Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1745,6 +1745,17 @@
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
   /// true.
   ///
+  /// This routine tries to weaken the known condition by replacing FoundLHS,
+  /// which must be an affine AddRec, with that AddRec's start value.
+  bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred,
+                                            const SCEV *LHS, const SCEV *RHS,
+                                            const SCEV *FoundLHS,
+                                            const SCEV *FoundRHS);
+
+  /// Test whether the condition described by Pred, LHS, and RHS is true
+  /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
+  /// true.
+  ///
   /// This routine tries to figure out predicate for Phis which are SCEVUnknown
   /// if it is true for every possible incoming value from their respective
   /// basic blocks.
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9886,6 +9886,44 @@
   return None;
 }
 
+bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+    const SCEV *FoundLHS, const SCEV *FoundRHS) {
+  // Match FoundLHS = {X,+,S} with FoundRHS available at loop entry and treat
+  // (Pred, X, FoundRHS) as known: S must be -1 for >/>=, 1 for </<=, and may
+  // be anything for ==/!=. Then try to prove (Pred, LHS, RHS) from it.
+  // FIXME: Only sound if the found fact also held on the first iteration; a
+  // fact known only on iteration k > 0 does not in general constrain X so.
+  auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS);
+  if (!AR || !AR->isAffine())
+    return false;
+  if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
+    return false;
+  switch (Pred) {
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_SGE:
+  case ICmpInst::ICMP_UGE:
+    if (AR->getOperand(1) != getMinusOne(AR->getType()))
+      return false;
+    break;
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_SLE:
+  case ICmpInst::ICMP_ULE:
+    if (AR->getOperand(1) != getOne(AR->getType()))
+      return false;
+    break;
+  case ICmpInst::ICMP_EQ:
+  case ICmpInst::ICMP_NE:
+    break;
+  default:
+    return false;
+  }
+
+  return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
+}
+
 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
     const SCEV *FoundLHS, const SCEV *FoundRHS) {
@@ -10083,6 +10121,9 @@
   if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
+  if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS))
+    return true;
+
   return isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS) ||
          // ~x < ~y --> x > y
Index: llvm/unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1217,4 +1217,36 @@
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ImpliedViaAddRecStart) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define void @foo(i32* %p) { "
+      "entry: "
+      "  %x = load i32, i32* %p, !range !0 "
+      "  br label %loop "
+      "loop: "
+      "  %iv = phi i32 [ %x, %entry], [%iv.next, %backedge] "
+      "  %ne.check = icmp ne i32 %iv, 0 "
+      "  br i1 %ne.check, label %backedge, label %exit "
+      "backedge: "
+      "  %iv.next = add i32 %iv, -1 "
+      "  br label %loop "
+      "exit:"
+      "  ret void "
+      "} "
+      "!0 = !{i32 0, i32 2147483647}",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *X = SE.getSCEV(getInstructionByName(F, "x"));
+    auto *Backedge = getInstructionByName(F, "iv.next")->getParent();
+    EXPECT_TRUE(SE.isKnownPredicateAt(ICmpInst::ICMP_NE, X,
+                                      SE.getZero(X->getType()), Backedge));
+  });
+}
+
 } // end namespace llvm