Index: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
===================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
@@ -829,6 +829,25 @@
   /// Test if the given expression is known to be non-zero.
   bool isKnownNonZero(const SCEV *S);
 
+  /// Splits SCEV expression \p S into two SCEVs. One of them is obtained from
+  /// \p S by substitution of all AddRec sub-expressions related to loop \p L
+  /// with the initial value of that AddRec. The second is obtained from \p S by
+  /// substitution of all AddRec sub-expressions related to loop \p L with post
+  /// increment of this AddRec in the loop \p L. In both cases all other AddRec
+  /// sub-expressions (not related to \p L) remain the same.
+  /// If \p S contains non-invariant unknown SCEV the function returns
+  /// CouldNotCompute SCEV in both values of std::pair.
+  /// For example, for SCEV S={0, +, 1}<L1> + {0, +, 1}<L2> and loop L=L1
+  /// the function returns pair:
+  /// first = {0, +, 1}<L2>
+  /// second = {1, +, 1}<L1> + {0, +, 1}<L2>
+  /// We can see that for the first AddRec sub-expression it was replaced with
+  /// 0 (initial value) for the first element and with {1, +, 1}<L1> (post
+  /// increment value) for the second one. In both cases AddRec expression
+  /// related to L2 remains the same.
+  std::pair<const SCEV *, const SCEV *> SplitIntoInitAndPostInc(const Loop *L,
+                                                                const SCEV *S);
+
   /// Test if the given expression is known to satisfy the condition described
   /// by Pred, LHS, and RHS.
   bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS,
Index: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
===================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp
@@ -8727,38 +8727,80 @@
   return isKnownNegative(S) || isKnownPositive(S);
 }
 
+std::pair<const SCEV *, const SCEV *>
+ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
+  // Compute SCEV on entry of loop L.
+  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
+  if (Start == getCouldNotCompute())
+    return { Start, Start };
+  // Compute post increment SCEV for loop L.
+  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
+  assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
+  return { Start, PostInc };
+}
+
 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
   // Canonicalize the inputs first.
   (void)SimplifyICmpOperands(Pred, LHS, RHS);
 
-  // If LHS or RHS is an addrec, check to see if the condition is true in
-  // every iteration of the loop.
-  // If LHS and RHS are both addrec, both conditions must be true in
-  // every iteration of the loop.
-  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
-  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
-  bool LeftGuarded = false;
-  bool RightGuarded = false;
-  if (LAR) {
-    const Loop *L = LAR->getLoop();
-    if (isAvailableAtLoopEntry(RHS, L) &&
-        isKnownOnEveryIteration(Pred, LAR, RHS)) {
-      if (!RAR) return true;
-      LeftGuarded = true;
-    }
-  }
-  if (RAR) {
-    const Loop *L = RAR->getLoop();
-    auto SwappedPred = ICmpInst::getSwappedPredicate(Pred);
-    if (isAvailableAtLoopEntry(LHS, L) &&
-        isKnownOnEveryIteration(SwappedPred, RAR, LHS)) {
-      if (!LAR) return true;
-      RightGuarded = true;
+  // We'd like to check the predicate on every iteration of the most dominated
+  // loop between loops used in LHS and RHS.
+  // To do this we use the following list of steps:
+  // 1. Collect the set S of all loops on which either LHS or RHS depend.
+  // 2. If S is non-empty
+  // a. Let PD be the element of S which is dominated by all other elements.
+  // b. Let E(LHS) be value of LHS on entry of PD.
+  //    To get E(LHS), we should just take LHS and replace all AddRecs that are
+  //    attached to PD with their entry values.
+  //    Define E(RHS) in the same way.
+  // c. Let B(LHS) be value of LHS on backedge of PD.
+  //    To get B(LHS), we should just take LHS and replace all AddRecs that are
+  //    attached to PD with their backedge values.
+  //    Define B(RHS) in the same way.
+  // d. Note that E(LHS) and E(RHS) are automatically available on entry of PD,
+  //    so we can assert on that.
+  // e. Return true if isLoopEntryGuardedByCond(Pred, E(LHS), E(RHS)) &&
+  //                   isLoopBackedgeGuardedByCond(Pred, B(LHS), B(RHS))
+
+  // First collect all loops.
+  SmallPtrSet<const Loop *, 8> LoopsUsed;
+  getUsedLoops(LHS, LoopsUsed);
+  getUsedLoops(RHS, LoopsUsed);
+
+  // Domination relationship must be a linear order on collected loops.
+#ifndef NDEBUG
+  for (auto *L1 : LoopsUsed)
+    for (auto *L2 : LoopsUsed)
+      assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
+              DT.dominates(L2->getHeader(), L1->getHeader())) &&
+             "Domination relationship is not a linear order");
+#endif
+  if (!LoopsUsed.empty()) {
+    // Pick the most dominated loop. std::max_element needs a strict ordering,
+    // hence properlyDominates (irreflexive) rather than dominates.
+    const Loop *MDL = *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
+                                        [&](const Loop *L1, const Loop *L2) {
+         return DT.properlyDominates(L1->getHeader(), L2->getHeader());
+       });
+
+    // Get init and post increment value for LHS.
+    auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
+    if (SplitLHS.first != getCouldNotCompute()) {
+      // If LHS does not contain unknown non-invariant SCEV then
+      // get init and post increment value for RHS.
+      auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
+      if (SplitRHS.first != getCouldNotCompute()) {
+        // If RHS does not contain unknown non-invariant SCEV then
+        // check whether implication is possible.
+        if (isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first,
+                                     SplitRHS.first) &&
+            isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
+                                        SplitRHS.second))
+          return true;
+      }
     }
   }
-  if (LeftGuarded && RightGuarded)
-    return true;
 
   if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
     return true;
Index: llvm/trunk/test/Transforms/IndVarSimplify/inner-loop-by-latch-cond.ll
===================================================================
--- llvm/trunk/test/Transforms/IndVarSimplify/inner-loop-by-latch-cond.ll
+++ llvm/trunk/test/Transforms/IndVarSimplify/inner-loop-by-latch-cond.ll
@@ -0,0 +1,33 @@
+; RUN: opt < %s -indvars -S | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128-ni:1"
+target triple = "x86_64-unknown-linux-gnu"
+
+declare void @foo(i64)
+
+define void @test(i64 %a) {
+entry:
+  br label %outer_header
+
+outer_header:
+  %i = phi i64 [20, %entry], [%i.next, %outer_latch]
+  %i.next = add nuw nsw i64 %i, 1
+  br label %inner_header
+
+inner_header:
+  %j = phi i64 [1, %outer_header], [%j.next, %inner_header]
+  %cmp = icmp ult i64 %j, %i.next
+; CHECK-NOT: select
+  %s = select i1 %cmp, i64 %j, i64 %i
+  call void @foo(i64 %s)
+  %j.next = add nuw nsw i64 %j, 1
+  %cond = icmp ult i64 %j, %i
+  br i1 %cond, label %inner_header, label %outer_latch
+
+outer_latch:
+  %cond2 = icmp ne i64 %i.next, 40
+  br i1 %cond2, label %outer_header, label %return
+
+return:
+  ret void
+}