Index: lib/Target/ARM/ARMISelLowering.cpp
===================================================================
--- lib/Target/ARM/ARMISelLowering.cpp
+++ lib/Target/ARM/ARMISelLowering.cpp
@@ -9018,12 +9018,83 @@
   return SDValue();
 }
 
-/// PerformBFICombine - (bfi A, (and B, Mask1), Mask2) -> (bfi A, B, Mask2) iff
-/// the bits being cleared by the AND are not demanded by the BFI.
+// ParseBFI - given a BFI instruction in N, extract the "from" value (Rn) and return it,
+// and fill in FromMask and ToMask with (consecutive) bits in "from" to be extracted and
+// their position in "to" (Rd).
+static SDValue ParseBFI(SDNode *N, APInt &ToMask, APInt &FromMask) {
+  assert(N->getOpcode() == ARMISD::BFI);
+
+  SDValue From = N->getOperand(1);
+  ToMask = ~cast<ConstantSDNode>(N->getOperand(2))->getAPIntValue();
+  FromMask = APInt::getLowBitsSet(ToMask.getBitWidth(), ToMask.countPopulation());
+
+  // If the Base came from a SHR #C, we can deduce that it is really testing bit
+  // #C in the base of the SHR.
+  if (From->getOpcode() == ISD::SRL &&
+      isa<ConstantSDNode>(From->getOperand(1))) {
+    APInt Shift = cast<ConstantSDNode>(From->getOperand(1))->getAPIntValue();
+    assert(Shift.getLimitedValue() < 32 && "Shift too large!");
+    FromMask <<= Shift.getLimitedValue(31);
+    From = From->getOperand(0);
+  }
+
+  return From;
+}
+
+// If A and B contain one contiguous set of bits, does A | B == A . B?
+static bool BitsProperlyConcatenate(const APInt &A, const APInt &B) {
+  unsigned LastActiveBitInA = A.getBitWidth() - A.countTrailingZeros() - 1;
+  unsigned FirstActiveBitInB = B.countLeadingZeros();
+  return LastActiveBitInA + 1 == FirstActiveBitInB;
+}
+
+static SDValue FindBFIToCombineWith(SDNode *N) {
+  // We have a BFI in N. Follow a possible chain of BFIs and find a BFI it can combine with,
+  // if one exists.
+  APInt ToMask, FromMask;
+  SDValue From = ParseBFI(N, ToMask, FromMask);
+  SDValue To = N->getOperand(0);
+
+  // Now check for a compatible BFI to merge with. We can pass through BFIs that
+  // aren't compatible, but not if they set the same bit in their destination as
+  // we do (or that of any BFI we're going to combine with).
+  SDValue V = To;
+  APInt CombinedToMask = ToMask;
+  while (V.getOpcode() == ARMISD::BFI) {
+    APInt NewToMask, NewFromMask;
+    SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask);
+    if (NewFrom != From)
+      // The BFIs have different bases so are not compatible!
+      return SDValue();
+
+    // Do the written bits conflict with any we've seen so far?
+    if ((NewToMask & CombinedToMask).getBoolValue())
+      // Conflicting bits - bail out because going further is unsafe.
+      return SDValue();
+
+    // Are the new bits contiguous when combined with the old bits?
+    if (BitsProperlyConcatenate(ToMask, NewToMask) &&
+        BitsProperlyConcatenate(FromMask, NewFromMask))
+      return V;
+    if (BitsProperlyConcatenate(NewToMask, ToMask) &&
+        BitsProperlyConcatenate(NewFromMask, FromMask))
+      return V;
+
+    // We've seen a write to some bits, so track it.
+    CombinedToMask |= NewToMask;
+    // Keep going...
+    V = V.getOperand(0);
+  }
+
+  return SDValue();
+}
+
 static SDValue PerformBFICombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
   SDValue N1 = N->getOperand(1);
   if (N1.getOpcode() == ISD::AND) {
+    // (bfi A, (and B, Mask1), Mask2) -> (bfi A, B, Mask2) iff
+    // the bits being cleared by the AND are not demanded by the BFI.
     ConstantSDNode *N11C = dyn_cast<ConstantSDNode>(N1.getOperand(1));
     if (!N11C)
       return SDValue();
@@ -9039,6 +9110,37 @@
       return DCI.DAG.getNode(ARMISD::BFI, SDLoc(N), N->getValueType(0),
                              N->getOperand(0), N1.getOperand(0),
                              N->getOperand(2));
+  } else if (N->getOperand(0).getOpcode() == ARMISD::BFI) {
+    // We have a BFI of a BFI. Walk up the BFI chain to see how long it goes.
+    // Keep track of any consecutive bits set that all come from the same base
+    // value. We can combine these together into a single BFI.
+    SDValue CombineBFI = FindBFIToCombineWith(N);
+    if (CombineBFI == SDValue())
+      return SDValue();
+
+    // We've found a BFI.
+    APInt ToMask1, FromMask1;
+    SDValue From1 = ParseBFI(N, ToMask1, FromMask1);
+
+    APInt ToMask2, FromMask2;
+    SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2);
+    assert(From1 == From2);
+
+    // First, unlink CombineBFI.
+    DCI.DAG.ReplaceAllUsesWith(CombineBFI, CombineBFI.getOperand(0));
+    // Then create a new BFI, combining the two together.
+    APInt NewFromMask = FromMask1 | FromMask2;
+    APInt NewToMask = ToMask1 | ToMask2;
+
+    EVT VT = N->getValueType(0);
+    SDLoc dl(N);
+
+    if (NewFromMask[0] == 0)
+      From1 = DCI.DAG.getNode(
+          ISD::SRL, dl, VT, From1,
+          DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT));
+    return DCI.DAG.getNode(ARMISD::BFI, dl, VT, N->getOperand(0), From1,
+                           DCI.DAG.getConstant(~NewToMask, dl, VT));
   }
   return SDValue();
 }
