diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -494,8 +494,11 @@ return TestFlags == maskFlags(Flags, TestFlags); }; + /// If true, all loops are assumed to be finite, even without mustprogress. + bool FiniteLoops; + ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, - DominatorTree &DT, LoopInfo &LI); + DominatorTree &DT, LoopInfo &LI, bool FiniteLoops = false); ScalarEvolution(ScalarEvolution &&Arg); ~ScalarEvolution(); @@ -1111,9 +1114,11 @@ /// Simplify LHS and RHS in a comparison with predicate Pred. Return true /// iff any changes were made. If the operands are provably equal or /// unequal, LHS and RHS are set to the same value and Pred is set to either - /// ICMP_EQ or ICMP_NE. + /// ICMP_EQ or ICMP_NE. ControllingFiniteLoop is set if this comparison + /// controls the exit of a loop known to have a finite number of iterations. bool SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, - const SCEV *&RHS, unsigned Depth = 0); + const SCEV *&RHS, unsigned Depth = 0, + bool ControllingFiniteLoop = false); /// Return the "disposition" of the given SCEV with respect to the given /// loop. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -7017,7 +7017,7 @@ // A mustprogress loop without side effects must be finite. // TODO: The check used here is very conservative. It's only *specific* // side effects which are well defined in infinite loops. - return isMustProgress(L) && loopHasNoSideEffects(L); + return FiniteLoops || (isMustProgress(L) && loopHasNoSideEffects(L)); } const SCEV *ScalarEvolution::createSCEV(Value *V) { @@ -8472,7 +8472,9 @@ } // Simplify the operands before analyzing them.
- (void)SimplifyICmpOperands(Pred, LHS, RHS); + bool ControllingFiniteLoop = + ControlsExit && loopHasNoAbnormalExits(L) && loopIsFiniteByAssumption(L); + (void)SimplifyICmpOperands(Pred, LHS, RHS, 0, ControllingFiniteLoop); // If we have a comparison of a chrec against a constant, try to use value // ranges to answer this query. @@ -9945,7 +9947,8 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, - unsigned Depth) { + unsigned Depth, + bool ControllingFiniteLoop) { bool Changed = false; // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or // '0 != 0'. @@ -10074,10 +10077,16 @@ } // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by - // adding or subtracting 1 from one of the operands. + // adding or subtracting 1 from one of the operands. This can be done for + // one of two reasons: + // 1) The range of the RHS does not include the (signed/unsigned) + // boundaries. + // 2) The comparison is the only exit condition of a loop known to be + // finite. The bound cannot include the corresponding boundary, since + // then the comparison would always hold and the loop would run forever.
switch (Pred) { case ICmpInst::ICMP_SLE: - if (!getSignedRangeMax(RHS).isMaxSignedValue()) { + if (ControllingFiniteLoop || !getSignedRangeMax(RHS).isMaxSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SLT; @@ -10090,7 +10099,7 @@ } break; case ICmpInst::ICMP_SGE: - if (!getSignedRangeMin(RHS).isMinSignedValue()) { + if (ControllingFiniteLoop || !getSignedRangeMin(RHS).isMinSignedValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS, SCEV::FlagNSW); Pred = ICmpInst::ICMP_SGT; @@ -10103,7 +10112,7 @@ } break; case ICmpInst::ICMP_ULE: - if (!getUnsignedRangeMax(RHS).isMaxValue()) { + if (ControllingFiniteLoop || !getUnsignedRangeMax(RHS).isMaxValue()) { RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS, SCEV::FlagNUW); Pred = ICmpInst::ICMP_ULT; @@ -10115,7 +10124,7 @@ } break; case ICmpInst::ICMP_UGE: - if (!getUnsignedRangeMin(RHS).isMinValue()) { + if (ControllingFiniteLoop || !getUnsignedRangeMin(RHS).isMinValue()) { RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS); Pred = ICmpInst::ICMP_UGT; Changed = true; @@ -10135,7 +10144,8 @@ // Recursively simplify until we either hit a recursion limit or nothing // changes.
if (Changed) - return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); + return SimplifyICmpOperands(Pred, LHS, RHS, Depth + 1, + ControllingFiniteLoop); return Changed; } @@ -12651,8 +12661,8 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC, DominatorTree &DT, - LoopInfo &LI) - : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), + LoopInfo &LI, bool FiniteLoops) + : FiniteLoops(FiniteLoops), F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64) { // To use guards for proving predicates, we need to scan every instruction in @@ -12672,7 +12682,8 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) - : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), - LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), + : FiniteLoops(Arg.FiniteLoops), F(Arg.F), HasGuards(Arg.HasGuards), + TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI), + CouldNotCompute(std::move(Arg.CouldNotCompute)), ValueExprMap(std::move(Arg.ValueExprMap)), PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp --- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -42,19 +42,20 @@ ScalarEvolutionsTest() : M("", Context), TLII(), TLI(TLII) {} - ScalarEvolution buildSE(Function &F) { + ScalarEvolution buildSE(Function &F, bool FiniteLoops = false) { AC.reset(new AssumptionCache(F)); DT.reset(new DominatorTree(F)); LI.reset(new LoopInfo(*DT)); - return ScalarEvolution(F, TLI, *AC, *DT, *LI); + return ScalarEvolution(F, TLI, *AC, *DT, *LI, FiniteLoops); } void runWithSE( Module &M, StringRef FuncName, - function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test) { + function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test, + bool FiniteLoops = false) { auto *F = M.getFunction(FuncName); ASSERT_NE(F, nullptr) << "Could not find " << FuncName; - ScalarEvolution SE = buildSE(*F); + ScalarEvolution SE = buildSE(*F, FiniteLoops); Test(*F, *LI, SE); }
@@ -1746,4 +1747,168 @@ }); } +TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteSLE) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr<Module> M = + parseAssemblyString("define void @foo(i32 %len) { " + "entry: " + " br label %for.body " + "for.body: " + " %iv = phi i32 [ %inc, %for.body ], [ 0, %entry ] " + " %inc = add i32 %iv, 1" + " %cmp = icmp sle i32 %iv, %len " + " br i1 %cmp, label %for.body, label %for.end " + "for.end: " + " ret void " + "} ", + Err, C); + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE( + *M, "foo", + [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *IV = getInstructionByName(F, "iv"); + auto *ScevIV = SE.getSCEV(IV); + const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop(); + const SCEV *ITC = SE.getExitCount(L, IV->getParent()); + // Assert exit count == smax(0, 1 + %len); + auto MaxExpr = dyn_cast<SCEVSMaxExpr>(ITC); + ASSERT_TRUE(MaxExpr); + auto AddOne = dyn_cast<SCEVAddExpr>(MaxExpr->getOperand(1)); + ASSERT_TRUE(AddOne); + auto One = dyn_cast<SCEVConstant>(AddOne->getOperand(0)); + ASSERT_TRUE(One); + EXPECT_TRUE(One->getAPInt().isOne()); + auto Len = dyn_cast<SCEVUnknown>(AddOne->getOperand(1)); + ASSERT_TRUE(Len); + EXPECT_EQ(Len->getValue()->getName(), "len"); + }, + /*FiniteLoops=*/true); +} +
+TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteULE) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr<Module> M = + parseAssemblyString("define void @foo(i32 %len) { " + "entry: " + " br label %for.body " + "for.body: " + " %iv = phi i32 [ %inc, %for.body ], [ 0 , %entry ] " + " %inc = add i32 %iv, 1" + " %cmp = icmp ule i32 %iv, %len " + " br i1 %cmp, label %for.body, label %for.end " + "for.end: " + " ret void " + "} ", + Err, C); + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE( + *M, "foo", + [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *IV = getInstructionByName(F, "iv"); + auto *ScevIV = SE.getSCEV(IV); + const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop(); + const SCEV *ITC = SE.getExitCount(L, IV->getParent()); + // Assert exit count == umax(0, 1 + %len) == 1 + %len; + auto AddOne = dyn_cast<SCEVAddExpr>(ITC); + ASSERT_TRUE(AddOne); + auto One = dyn_cast<SCEVConstant>(AddOne->getOperand(0)); + ASSERT_TRUE(One); + EXPECT_TRUE(One->getAPInt().isOne()); + auto Len = dyn_cast<SCEVUnknown>(AddOne->getOperand(1)); + ASSERT_TRUE(Len); + EXPECT_EQ(Len->getValue()->getName(), "len"); + }, + /*FiniteLoops=*/true); +} +
+TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteSGE) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr<Module> M = parseAssemblyString( + "define void @foo(i32 %len) { " + "entry: " + " br label %for.body " + "for.body: " + " %iv = phi i32 [ %inc, %for.body ], [ %len, %entry ] " + " %inc = add i32 %iv, -1" + " %cmp = icmp sge i32 %iv, 1 " + " br i1 %cmp, label %for.body, label %for.end " + "for.end: " + " ret void " + "} ", + Err, C); + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE( + *M, "foo", + [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *IV = getInstructionByName(F, "iv"); + auto *ScevIV = SE.getSCEV(IV); + const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop(); + const SCEV *ITC = SE.getExitCount(L, IV->getParent()); + // Assert exit count == -1 * (0 smin %len) + %len == if %len <= 0 then 0 + // else len + auto AddExpr = dyn_cast<SCEVAddExpr>(ITC); + ASSERT_TRUE(AddExpr); + auto MulMinusOne = dyn_cast<SCEVMulExpr>(AddExpr->getOperand(0)); + ASSERT_TRUE(MulMinusOne); + auto MinusOne = dyn_cast<SCEVConstant>(MulMinusOne->getOperand(0)); + ASSERT_TRUE(MinusOne); + EXPECT_TRUE(MinusOne->getAPInt().isAllOnes()); + auto SMin = dyn_cast<SCEVSMinExpr>(MulMinusOne->getOperand(1)); + ASSERT_TRUE(SMin); + auto Zero = dyn_cast<SCEVConstant>(SMin->getOperand(0)); + ASSERT_TRUE(Zero); + EXPECT_TRUE(Zero->getAPInt().isZero()); + auto Len1 = dyn_cast<SCEVUnknown>(SMin->getOperand(1)); + ASSERT_TRUE(Len1); + EXPECT_EQ(Len1->getValue()->getName(), "len"); + auto Len2 = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1)); + ASSERT_TRUE(Len2); + EXPECT_EQ(Len2->getValue()->getName(), "len"); + }, + /*FiniteLoops=*/true); +} +
+TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteUGE) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr<Module> M = parseAssemblyString( + "define void @foo(i32 %len) { " + "entry: " + " br label %for.body " + "for.body: " + " %iv = phi i32 [ %inc, %for.body ], [ %len , %entry ] " + " %inc = add i32 %iv, -1" + " %cmp = icmp uge i32 %iv, 1 " + " br i1 %cmp, label %for.body, label %for.end " + "for.end: " + " ret void " + "} ", + Err, C); + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE( + *M, "foo", + [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + auto *IV = getInstructionByName(F, "iv"); + auto *ScevIV = SE.getSCEV(IV); + const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop(); + const SCEV *ITC = SE.getExitCount(L, IV->getParent()); + // Assert exit count == len + auto Len = dyn_cast<SCEVUnknown>(ITC); + ASSERT_TRUE(Len); + EXPECT_EQ(Len->getValue()->getName(), "len"); + }, + /*FiniteLoops=*/true); +} + } // end namespace llvm