diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -413,39 +413,8 @@ KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut); } -static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, - const APInt &DemandedElts, KnownBits &Known, - KnownBits &Known2, unsigned Depth, - const Query &Q) { +static void computeKnownBitsMul(KnownBits &Known, KnownBits &Known2) { unsigned BitWidth = Known.getBitWidth(); - computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q); - computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q); - - bool isKnownNegative = false; - bool isKnownNonNegative = false; - // If the multiplication is known not to overflow, compute the sign bit. - if (NSW) { - if (Op0 == Op1) { - // The product of a number with itself is non-negative. - isKnownNonNegative = true; - } else { - bool isKnownNonNegativeOp1 = Known.isNonNegative(); - bool isKnownNonNegativeOp0 = Known2.isNonNegative(); - bool isKnownNegativeOp1 = Known.isNegative(); - bool isKnownNegativeOp0 = Known2.isNegative(); - // The product of two numbers with the same sign is non-negative. - isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) || - (isKnownNonNegativeOp1 && isKnownNonNegativeOp0); - // The product of a negative number and a non-negative number is either - // negative or zero. - if (!isKnownNonNegative) - isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && - isKnownNonZero(Op0, Depth, Q)) || - (isKnownNegativeOp0 && isKnownNonNegativeOp1 && - isKnownNonZero(Op1, Depth, Q)); - } - } - assert(!Known.hasConflict() && !Known2.hasConflict()); // Compute a conservative estimate for high known-0 bits. 
unsigned LeadZ = std::max(Known.countMinLeadingZeros() + @@ -518,6 +487,41 @@ Known.Zero.setHighBits(LeadZ); Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); Known.One |= BottomKnown.getLoBits(ResultBitsKnown); +} + +static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, + const APInt &DemandedElts, KnownBits &Known, + KnownBits &Known2, unsigned Depth, + const Query &Q) { + computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q); + computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q); + + bool isKnownNegative = false; + bool isKnownNonNegative = false; + // If the multiplication is known not to overflow, compute the sign bit. + if (NSW) { + if (Op0 == Op1) { + // The product of a number with itself is non-negative. + isKnownNonNegative = true; + } else { + bool isKnownNonNegativeOp1 = Known.isNonNegative(); + bool isKnownNonNegativeOp0 = Known2.isNonNegative(); + bool isKnownNegativeOp1 = Known.isNegative(); + bool isKnownNegativeOp0 = Known2.isNegative(); + // The product of two numbers with the same sign is non-negative. + isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) || + (isKnownNonNegativeOp1 && isKnownNonNegativeOp0); + // The product of a negative number and a non-negative number is either + // negative or zero. + if (!isKnownNonNegative) + isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && + isKnownNonZero(Op0, Depth, Q)) || + (isKnownNegativeOp0 && isKnownNonNegativeOp1 && + isKnownNonZero(Op1, Depth, Q)); + } + } + + computeKnownBitsMul(Known, Known2); // Only make use of no-wrap flags if we failed to compute the sign bit // directly. This matters if the multiplication always overflows, in @@ -1452,48 +1456,75 @@ // to determine if we can prove known low zero bits. 
KnownBits LocalKnown(BitWidth); computeKnownBits(I->getOperand(0), LocalKnown, Depth + 1, Q); + KnownBits AddrKnownBits(LocalKnown); + unsigned TrailZ = LocalKnown.countMinTrailingZeros(); gep_type_iterator GTI = gep_type_begin(I); + // If the inbounds keyword is not present, the offsets are added to the base + // address with silently-wrapping two's complement arithmetic. + bool IsInBounds = cast<GEPOperator>(I)->isInBounds(); for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) { - // TrailZ can only become smaller, short-circuit if we hit zero. - if (TrailZ == 0) - break; Value *Index = I->getOperand(i); - if (StructType *STy = GTI.getStructTypeOrNull()) { - // Handle struct member offset arithmetic. + unsigned IndexBitWidth = Index->getType()->getScalarSizeInBits(); + KnownBits IndexBits(IndexBitWidth); + computeKnownBits(Index, IndexBits, Depth + 1, Q); + // If the offsets have a different width from the pointer, according to + // the language reference we need to sign-extend or truncate them to + // the width of the pointer. + if (IndexBitWidth < BitWidth) + IndexBits = IndexBits.sext(BitWidth); + else if (IndexBitWidth > BitWidth) + IndexBits = IndexBits.trunc(BitWidth); + // Multiply by current sizeof type. + // &A[i] == A + i * sizeof(*A[i]). + uint64_t TypeSizeInBytes = Q.DL.getTypeAllocSize(GTI.getIndexedType()); + KnownBits ScalingFactor(BitWidth); + ScalingFactor.Zero = ~TypeSizeInBytes; + ScalingFactor.One = TypeSizeInBytes; + computeKnownBitsMul(IndexBits, ScalingFactor); + + AddrKnownBits = KnownBits::computeForAddSub( + /*Add=*/true, + /*NSW=*/IsInBounds, AddrKnownBits, IndexBits); + // TrailZ can only become smaller, short-circuit if we hit zero. + if (TrailZ) { + if (StructType *STy = GTI.getStructTypeOrNull()) { + // Handle struct member offset arithmetic. 
- // Handle case when index is vector zeroinitializer - Constant *CIndex = cast<Constant>(Index); - if (CIndex->isZeroValue()) - continue; + // Handle case when index is vector zeroinitializer + Constant *CIndex = cast<Constant>(Index); + if (CIndex->isZeroValue()) + continue; - if (CIndex->getType()->isVectorTy()) - Index = CIndex->getSplatValue(); + if (CIndex->getType()->isVectorTy()) + Index = CIndex->getSplatValue(); - unsigned Idx = cast<ConstantInt>(Index)->getZExtValue(); - const StructLayout *SL = Q.DL.getStructLayout(STy); - uint64_t Offset = SL->getElementOffset(Idx); - TrailZ = std::min(TrailZ, - countTrailingZeros(Offset)); - } else { - // Handle array index arithmetic. - Type *IndexedTy = GTI.getIndexedType(); - if (!IndexedTy->isSized()) { - TrailZ = 0; - break; + unsigned Idx = cast<ConstantInt>(Index)->getZExtValue(); + const StructLayout *SL = Q.DL.getStructLayout(STy); + uint64_t Offset = SL->getElementOffset(Idx); + TrailZ = std::min(TrailZ, countTrailingZeros(Offset)); + } else { + // Handle array index arithmetic. 
+ Type *IndexedTy = GTI.getIndexedType(); + if (!IndexedTy->isSized()) { + TrailZ = 0; + continue; + } + unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); + uint64_t TypeSize = + Q.DL.getTypeAllocSize(IndexedTy).getKnownMinSize(); + LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0); + computeKnownBits(Index, LocalKnown, Depth + 1, Q); + TrailZ = + std::min(TrailZ, unsigned(countTrailingZeros(TypeSize) + + LocalKnown.countMinTrailingZeros())); } - unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); - uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy).getKnownMinSize(); - LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0); - computeKnownBits(Index, LocalKnown, Depth + 1, Q); - TrailZ = std::min(TrailZ, - unsigned(countTrailingZeros(TypeSize) + - LocalKnown.countMinTrailingZeros())); } } + Known = AddrKnownBits; Known.Zero.setLowBits(TrailZ); break; } diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -1013,6 +1013,106 @@ EXPECT_EQ(Known.One.getZExtValue(), 0u); } +TEST_F(ComputeKnownBitsTest, ComputeKnownBitsAddWithRange) { + parseAssembly("define void @test(i64* %p) {\n" + " %A = load i64, i64* %p, !range !{i64 64, i64 65536}\n" + " %APlus512 = add i64 %A, 512\n" + " %c = icmp ugt i64 %APlus512, 523\n" + " call void @llvm.assume(i1 %c)\n" + " ret void\n" + "}\n" + "declare void @llvm.assume(i1)\n"); + AssumptionCache AC(*F); + KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC, + F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), ~(65536llu - 1)); + EXPECT_EQ(Known.One.getZExtValue(), 0u); + Instruction &APlus512 = findInstructionByName(F, "APlus512"); + Known = computeKnownBits(&APlus512, M->getDataLayout(), /* Depth */ 0, &AC, + F->front().getTerminator()); + // We know of one less zero because 512 may have produced a 1 that 
+ // got carried all the way to the first trailing zero. + EXPECT_EQ(Known.Zero.getZExtValue(), (~(65536llu - 1)) << 1); + EXPECT_EQ(Known.One.getZExtValue(), 0u); +} + +// 512 + [32, 64) doesn't produce overlapping bits. +// Make sure we get all the individual bits properly. +TEST_F(ComputeKnownBitsTest, ComputeKnownBitsAddWithRangeNoOverlap) { + parseAssembly("define void @test(i64* %p) {\n" + " %A = load i64, i64* %p, !range !{i64 32, i64 64}\n" + " %APlus512 = add i64 %A, 512\n" + " %c = icmp ugt i64 %APlus512, 523\n" + " call void @llvm.assume(i1 %c)\n" + " ret void\n" + "}\n" + "declare void @llvm.assume(i1)\n"); + AssumptionCache AC(*F); + KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC, + F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), ~(64llu - 1)); + EXPECT_EQ(Known.One.getZExtValue(), 32u); + Instruction &APlus512 = findInstructionByName(F, "APlus512"); + Known = computeKnownBits(&APlus512, M->getDataLayout(), /* Depth */ 0, &AC, + F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), ~512llu & ~(64llu - 1)); + EXPECT_EQ(Known.One.getZExtValue(), 512u | 32u); +} + +TEST_F(ComputeKnownBitsTest, ComputeKnownBitsGEPWithRange) { + parseAssembly( + "define void @test(i64* %p) {\n" + " %A = load i64, i64* %p, !range !{i64 64, i64 65536}\n" + " %APtr = inttoptr i64 %A to float*\n" + " %APtrPlus512 = getelementptr float, float* %APtr, i32 128\n" + " %c = icmp ugt float* %APtrPlus512, inttoptr (i32 523 to float*)\n" + " call void @llvm.assume(i1 %c)\n" + " ret void\n" + "}\n" + "declare void @llvm.assume(i1)\n"); + AssumptionCache AC(*F); + KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC, + F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), ~(65536llu - 1)); + EXPECT_EQ(Known.One.getZExtValue(), 0u); + Instruction &APtrPlus512 = findInstructionByName(F, "APtrPlus512"); + Known = computeKnownBits(&APtrPlus512, M->getDataLayout(), /* Depth */ 0, &AC, + 
F->front().getTerminator()); + // We know of one less zero because 512 may have produced a 1 that + // got carried all the way to the first trailing zero. + EXPECT_EQ(Known.Zero.getZExtValue(), ~(65536llu - 1) << 1); + EXPECT_EQ(Known.One.getZExtValue(), 0u); +} + +// 4*128 + [32, 64) doesn't produce overlapping bits. +// Make sure we get all the individual bits properly. +// This test is useful to check that we account for the scaling factor +// in the gep. Indeed, gep float, [32,64), 128 is not 128 + [32,64). +TEST_F(ComputeKnownBitsTest, ComputeKnownBitsGEPWithRangeNoOverlap) { + parseAssembly( + "define void @test(i64* %p) {\n" + " %A = load i64, i64* %p, !range !{i64 32, i64 64}\n" + " %APtr = inttoptr i64 %A to float*\n" + " %APtrPlus512 = getelementptr float, float* %APtr, i32 128\n" + " %c = icmp ugt float* %APtrPlus512, inttoptr (i32 523 to float*)\n" + " call void @llvm.assume(i1 %c)\n" + " ret void\n" + "}\n" + "declare void @llvm.assume(i1)\n"); + AssumptionCache AC(*F); + KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC, + F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), ~(64llu - 1)); + EXPECT_EQ(Known.One.getZExtValue(), 32u); + Instruction &APtrPlus512 = findInstructionByName(F, "APtrPlus512"); + Known = computeKnownBits(&APtrPlus512, M->getDataLayout(), /* Depth */ 0, &AC, + F->front().getTerminator()); + EXPECT_EQ(Known.Zero.getZExtValue(), ~512llu & ~(64llu - 1)); + EXPECT_EQ(Known.One.getZExtValue(), 512u | 32u); +} + class IsBytewiseValueTest : public ValueTrackingTest, public ::testing::WithParamInterface< std::pair<const char *, std::string>> {