Index: lib/Target/ARM/ARMISelLowering.cpp
===================================================================
--- lib/Target/ARM/ARMISelLowering.cpp
+++ lib/Target/ARM/ARMISelLowering.cpp
@@ -9018,46 +9018,82 @@
   return SDValue();
 }
 
-static void WalkBFIChain(SDNode *N, SDValue &ChainFromVal, SDValue &ChainToVal, APInt &From, APInt &To, SmallVectorImpl &ExtraInstrs) {
+// ParseBFI - N must be an ARMISD::BFI node (asserted). Returns its "from" value (Rn, operand 1),
+// looking through a constant-amount SRL, and sets ToMask to the inverse of the constant in
+// operand 2 (bits written in "to", Rd) and FromMask to the matching bits read from "from".
+static SDValue ParseBFI(SDNode *N, APInt &ToMask, APInt &FromMask) {
   assert(N->getOpcode() == ARMISD::BFI);
 
-// dbgs() << "WalkBFIChain: "; N->dump();
-// dbgs() << " From: " << utohexstr(From.getLimitedValue()) << ", To: " << utohexstr(To.getLimitedValue()) << "\n";
-  APInt ToMask = ~cast(N->getOperand(2))->getAPIntValue();
-  APInt FromMask = APInt::getLowBitsSet(ToMask.getBitWidth(), ToMask.countPopulation());
-  // We have a BFI. Identify its base value (Rn).
-  SDValue FromVal = N->getOperand(1);
-  // If the BaseVal came from a SHR #C and the mask is simply one bit long, we can deduce that it is really testing bit #C in the base of the SHR;
-  if (/*FromMask.countPopulation() == 1 &&*/ FromVal->getOpcode() == ISD::SRL && isa(FromVal->getOperand(1))) {
-    APInt Shift = cast(FromVal->getOperand(1))->getAPIntValue();
+  SDValue From = N->getOperand(1);
+  ToMask = ~cast(N->getOperand(2))->getAPIntValue();
+  FromMask = APInt::getLowBitsSet(ToMask.getBitWidth(), ToMask.countPopulation());
+
+  // If From is (srl X, C) with a constant C, use X as the base instead and shift
+  // FromMask left by C so that it names the bits of X that are really read.
+  if (From->getOpcode() == ISD::SRL &&
+      isa(From->getOperand(1))) {
+    APInt Shift = cast(From->getOperand(1))->getAPIntValue();
     assert(Shift.getLimitedValue() < 32 && "Shift too large!");
     FromMask <<= Shift.getLimitedValue(31);
-    FromVal = FromVal->getOperand(0);
+    From = From->getOperand(0);
   }
 
-// dbgs() << " FromMask: " << utohexstr(FromMask.getLimitedValue()) << ", ToMask: " << utohexstr(ToMask.getLimitedValue()) << "\n";
-  if ((!ChainFromVal || ChainFromVal == FromVal) &&
-      // Have the "from" values already been queried? if so, bail and we can't go any further.
-      (FromMask & From) == 0 && (ToMask & To) == 0) {
-    ChainFromVal = FromVal;
-    From |= FromMask;
-    To |= ToMask;
-  } else {
-    ExtraInstrs.push_back(N);
+  return From;
+}
+
+// For A and B each holding one contiguous run of set bits: is B's top bit directly below A's bottom bit?
+static bool BitsProperlyConcatenate(const APInt &A, const APInt &B) {
+  unsigned LastActiveBitInA = A.getBitWidth() - A.countTrailingZeros() - 1;
+  unsigned FirstActiveBitInB = B.countLeadingZeros();
+  return LastActiveBitInA + 1 == FirstActiveBitInB;
+}
+
+static SDValue FindBFIToCombineWith(SDNode *N) {
+  // N is a BFI. Walk the chain of BFIs reached through operand 0 and return the first one
+  // that N can be merged with, or an empty SDValue if there is none.
+  APInt ToMask, FromMask;
+  SDValue From = ParseBFI(N, ToMask, FromMask);
+  SDValue To = N->getOperand(0);
+
+  // Look for a BFI to merge with. A same-base BFI whose bits do not concatenate
+  // with ours is passed through, unless it writes a destination bit already
+  // written by N or by a BFI passed earlier; a different base ends the search.
+  SDValue V = To;
+  APInt CombinedToMask = ToMask;
+  while (V.getOpcode() == ARMISD::BFI) {
+    APInt NewToMask, NewFromMask;
+    SDValue NewFrom = ParseBFI(V.getNode(), NewToMask, NewFromMask);
+    if (NewFrom != From)
+      // Different base: stop. NOTE(review): could this BFI be skipped like same-base ones? Confirm.
+      return SDValue();
+
+    // Does this BFI write a destination bit that N or an earlier BFI also writes?
+    if ((NewToMask & CombinedToMask).getBoolValue())
+      // Overlapping destination bits - give up rather than look any deeper.
+      return SDValue();
+
+    // Do its to- and from-masks both concatenate with N's, in the same order?
+    if (BitsProperlyConcatenate(ToMask, NewToMask) &&
+        BitsProperlyConcatenate(FromMask, NewFromMask))
+      return V;
+    if (BitsProperlyConcatenate(NewToMask, ToMask) &&
+        BitsProperlyConcatenate(NewFromMask, FromMask))
+      return V;
+
+    // Not mergeable: remember the bits it writes for later conflict checks.
+    CombinedToMask |= NewToMask;
+    // Move on to whatever feeds this BFI's operand 0.
+    V = V.getOperand(0);
   }
-  ChainToVal = N->getOperand(0);
-
-  // Recurse down if the next instruction is also a BFI.
-  if (N->getOperand(0)->getOpcode() == ARMISD::BFI)
-    WalkBFIChain(N->getOperand(0).getNode(), ChainFromVal, ChainToVal, From, To, ExtraInstrs);
+
+  return SDValue();
 }
 
 static SDValue PerformBFICombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
   SDValue N1 = N->getOperand(1);
   if (N1.getOpcode() == ISD::AND) {
-    // PerformBFICombine - (bfi A, (and B, Mask1), Mask2) -> (bfi A, B, Mask2)
-    // iff
+    // (bfi A, (and B, Mask1), Mask2) -> (bfi A, B, Mask2) iff
     // the bits being cleared by the AND are not demanded by the BFI.
     ConstantSDNode *N11C = dyn_cast(N1.getOperand(1));
     if (!N11C)
@@ -9078,39 +9114,33 @@
     // We have a BFI of a BFI. Walk up the BFI chain to see how long it goes.
     // Keep track of any consecutive bits set that all come from the same base
     // value. We can combine these together into a single BFI.
+    SDValue CombineBFI = FindBFIToCombineWith(N);
+    if (CombineBFI == SDValue())
+      return SDValue();
 
-    SDValue FromVal, ToVal;
-    EVT VT = N1.getValueType();
-    APInt From(VT.getSizeInBits(), 0);
-    APInt To(VT.getSizeInBits(), 0);
-    SmallVector Instrs;
-    WalkBFIChain(N, FromVal, ToVal, From, To, Instrs);
-
-// dbgs() << "JM FromVal: "; FromVal->dump();
-// dbgs() << "From: " << utohexstr(From.getLimitedValue()) << "\n";
-// dbgs() << "To: " << utohexstr(To.getLimitedValue()) << "\n";
+    // Found a BFI to merge with: decode the masks of both N and CombineBFI.
+    APInt ToMask1, FromMask1;
+    SDValue From1 = ParseBFI(N, ToMask1, FromMask1);
 
-    // Check contiguity.
-    unsigned Sz = VT.getSizeInBits();
-    if (APInt::getBitsSet(Sz, From.countTrailingZeros(), Sz - From.countLeadingZeros()) != From)
-      return SDValue();
-    if (APInt::getBitsSet(Sz, To.countTrailingZeros(), Sz - To.countLeadingZeros()) != To)
-      return SDValue();
-
-    if (From.countPopulation() > 1) {
-// dbgs () << "JMJM\n";
+    APInt ToMask2, FromMask2;
+    SDValue From2 = ParseBFI(CombineBFI.getNode(), ToMask2, FromMask2);
+    assert(From1 == From2);
+
+    // First, unlink CombineBFI by pointing all of its users at its operand 0.
+    DCI.DAG.ReplaceAllUsesWith(CombineBFI, CombineBFI.getOperand(0));
+    // Then build a single BFI from the union of both from-masks and both to-masks.
+    APInt NewFromMask = FromMask1 | FromMask2;
+    APInt NewToMask = ToMask1 | ToMask2;
 
-      // Walk backwards over the extra instructions list, reconstructing as necessary.
-      SDValue V(ToVal);
-      SDLoc dl(ToVal);
-      for (auto *I : reverse(Instrs))
-        V = DCI.DAG.getNode(ARMISD::BFI, dl, VT, V, I->getOperand(1), I->getOperand(2));
+    EVT VT = N->getValueType(0);
+    SDLoc dl(N);
 
-      // OK, now we just need to construct a BFI from From to To, chaining to V.
-      if (From[0] == 0)
-        FromVal = DCI.DAG.getNode(ISD::SRL, dl, VT, FromVal, DCI.DAG.getConstant(From.countTrailingZeros(), dl, VT));
-      return DCI.DAG.getNode(ARMISD::BFI, dl, VT, V, FromVal, DCI.DAG.getConstant(~To, dl, VT));
-    }
+    if (NewFromMask[0] == 0)
+      From1 = DCI.DAG.getNode(
+        ISD::SRL, dl, VT, From1,
+        DCI.DAG.getConstant(NewFromMask.countTrailingZeros(), dl, VT));
+    return DCI.DAG.getNode(ARMISD::BFI, dl, VT, N->getOperand(0), From1,
+                           DCI.DAG.getConstant(~NewToMask, dl, VT));
   }
   return SDValue();
 }
Index: test/CodeGen/ARM/bfi.ll
===================================================================
--- test/CodeGen/ARM/bfi.ll
+++ test/CodeGen/ARM/bfi.ll
@@ -97,3 +97,42 @@
   %sel = select i1 %cmp, i32 %or, i32 %y2
   ret i32 %sel
 }
+
+define i32 @f9(i32 %x, i32 %y) {
+; CHECK-LABEL: f9:
+; CHECK: bfi r1, r0, #4, #2
+  %y2 = and i32 %y, 4294967040 ; 0xFFFFFF00
+  %and = and i32 %x, 4
+  %or = or i32 %y2, 32
+  %cmp = icmp ne i32 %and, 0
+  %sel = select i1 %cmp, i32 %or, i32 %y2
+
+  %aand = and i32 %x, 2
+  %aor = or i32 %sel, 16
+  %acmp = icmp ne i32 %aand, 0
+  %asel = select i1 %acmp, i32 %aor, i32 %sel
+
+  ret i32 %asel
+}
+
+define i32 @f10(i32 %x, i32 %y) {
+; CHECK-LABEL: f10:
+; CHECK: bfi r1, r0, #4, #3
+  %y2 = and i32 %y, 4294967040 ; 0xFFFFFF00
+  %and = and i32 %x, 4
+  %or = or i32 %y2, 32
+  %cmp = icmp ne i32 %and, 0
+  %sel = select i1 %cmp, i32 %or, i32 %y2
+
+  %aand = and i32 %x, 2
+  %aor = or i32 %sel, 16
+  %acmp = icmp ne i32 %aand, 0
+  %asel = select i1 %acmp, i32 %aor, i32 %sel
+
+  %band = and i32 %x, 8
+  %bor = or i32 %asel, 64
+  %bcmp = icmp ne i32 %band, 0
+  %bsel = select i1 %bcmp, i32 %bor, i32 %asel
+
+  ret i32 %bsel
+}