diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -494,8 +494,12 @@
     return TestFlags == maskFlags(Flags, TestFlags);
   };
 
+  /// Whether all loops are finite by assumption. When true, every loop is
+  /// treated as finite even if it is not a side-effect-free mustprogress loop.
+  bool FiniteLoops;
+
   ScalarEvolution(Function &F, TargetLibraryInfo &TLI, AssumptionCache &AC,
-                  DominatorTree &DT, LoopInfo &LI);
+                  DominatorTree &DT, LoopInfo &LI, bool FiniteLoops = false);
   ScalarEvolution(ScalarEvolution &&Arg);
   ~ScalarEvolution();
 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7017,7 +7017,8 @@
   // A mustprogress loop without side effects must be finite.
   // TODO: The check used here is very conservative. It's only *specific*
   // side effects which are well defined in infinite loops.
-  return isMustProgress(L) && loopHasNoSideEffects(L);
+  // Alternatively, the client may have asserted that all loops are finite.
+  return FiniteLoops || (isMustProgress(L) && loopHasNoSideEffects(L));
 }
 
 const SCEV *ScalarEvolution::createSCEV(Value *V) {
@@ -8513,6 +8514,39 @@
     }
   }
 
+  // Moreover, if the loop is assumed to be finite, has no abnormal exits, and
+  // this condition controls the exit, then lhs <= rhs is equivalent to
+  // lhs < rhs + 1 and lhs >= rhs is equivalent to lhs > rhs - 1.
+  // For lhs <= rhs, the two forms only differ when rhs is the unsigned/signed
+  // maximum, which would have resulted in an infinite loop. For lhs >= rhs,
+  // they only differ when rhs is the unsigned/signed minimum, which would
+  // again have resulted in an infinite loop.
+  if (ControlsExit && isLoopInvariant(RHS, L) && loopHasNoAbnormalExits(L) &&
+      loopIsFiniteByAssumption(L)) {
+    switch (Pred) {
+    case ICmpInst::ICMP_SLE:
+      Pred = ICmpInst::ICMP_SLT;
+      RHS = getAddExpr(RHS, getOne(RHS->getType()), SCEV::FlagNSW);
+      break;
+    case ICmpInst::ICMP_ULE:
+      Pred = ICmpInst::ICMP_ULT;
+      RHS = getAddExpr(RHS, getOne(RHS->getType()), SCEV::FlagNUW);
+      break;
+    case ICmpInst::ICMP_SGE:
+      Pred = ICmpInst::ICMP_SGT;
+      RHS = getAddExpr(RHS, getMinusOne(RHS->getType()), SCEV::FlagNSW);
+      break;
+    case ICmpInst::ICMP_UGE:
+      Pred = ICmpInst::ICMP_UGT;
+      // rhs + (-1) wraps in the unsigned sense for every rhs != 0, so NUW
+      // must not be claimed for this addition.
+      RHS = getAddExpr(RHS, getMinusOne(RHS->getType()), SCEV::FlagAnyWrap);
+      break;
+    default:
+      break;
+    }
+  }
+
   switch (Pred) {
   case ICmpInst::ICMP_NE: { // while (X != Y)
     // Convert to: while (X-Y != 0)
@@ -12651,8 +12685,8 @@
 
 ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
                                  AssumptionCache &AC, DominatorTree &DT,
-                                 LoopInfo &LI)
-    : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
+                                 LoopInfo &LI, bool FiniteLoops)
+    : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI), FiniteLoops(FiniteLoops),
       CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
       LoopDispositions(64), BlockDispositions(64) {
   // To use guards for proving predicates, we need to scan every instruction in
@@ -12672,7 +12706,8 @@
 
 ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
     : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),
-      LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
+      LI(Arg.LI), FiniteLoops(Arg.FiniteLoops),
+      CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
       PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -42,19 +42,20 @@
 
   ScalarEvolutionsTest() : M("", Context), TLII(), TLI(TLII) {}
 
