Index: llvm/include/llvm/Analysis/ScalarEvolution.h =================================================================== --- llvm/include/llvm/Analysis/ScalarEvolution.h +++ llvm/include/llvm/Analysis/ScalarEvolution.h @@ -916,6 +916,14 @@ bool isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS); + /// Test whether the condition described by Pred, LHS, and RHS is true + /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is + /// true in given Context. If Context is nullptr, then the found predicate is + /// true everywhere. + bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, + const SCEV *FoundRHS, const BasicBlock *Context = nullptr); + /// Test if the given expression is known to satisfy the condition described /// by Pred, LHS, and RHS in the given Context. bool isKnownPredicateAt(ICmpInst::Predicate Pred, const SCEV *LHS, @@ -1685,14 +1693,6 @@ const BasicBlock *Context = nullptr); /// Test whether the condition described by Pred, LHS, and RHS is true - /// whenever the condition described by FoundPred, FoundLHS, FoundRHS is - /// true in given Context. If Context is nullptr, then the found predicate is - /// true everywhere. - bool isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, - const SCEV *FoundRHS, const BasicBlock *Context = nullptr); - - /// Test whether the condition described by Pred, LHS, and RHS is true /// whenever the condition described by Pred, FoundLHS, and FoundRHS is /// true in given Context. If Context is nullptr, then the found predicate is /// true everywhere. 
Index: llvm/lib/Transforms/Scalar/IndVarSimplify.cpp =================================================================== --- llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -2329,18 +2329,24 @@ return MadeAnyChanges; } +enum ExitCondAnalysisResult { + CanBeRemoved, + CanBeReplacedWithFirstIterCheck, + CannotOptimize +}; + // Returns true if the condition of \p BI being checked is invariant and can be // proved to be trivially true during at least first \p MaxIter iterations. -static bool isTrivialCond(const Loop *L, BranchInst *BI, ScalarEvolution *SE, - bool ProvingLoopExit, const SCEV *MaxIter, - bool SkipLastIter) { +static ExitCondAnalysisResult analyzeCond(const Loop *L, BranchInst *BI, + ScalarEvolution *SE, bool ProvingLoopExit, + const SCEV *MaxIter, bool SkipLastIter) { ICmpInst::Predicate Pred; Value *LHS, *RHS; using namespace PatternMatch; BasicBlock *TrueSucc, *FalseSucc; if (!match(BI, m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) - return false; + return CannotOptimize; assert((L->contains(TrueSucc) != L->contains(FalseSucc)) && "Not a loop exit!"); @@ -2357,37 +2363,40 @@ const SCEV *RHSS = SE->getSCEVAtScope(RHS, L); // Can we prove it to be trivially true? if (SE->isKnownPredicateAt(Pred, LHSS, RHSS, BI)) - return true; + return CanBeRemoved; if (ProvingLoopExit) - return false; + return CannotOptimize; // If we are proving that we stay in loop, try to prove the following set of // facts: // - The predicate is true on the 1st iteration; // - The predicate is still true on suggested last iteration; // - No overflow happens in between. + // Perhaps we won't be able to prove the fact on 1st iteration. But if this + // fact implies the other two, then we can replace the whole check with the + // first iteration check, replacing the recurrence with a loop invariant. auto *AR = dyn_cast<SCEVAddRecExpr>(LHSS); // TODO: Lift affinity limitation in the future. 
if (!AR || !AR->isAffine()) - return false; + return CannotOptimize; // The predicate must be relational (i.e. <, <=, >=, >). if (!ICmpInst::isRelational(Pred)) - return false; + return CannotOptimize; // TODO: Support steps other than +/- 1. const SCEV *Step = AR->getOperand(1); auto *One = SE->getOne(Step->getType()); auto *MinusOne = SE->getNegativeSCEV(One); if (Step != One && Step != MinusOne) - return false; + return CannotOptimize; // Type mismatch here means that MaxIter is potentially larger than max // unsigned value in start type, which mean we cannot prove no wrap for the // indvar. if (AR->getType() != MaxIter->getType()) - return false; + return CannotOptimize; if (SkipLastIter) @@ -2395,14 +2404,17 @@ // First, check the predicate on the 1st iteration. const SCEV *Start = AR->getStart(); - if (!SE->isKnownPredicateAt(Pred, Start, RHSS, BI)) - return false; + bool StartCondProved = SE->isKnownPredicateAt(Pred, Start, RHSS, BI); + auto *BB = BI->getParent(); // Value of IV on suggested last iteration. const SCEV *Last = AR->evaluateAtIteration(MaxIter, *SE); // Does it still meet the requirement? - if (!SE->isKnownPredicateAt(Pred, Last, RHSS, BI)) - return false; + bool LastCondProved = SE->isKnownPredicateAt(Pred, Last, RHSS, BI); + bool LastCondImplied = + LastCondProved || + SE->isImpliedCond(Pred, Last, RHSS, Pred, Start, RHSS, BB); + // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does // not exceed max unsigned value of this type), this effectively proves // that there is no wrap during the iteration. To prove that there is no @@ -2412,11 +2424,15 @@ CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; if (Step == MinusOne) NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred); - if (!SE->isKnownPredicateAt(NoOverflowPred, Start, Last, BI)) - return false; - - // Everything is fine. 
- return true; + bool NoOverflowProved = + SE->isKnownPredicateAt(NoOverflowPred, Start, Last, BI); + bool NoOverflowImplied = + NoOverflowProved || + SE->isImpliedCond(NoOverflowPred, Start, Last, Pred, Start, RHSS, BB); + + if (LastCondImplied && NoOverflowImplied) + return StartCondProved ? CanBeRemoved : CanBeReplacedWithFirstIterCheck; + return CannotOptimize; } bool IndVarSimplify::optimizeLoopExits(Loop *L, SCEVExpander &Rewriter) { @@ -2497,11 +2513,29 @@ if (isa<SCEVCouldNotCompute>(ExitCount)) { auto *BI = cast<BranchInst>(ExitingBB->getTerminator()); auto OptimizeCond = [&](bool Inverted, bool SkipLastIter) { - if (isTrivialCond(L, BI, SE, Inverted, MaxExitCount, SkipLastIter)) { + switch (analyzeCond(L, BI, SE, Inverted, MaxExitCount, SkipLastIter)) { + case CanBeRemoved: FoldExit(ExitingBB, Inverted); return true; + case CanBeReplacedWithFirstIterCheck: { + auto *Cond = cast<ICmpInst>(BI->getCondition()); + const SCEV *StartSCEV = cast<SCEVAddRecExpr>( + SE->getSCEV(Cond->getOperand(0)))->getStart(); + Rewriter.setInsertPoint(BI); + auto *StartV = Rewriter.expandCodeFor(StartSCEV); + IRBuilder<> Builder(BI); + auto *NewCond = + Builder.CreateICmp(Cond->getPredicate(), StartV, + Cond->getOperand(1), Cond->getName()); + BI->setOperand(0, NewCond); + if (Cond->getNumUses() == 0) + Cond->eraseFromParent(); + return true; } - return false; + case CannotOptimize: + return false; + } + llvm_unreachable("Unknown analysis result!"); }; // Okay, we do not know the exit count here. Can we at least prove that it // will remain the same within iteration space? 
Index: llvm/test/Transforms/IndVarSimplify/predicated_ranges.ll =================================================================== --- llvm/test/Transforms/IndVarSimplify/predicated_ranges.ll +++ llvm/test/Transforms/IndVarSimplify/predicated_ranges.ll @@ -470,6 +470,7 @@ ; CHECK-LABEL: @test_can_predicate_simple_unsigned( ; CHECK-NEXT: preheader: ; CHECK-NEXT: [[LEN:%.*]] = load i32, i32* [[P:%.*]], align 4 +; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[LEN]], -1 ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[LEN]], [[PREHEADER:%.*]] ], [ [[IV_NEXT:%.*]], [[BACKEDGE:%.*]] ] @@ -477,8 +478,8 @@ ; CHECK-NEXT: br i1 [[ZERO_COND]], label [[EXIT:%.*]], label [[RANGE_CHECK_BLOCK:%.*]] ; CHECK: range_check_block: ; CHECK-NEXT: [[IV_NEXT]] = sub i32 [[IV]], 1 -; CHECK-NEXT: [[RANGE_CHECK:%.*]] = icmp ult i32 [[IV_NEXT]], [[LEN]] -; CHECK-NEXT: br i1 [[RANGE_CHECK]], label [[BACKEDGE]], label [[FAIL:%.*]] +; CHECK-NEXT: [[RANGE_CHECK1:%.*]] = icmp ult i32 [[TMP0]], [[LEN]] +; CHECK-NEXT: br i1 [[RANGE_CHECK1]], label [[BACKEDGE]], label [[FAIL:%.*]] ; CHECK: backedge: ; CHECK-NEXT: [[EL_PTR:%.*]] = getelementptr i32, i32* [[P]], i32 [[IV]] ; CHECK-NEXT: [[EL:%.*]] = load i32, i32* [[EL_PTR]], align 4