Index: llvm/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1678,23 +1678,29 @@
   getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) const;
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
-  /// whenever the given FoundCondValue value evaluates to true.
+  /// whenever the given FoundCondValue value evaluates to true in given
+  /// Context. If Context is nullptr, then the found predicate is true
+  /// everywhere.
   bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
-                     const Value *FoundCondValue, bool Inverse);
+                     const Value *FoundCondValue, bool Inverse,
+                     const BasicBlock *Context = nullptr);
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is
-  /// true.
+  /// true in given Context. If Context is nullptr, then the found predicate is
+  /// true everywhere.
   bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
                      ICmpInst::Predicate FoundPred, const SCEV *FoundLHS,
-                     const SCEV *FoundRHS);
+                     const SCEV *FoundRHS, const BasicBlock *Context = nullptr);
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
-  /// true.
+  /// true in given Context. If Context is nullptr, then the found predicate is
+  /// true everywhere.
   bool isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS,
                              const SCEV *RHS, const SCEV *FoundLHS,
-                             const SCEV *FoundRHS);
+                             const SCEV *FoundRHS,
+                             const BasicBlock *Context = nullptr);
 
   /// Test whether the condition described by Pred, LHS, and RHS is true
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
@@ -1745,6 +1751,20 @@
   /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
   /// true.
   ///
+  /// This routine tries to weaken the known condition based on the fact that
+  /// FoundLHS or FoundRHS is an AddRec. Context must be inside the AddRec's
+  /// loop and dominate its latch, so that the known condition also held on the
+  /// first iteration, i.e. for the AddRec's start value.
+  bool isImpliedCondOperandsViaAddRecStart(ICmpInst::Predicate Pred,
+                                           const SCEV *LHS, const SCEV *RHS,
+                                           const SCEV *FoundLHS,
+                                           const SCEV *FoundRHS,
+                                           const BasicBlock *Context);
+
+  /// Test whether the condition described by Pred, LHS, and RHS is true
+  /// whenever the condition described by Pred, FoundLHS, and FoundRHS is
+  /// true.
+  ///
   /// This routine tries to figure out predicate for Phis which are SCEVUnknown
   /// if it is true for every possible incoming value from their respective
   /// basic blocks.
Index: llvm/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/lib/Analysis/ScalarEvolution.cpp
+++ llvm/lib/Analysis/ScalarEvolution.cpp
@@ -9544,16 +9544,17 @@
   };
 
   // Try to prove (Pred, LHS, RHS) using isImpliedCond.
-  auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
-    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+  auto ProveViaCond = [&](const Value *Condition, bool Inverse,
+                          const BasicBlock *Context = nullptr) {
+    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context))
       return true;
     if (ProvingStrictComparison) {
       if (!ProvedNonStrictComparison)
-        ProvedNonStrictComparison =
-            isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+        ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS,
+                                                  Condition, Inverse, Context);
       if (!ProvedNonEquality)
-        ProvedNonEquality =
-            isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+        ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS,
+                                          Condition, Inverse, Context);
       if (ProvedNonStrictComparison && ProvedNonEquality)
         return true;
     }
@@ -9581,7 +9582,8 @@
       continue;
 
     if (ProveViaCond(LoopEntryPredicate->getCondition(),
-                     LoopEntryPredicate->getSuccessor(0) != Pair.second))
+                     LoopEntryPredicate->getSuccessor(0) != Pair.second,
+                     Pair.second))
       return true;
   }
 
@@ -9619,7 +9621,8 @@
 
 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS,
-                                    const Value *FoundCondValue, bool Inverse) {
+                                    const Value *FoundCondValue, bool Inverse,
+                                    const BasicBlock *Context) {
   if (!PendingLoopPredicates.insert(FoundCondValue).second)
     return false;
 
@@ -9630,12 +9633,16 @@
   if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
     if (BO->getOpcode() == Instruction::And) {
       if (!Inverse)
-        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
-               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
+                             Context) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
+                             Context);
     } else if (BO->getOpcode() == Instruction::Or) {
       if (Inverse)
-        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
-               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
+                             Context) ||
+               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
+                             Context);
     }
   }
 
