diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -531,7 +531,14 @@
 
   /// Determine the possible constant range of an integer or vector of integer
-  /// value. This is intended as a cheap, non-recursive check.
-  ConstantRange computeConstantRange(const Value *V, bool UseInstrInfo = true);
+  /// value. This is intended as a cheap check.
+  /// If both \p AC and \p CtxI are provided, llvm.assume calls that are valid
+  /// at \p CtxI and have the form `icmp pred V, X` are used to narrow the
+  /// result; the range of X is computed recursively, bounded by MaxDepth.
+  /// \p Depth is the current recursion depth; external callers pass 0.
+  ConstantRange computeConstantRange(const Value *V, bool UseInstrInfo = true,
+                                     AssumptionCache *AC = nullptr,
+                                     const Instruction *CtxI = nullptr,
+                                     unsigned Depth = 0);
 
   /// Return true if this function can prove that the instruction I will
   /// always transfer execution to one of its successors (including the next
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6367,9 +6367,15 @@
   }
 }
 
-ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) {
+ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo,
+                                         AssumptionCache *AC,
+                                         const Instruction *CtxI,
+                                         unsigned Depth) {
   assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");
 
+  if (Depth == MaxDepth)
+    return ConstantRange::getFull(V->getType()->getScalarSizeInBits());
+
   const APInt *C;
   if (match(V, m_APInt(C)))
     return ConstantRange(*C);
@@ -6391,6 +6397,36 @@
     if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
       CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));
 
+  if (CtxI && AC) {
+    // Try to restrict the range based on information from assumptions.
+    for (auto &AssumeVH : AC->assumptionsFor(V)) {
+      if (!AssumeVH)
+        continue;
+      CallInst *I = cast<CallInst>(AssumeVH);
+      assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
+             "Got assumption for the wrong function!");
+      assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume &&
+             "must be an assume intrinsic");
+
+      if (!isValidAssumeForContext(I, CtxI, nullptr))
+        continue;
+      Value *Arg = I->getArgOperand(0);
+      ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
+      // Currently we just use information from comparisons.
+      if (!Cmp || Cmp->getOperand(0) != V)
+        continue;
+      ConstantRange RHS = computeConstantRange(Cmp->getOperand(1), UseInstrInfo,
+                                               AC, I, Depth + 1);
+      // The assumption only says that V satisfies the predicate for the
+      // actual (unknown) value of the RHS, i.e. for *some* element of RHS, so
+      // use the allowed region. makeSatisfyingICmpRegion would only keep
+      // values satisfying the predicate against *every* element of RHS, which
+      // is unsound here (e.g. "x ult y" with y >= 5 does not imply x ult 5).
+      CR = CR.intersectWith(
+          ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
+    }
+  }
+
   return CR;
 }
 
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -9,6 +9,7 @@
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InstIterator.h"
 #include "llvm/IR/Instructions.h"
@@ -23,6 +24,14 @@
 
 namespace {
 
+static Instruction &findInstructionByName(Function *F, StringRef Name) {
+  for (Instruction &I : instructions(F))
+    if (I.getName() == Name)
+      return I;
+
+  llvm_unreachable("Expected value not found");
+}
+
 class ValueTrackingTest : public testing::Test {
 protected:
   std::unique_ptr<Module> parseModule(StringRef Assembly) {
@@ -46,13 +55,7 @@
     if (!F)
       return;
 
-    A = nullptr;
-    for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) {
-      if (I->hasName()) {
-        if (I->getName() == "A")
-          A = &*I;
-      }
-    }
+    A = &findInstructionByName(F, "A");
     ASSERT_TRUE(A) << "@test must have an instruction %A";
   }
 
@@ -1246,3 +1249,176 @@
   S << *Actual;
   EXPECT_EQ(GetParam().first, S.str());
 }