@@ -10243,6 +10345,16 @@
     KnownOne &= Mask;
     return;
   }
+  if (Op.getOpcode() == ARMISD::CMOV) {
+    APInt KZ2(KnownZero.getBitWidth(), 0);
+    APInt KO2(KnownOne.getBitWidth(), 0);
+    computeKnownBits(DAG, Op.getOperand(1), KnownZero, KnownOne);
+    computeKnownBits(DAG, Op.getOperand(2), KZ2, KO2);
+
+    KnownZero &= KZ2;
+    KnownOne &= KO2;
+    return;
+  }
   return DAG.computeKnownBits(Op, KnownZero, KnownOne);
 }
 
Index: test/CodeGen/ARM/bfi.ll
===================================================================
--- test/CodeGen/ARM/bfi.ll
+++ test/CodeGen/ARM/bfi.ll
@@ -97,3 +97,42 @@
   %sel = select i1 %cmp, i32 %or, i32 %y2
   ret i32 %sel
 }
+
+define i32 @f9(i32 %x, i32 %y) {
+; CHECK-LABEL: f9:
+; CHECK: bfi r1, r0, #4, #2
+  %y2 = and i32 %y, 4294967040 ; 0xFFFFFF00
+  %and = and i32 %x, 4
+  %or = or i32 %y2, 32
+  %cmp = icmp ne i32 %and, 0
+  %sel = select i1 %cmp, i32 %or, i32 %y2
+
+  %aand = and i32 %x, 2
+  %aor = or i32 %sel, 16
+  %acmp = icmp ne i32 %aand, 0
+  %asel = select i1 %acmp, i32 %aor, i32 %sel
+
+  ret i32 %asel
+}
+
+define i32 @f10(i32 %x, i32 %y) {
+; CHECK-LABEL: f10:
+; CHECK: bfi r1, r0, #4, #3
+  %y2 = and i32 %y, 4294967040 ; 0xFFFFFF00
+  %and = and i32 %x, 4
+  %or = or i32 %y2, 32
+  %cmp = icmp ne i32 %and, 0
+  %sel = select i1 %cmp, i32 %or, i32 %y2
+
+  %aand = and i32 %x, 2
+  %aor = or i32 %sel, 16
+  %acmp = icmp ne i32 %aand, 0
+  %asel = select i1 %acmp, i32 %aor, i32 %sel
+
+  %band = and i32 %x, 8
+  %bor = or i32 %asel, 64
+  %bcmp = icmp ne i32 %band, 0
+  %bsel = select i1 %bcmp, i32 %bor, i32 %asel
+
+  ret i32 %bsel
+}