Index: include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- include/llvm/Analysis/ScalarEvolution.h
+++ include/llvm/Analysis/ScalarEvolution.h
@@ -829,6 +829,27 @@
   /// Test if the given expression is known to be non-zero.
   bool isKnownNonZero(const SCEV *S);
 
+  /// Splits SCEV expression \p S into two SCEVs. The first is obtained from
+  /// \p S by substituting every AddRec sub-expression related to loop \p L
+  /// with the initial value of that AddRec. The second is obtained from \p S
+  /// by substituting every AddRec sub-expression related to loop \p L with
+  /// the post-increment value of that AddRec in loop \p L. In both cases all
+  /// other AddRec sub-expressions (not related to \p L) remain the same.
+  /// If \p S contains a non-invariant unknown SCEV, the function returns
+  /// nullptr in both members of the std::pair.
+  /// For example, for S = {0, +, 1}<L1> + {0, +, 1}<L2> and loop L = L1
+  /// the function returns the pair:
+  ///   first  = {0, +, 1}<L2>
+  ///   second = {1, +, 1}<L1> + {0, +, 1}<L2>
+  /// That is, the AddRec related to L1 is replaced with 0 (its initial value)
+  /// in the first element and with {1, +, 1}<L1> (its post-increment value)
+  /// in the second one, while the AddRec related to L2 remains the same in
+  /// both.
+  /// The primary purpose of this function is to serve as a utility for
+  /// isKnownPredicate.
+  std::pair<const SCEV *, const SCEV *> SplitIntoInitAndPostInc(const Loop *L,
+                                                                const SCEV *S);
+
   /// Test if the given expression is known to satisfy the condition described
   /// by Pred, LHS, and RHS.
   bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS,
Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -8727,38 +8727,86 @@
   return isKnownNegative(S) || isKnownPositive(S);
 }
 
+std::pair<const SCEV *, const SCEV *>
+ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
+  // Compute SCEV on entry of loop L.
+  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
+  if (Start == getCouldNotCompute())
+    return { nullptr, nullptr };
+  // Compute post increment SCEV for loop L.
+  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
+  assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
+  return { Start, PostInc };
+}
+
 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
   // Canonicalize the inputs first.
   (void)SimplifyICmpOperands(Pred, LHS, RHS);
 
-  // If LHS or RHS is an addrec, check to see if the condition is true in
-  // every iteration of the loop.
-  // If LHS and RHS are both addrec, both conditions must be true in
-  // every iteration of the loop.
-  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
-  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
-  bool LeftGuarded = false;
-  bool RightGuarded = false;
-  if (LAR) {
-    const Loop *L = LAR->getLoop();
-    if (isAvailableAtLoopEntry(RHS, L) &&
-        isKnownOnEveryIteration(Pred, LAR, RHS)) {
-      if (!RAR) return true;
-      LeftGuarded = true;
-    }
-  }
-  if (RAR) {
-    const Loop *L = RAR->getLoop();
-    auto SwappedPred = ICmpInst::getSwappedPredicate(Pred);
-    if (isAvailableAtLoopEntry(LHS, L) &&
-        isKnownOnEveryIteration(SwappedPred, RAR, LHS)) {
-      if (!LAR) return true;
-      RightGuarded = true;
-    }
-  }
-  if (LeftGuarded && RightGuarded)
-    return true;
+  // We'd like to check the predicate on every iteration of the most dominated
+  // loop among the loops used in LHS and RHS.
+  // To do this we use the following list of steps:
+  // 1. Collect the set S of all loops on which either LHS or RHS depends.
+  // 2. If S is non-empty
+  // a. Let PD be the element of S which is dominated by all other elements.
+  // b. Let E(LHS) be the value of LHS on entry of PD.
+  //    To get E(LHS), take LHS and replace all AddRecs that are attached to
+  //    PD with their entry values.
+  //    Define E(RHS) in the same way.
+  // c. Let B(LHS) be the value of LHS on the backedge of PD.
+  //    To get B(LHS), take LHS and replace all AddRecs that are attached to
+  //    PD with their backedge values.
+  //    Define B(RHS) in the same way.
+  // d. E(LHS) and E(RHS) are not automatically available on entry of PD:
+  //    loop invariance of an unknown SCEV does not imply that its definition
+  //    dominates the header of PD. Bail out unless both are available.
+  // e. Return true if isLoopEntryGuardedByCond(Pred, E(LHS), E(RHS)) &&
+  //    isLoopBackedgeGuardedByCond(Pred, B(LHS), B(RHS)).
+
+  // First collect all loops.
+  SmallPtrSet<const Loop *, 8> LoopsUsed;
+  getUsedLoops(LHS, LoopsUsed);
+  getUsedLoops(RHS, LoopsUsed);
+
+  // Domination relationship must be a linear order on collected loops.
+#ifndef NDEBUG
+  for (auto *L1 : LoopsUsed)
+    for (auto *L2 : LoopsUsed)
+      assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
+              DT.dominates(L2->getHeader(), L1->getHeader())) &&
+             "Domination relationship is not a linear order");
+#endif
+  if (!LoopsUsed.empty()) {
+    // Pick the most dominated loop. std::max_element requires a strict weak
+    // ordering, so compare headers with properlyDominates rather than the
+    // reflexive dominates.
+    const Loop *MDL =
+        *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
+                          [&](const Loop *L1, const Loop *L2) {
+                            return DT.properlyDominates(L1->getHeader(),
+                                                        L2->getHeader());
+                          });
+
+    // Get init and post increment value for LHS.
+    auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
+    if (SplitLHS.first) {
+      // If LHS does not contain an unknown non-invariant SCEV then
+      // get init and post increment value for RHS.
+      auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
+      // If RHS does not contain an unknown non-invariant SCEV either, and both
+      // init values are available on entry of MDL, check the predicate for the
+      // init values on entry and for the post-increment values on the
+      // backedge.
+      if (SplitRHS.first && isAvailableAtLoopEntry(SplitLHS.first, MDL) &&
+          isAvailableAtLoopEntry(SplitRHS.first, MDL) &&
+          isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first,
+                                   SplitRHS.first) &&
+          isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
+                                      SplitRHS.second))
+        return true;
+    }
+  }
 
   if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
     return true;