+
+TEST_F(ValueTrackingTest, ComputeConstantRange) {
+  {
+    // Assumptions:
+    //  * stride >= 5
+    //  * stride < 10
+    //
+    // stride = [5, 10)
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride) {
+    %gt = icmp uge i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %stride, 10
+    call void @llvm.assume(i1 %lt)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+    ConstantRange CR1 = computeConstantRange(Stride, true, &AC, nullptr);
+    EXPECT_TRUE(CR1.isFullSet());
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR2 = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_EQ(5, CR2.getLower());
+    EXPECT_EQ(10, CR2.getUpper());
+  }
+
+  {
+    // Assumptions:
+    //  * stride >= 5
+    //  * stride < 200
+    //  * stride == 99
+    //
+    // stride = [99, 100)
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride) {
+    %gt = icmp uge i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %stride, 200
+    call void @llvm.assume(i1 %lt)
+    %eq = icmp eq i32 %stride, 99
+    call void @llvm.assume(i1 %eq)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_EQ(99, *CR.getSingleElement());
+  }
+
+  {
+    // Assumptions:
+    //  * stride >= 5
+    //  * stride >= 50
+    //  * stride < 100
+    //  * stride < 200
+    //
+    // stride = [50, 100)
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride, i1 %cond) {
+    %gt = icmp uge i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %gt.2 = icmp uge i32 %stride, 50
+    call void @llvm.assume(i1 %gt.2)
+    br i1 %cond, label %bb1, label %bb2
+
+  bb1:
+    %lt = icmp ult i32 %stride, 200
+    call void @llvm.assume(i1 %lt)
+    %lt.2 = icmp ult i32 %stride, 100
+    call void @llvm.assume(i1 %lt.2)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+
+  bb2:
+    ret i32 0
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+    Instruction *GT2 = &findInstructionByName(F, "gt.2");
+    ConstantRange CR = computeConstantRange(Stride, true, &AC, GT2);
+    EXPECT_EQ(5, CR.getLower());
+    EXPECT_EQ(0, CR.getUpper());
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR2 = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_EQ(50, CR2.getLower());
+    EXPECT_EQ(100, CR2.getUpper());
+  }
+
+  {
+    // Assumptions:
+    //  * stride > 5
+    //  * stride < 5
+    //
+    // stride = empty range, as the assumptions contradict each other.
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %stride, i1 %cond) {
+    %gt = icmp ugt i32 %stride, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %stride, 5
+    call void @llvm.assume(i1 %lt)
+    %stride.plus.one = add nsw nuw i32 %stride, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *Stride = &*F->arg_begin();
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR = computeConstantRange(Stride, true, &AC, I);
+    EXPECT_TRUE(CR.isEmptySet());
+  }
+
+  {
+    // Assumptions:
+    //  * x.1 >= 5
+    //  * x.2 < x.1
+    //
+    // x.1 = [5, 0), i.e. x.1 >= 5.
+    // x.2 = [0, -1): x.2 only has to be less than the actual value of x.1,
+    // which can be as large as UINT32_MAX, so only x.2 == UINT32_MAX can be
+    // excluded. Concluding x.2 < 5 would be unsound (e.g. x.1 = 100, x.2 = 50).
+    auto M = parseModule(R"(
+  declare void @llvm.assume(i1)
+
+  define i32 @test(i32 %x.1, i32 %x.2) {
+    %gt = icmp uge i32 %x.1, 5
+    call void @llvm.assume(i1 %gt)
+    %lt = icmp ult i32 %x.2, %x.1
+    call void @llvm.assume(i1 %lt)
+    %stride.plus.one = add nsw nuw i32 %x.1, 1
+    ret i32 %stride.plus.one
+  })");
+    Function *F = M->getFunction("test");
+
+    AssumptionCache AC(*F);
+    Value *X1 = &*F->arg_begin();
+    Value *X2 = &*std::next(F->arg_begin());
+
+    Instruction *I = &findInstructionByName(F, "stride.plus.one");
+    ConstantRange CR1 = computeConstantRange(X1, true, &AC, I);
+    EXPECT_EQ(5, CR1.getLower());
+    EXPECT_EQ(0, CR1.getUpper());
+
+    ConstantRange CR2 = computeConstantRange(X2, true, &AC, I);
+    EXPECT_EQ(0, CR2.getLower());
+    EXPECT_EQ(0xffffffff, CR2.getUpper());
+
+    // Check the depth cutoff results in a conservative result (full set) by
+    // passing Depth == MaxDepth == 6.
+    ConstantRange CR3 = computeConstantRange(X2, true, &AC, I, 6);
+    EXPECT_TRUE(CR3.isFullSet());
+  }
+}