Index: llvm/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/lib/Analysis/ScalarEvolution.cpp +++ llvm/lib/Analysis/ScalarEvolution.cpp @@ -10483,6 +10483,37 @@ Context)) return true; + // If FoundLHS is AddRec and FoundPred is EQ, we can say that the min value of + // FoundRHS is AddRec's start value if and only if "AddRec == FoundRHS" is + // true. It means we can use "FoundRHS >= AddRec's start value". + if (isa(FoundRHS) && !isa(FoundLHS)) { + std::swap(FoundLHS, FoundRHS); + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + if (const auto *AddRec = dyn_cast(FoundLHS)) { + if (FoundPred == ICmpInst::ICMP_EQ && Pred == ICmpInst::ICMP_SLT) { + const auto *StepCst = + dyn_cast(AddRec->getStepRecurrence(*this)); + if (StepCst) { + auto isImpliedCondWithNewPred = [&](ICmpInst::Predicate NewFoundPred) { + const auto *NewFoundLHS = cast(FoundLHS)->getStart(); + return isImpliedCondBalancedTypes(Pred, LHS, RHS, NewFoundPred, + NewFoundLHS, FoundRHS, Context); + }; + if ((AddRec->hasNoSignedWrap() && !StepCst->getValue()->isNegative())) { + if (isImpliedCondWithNewPred(ICmpInst::ICMP_SLE)) + return true; + } + if (AddRec->hasNoUnsignedWrap()) { + if (isImpliedCondWithNewPred(ICmpInst::ICMP_ULE)) + return true; + } + } + } + } + // Otherwise assume the worst. return false; } Index: llvm/test/Transforms/IndVarSimplify/lftr-pr20680.ll =================================================================== --- llvm/test/Transforms/IndVarSimplify/lftr-pr20680.ll +++ llvm/test/Transforms/IndVarSimplify/lftr-pr20680.ll @@ -49,10 +49,7 @@ ; CHECK-NEXT: store i32 1, i32* @b, align 4 ; CHECK-NEXT: br label [[FOR_COND2_LOOPEXIT_US_US]] ; CHECK: for.inc.us.us: -; CHECK-NEXT: [[TMP5:%.*]] = phi i32 [ [[TMP4]], [[FOR_INC_LR_PH_US_US]] ], [ [[INC_US_US:%.*]], [[FOR_INC_US_US]] ] -; CHECK-NEXT: [[INC_US_US]] = add nsw i32 [[TMP5]], 1 -; CHECK-NEXT: [[EXITCOND3:%.*]] = icmp ne i32 [[INC_US_US]], 1 -; CHECK-NEXT: br i1 [[EXITCOND3]], label [[FOR_INC_US_US]], label [[FOR_COND8_FOR_COND2_LOOPEXIT_CRIT_EDGE_US_US:%.*]] +; CHECK-NEXT: br i1 true, label [[FOR_INC_US_US]], label [[FOR_COND8_FOR_COND2_LOOPEXIT_CRIT_EDGE_US_US:%.*]] ; CHECK: for.cond2.for.inc13_crit_edge.us-lcssa.us.us-lcssa.us: ; CHECK-NEXT: br label [[FOR_COND2_FOR_INC13_CRIT_EDGE_US_LCSSA_US:%.*]] ; CHECK: for.body3.lr.ph.split.us.split: @@ -62,14 +59,11 @@ ; CHECK: cond.false.us: ; CHECK-NEXT: br label [[COND_END_US]] ; CHECK: cond.end.us: -; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* @b, align 4 -; CHECK-NEXT: [[CMP91_US:%.*]] = icmp slt i32 [[TMP6]], 1 +; CHECK-NEXT: [[TMP5:%.*]] = load i32, i32* @b, align 4 +; CHECK-NEXT: [[CMP91_US:%.*]] = icmp slt i32 [[TMP5]], 1 ; CHECK-NEXT: br i1 [[CMP91_US]], label [[FOR_INC_LR_PH_US:%.*]], label [[FOR_COND2_LOOPEXIT_US:%.*]] ; CHECK: for.inc.us: -; CHECK-NEXT: [[TMP7:%.*]] = phi i32 [ [[TMP6]], [[FOR_INC_LR_PH_US]] ], [ [[INC_US:%.*]], [[FOR_INC_US:%.*]] ] -; CHECK-NEXT: [[INC_US]] = add nsw i32 [[TMP7]], 1 -; CHECK-NEXT: [[EXITCOND2:%.*]] = icmp ne i32 [[INC_US]], 1 -; CHECK-NEXT: br i1 [[EXITCOND2]], label [[FOR_INC_US]], label [[FOR_COND8_FOR_COND2_LOOPEXIT_CRIT_EDGE_US:%.*]] +; CHECK-NEXT: br i1 true, label [[FOR_INC_US:%.*]], label [[FOR_COND8_FOR_COND2_LOOPEXIT_CRIT_EDGE_US:%.*]] ; CHECK: for.cond2.loopexit.us: ; CHECK-NEXT: br i1 false, label [[FOR_COND2_FOR_INC13_CRIT_EDGE_US_LCSSA_US_US_LCSSA:%.*]], label [[FOR_BODY3_US]] ; CHECK: for.inc.lr.ph.us: @@ -93,12 +87,12 @@ ; CHECK: cond.false.us4: ; CHECK-NEXT: br label [[COND_END_US5]] ; CHECK: cond.end.us5: -; CHECK-NEXT: [[TMP8:%.*]] = load i32, i32* @b, align 4 -; CHECK-NEXT: [[CMP91_US7:%.*]] = icmp slt i32 [[TMP8]], 1 +; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* @b, align 4 +; CHECK-NEXT: [[CMP91_US7:%.*]] = icmp slt i32 [[TMP6]], 1 ; CHECK-NEXT: br i1 [[CMP91_US7]], label [[FOR_INC_LR_PH_US12:%.*]], label [[FOR_COND2_LOOPEXIT_US11:%.*]] ; CHECK: for.inc.us8: -; CHECK-NEXT: [[TMP9:%.*]] = phi i32 [ [[TMP8]], [[FOR_INC_LR_PH_US12]] ], [ [[INC_US9:%.*]], [[FOR_INC_US8:%.*]] ] -; CHECK-NEXT: [[INC_US9]] = add nsw i32 [[TMP9]], 1 +; CHECK-NEXT: [[TMP7:%.*]] = phi i32 [ [[TMP6]], [[FOR_INC_LR_PH_US12]] ], [ [[INC_US9:%.*]], [[FOR_INC_US8:%.*]] ] +; CHECK-NEXT: [[INC_US9]] = add nsw i32 [[TMP7]], 1 ; CHECK-NEXT: [[EXITCOND1:%.*]] = icmp ne i32 [[INC_US9]], 1 ; CHECK-NEXT: br i1 [[EXITCOND1]], label [[FOR_INC_US8]], label [[FOR_COND8_FOR_COND2_LOOPEXIT_CRIT_EDGE_US13:%.*]] ; CHECK: for.cond2.loopexit.us11: @@ -122,14 +116,14 @@ ; CHECK: cond.false: ; CHECK-NEXT: br label [[COND_END]] ; CHECK: cond.end: -; CHECK-NEXT: [[TMP10:%.*]] = load i32, i32* @b, align 4 -; CHECK-NEXT: [[CMP91:%.*]] = icmp slt i32 [[TMP10]], 1 +; CHECK-NEXT: [[TMP8:%.*]] = load i32, i32* @b, align 4 +; CHECK-NEXT: [[CMP91:%.*]] = icmp slt i32 [[TMP8]], 1 ; CHECK-NEXT: br i1 [[CMP91]], label [[FOR_INC_LR_PH:%.*]], label [[FOR_COND2_LOOPEXIT]] ; CHECK: for.inc.lr.ph: ; CHECK-NEXT: br label [[FOR_INC:%.*]] ; CHECK: for.inc: -; CHECK-NEXT: [[TMP11:%.*]] = phi i32 [ [[TMP10]], [[FOR_INC_LR_PH]] ], [ [[INC:%.*]], [[FOR_INC]] ] -; CHECK-NEXT: [[INC]] = add nsw i32 [[TMP11]], 1 +; CHECK-NEXT: [[TMP9:%.*]] = phi i32 [ [[TMP8]], [[FOR_INC_LR_PH]] ], [ [[INC:%.*]], [[FOR_INC]] ] +; CHECK-NEXT: [[INC]] = add nsw i32 [[TMP9]], 1 ; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i32 [[INC]], 1 ; CHECK-NEXT: br i1 [[EXITCOND]], label [[FOR_INC]], label [[FOR_COND8_FOR_COND2_LOOPEXIT_CRIT_EDGE:%.*]] ; CHECK: for.cond2.for.inc13_crit_edge.us-lcssa.us-lcssa: @@ -143,8 +137,8 @@ ; CHECK-NEXT: br label [[FOR_INC13]] ; CHECK: for.inc13: ; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nsw i32 [[INDVARS_IV]], 1 -; CHECK-NEXT: [[EXITCOND4:%.*]] = icmp ne i32 [[INDVARS_IV_NEXT]], 0 -; CHECK-NEXT: br i1 [[EXITCOND4]], label [[FOR_COND2_PREHEADER]], label [[FOR_END15:%.*]] +; CHECK-NEXT: [[EXITCOND2:%.*]] = icmp ne i32 [[INDVARS_IV_NEXT]], 0 +; CHECK-NEXT: br i1 [[EXITCOND2]], label [[FOR_COND2_PREHEADER]], label [[FOR_END15:%.*]] ; CHECK: for.end15: ; CHECK-NEXT: ret void ; Index: llvm/unittests/Analysis/ScalarEvolutionTest.cpp =================================================================== --- llvm/unittests/Analysis/ScalarEvolutionTest.cpp +++ llvm/unittests/Analysis/ScalarEvolutionTest.cpp @@ -1414,6 +1414,146 @@ }); } +TEST_F(ScalarEvolutionsTest, ImpliedCondWithAddRecNSWStepPositive) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = + parseAssemblyString("define void @foo(i32 %len) { " + "entry: " + " br label %loop " + "loop: " + " %iv = phi i32 [ 1, %entry], [%iv.next, %loop] " + " %iv.next = add nsw i32 %iv, 1 " + " %cmp = icmp eq i32 %iv, %len " + " br i1 %cmp, label %loop, label %exit " + "exit:" + " ret void " + "}", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + ICmpInst *ICmp = cast(getInstructionByName(F, "cmp")); + const SCEV *FoundLHS = SE.getSCEV(ICmp->getOperand(0)); + const SCEV *FoundRHS = SE.getSCEV(ICmp->getOperand(1)); + ICmpInst::Predicate FoundPred = ICmp->getPredicate(); + + const SCEV *LHS = SE.getZero(ICmp->getOperand(0)->getType()); + const SCEV *RHS = SE.getSCEV(getArgByName(F, "len")); + ICmpInst::Predicate Pred = ICmpInst::ICMP_SLT; + + EXPECT_TRUE( + isImpliedCond(SE, Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS)); + }); +} + +TEST_F(ScalarEvolutionsTest, ImpliedCondWithAddRecNSWStepNegative) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = + parseAssemblyString("define void @foo(i32 %len) { " + "entry: " + " br label %loop " + "loop: " + " %iv = phi i32 [ 100, %entry], [%iv.next, %loop] " + " %iv.next = add nsw i32 %iv, -1 " + " %cmp = icmp eq i32 %iv, %len " + " br i1 %cmp, label %loop, label %exit " + "exit:" + " ret void " + "}", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + ICmpInst *ICmp = cast(getInstructionByName(F, "cmp")); + const SCEV *FoundLHS = SE.getSCEV(ICmp->getOperand(0)); + const SCEV *FoundRHS = SE.getSCEV(ICmp->getOperand(1)); + ICmpInst::Predicate FoundPred = ICmp->getPredicate(); + + const SCEV *LHS = SE.getZero(ICmp->getOperand(0)->getType()); + const SCEV *RHS = SE.getSCEV(getArgByName(F, "len")); + ICmpInst::Predicate Pred = ICmpInst::ICMP_SLT; + + EXPECT_FALSE( + isImpliedCond(SE, Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS)); + }); +} + +TEST_F(ScalarEvolutionsTest, ImpliedCondWithAddRecStepNegative) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = + parseAssemblyString("define void @foo(i32 %len) { " + "entry: " + " br label %loop " + "loop: " + " %iv = phi i32 [ 100, %entry], [%iv.next, %loop] " + " %iv.next = add i32 %iv, -1 " + " %cmp = icmp eq i32 %iv, %len " + " br i1 %cmp, label %loop, label %exit " + "exit:" + " ret void " + "}", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + ICmpInst *ICmp = cast(getInstructionByName(F, "cmp")); + const SCEV *FoundLHS = SE.getSCEV(ICmp->getOperand(0)); + const SCEV *FoundRHS = SE.getSCEV(ICmp->getOperand(1)); + ICmpInst::Predicate FoundPred = ICmp->getPredicate(); + + const SCEV *LHS = SE.getZero(ICmp->getOperand(0)->getType()); + const SCEV *RHS = SE.getSCEV(getArgByName(F, "len")); + ICmpInst::Predicate Pred = ICmpInst::ICMP_SLT; + + EXPECT_FALSE( + isImpliedCond(SE, Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS)); + }); +} + +TEST_F(ScalarEvolutionsTest, ImpliedCondWithAddRecNUW) { + LLVMContext C; + SMDiagnostic Err; + std::unique_ptr M = + parseAssemblyString("define void @foo(i32 %len) { " + "entry: " + " br label %loop " + "loop: " + " %iv = phi i32 [ 1, %entry], [%iv.next, %loop] " + " %iv.next = add nuw i32 %iv, 1 " + " %cmp = icmp eq i32 %iv, %len " + " br i1 %cmp, label %loop, label %exit " + "exit:" + " ret void " + "}", + Err, C); + + ASSERT_TRUE(M && "Could not parse module?"); + ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!"); + + runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) { + ICmpInst *ICmp = cast(getInstructionByName(F, "cmp")); + const SCEV *FoundLHS = SE.getSCEV(ICmp->getOperand(0)); + const SCEV *FoundRHS = SE.getSCEV(ICmp->getOperand(1)); + ICmpInst::Predicate FoundPred = ICmp->getPredicate(); + + const SCEV *LHS = SE.getZero(ICmp->getOperand(0)->getType()); + const SCEV *RHS = SE.getSCEV(getArgByName(F, "len")); + ICmpInst::Predicate Pred = ICmpInst::ICMP_ULT; + + EXPECT_FALSE( + isImpliedCond(SE, Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS)); + }); +} + TEST_F(ScalarEvolutionsTest, MatchURem) { LLVMContext C; SMDiagnostic Err;