diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -107,40 +107,13 @@
   // provide it currently.
   OptimizationRemarkEmitter *ORE;
 
-  /// Set of assumptions that should be excluded from further queries.
-  /// This is because of the potential for mutual recursion to cause
-  /// computeKnownBits to repeatedly visit the same assume intrinsic. The
-  /// classic case of this is assume(x = y), which will attempt to determine
-  /// bits in x from bits in y, which will attempt to determine bits in y from
-  /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call
-  /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo
-  /// (all of which can call computeKnownBits), and so on.
-  std::array<const Value *, MaxAnalysisRecursionDepth> Excluded;
-
   /// If true, it is safe to use metadata during simplification.
   InstrInfoQuery IIQ;
 
-  unsigned NumExcluded = 0;
-
   Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI,
         const DominatorTree *DT, bool UseInstrInfo,
         OptimizationRemarkEmitter *ORE = nullptr)
       : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {}
-
-  Query(const Query &Q, const Value *NewExcl)
-      : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ),
-        NumExcluded(Q.NumExcluded) {
-    Excluded = Q.Excluded;
-    Excluded[NumExcluded++] = NewExcl;
-    assert(NumExcluded <= Excluded.size());
-  }
-
-  bool isExcluded(const Value *Value) const {
-    if (NumExcluded == 0)
-      return false;
-    auto End = Excluded.begin() + NumExcluded;
-    return std::find(Excluded.begin(), End, Value) != End;
-  }
 };
 
 } // end anonymous namespace
@@ -632,8 +605,6 @@
     CallInst *I = cast<CallInst>(AssumeVH);
     assert(I->getFunction() == Q.CxtI->getFunction() &&
           "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
     // Warning: This loop can end up being somewhat performance sensitive.
     // We're running this loop for once for each value queried resulting in a
@@ -681,8 +652,6 @@
     CallInst *I = cast<CallInst>(AssumeVH);
    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
           "Got assumption for the wrong function!");
-    if (Q.isExcluded(I))
-      continue;
 
    // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop for once for each value queried resulting in a
@@ -713,6 +682,15 @@
     if (!Cmp)
       continue;
 
+    // We are attempting to compute known bits for the operands of an assume.
+    // Do not try to use other assumptions for those recursive calls because
+    // that can lead to mutual recursion and a compile-time explosion.
+    // An example of the mutual recursion: computeKnownBits can call
+    // isKnownNonZero which calls computeKnownBitsFromAssume (this function)
+    // and so on.
+    Query QueryNoAC = Q;
+    QueryNoAC.AC = nullptr;
+
     // Note that ptrtoint may change the bitwidth.
     Value *A, *B;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));
@@ -727,7 +705,7 @@
       if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         Known.Zero |= RHSKnown.Zero;
         Known.One |= RHSKnown.One;
         // assume(v & b = a)
@@ -735,9 +713,9 @@
                        m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // known bits from the RHS to V.
@@ -748,9 +726,9 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits MaskKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in the mask that are known to be one, we can propagate
         // inverted known bits from the RHS to V.
@@ -761,9 +739,9 @@
                        m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V.
@@ -774,9 +752,9 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V.
@@ -787,9 +765,9 @@
                        m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V. For those bits in B that are known to be one,
@@ -803,9 +781,9 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         KnownBits BKnown =
-            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(B, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V. For those bits in B that are
@@ -819,7 +797,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
@@ -832,7 +810,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         RHSKnown.One.lshrInPlace(C);
@@ -844,7 +822,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.Zero << C;
@@ -854,7 +832,7 @@
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.One << C;
@@ -866,7 +844,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -879,7 +857,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -892,7 +870,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth + 1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -905,7 +883,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         if (RHSKnown.isZero() || RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -918,7 +896,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // Whatever high bits in c are zero are known to be zero.
         Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -929,7 +907,7 @@
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
         KnownBits RHSKnown =
-            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+            computeKnownBits(A, Depth+1, QueryNoAC).anyextOrTrunc(BitWidth);
 
         // If the RHS is known zero, then this assumption must be wrong (nothing
         // is unsigned less than zero). Signal a conflict and get out of here.
@@ -941,7 +919,7 @@
 
         // Whatever high bits in c are zero are known to be zero (if c is a power
         // of 2, then one more).
-        if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I)))
+        if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, QueryNoAC))
           Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros() + 1);
         else
           Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -175,15 +175,20 @@
   ret i32 %and1
 }
 
-define i32 @bar4(i32 %a, i32 %b) {
-; CHECK-LABEL: @bar4(
+; If we allow recursive known bits queries based on
+; assumptions, we could do better here:
+; a == b and a & 7 == 1, so b & 7 == 1, so b & 3 == 1, so return 1.
+
+define i32 @known_bits_recursion_via_assumes(i32 %a, i32 %b) {
+; CHECK-LABEL: @known_bits_recursion_via_assumes(
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AND1:%.*]] = and i32 [[B:%.*]], 3
 ; CHECK-NEXT:    [[AND:%.*]] = and i32 [[A:%.*]], 7
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[AND]], 1
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP]])
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B:%.*]]
+; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[A]], [[B]]
 ; CHECK-NEXT:    tail call void @llvm.assume(i1 [[CMP2]])
-; CHECK-NEXT:    ret i32 1
+; CHECK-NEXT:    ret i32 [[AND1]]
 ;
 entry:
   %and1 = and i32 %b, 3
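
For context, here is a minimal IR sketch of the mutual recursion the removed Query comment describes. This is an illustration only, not part of the patch or the test suite, and the function name is invented:

; assume(x == y) is the classic trigger: refining the known bits of %x
; consults the assume and recurses into %y, and refining %y consults the
; same assume and recurses back into %x. The removed Excluded array broke
; this cycle one assume at a time; with this patch, the recursive operand
; queries run with QueryNoAC.AC == nullptr and never consult assumptions.
declare void @llvm.assume(i1)

define i32 @assume_eq_ping_pong(i32 %x, i32 %y) {
entry:
  %cmp = icmp eq i32 %x, %y
  call void @llvm.assume(i1 %cmp)
  %masked = and i32 %x, 7
  ret i32 %masked
}

The updated @known_bits_recursion_via_assumes test documents the cost of this trade-off: deriving b & 3 == 1 from a == b and a & 7 == 1 requires exactly the assume-on-assume reasoning that is now disabled, so InstCombine keeps the 'and' and no longer folds the return value to the constant 1.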