Index: lib/Transforms/Scalar/IndVarSimplify.cpp =================================================================== --- lib/Transforms/Scalar/IndVarSimplify.cpp +++ lib/Transforms/Scalar/IndVarSimplify.cpp @@ -146,7 +146,8 @@ bool rewriteFirstIterationLoopExitValues(Loop *L); bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) const; - bool linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, + bool linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, + const SCEV *BackedgeTakenCount, PHINode *IndVar, SCEVExpander &Rewriter); bool sinkUnusedInvariants(Loop *L); @@ -1979,33 +1980,6 @@ // linearFunctionTestReplace and its kin. Rewrite the loop exit condition. //===----------------------------------------------------------------------===// -/// Return true if this loop's backedge taken count expression can be safely and -/// cheaply expanded into an instruction sequence that can be used by -/// linearFunctionTestReplace. -static bool canExpandBackedgeTakenCount(Loop *L, ScalarEvolution *SE, - SCEVExpander &Rewriter) { - const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); - if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) - return false; - - // Better to break the backedge - if (BackedgeTakenCount->isZero()) - return false; - - // Loops with multiple exits are not currently suported by lftr - if (!L->getExitingBlock()) - return false; - - // Can't rewrite non-branch yet. - if (!isa<BranchInst>(L->getExitingBlock()->getTerminator())) - return false; - - if (Rewriter.isHighCostExpansion(BackedgeTakenCount, L)) - return false; - - return true; -} - /// Given an Value which is hoped to be part of an add recurance in the given /// loop, return the associated Phi node if so. Otherwise, return null. Note /// that this is less general than SCEVs AddRec checking. @@ -2048,15 +2022,14 @@ /// Given a loop with one backedge and one exit, return the ICmpInst /// controlling the sole loop exit. There is no guarantee that the exiting /// block is also the latch. 
-static ICmpInst *getLoopTest(Loop *L) { - assert(L->getExitingBlock() && "expected loop exit"); +static ICmpInst *getLoopTest(Loop *L, BasicBlock *ExitingBB) { BasicBlock *LatchBlock = L->getLoopLatch(); // Don't bother with LFTR if the loop is not properly simplified. if (!LatchBlock) return nullptr; - BranchInst *BI = dyn_cast<BranchInst>(L->getExitingBlock()->getTerminator()); + BranchInst *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); assert(BI && "expected exit branch"); return dyn_cast<ICmpInst>(BI->getCondition()); @@ -2064,9 +2037,9 @@ /// linearFunctionTestReplace policy. Return true unless we can show that the /// current exit test is already sufficiently canonical. -static bool needsLFTR(Loop *L) { +static bool needsLFTR(Loop *L, BasicBlock *ExitingBB) { // Do LFTR to simplify the exit condition to an ICMP. - ICmpInst *Cond = getLoopTest(L); + ICmpInst *Cond = getLoopTest(L, ExitingBB); if (!Cond) return true; @@ -2188,12 +2161,11 @@ /// BECount may be an i8* pointer type. The pointer difference is already /// valid count without scaling the address stride, so it remains a pointer /// expression as far as SCEV is concerned. -static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, - ScalarEvolution *SE) { +static PHINode *FindLoopCounter(Loop *L, BasicBlock *ExitingBB, + const SCEV *BECount, ScalarEvolution *SE) { uint64_t BCWidth = SE->getTypeSizeInBits(BECount->getType()); - Value *Cond = - cast<BranchInst>(L->getExitingBlock()->getTerminator())->getCondition(); + Value *Cond = cast<BranchInst>(ExitingBB->getTerminator())->getCondition(); // Loop over all of the PHI nodes, looking for a simple counter. PHINode *BestPhi = nullptr; @@ -2226,13 +2198,15 @@ // We explicitly allow unknown phis as long as they are already used by // the loop test. In this case we assume that performing LFTR could not // increase the number of undef users. 
- if (ICmpInst *Cond = getLoopTest(L)) { - if (Phi != getLoopPhiForCounter(Cond->getOperand(0), L) && - Phi != getLoopPhiForCounter(Cond->getOperand(1), L)) { - continue; - } - } + // TODO: Generalize this to allow *any* loop exit which is known to + // execute on each iteration + if (L->getExitingBlock()) + if (ICmpInst *Cond = getLoopTest(L, ExitingBB)) + if (Phi != getLoopPhiForCounter(Cond->getOperand(0), L) && + Phi != getLoopPhiForCounter(Cond->getOperand(1), L)) + continue; } + const SCEV *Init = AR->getStart(); if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) { @@ -2261,7 +2235,8 @@ /// Insert an IR expression which computes the value held by the IV IndVar /// (which must be an loop counter w/unit stride) after the backedge of loop L /// is taken IVCount times. -static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, +static Value *genLoopLimit(PHINode *IndVar, BasicBlock *ExitingBB, + const SCEV *IVCount, Loop *L, SCEVExpander &Rewriter, ScalarEvolution *SE) { assert(isLoopCounter(IndVar, L, SE)); const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar)); @@ -2284,13 +2259,13 @@ // Expand the code for the iteration count. assert(SE->isLoopInvariant(IVOffset, L) && "Computed iteration count is not loop invariant!"); - BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); Value *GEPOffset = Rewriter.expandCodeFor(IVOffset, OfsTy, BI); Value *GEPBase = IndVar->getIncomingValueForBlock(L->getLoopPreheader()); assert(AR->getStart() == SE->getSCEV(GEPBase) && "bad loop counter"); // We could handle pointer IVs other than i8*, but we need to compensate for - // gep index scaling. See canExpandBackedgeTakenCount comments. + // gep index scaling. assert(SE->getSizeOfExpr(IntegerType::getInt64Ty(IndVar->getContext()), cast<PointerType>(GEPBase->getType()) ->getElementType())->isOne() && @@ -2328,7 +2303,7 @@ IVLimit = SE->getAddExpr(IVInit, IVCount); } // Expand the code for the iteration count. 
- BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); IRBuilder<> Builder(BI); assert(SE->isLoopInvariant(IVLimit, L) && "Computed iteration count is not loop invariant!"); @@ -2347,10 +2322,9 @@ /// determine a loop-invariant trip count of the loop, which is actually a much /// broader range than just linear tests. bool IndVarSimplify:: -linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, +linearFunctionTestReplace(Loop *L, BasicBlock *ExitingBB, + const SCEV *BackedgeTakenCount, PHINode *IndVar, SCEVExpander &Rewriter) { - assert(canExpandBackedgeTakenCount(L, SE, Rewriter) && "precondition"); - // Initialize CmpIndVar and IVCount to their preincremented values. Value *CmpIndVar = IndVar; const SCEV *IVCount = BackedgeTakenCount; @@ -2360,7 +2334,7 @@ // If the exiting block is the same as the backedge block, we prefer to // compare against the post-incremented value, otherwise we must compare // against the preincremented value. - if (L->getExitingBlock() == L->getLoopLatch()) { + if (ExitingBB == L->getLoopLatch()) { // Add one to the "backedge-taken" count to get the trip count. // This addition may overflow, which is valid as long as the comparison is // truncated to BackedgeTakenCount->getType(). @@ -2369,16 +2343,16 @@ // The BackedgeTaken expression contains the number of times that the // backedge branches to the loop header. This is one less than the // number of times the loop executes, so use the incremented indvar. - CmpIndVar = IndVar->getIncomingValueForBlock(L->getExitingBlock()); + CmpIndVar = IndVar->getIncomingValueForBlock(ExitingBB); } - Value *ExitCnt = genLoopLimit(IndVar, IVCount, L, Rewriter, SE); + Value *ExitCnt = genLoopLimit(IndVar, ExitingBB, IVCount, L, Rewriter, SE); assert(ExitCnt->getType()->isPointerTy() == IndVar->getType()->isPointerTy() && "genLoopLimit missed a cast"); // Insert a new icmp_ne or icmp_eq instruction before the branch. 
- BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); + BranchInst *BI = cast<BranchInst>(ExitingBB->getTerminator()); ICmpInst::Predicate P; if (L->contains(BI->getSuccessor(0))) P = ICmpInst::ICMP_NE; @@ -2567,6 +2541,13 @@ return MadeAnyChanges; } +/// If we have a unique latch block which is also an exiting block from the +/// loop, return it. Otherwise, return null. +static BasicBlock* getLoopExitingLatch(Loop *L) { + BasicBlock *Latch = L->getLoopLatch(); + return Latch && L->isLoopExiting(Latch) ? Latch : nullptr; +} + //===----------------------------------------------------------------------===// // IndVarSimplify driver. Manage several subpasses of IV simplification. //===----------------------------------------------------------------------===// @@ -2624,22 +2605,58 @@ NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); // If we have a trip count expression, rewrite the loop's exit condition - // using it. We can currently only handle loops with a single exit. - if (!DisableLFTR && canExpandBackedgeTakenCount(L, SE, Rewriter) && - needsLFTR(L)) { - PHINode *IndVar = FindLoopCounter(L, BackedgeTakenCount, SE); - if (IndVar) { + // using it. + if (!DisableLFTR) { + // For the moment, we only allow when either a) we have a single exit + // loop, or b) we're considering the latch exit of a multiple exit loop. + // TODO: Expand this to all loop exit blocks. + SmallVector<BasicBlock*, 16> ExitingBlocks; + if (auto *ExitingBB = L->getExitingBlock()) + ExitingBlocks.push_back(ExitingBB); + // NOTE TO REVIEWERS: When actually landing, I'm going to commit without + // this next if block (to form an NFC patch), and then land this line and + // the helper function separately. I want any regressions caused to be + // clearly separated between the NFC refactoring and the increase in scope. + if (auto *ExitingBB = getLoopExitingLatch(L)) + ExitingBlocks.push_back(ExitingBB); + for (BasicBlock *ExitingBB : ExitingBlocks) { + // Can't rewrite non-branch yet. 
+ if (!isa<BranchInst>(ExitingBB->getTerminator())) + continue; + + if (!needsLFTR(L, ExitingBB)) + continue; + + const SCEV *BETakenCount = SE->getExitCount(L, ExitingBB); + if (isa<SCEVCouldNotCompute>(BETakenCount)) + continue; + + // Better to fold to true (TODO: do so!) + if (BETakenCount->isZero()) + continue; + + PHINode *IndVar = FindLoopCounter(L, ExitingBB, BETakenCount, SE); + if (!IndVar) + continue; + + // Avoid high cost expansions. Note: This heuristic is questionable in + // that our definition of "high cost" is not exactly principled. + if (Rewriter.isHighCostExpansion(BETakenCount, L)) + continue; + // Check preconditions for proper SCEVExpander operation. SCEV does not - // express SCEVExpander's dependencies, such as LoopSimplify. Instead any - // pass that uses the SCEVExpander must do it. This does not work well for - // loop passes because SCEVExpander makes assumptions about all loops, - // while LoopPassManager only forces the current loop to be simplified. + // express SCEVExpander's dependencies, such as LoopSimplify. Instead + // any pass that uses the SCEVExpander must do it. This does not work + // well for loop passes because SCEVExpander makes assumptions about + // all loops, while LoopPassManager only forces the current loop to be + // simplified. // // FIXME: SCEV expansion has no way to bail out, so the caller must // explicitly check any assumptions made by SCEV. Brittle. 
- const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(BackedgeTakenCount); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(BETakenCount); if (!AR || AR->getLoop()->getLoopPreheader()) - Changed |= linearFunctionTestReplace(L, BackedgeTakenCount, IndVar, + Changed |= linearFunctionTestReplace(L, ExitingBB, + BETakenCount, IndVar, Rewriter); } } Index: test/Transforms/IndVarSimplify/lftr-multi-exit.ll =================================================================== --- test/Transforms/IndVarSimplify/lftr-multi-exit.ll +++ test/Transforms/IndVarSimplify/lftr-multi-exit.ll @@ -20,8 +20,8 @@ ; CHECK: latch: ; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i32 [[IV]], 1 ; CHECK-NEXT: store i32 [[IV]], i32* @A -; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[IV_NEXT]], 1000 -; CHECK-NEXT: br i1 [[C]], label [[LOOP]], label [[EXIT]] +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i32 [[IV_NEXT]], 1000 +; CHECK-NEXT: br i1 [[EXITCOND]], label [[LOOP]], label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret void ; @@ -55,8 +55,8 @@ ; CHECK: latch: ; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i32 [[IV]], 1 ; CHECK-NEXT: store i32 [[IV]], i32* @A -; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[IV_NEXT]], 1000 -; CHECK-NEXT: br i1 [[C]], label [[LOOP]], label [[EXIT]] +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i32 [[IV_NEXT]], 1000 +; CHECK-NEXT: br i1 [[EXITCOND]], label [[LOOP]], label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret void ; @@ -95,8 +95,8 @@ ; CHECK: latch: ; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i32 [[IV]], 1 ; CHECK-NEXT: store volatile i32 [[IV]], i32* @A -; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[IV_NEXT]], 1000 -; CHECK-NEXT: br i1 [[C]], label [[LOOP]], label [[EXIT]] +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i32 [[IV_NEXT]], 1000 +; CHECK-NEXT: br i1 [[EXITCOND]], label [[LOOP]], label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret void ; @@ -138,8 +138,8 @@ ; CHECK: latch: ; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i32 [[IV]], 1 ; CHECK-NEXT: store volatile i32 [[IV]], i32* @A -; CHECK-NEXT: [[C:%.*]] = icmp ult i32 [[IV_NEXT]], 
1000 -; CHECK-NEXT: br i1 [[C]], label [[LOOP]], label [[EXIT]] +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i32 [[IV_NEXT]], 1000 +; CHECK-NEXT: br i1 [[EXITCOND]], label [[LOOP]], label [[EXIT]] ; CHECK: exit: ; CHECK-NEXT: ret void ;