diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -413,39 +413,8 @@
   KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut);
 }
 
-static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
-                                const APInt &DemandedElts, KnownBits &Known,
-                                KnownBits &Known2, unsigned Depth,
-                                const Query &Q) {
+static void computeKnownBitsMul(KnownBits &Known, KnownBits &Known2) {
   unsigned BitWidth = Known.getBitWidth();
-  computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q);
-  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
-
-  bool isKnownNegative = false;
-  bool isKnownNonNegative = false;
-  // If the multiplication is known not to overflow, compute the sign bit.
-  if (NSW) {
-    if (Op0 == Op1) {
-      // The product of a number with itself is non-negative.
-      isKnownNonNegative = true;
-    } else {
-      bool isKnownNonNegativeOp1 = Known.isNonNegative();
-      bool isKnownNonNegativeOp0 = Known2.isNonNegative();
-      bool isKnownNegativeOp1 = Known.isNegative();
-      bool isKnownNegativeOp0 = Known2.isNegative();
-      // The product of two numbers with the same sign is non-negative.
-      isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
-                           (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
-      // The product of a negative number and a non-negative number is either
-      // negative or zero.
-      if (!isKnownNonNegative)
-        isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
-                           isKnownNonZero(Op0, Depth, Q)) ||
-                          (isKnownNegativeOp0 && isKnownNonNegativeOp1 &&
-                           isKnownNonZero(Op1, Depth, Q));
-    }
-  }
-
   assert(!Known.hasConflict() && !Known2.hasConflict());
   // Compute a conservative estimate for high known-0 bits.
   unsigned LeadZ = std::max(Known.countMinLeadingZeros() +
@@ -518,6 +487,41 @@
   Known.Zero.setHighBits(LeadZ);
   Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown);
   Known.One |= BottomKnown.getLoBits(ResultBitsKnown);
+}
+
+static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
+                                const APInt &DemandedElts, KnownBits &Known,
+                                KnownBits &Known2, unsigned Depth,
+                                const Query &Q) {
+  computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q);
+  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
+
+  bool isKnownNegative = false;
+  bool isKnownNonNegative = false;
+  // If the multiplication is known not to overflow, compute the sign bit.
+  if (NSW) {
+    if (Op0 == Op1) {
+      // The product of a number with itself is non-negative.
+      isKnownNonNegative = true;
+    } else {
+      bool isKnownNonNegativeOp1 = Known.isNonNegative();
+      bool isKnownNonNegativeOp0 = Known2.isNonNegative();
+      bool isKnownNegativeOp1 = Known.isNegative();
+      bool isKnownNegativeOp0 = Known2.isNegative();
+      // The product of two numbers with the same sign is non-negative.
+      isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
+                           (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
+      // The product of a negative number and a non-negative number is either
+      // negative or zero.
+      if (!isKnownNonNegative)
+        isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
+                           isKnownNonZero(Op0, Depth, Q)) ||
+                          (isKnownNegativeOp0 && isKnownNonNegativeOp1 &&
+                           isKnownNonZero(Op1, Depth, Q));
+    }
+  }
+
+  computeKnownBitsMul(Known, Known2);
 
   // Only make use of no-wrap flags if we failed to compute the sign bit
   // directly.  This matters if the multiplication always overflows, in
@@ -1452,15 +1456,23 @@
     // Analyze all of the subscripts of this getelementptr instruction
     // to determine if we can prove known low zero bits.
     KnownBits LocalKnown(BitWidth);
     computeKnownBits(I->getOperand(0), LocalKnown, Depth + 1, Q);
+    KnownBits AddrKnownBits(LocalKnown);
+    // True if we track all the bits of the address.
+    bool TrackAddr = true;
+
     unsigned TrailZ = LocalKnown.countMinTrailingZeros();
 
     gep_type_iterator GTI = gep_type_begin(I);
+    // If the inbounds keyword is not present, the offsets are added to the base
+    // address with silently-wrapping two's complement arithmetic.
+    bool IsInBounds = cast<GEPOperator>(I)->isInBounds();
     for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
-      // TrailZ can only become smaller, short-circuit if we hit zero.
-      if (TrailZ == 0)
+      if (TrailZ == 0 && !TrackAddr)
        break;
-
       Value *Index = I->getOperand(i);
+
+      unsigned IndexBitWidth = Index->getType()->getScalarSizeInBits();
+      KnownBits IndexBits(IndexBitWidth);
       if (StructType *STy = GTI.getStructTypeOrNull()) {
         // Handle struct member offset arithmetic.
@@ -1475,6 +1487,10 @@
         unsigned Idx = cast<ConstantInt>(Index)->getZExtValue();
         const StructLayout *SL = Q.DL.getStructLayout(STy);
         uint64_t Offset = SL->getElementOffset(Idx);
+        if (TrackAddr) {
+          IndexBits.Zero = ~Offset;
+          IndexBits.One = Offset;
+        }
         TrailZ = std::min(TrailZ,
                           countTrailingZeros(Offset));
       } else {
@@ -1482,19 +1498,43 @@
         Type *IndexedTy = GTI.getIndexedType();
         if (!IndexedTy->isSized()) {
           TrailZ = 0;
+          TrackAddr = false;
           break;
         }
-        unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits();
-        uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy).getKnownMinSize();
-        LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0);
-        computeKnownBits(Index, LocalKnown, Depth + 1, Q);
-        TrailZ = std::min(TrailZ,
-                          unsigned(countTrailingZeros(TypeSize) +
-                                   LocalKnown.countMinTrailingZeros()));
+        computeKnownBits(Index, IndexBits, Depth + 1, Q);
+        TypeSize IndexTypeSize = Q.DL.getTypeAllocSize(IndexedTy);
+        uint64_t TypeSizeInBytes = IndexTypeSize.getKnownMinSize();
+        TrackAddr &= !IndexTypeSize.isScalable();
+        if (TrackAddr) {
+          // Multiply by current sizeof type.
+          // &A[i] == A + i * sizeof(*A[i]).
+          KnownBits ScalingFactor(IndexBitWidth);
+          ScalingFactor.Zero = ~TypeSizeInBytes;
+          ScalingFactor.One = TypeSizeInBytes;
+          computeKnownBitsMul(IndexBits, ScalingFactor);
+        }
+        TrailZ = std::min(TrailZ, unsigned(countTrailingZeros(TypeSizeInBytes) +
+                                           IndexBits.countMinTrailingZeros()));
       }
-    }
+      if (!TrackAddr)
+        continue;
 
-    Known.Zero.setLowBits(TrailZ);
+      // If the offsets have a different width from the pointer, according
+      // to the language reference we need to sign-extend or truncate them
+      // to the width of the pointer.
+      if (IndexBitWidth < BitWidth)
+        IndexBits = IndexBits.sext(BitWidth);
+      else if (IndexBitWidth > BitWidth)
+        IndexBits = IndexBits.trunc(BitWidth);
+
+      AddrKnownBits = KnownBits::computeForAddSub(
+          /*Add=*/true,
+          /*NSW=*/IsInBounds, AddrKnownBits, IndexBits);
+    }
+    if (TrackAddr)
+      Known = AddrKnownBits;
+    else
+      Known.Zero.setLowBits(TrailZ);
     break;
   }
   case Instruction::PHI: {
diff --git a/llvm/test/Transforms/InstCombine/constant-fold-address-space-pointer.ll b/llvm/test/Transforms/InstCombine/constant-fold-address-space-pointer.ll
--- a/llvm/test/Transforms/InstCombine/constant-fold-address-space-pointer.ll
+++ b/llvm/test/Transforms/InstCombine/constant-fold-address-space-pointer.ll
@@ -197,7 +197,7 @@
 define i32 @test_constant_cast_gep_struct_indices_as() {
 ; CHECK-LABEL: @test_constant_cast_gep_struct_indices_as(
-; CHECK: load i32, i32 addrspace(3)* getelementptr inbounds (%struct.foo, %struct.foo addrspace(3)* @constant_fold_global_ptr, i16 0, i32 2, i16 2), align 8
+; CHECK: load i32, i32 addrspace(3)* getelementptr inbounds (%struct.foo, %struct.foo addrspace(3)* @constant_fold_global_ptr, i16 0, i32 2, i16 2), align 16
   %x = getelementptr %struct.foo, %struct.foo addrspace(3)* @constant_fold_global_ptr, i18 0, i32 2, i12 2
   %y = load i32, i32 addrspace(3)* %x, align 4
   ret i32 %y
diff --git a/llvm/test/Transforms/InstCombine/constant-fold-gep.ll b/llvm/test/Transforms/InstCombine/constant-fold-gep.ll
--- a/llvm/test/Transforms/InstCombine/constant-fold-gep.ll
+++ b/llvm/test/Transforms/InstCombine/constant-fold-gep.ll
@@ -17,7 +17,7 @@
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 2), align 4
 ; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 1, i64 0), align 4
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 3), align 4
-; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 1, i64 1), align 4
+; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 1, i64 1), align 16
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 4), align 4
 ; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 1, i64 2), align 4
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 5), align 4
@@ -25,11 +25,11 @@
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 6), align 4
 ; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 1, i32 0, i64 1), align 4
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 7), align 4
-; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 1, i32 0, i64 2), align 8
+; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 1, i32 0, i64 2), align 16
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 8), align 4
 ; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 1, i32 1, i64 0), align 4
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 9), align 4
-; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 1, i32 1, i64 1), align 4
+; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 1, i32 1, i64 1), align 8
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 10), align 4
 ; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 1, i32 1, i64 2), align 4
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 11), align 4
@@ -39,12 +39,12 @@
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 13), align 4
 ; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 2, i32 0, i64 2), align 8
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 14), align 8
-; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 2, i32 1, i64 0), align 8
-  store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 15), align 8
-; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 2, i32 1, i64 1), align 8
+; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 2, i32 1, i64 0), align 4
+  store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 15), align 4
+; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 2, i32 1, i64 1), align 16
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 16), align 8
-; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 2, i32 1, i64 2), align 8
-  store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 17), align 8
+; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 2, i32 1, i64 2), align 4
+  store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 17), align 4
 ; CHECK: store i32 1, i32* getelementptr inbounds ([3 x %struct.X], [3 x %struct.X]* @Y, i64 1, i64 0, i32 0, i64 0), align 8
   store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 0, i64 0, i32 0, i64 18), align 8
 ; CHECK: store i32 1, i32* getelementptr ([3 x %struct.X], [3 x %struct.X]* @Y, i64 2, i64 0, i32 0, i64 0), align 16
@@ -90,3 +90,19 @@
   ret i16 %E
 }
+
+; Check that we improve the alignment information.
+; The base pointer is 16-byte aligned and we access the field at
+; an offset of 8 bytes.
+; Every element in the @CallerInfos array is 16-byte aligned so
+; any access from the following gep is 8-byte aligned.
+%struct.CallerInfo = type { i8*, i32 }
+@CallerInfos = global [128 x %struct.CallerInfo] zeroinitializer, align 16
+
+; CHECK-LABEL: @test_gep_in_struct(
+; CHECK: load i32, i32* %NS7, align 8
+define i32 @test_gep_in_struct(i64 %idx) {
+  %NS7 = getelementptr inbounds [128 x %struct.CallerInfo], [128 x %struct.CallerInfo]* @CallerInfos, i64 0, i64 %idx, i32 1
+  %res = load i32, i32* %NS7, align 1
+  ret i32 %res
+}
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -1013,6 +1013,106 @@
   EXPECT_EQ(Known.One.getZExtValue(), 0u);
 }
 
+TEST_F(ComputeKnownBitsTest, ComputeKnownBitsAddWithRange) {
+  parseAssembly("define void @test(i64* %p) {\n"
+                "  %A = load i64, i64* %p, !range !{i64 64, i64 65536}\n"
+                "  %APlus512 = add i64 %A, 512\n"
+                "  %c = icmp ugt i64 %APlus512, 523\n"
+                "  call void @llvm.assume(i1 %c)\n"
+                "  ret void\n"
+                "}\n"
+                "declare void @llvm.assume(i1)\n");
+  AssumptionCache AC(*F);
+  KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC,
+                                     F->front().getTerminator());
+  EXPECT_EQ(Known.Zero.getZExtValue(), ~(65536llu - 1));
+  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+  Instruction &APlus512 = findInstructionByName(F, "APlus512");
+  Known = computeKnownBits(&APlus512, M->getDataLayout(), /* Depth */ 0, &AC,
+                           F->front().getTerminator());
+  // We know of one less zero because 512 may have produced a 1 that
+  // got carried all the way to the first trailing zero.
+  EXPECT_EQ(Known.Zero.getZExtValue(), (~(65536llu - 1)) << 1);
+  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+}
+
+// 512 + [32, 64) doesn't produce overlapping bits.
+// Make sure we get all the individual bits properly.
+TEST_F(ComputeKnownBitsTest, ComputeKnownBitsAddWithRangeNoOverlap) {
+  parseAssembly("define void @test(i64* %p) {\n"
+                "  %A = load i64, i64* %p, !range !{i64 32, i64 64}\n"
+                "  %APlus512 = add i64 %A, 512\n"
+                "  %c = icmp ugt i64 %APlus512, 523\n"
+                "  call void @llvm.assume(i1 %c)\n"
+                "  ret void\n"
+                "}\n"
+                "declare void @llvm.assume(i1)\n");
+  AssumptionCache AC(*F);
+  KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC,
+                                     F->front().getTerminator());
+  EXPECT_EQ(Known.Zero.getZExtValue(), ~(64llu - 1));
+  EXPECT_EQ(Known.One.getZExtValue(), 32u);
+  Instruction &APlus512 = findInstructionByName(F, "APlus512");
+  Known = computeKnownBits(&APlus512, M->getDataLayout(), /* Depth */ 0, &AC,
+                           F->front().getTerminator());
+  EXPECT_EQ(Known.Zero.getZExtValue(), ~512llu & ~(64llu - 1));
+  EXPECT_EQ(Known.One.getZExtValue(), 512u | 32u);
+}
+
+TEST_F(ComputeKnownBitsTest, ComputeKnownBitsGEPWithRange) {
+  parseAssembly(
+      "define void @test(i64* %p) {\n"
+      "  %A = load i64, i64* %p, !range !{i64 64, i64 65536}\n"
+      "  %APtr = inttoptr i64 %A to float*\n"
+      "  %APtrPlus512 = getelementptr float, float* %APtr, i32 128\n"
+      "  %c = icmp ugt float* %APtrPlus512, inttoptr (i32 523 to float*)\n"
+      "  call void @llvm.assume(i1 %c)\n"
+      "  ret void\n"
+      "}\n"
+      "declare void @llvm.assume(i1)\n");
+  AssumptionCache AC(*F);
+  KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC,
+                                     F->front().getTerminator());
+  EXPECT_EQ(Known.Zero.getZExtValue(), ~(65536llu - 1));
+  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+  Instruction &APtrPlus512 = findInstructionByName(F, "APtrPlus512");
+  Known = computeKnownBits(&APtrPlus512, M->getDataLayout(), /* Depth */ 0, &AC,
+                           F->front().getTerminator());
+  // We know of one less zero because 512 may have produced a 1 that
+  // got carried all the way to the first trailing zero.
+  EXPECT_EQ(Known.Zero.getZExtValue(), ~(65536llu - 1) << 1);
+  EXPECT_EQ(Known.One.getZExtValue(), 0u);
+}
+
+// 4*128 + [32, 64) doesn't produce overlapping bits.
+// Make sure we get all the individual bits properly.
+// This test is useful to check that we account for the scaling factor
+// in the gep. Indeed, gep float, [32,64), 128 is not 128 + [32,64).
+TEST_F(ComputeKnownBitsTest, ComputeKnownBitsGEPWithRangeNoOverlap) {
+  parseAssembly(
+      "define void @test(i64* %p) {\n"
+      "  %A = load i64, i64* %p, !range !{i64 32, i64 64}\n"
+      "  %APtr = inttoptr i64 %A to float*\n"
+      "  %APtrPlus512 = getelementptr float, float* %APtr, i32 128\n"
+      "  %c = icmp ugt float* %APtrPlus512, inttoptr (i32 523 to float*)\n"
+      "  call void @llvm.assume(i1 %c)\n"
+      "  ret void\n"
+      "}\n"
+      "declare void @llvm.assume(i1)\n");
+  AssumptionCache AC(*F);
+  KnownBits Known = computeKnownBits(A, M->getDataLayout(), /* Depth */ 0, &AC,
+                                     F->front().getTerminator());
+  EXPECT_EQ(Known.Zero.getZExtValue(), ~(64llu - 1));
+  EXPECT_EQ(Known.One.getZExtValue(), 32u);
+  Instruction &APtrPlus512 = findInstructionByName(F, "APtrPlus512");
+  Known = computeKnownBits(&APtrPlus512, M->getDataLayout(), /* Depth */ 0, &AC,
+                           F->front().getTerminator());
+  // The scaled offset 4 * 128 == 512 does not overlap [32, 64), so every bit
+  // of the result is known exactly.
+  EXPECT_EQ(Known.Zero.getZExtValue(), ~512llu & ~(64llu - 1));
+  EXPECT_EQ(Known.One.getZExtValue(), 512u | 32u);
+}
+
 class IsBytewiseValueTest : public ValueTrackingTest,
                             public ::testing::WithParamInterface<
                                 std::pair<const char *, const char *>> {
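
Editor's note: the sketch below is not part of the patch. It is a minimal, self-contained C++ illustration (no LLVM dependencies) of the idea the ValueTracking change and the unit tests above rely on: a GEP is modelled as base + index * sizeof(element), and when the possibly-set bits of the two addends do not overlap, every bit of the sum is known exactly. The type Bits and the helpers addNoOverlap and gepKnownBits are invented names for illustration only; the patch itself expresses this with KnownBits, computeKnownBitsMul, and KnownBits::computeForAddSub.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Known-bits lattice element: each bit is known 0, known 1, or unknown.
struct Bits {
  uint64_t Zero; // bits known to be 0
  uint64_t One;  // bits known to be 1
};

// Known bits of LHS + RHS in the easy case where no carry can occur because
// the possibly-one bits of the two operands do not overlap.
static Bits addNoOverlap(Bits LHS, Bits RHS) {
  assert((~LHS.Zero & ~RHS.Zero) == 0 && "operands may overlap, carries possible");
  return {LHS.Zero & RHS.Zero, LHS.One | RHS.One};
}

// "getelementptr float, float* %base, i32 Index" with a constant index is
// %base + Index * sizeof(float); a constant offset has every bit known.
static Bits gepKnownBits(Bits Base, uint64_t Index, uint64_t EltSize) {
  uint64_t Offset = Index * EltSize;
  Bits OffsetBits = {~Offset, Offset};
  return addNoOverlap(Base, OffsetBits);
}

int main() {
  // %A in [32, 64): bit 5 is always one, bits 0-4 are unknown, the rest are zero.
  Bits A = {~uint64_t(63), 32};
  // getelementptr float, float* %A, i32 128 adds 4 * 128 = 512 bytes.
  Bits P = gepKnownBits(A, 128, 4);
  // Matches ComputeKnownBitsGEPWithRangeNoOverlap: Zero = ~512 & ~63, One = 544.
  std::printf("Zero=0x%llx One=0x%llx\n", (unsigned long long)P.Zero,
              (unsigned long long)P.One);
  return 0;
}

A real implementation additionally has to model carries when the operands may overlap, sign-extend or truncate narrow indices to the pointer width, and only exploit the no-wrap guarantee when the GEP is inbounds, which is what the patch threads through KnownBits::computeForAddSub.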