-  ScalarEvolution buildSE(Function &F) {
+  ScalarEvolution buildSE(Function &F, bool FiniteLoops = false) {
     AC.reset(new AssumptionCache(F));
     DT.reset(new DominatorTree(F));
     LI.reset(new LoopInfo(*DT));
-    return ScalarEvolution(F, TLI, *AC, *DT, *LI);
+    return ScalarEvolution(F, TLI, *AC, *DT, *LI, FiniteLoops);
   }
 
   void runWithSE(
       Module &M, StringRef FuncName,
-      function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test) {
+      function_ref<void(Function &F, LoopInfo &LI, ScalarEvolution &SE)> Test,
+      bool FiniteLoops = false) {
     auto *F = M.getFunction(FuncName);
     ASSERT_NE(F, nullptr) << "Could not find " << FuncName;
-    ScalarEvolution SE = buildSE(*F);
+    ScalarEvolution SE = buildSE(*F, FiniteLoops);
     Test(*F, *LI, SE);
   }
 
@@ -1746,4 +1747,176 @@
   });
 }
 
+// With loops assumed finite, the signed `iv <= %len` exit test is rewritten to
+// `iv < %len + 1`, which makes the exit count of the loop computable.
+TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteSLE) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define void @foo(i32 %len) { "
+                          "entry: "
+                          " br label %for.body "
+                          "for.body: "
+                          " %iv = phi i32 [ %inc, %for.body ], [ 0, %entry ] "
+                          " %inc = add i32 %iv, 1"
+                          " %cmp = icmp sle i32 %iv, %len "
+                          " br i1 %cmp, label %for.body, label %for.end "
+                          "for.end: "
+                          " ret void "
+                          "} ",
+                          Err, C);
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(
+      *M, "foo",
+      [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        auto *IV = getInstructionByName(F, "iv");
+        auto *ScevIV = SE.getSCEV(IV);
+        const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
+        const SCEV *ITC = SE.getExitCount(L, IV->getParent());
+        // Assert exit count == smax(0, 1 + %len);
+        auto *MaxExpr = dyn_cast<SCEVSMaxExpr>(ITC);
+        ASSERT_NE(MaxExpr, nullptr);
+        auto *AddOne = dyn_cast<SCEVAddExpr>(MaxExpr->getOperand(1));
+        ASSERT_NE(AddOne, nullptr);
+        auto *One = dyn_cast<SCEVConstant>(AddOne->getOperand(0));
+        ASSERT_NE(One, nullptr);
+        EXPECT_TRUE(One->getAPInt().isOne());
+        auto *Len = dyn_cast<SCEVUnknown>(AddOne->getOperand(1));
+        ASSERT_NE(Len, nullptr);
+        EXPECT_TRUE(Len->getValue()->getName() == "len");
+      },
+      /*FiniteLoops=*/true);
+}
+
+// With loops assumed finite, the unsigned `iv <= %len` exit test is rewritten
+// to `iv < %len + 1`, which makes the exit count of the loop computable.
+TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteULE) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define void @foo(i32 %len) { "
+                          "entry: "
+                          " br label %for.body "
+                          "for.body: "
+                          " %iv = phi i32 [ %inc, %for.body ], [ 0 , %entry ] "
+                          " %inc = add i32 %iv, 1"
+                          " %cmp = icmp ule i32 %iv, %len "
+                          " br i1 %cmp, label %for.body, label %for.end "
+                          "for.end: "
+                          " ret void "
+                          "} ",
+                          Err, C);
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(
+      *M, "foo",
+      [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        auto *IV = getInstructionByName(F, "iv");
+        auto *ScevIV = SE.getSCEV(IV);
+        const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
+        const SCEV *ITC = SE.getExitCount(L, IV->getParent());
+        // Assert exit count == umax(0, 1 + %len) == 1 + %len;
+        auto *AddOne = dyn_cast<SCEVAddExpr>(ITC);
+        ASSERT_NE(AddOne, nullptr);
+        auto *One = dyn_cast<SCEVConstant>(AddOne->getOperand(0));
+        ASSERT_NE(One, nullptr);
+        EXPECT_TRUE(One->getAPInt().isOne());
+        auto *Len = dyn_cast<SCEVUnknown>(AddOne->getOperand(1));
+        ASSERT_NE(Len, nullptr);
+        EXPECT_TRUE(Len->getValue()->getName() == "len");
+      },
+      /*FiniteLoops=*/true);
+}
+
+// With loops assumed finite, the signed `iv >= 1` exit test of a down-counting
+// loop is rewritten to `iv > 0`, which makes the exit count computable.
+TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteSGE) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define void @foo(i32 %len) { "
+      "entry: "
+      " br label %for.body "
+      "for.body: "
+      " %iv = phi i32 [ %inc, %for.body ], [ %len, %entry ] "
+      " %inc = add i32 %iv, -1"
+      " %cmp = icmp sge i32 %iv, 1 "
+      " br i1 %cmp, label %for.body, label %for.end "
+      "for.end: "
+      " ret void "
+      "} ",
+      Err, C);
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(
+      *M, "foo",
+      [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        auto *IV = getInstructionByName(F, "iv");
+        auto *ScevIV = SE.getSCEV(IV);
+        const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
+        const SCEV *ITC = SE.getExitCount(L, IV->getParent());
+        // Assert exit count == -1 * (0 smin %len) + %len == if %len <= 0 then 0
+        // else len
+        auto *AddExpr = dyn_cast<SCEVAddExpr>(ITC);
+        ASSERT_NE(AddExpr, nullptr);
+        auto *MulMinusOne = dyn_cast<SCEVMulExpr>(AddExpr->getOperand(0));
+        ASSERT_NE(MulMinusOne, nullptr);
+        auto *MinusOne = dyn_cast<SCEVConstant>(MulMinusOne->getOperand(0));
+        ASSERT_NE(MinusOne, nullptr);
+        EXPECT_TRUE(MinusOne->getAPInt().isAllOnes());
+        auto *SMin = dyn_cast<SCEVSMinExpr>(MulMinusOne->getOperand(1));
+        ASSERT_NE(SMin, nullptr);
+        auto *Zero = dyn_cast<SCEVConstant>(SMin->getOperand(0));
+        ASSERT_NE(Zero, nullptr);
+        EXPECT_TRUE(Zero->getAPInt().isZero());
+        auto *Len1 = dyn_cast<SCEVUnknown>(SMin->getOperand(1));
+        ASSERT_NE(Len1, nullptr);
+        EXPECT_TRUE(Len1->getValue()->getName() == "len");
+        auto *Len2 = dyn_cast<SCEVUnknown>(AddExpr->getOperand(1));
+        ASSERT_NE(Len2, nullptr);
+        EXPECT_TRUE(Len2->getValue()->getName() == "len");
+      },
+      /*FiniteLoops=*/true);
+}
+
+// With loops assumed finite, the unsigned `iv >= 1` exit test of a
+// down-counting loop is rewritten to `iv > 0`, giving an exit count of %len.
+TEST_F(ScalarEvolutionsTest, ComputeTripForFiniteUGE) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define void @foo(i32 %len) { "
+      "entry: "
+      " br label %for.body "
+      "for.body: "
+      " %iv = phi i32 [ %inc, %for.body ], [ %len , %entry ] "
+      " %inc = add i32 %iv, -1"
+      " %cmp = icmp uge i32 %iv, 1 "
+      " br i1 %cmp, label %for.body, label %for.end "
+      "for.end: "
+      " ret void "
+      "} ",
+      Err, C);
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(
+      *M, "foo",
+      [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        auto *IV = getInstructionByName(F, "iv");
+        auto *ScevIV = SE.getSCEV(IV);
+        const Loop *L = cast<SCEVAddRecExpr>(ScevIV)->getLoop();
+        const SCEV *ITC = SE.getExitCount(L, IV->getParent());
+        // Assert exit count == len
+        auto *Len = dyn_cast<SCEVUnknown>(ITC);
+        ASSERT_NE(Len, nullptr);
+        EXPECT_TRUE(Len->getValue()->getName() == "len");
+      },
+      /*FiniteLoops=*/true);
+}
+
 } // end namespace llvm