Index: lib/Transforms/InstCombine/InstCombineAndOrXor.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -1565,127 +1565,129 @@ return Changed ? &I : nullptr; } +/// A potential constituent of a bitreverse or bswap expression. See +/// collectBitParts for a fuller explanation. +struct BitPart { + BitPart(Value *P, unsigned BW) : Provider(P) { + Provenance.resize(BW); + } + + /// The Value that this is a bitreverse/bswap of. + Value *Provider; + /// The "provenance" of each bit. Provenance[A] = B means that bit A + /// in the result of this expression comes from bit B in Provider. + SmallVector Provenance; // int8_t means max size is i128. + + enum { Unset = -1 }; +}; /// Analyze the specified subexpression and see if it is capable of providing /// pieces of a bswap or bitreverse. The subexpression provides a potential /// piece of a bswap or bitreverse if it can be proven that each non-zero bit in /// the output of the expression came from a corresponding bit in some other /// value. This function is recursive, and the end result is a mapping of -/// (value, bitnumber) to bitnumber. It is the caller's responsibility to -/// validate that all `value`s are identical and that the bitnumber to bitnumber -/// mapping is correct for a bswap or bitreverse. +/// bitnumber to bitnumber. It is the caller's responsibility to validate that +/// the bitnumber to bitnumber mapping is correct for a bswap or bitreverse. /// /// For example, if the current subexpression if "(shl i32 %X, 24)" then we know /// that the expression deposits the low byte of %X into the high byte of the -/// result and that all other bits are zero. This expression is accepted, -/// BitValues[24-31] are set to %X and BitProvenance[24-31] are set to [0-7]. -/// -/// This function returns true if the match was unsuccessful and false if so. 
-/// On entry to the function the "OverallLeftShift" is a signed integer value -/// indicating the number of bits that the subexpression is later shifted. For -/// example, if the expression is later right shifted by 16 bits, the -/// OverallLeftShift value would be -16 on entry. This is used to specify which -/// bits of BitValues are actually being set. +/// result and that all other bits are zero. This expression is accepted and a +/// BitPart is returned with Provider set to %X and Provenance[24-31] set to +/// [0-7]. /// -/// Similarly, BitMask is a bitmask where a bit is clear if its corresponding -/// bit is masked to zero by a user. For example, in (X & 255), X will be -/// processed with a bytemask of 255. BitMask is always in the local -/// (OverallLeftShift) coordinate space. +/// To avoid revisiting values, the BitPart results are memoized into the +/// provided map. /// -static bool CollectBitParts(Value *V, int OverallLeftShift, APInt BitMask, - SmallVectorImpl &BitValues, - SmallVectorImpl &BitProvenance) { +static const Optional & +collectBitParts(Value *V, DenseMap> &BPS) { + auto I = BPS.find(V); + if (I != BPS.end()) + return I->second; + + auto &Result = BPS[V] = None; + auto BitWidth = cast(V->getType())->getBitWidth(); + if (Instruction *I = dyn_cast(V)) { // If this is an or instruction, it may be an inner node of the bswap. - if (I->getOpcode() == Instruction::Or) - return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask, - BitValues, BitProvenance) || - CollectBitParts(I->getOperand(1), OverallLeftShift, BitMask, - BitValues, BitProvenance); - - // If this is a logical shift by a constant, recurse with OverallLeftShift - // and BitMask adjusted. + if (I->getOpcode() == Instruction::Or) { + auto &A = collectBitParts(I->getOperand(0), BPS); + auto &B = collectBitParts(I->getOperand(1), BPS); + if (!A || !B) + return Result; + + // Try and merge the two together. 
+ if (!A->Provider || A->Provider != B->Provider) + return Result; + + Result = BitPart(A->Provider, BitWidth); + for (unsigned i = 0; i < A->Provenance.size(); ++i) { + if (A->Provenance[i] != BitPart::Unset && + B->Provenance[i] != BitPart::Unset && + A->Provenance[i] != B->Provenance[i]) + return Result = None; + + if (A->Provenance[i] == BitPart::Unset) + Result->Provenance[i] = B->Provenance[i]; + else + Result->Provenance[i] = A->Provenance[i]; + } + + return Result; + } + + // If this is a logical shift by a constant, recurse then shift the result. if (I->isLogicalShift() && isa(I->getOperand(1))) { - unsigned ShAmt = + unsigned BitShift = cast(I->getOperand(1))->getLimitedValue(~0U); // Ensure the shift amount is defined. - if (ShAmt > BitValues.size()) - return true; + if (BitShift > BitWidth) + return Result; - unsigned BitShift = ShAmt; + auto &Res = collectBitParts(I->getOperand(0), BPS); + if (!Res) + return Result; + Result = Res; + + // Perform the "shift" on the Provenance vector. + auto &P = Result->Provenance; if (I->getOpcode() == Instruction::Shl) { - // X << C -> collect(X, +C) - OverallLeftShift += BitShift; - BitMask = BitMask.lshr(BitShift); + P.erase(std::prev(P.end(), BitShift), P.end()); + P.insert(P.begin(), BitShift, BitPart::Unset); } else { - // X >>u C -> collect(X, -C) - OverallLeftShift -= BitShift; - BitMask = BitMask.shl(BitShift); + P.erase(P.begin(), std::next(P.begin(), BitShift)); + P.insert(P.end(), BitShift, BitPart::Unset); } - if (OverallLeftShift >= (int)BitValues.size()) - return true; - if (OverallLeftShift <= -(int)BitValues.size()) - return true; - - return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask, - BitValues, BitProvenance); + return Result; } - // If this is a logical 'and' with a mask that clears bits, clear the - // corresponding bits in BitMask. + // If this is a logical 'and' with a mask that clears bits, recurse then + // unset the appropriate bits. 
if (I->getOpcode() == Instruction::And && isa(I->getOperand(1))) { - unsigned NumBits = BitValues.size(); APInt Bit(I->getType()->getPrimitiveSizeInBits(), 1); const APInt &AndMask = cast(I->getOperand(1))->getValue(); - for (unsigned i = 0; i != NumBits; ++i, Bit <<= 1) { - // If this bit is masked out by a later operation, we don't care what - // the and mask is. - if (BitMask[i] == 0) - continue; + auto &Res = collectBitParts(I->getOperand(0), BPS); + if (!Res) + return Result; + Result = Res; + for (unsigned i = 0; i < BitWidth; ++i, Bit <<= 1) // If the AndMask is zero for this bit, clear the bit. - APInt MaskB = AndMask & Bit; - if (MaskB == 0) { - BitMask.clearBit(i); - continue; - } + if ((AndMask & Bit) == 0) + Result->Provenance[i] = BitPart::Unset; - // Otherwise, this bit is kept. - } - - return CollectBitParts(I->getOperand(0), OverallLeftShift, BitMask, - BitValues, BitProvenance); + return Result; } } // Okay, we got to something that isn't a shift, 'or' or 'and'. This must be - // the input value to the bswap/bitreverse. To be part of a bswap or - // bitreverse we must be demanding a contiguous range of bits from it. - unsigned InputBitLen = BitMask.countPopulation(); - unsigned InputBitNo = BitMask.countTrailingZeros(); - if (BitMask.getBitWidth() - BitMask.countLeadingZeros() - InputBitNo != - InputBitLen) - // Not a contiguous set range of bits! - return true; - - // We know we're moving a contiguous range of bits from the input to the - // output. Record which bits in the output came from which bits in the input. - unsigned DestBitNo = InputBitNo + OverallLeftShift; - for (unsigned I = 0; I < InputBitLen; ++I) - BitProvenance[DestBitNo + I] = InputBitNo + I; - - // If the destination bit value is already defined, the values are or'd - // together, which isn't a bswap/bitreverse (unless it's an or of the same - // bits). 
- if (BitValues[DestBitNo] && BitValues[DestBitNo] != V) - return true; - for (unsigned I = 0; I < InputBitLen; ++I) - BitValues[DestBitNo + I] = V; - - return false; + // the input value to the bswap/bitreverse. + Result = BitPart(V, BitWidth); + for (unsigned i = 0; i < BitWidth; ++i) + Result->Provenance[i] = i; + return Result; } static bool bitTransformIsCorrectForBSwap(unsigned From, unsigned To, @@ -1708,32 +1710,21 @@ /// idiom. If so, insert the new intrinsic and return it. Instruction *InstCombiner::MatchBSwapOrBitReverse(BinaryOperator &I) { IntegerType *ITy = dyn_cast(I.getType()); - if (!ITy) - return nullptr; // Can't do vectors. + if (!ITy || ITy->getBitWidth() > 128) + return nullptr; // Can't do vectors or integers > 128 bits. unsigned BW = ITy->getBitWidth(); - /// We keep track of which bit (BitProvenance) inside which value (BitValues) - /// defines each bit in the result. - SmallVector BitValues(BW, nullptr); - SmallVector BitProvenance(BW, -1); - // Try to find all the pieces corresponding to the bswap. - APInt BitMask = APInt::getAllOnesValue(BitValues.size()); - if (CollectBitParts(&I, 0, BitMask, BitValues, BitProvenance)) + DenseMap> BPS; + auto Res = collectBitParts(&I, BPS); + if (!Res) return nullptr; - - // Check to see if all of the bits come from the same value. - Value *V = BitValues[0]; - if (!V) return nullptr; // Didn't find a bit? Must be zero. - - if (!std::all_of(BitValues.begin(), BitValues.end(), - [&](const Value *X) { return X == V; })) - return nullptr; - + auto &BitProvenance = Res->Provenance; + // Now, is the bit permutation correct for a bswap or a bitreverse? We can // only byteswap values with an even number of bytes. 
bool OKForBSwap = BW % 16 == 0, OKForBitReverse = true;; - for (unsigned i = 0, e = BitValues.size(); i != e; ++i) { + for (unsigned i = 0; i < BW; ++i) { OKForBSwap &= bitTransformIsCorrectForBSwap(BitProvenance[i], i, BW); OKForBitReverse &= bitTransformIsCorrectForBitReverse(BitProvenance[i], i, BW); @@ -1748,7 +1739,7 @@ return nullptr; Function *F = Intrinsic::getDeclaration(I.getModule(), Intrin, ITy); - return CallInst::Create(F, V); + return CallInst::Create(F, Res->Provider); } /// We have an expression of the form (A&C)|(B&D). Check if A is (cond?-1:0) Index: test/Transforms/InstCombine/bitreverse-hang.ll =================================================================== --- /dev/null +++ test/Transforms/InstCombine/bitreverse-hang.ll @@ -0,0 +1,53 @@ +; RUN: opt < %s -loop-unroll -instcombine -S | FileCheck %s + +; This test is a worst-case scenario for bitreversal/byteswap detection. +; After loop unrolling (the unrolled loop is unreadably large so it has been kept +; rolled here), we have a binary tree of OR operands (as bitreversal detection +; looks straight through shifts): +; +; OR +; | \ +; | LSHR +; | / +; OR +; | \ +; | LSHR +; | / +; OR +; +; This results in exponential runtime. The loop here is 32 iterations which will +; totally hang if we don't deal with this case cleverly. 
+ +@b = common global i32 0, align 4 + +; CHECK: define i32 @fn1 +define i32 @fn1() #0 { +entry: + %b.promoted = load i32, i32* @b, align 4, !tbaa !2 + br label %for.body + +for.body: ; preds = %for.body, %entry + %or4 = phi i32 [ %b.promoted, %entry ], [ %or, %for.body ] + %i.03 = phi i32 [ 0, %entry ], [ %inc, %for.body ] + %shr = lshr i32 %or4, 1 + %or = or i32 %shr, %or4 + %inc = add nuw nsw i32 %i.03, 1 + %exitcond = icmp eq i32 %inc, 32 + br i1 %exitcond, label %for.end, label %for.body + +for.end: ; preds = %for.body + store i32 %or, i32* @b, align 4, !tbaa !2 + ret i32 undef +} + +attributes #0 = { norecurse nounwind ssp uwtable "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="core2" "target-features"="+cx16,+fxsr,+mmx,+sse,+sse2,+sse3,+ssse3" "unsafe-fp-math"="false" "use-soft-float"="false" } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"PIC Level", i32 2} +!1 = !{!"clang version 3.8.0 (http://llvm.org/git/clang.git eb70f4e9cc9a4dc3dd57b032fb858d56b4b64a0e)"} +!2 = !{!3, !3, i64 0} +!3 = !{!"int", !4, i64 0} +!4 = !{!"omnipotent char", !5, i64 0} +!5 = !{!"Simple C/C++ TBAA"}