@@ -9653,14 +9660,14 @@
   const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
   const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
 
-  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
+  return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context);
 }
 
 bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
                                     const SCEV *RHS,
                                     ICmpInst::Predicate FoundPred,
-                                    const SCEV *FoundLHS,
-                                    const SCEV *FoundRHS) {
+                                    const SCEV *FoundLHS, const SCEV *FoundRHS,
+                                    const BasicBlock *Context) {
   // Balance the types.
   if (getTypeSizeInBits(LHS->getType()) <
       getTypeSizeInBits(FoundLHS->getType())) {
@@ -9704,16 +9711,16 @@
 
   // Check whether the found predicate is the same as the desired predicate.
   if (FoundPred == Pred)
-    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
 
   // Check whether swapping the found predicate makes it the same as the
   // desired predicate.
   if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
     if (isa<SCEVConstant>(RHS))
-      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
+      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context);
     else
-      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
-                                   RHS, LHS, FoundLHS, FoundRHS);
+      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS,
+                                   LHS, FoundLHS, FoundRHS, Context);
   }
 
   // Unsigned comparison is the same as signed comparison when both the operands
@@ -9721,7 +9728,7 @@
   if (CmpInst::isUnsigned(FoundPred) &&
       CmpInst::getSignedPredicate(FoundPred) == Pred &&
       isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
-    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
 
   // Check if we can make progress by sharpening ranges.
   if (FoundPred == ICmpInst::ICMP_NE &&
@@ -9758,8 +9765,8 @@
         case ICmpInst::ICMP_UGE:
           // We know V `Pred` SharperMin.  If this implies LHS `Pred`
           // RHS, we're done.
-          if (isImpliedCondOperands(Pred, LHS, RHS, V,
-                                    getConstant(SharperMin)))
+          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
+                                    Context))
             return true;
           LLVM_FALLTHROUGH;
 
@@ -9774,7 +9781,8 @@
           //
           // If V `Pred` Min implies LHS `Pred` RHS, we're done.
 
-          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+          if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min),
+                                    Context))
             return true;
           break;
 
@@ -9782,14 +9790,14 @@
         case ICmpInst::ICMP_SLE:
         case ICmpInst::ICMP_ULE:
           if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
-                                    LHS, V, getConstant(SharperMin)))
+                                    LHS, V, getConstant(SharperMin), Context))
             return true;
           LLVM_FALLTHROUGH;
 
         case ICmpInst::ICMP_SLT:
         case ICmpInst::ICMP_ULT:
           if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
-                                    LHS, V, getConstant(Min)))
+                                    LHS, V, getConstant(Min), Context))
             return true;
           break;
 
@@ -9803,11 +9811,12 @@
   // Check whether the actual condition is beyond sufficient.
   if (FoundPred == ICmpInst::ICMP_EQ)
     if (ICmpInst::isTrueWhenEqual(Pred))
-      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
+      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context))
         return true;
   if (Pred == ICmpInst::ICMP_NE)
     if (!ICmpInst::isTrueWhenEqual(FoundPred))
-      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
+      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS,
+                                Context))
         return true;
 
   // Otherwise assume the worst.
@@ -9886,6 +9895,55 @@
   return None;
 }
 
+bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
+    ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+    const SCEV *FoundLHS, const SCEV *FoundRHS, const BasicBlock *Context) {
+  // Try to recognize the following pattern:
+  //
+  //   FoundRHS = ...
+  // ...
+  // loop:
+  //   FoundLHS = {Start,+,W}
+  // context:  // Basic block from the same loop, dominating its latch
+  //   known(Pred, FoundLHS, FoundRHS)
+  //
+  // If the context block executes on some iteration of the loop, it has also
+  // executed on the first iteration: to start any later iteration, control
+  // must pass through the latch, which the context block dominates. On the
+  // first iteration FoundLHS == Start, so a predicate known in the context
+  // implies `Start Pred FoundRHS` (FoundRHS must be loop-invariant). Try to
+  // prove the original pred using this fact.
+  if (!Context)
+    return false;
+
+  // Returns true if Context is inside L and dominates L's unique latch, so
+  // that Context runs on the first iteration of L if it runs at all.
+  auto ContextRunsOnFirstIteration = [&](const Loop *L) {
+    if (!L->contains(Context))
+      return false;
+    const BasicBlock *Latch = L->getLoopLatch();
+    return Latch && DT.dominates(Context, Latch);
+  };
+
+  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
+    if (!ContextRunsOnFirstIteration(AR->getLoop()))
+      return false;
+    if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
+      return false;
+    return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
+  }
+
+  if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
+    if (!ContextRunsOnFirstIteration(AR->getLoop()))
+      return false;
+    if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
+      return false;
+    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
+  }
+
+  return false;
+}
+
 bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
     ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
     const SCEV *FoundLHS, const SCEV *FoundRHS) {
@@ -10076,13 +10134,18 @@
 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
-                                            const SCEV *FoundRHS) {
+                                            const SCEV *FoundRHS,
+                                            const BasicBlock *Context) {
   if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
   if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
     return true;
 
+  if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
+                                          Context))
+    return true;
+
   return isImpliedCondOperandsHelper(Pred, LHS, RHS,
                                      FoundLHS, FoundRHS) ||
          // ~x < ~y --> x > y
Index: llvm/unittests/Analysis/ScalarEvolutionTest.cpp
===================================================================
--- llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1217,4 +1217,75 @@
   });
 }
 
+TEST_F(ScalarEvolutionsTest, ImpliedViaAddRecStart) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define void @foo(i32* %p) { "
+      "entry: "
+      "  %x = load i32, i32* %p, !range !0 "
+      "  br label %loop "
+      "loop: "
+      "  %iv = phi i32 [ %x, %entry], [%iv.next, %backedge] "
+      "  %ne.check = icmp ne i32 %iv, 0 "
+      "  br i1 %ne.check, label %backedge, label %exit "
+      "backedge: "
+      "  %iv.next = add i32 %iv, -1 "
+      "  br label %loop "
+      "exit:"
+      "  ret void "
+      "} "
+      "!0 = !{i32 0, i32 2147483647}",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *X = SE.getSCEV(getInstructionByName(F, "x"));
+    auto *Backedge = getInstructionByName(F, "iv.next")->getParent();
+    EXPECT_TRUE(SE.isKnownPredicateAt(ICmpInst::ICMP_NE, X,
+                                      SE.getZero(X->getType()), Backedge));
+  });
+}
+
+TEST_F(ScalarEvolutionsTest, ImpliedViaAddRecStartNonDominating) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  // %guarded does not dominate the latch, so it may first run on a later
+  // iteration (with %x == 0, iteration 0 skips it and iteration 1 has
+  // %iv == -1). Knowing %iv != 0 there says nothing about the start %x.
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "define void @foo(i32* %p, i1* %q) { "
+      "entry: "
+      "  %x = load i32, i32* %p, !range !0 "
+      "  br label %loop "
+      "loop: "
+      "  %iv = phi i32 [ %x, %entry], [%iv.next, %latch] "
+      "  %ne.check = icmp ne i32 %iv, 0 "
+      "  br i1 %ne.check, label %guarded, label %latch "
+      "guarded: "
+      "  %in.guarded = add i32 %iv, 1 "
+      "  br label %latch "
+      "latch: "
+      "  %iv.next = add i32 %iv, -1 "
+      "  %cont = load volatile i1, i1* %q "
+      "  br i1 %cont, label %loop, label %exit "
+      "exit: "
+      "  ret void "
+      "} "
+      "!0 = !{i32 0, i32 2147483647}",
+      Err, C);
+
+  ASSERT_TRUE(M && "Could not parse module?");
+  ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+  runWithSE(*M, "foo", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto *X = SE.getSCEV(getInstructionByName(F, "x"));
+    auto *Guarded = getInstructionByName(F, "in.guarded")->getParent();
+    EXPECT_FALSE(SE.isKnownPredicateAt(ICmpInst::ICMP_NE, X,
+                                       SE.getZero(X->getType()), Guarded));
+  });
+}
+
 } // end namespace llvm