diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -785,6 +785,7 @@ if (!Cmp) continue; + // Note that ptrtoint may change the bitwidth. Value *A, *B; auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V))); @@ -797,18 +798,18 @@ // assume(v = a) if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); Known.Zero |= RHSKnown.Zero; Known.One |= RHSKnown.One; // assume(v & b = a) } else if (match(Cmp, m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits MaskKnown(BitWidth); - computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); + KnownBits MaskKnown = + computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in the mask that are known to be one, we can propagate // known bits from the RHS to V. @@ -818,10 +819,10 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits MaskKnown(BitWidth); - computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); + KnownBits MaskKnown = + computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in the mask that are known to be one, we can propagate // inverted known bits from the RHS to V. 
@@ -831,10 +832,10 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); + KnownBits BKnown = + computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in B that are known to be zero, we can propagate known // bits from the RHS to V. @@ -844,10 +845,10 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); + KnownBits BKnown = + computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in B that are known to be zero, we can propagate // inverted known bits from the RHS to V. @@ -857,10 +858,10 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); + KnownBits BKnown = + computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in B that are known to be zero, we can propagate known // bits from the RHS to V. 
For those bits in B that are known to be one, @@ -873,10 +874,10 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); - KnownBits BKnown(BitWidth); - computeKnownBits(B, BKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); + KnownBits BKnown = + computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in B that are known to be zero, we can propagate // inverted known bits from the RHS to V. For those bits in B that are @@ -889,8 +890,9 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); + // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. RHSKnown.Zero.lshrInPlace(C); @@ -901,8 +903,8 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. 
RHSKnown.One.lshrInPlace(C); @@ -913,8 +915,8 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. Known.Zero |= RHSKnown.Zero << C; @@ -923,8 +925,8 @@ } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. Known.Zero |= RHSKnown.One << C; @@ -935,8 +937,8 @@ // assume(v >=_s c) where c is non-negative if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth); if (RHSKnown.isNonNegative()) { // We know that the sign bit is zero. @@ -948,8 +950,8 @@ // assume(v >_s c) where c is at least -1. if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth); if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) { // We know that the sign bit is zero. 
@@ -961,8 +963,8 @@ // assume(v <=_s c) where c is negative if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth); if (RHSKnown.isNegative()) { // We know that the sign bit is one. @@ -974,8 +976,8 @@ // assume(v <_s c) where c is non-positive if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); if (RHSKnown.isZero() || RHSKnown.isNegative()) { // We know that the sign bit is one. @@ -987,8 +989,8 @@ // assume(v <=_u c) if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // Whatever high bits in c are zero are known to be zero. Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros()); @@ -998,8 +1000,8 @@ // assume(v <_u c) if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { - KnownBits RHSKnown(BitWidth); - computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + KnownBits RHSKnown = + computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth); // If the RHS is known zero, then this assumption must be wrong (nothing // is unsigned less than zero). Signal a conflict and get out of here. 
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/AsmParser/Parser.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" @@ -40,7 +41,7 @@ M = parseModule(Assembly); ASSERT_TRUE(M); - Function *F = M->getFunction("test"); + F = M->getFunction("test"); ASSERT_TRUE(F) << "Test must have a function @test"; if (!F) return; @@ -57,6 +58,7 @@ LLVMContext Context; std::unique_ptr<Module> M; + Function *F = nullptr; Instruction *A = nullptr; }; @@ -954,6 +956,44 @@ expectKnownBits(/*zero*/ 2u, /*one*/ 0u); } +TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntTrunc) { + // ptrtoint truncates the pointer type. + parseAssembly( + "define void @test(i8** %p) {\n" + " %A = load i8*, i8** %p\n" + " %i = ptrtoint i8* %A to i32\n" + " %m = and i32 %i, 31\n" + " %c = icmp eq i32 %m, 0\n" + " call void @llvm.assume(i1 %c)\n" + " ret void\n" + "}\n" + "declare void @llvm.assume(i1)\n"); + AssumptionCache AC(*F); + KnownBits Known = computeKnownBits( + A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), 31u); + EXPECT_EQ(Known.One.getZExtValue(), 0u); +} + +TEST_F(ComputeKnownBitsTest, ComputeKnownBitsPtrToIntZext) { + // ptrtoint zero extends the pointer type. 
+ parseAssembly( + "define void @test(i8** %p) {\n" + " %A = load i8*, i8** %p\n" + " %i = ptrtoint i8* %A to i128\n" + " %m = and i128 %i, 31\n" + " %c = icmp eq i128 %m, 0\n" + " call void @llvm.assume(i1 %c)\n" + " ret void\n" + "}\n" + "declare void @llvm.assume(i1)\n"); + AssumptionCache AC(*F); + KnownBits Known = computeKnownBits( + A, M->getDataLayout(), /* Depth */ 0, &AC, F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), 31u); + EXPECT_EQ(Known.One.getZExtValue(), 0u); +} + class IsBytewiseValueTest : public ValueTrackingTest, public ::testing::WithParamInterface< std::pair<const char *, const char *>> {