Index: llvm/lib/Transforms/Utils/LoopUnroll.cpp =================================================================== --- llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -328,6 +328,37 @@ if (MaxTripCount && ULO.Count > MaxTripCount) ULO.Count = MaxTripCount; + struct ExitInfo { + unsigned TripCount; + unsigned TripMultiple; + unsigned BreakoutTrip; + bool ExitOnTrue; + SmallVector ExitingBlocks; + }; + DenseMap ExitInfos; + SmallVector ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (auto *ExitingBlock : ExitingBlocks) { + // The folding code is not prepared to deal with non-branch instructions + // right now. + auto *BI = dyn_cast(ExitingBlock->getTerminator()); + if (!BI) + continue; + + ExitInfo &Info = ExitInfos.try_emplace(ExitingBlock).first->second; + Info.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); + Info.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); + if (Info.TripCount != 0) { + Info.BreakoutTrip = Info.TripCount % ULO.Count; + Info.TripMultiple = 0; + } else { + Info.BreakoutTrip = Info.TripMultiple = + (unsigned)GreatestCommonDivisor64(ULO.Count, Info.TripMultiple); + } + Info.ExitOnTrue = !L->contains(BI->getSuccessor(0)); + Info.ExitingBlocks.push_back(ExitingBlock); + } + // Are we eliminating the loop control altogether? Note that we can know // we're eliminating the backedge without knowing exactly which iteration // of the unrolled body exits. @@ -362,31 +393,12 @@ // A conditional branch which exits the loop, which can be optimized to an // unconditional branch in the unrolled loop in some cases. - BranchInst *ExitingBI = nullptr; bool LatchIsExiting = L->isLoopExiting(LatchBlock); - if (LatchIsExiting) - ExitingBI = LatchBI; - else if (BasicBlock *ExitingBlock = L->getExitingBlock()) - ExitingBI = dyn_cast(ExitingBlock->getTerminator()); if (!LatchBI || (LatchBI->isConditional() && !LatchIsExiting)) { LLVM_DEBUG( dbgs() << "Can't unroll; a conditional latch must exit the loop"); return LoopUnrollResult::Unmodified; } - LLVM_DEBUG({ - if (ExitingBI) - dbgs() << " Exiting Block = " << ExitingBI->getParent()->getName() - << "\n"; - else - dbgs() << " No single exiting block\n"; - }); - - // Warning: ExactTripCount is the exact trip count for the block ending in - // ExitingBI, not neccessarily an exact exit count *for the loop*. The - // distinction comes when we have an exiting latch, but the loop exits - // through another exit first. - const unsigned ExactTripCount = ExitingBI ? - SE->getSmallConstantTripCount(L,ExitingBI->getParent()) : 0; // Loops containing convergent instructions must have a count that divides // their TripMultiple. @@ -421,6 +433,7 @@ } // If we know the trip count, we know the multiple... + // TODO: This is only used for the ORE code, remove it. unsigned BreakoutTrip = 0; if (ULO.TripCount != 0) { BreakoutTrip = ULO.TripCount % ULO.Count; @@ -504,12 +517,9 @@ } std::vector Headers; - std::vector ExitingBlocks; std::vector Latches; Headers.push_back(Header); Latches.push_back(LatchBlock); - if (ExitingBI) - ExitingBlocks.push_back(ExitingBI->getParent()); // The current on-the-fly SSA update requires blocks to be processed in // reverse postorder so that LastValueMap contains the correct value at each @@ -609,9 +619,9 @@ // Keep track of the exiting block and its successor block contained in // the loop for the current iteration. - if (ExitingBI) - if (*BB == ExitingBlocks[0]) - ExitingBlocks.push_back(New); + auto ExitInfoIt = ExitInfos.find(*BB); + if (ExitInfoIt != ExitInfos.end()) + ExitInfoIt->second.ExitingBlocks.push_back(New); NewBlocks.push_back(New); UnrolledLoopBlocks.push_back(New); @@ -701,71 +711,75 @@ DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); - if (ExitingBI) { - auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) { - auto *Term = cast(Src->getTerminator()); - const unsigned Idx = ExitOnTrue ^ WillExit; - BasicBlock *Dest = Term->getSuccessor(Idx); - BasicBlock *DeadSucc = Term->getSuccessor(1-Idx); + auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) { + auto *Term = cast(Src->getTerminator()); + const unsigned Idx = ExitOnTrue ^ WillExit; + BasicBlock *Dest = Term->getSuccessor(Idx); + BasicBlock *DeadSucc = Term->getSuccessor(1-Idx); - // Remove predecessors from all non-Dest successors. - DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true); + // Remove predecessors from all non-Dest successors. + DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true); - // Replace the conditional branch with an unconditional one. - BranchInst::Create(Dest, Term); - Term->eraseFromParent(); + // Replace the conditional branch with an unconditional one. + BranchInst::Create(Dest, Term); + Term->eraseFromParent(); - DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}}); - }; + DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}}); + }; - auto WillExit = [&](unsigned i, unsigned j) -> Optional { - if (CompletelyUnroll) { - if (PreserveOnlyFirst) { - if (i == 0) - return None; - return j == 0; - } - // Complete (but possibly inexact) unrolling - if (j == 0) - return true; - // Warning: ExactTripCount is the trip count of the exiting - // block which ends in ExitingBI, not neccessarily the loop. - if (ExactTripCount && j != ExactTripCount) - return false; - return None; + auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j, + bool IsLatch) -> Optional { + if (CompletelyUnroll) { + if (PreserveOnlyFirst) { + if (i == 0) + return None; + return j == 0; } - - if (RuntimeTripCount && j != 0) + // Complete (but possibly inexact) unrolling + if (j == 0) + return true; + if (Info.TripCount && j != Info.TripCount) return false; + return None; + } - if (j != BreakoutTrip && - (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) { - // If we know the trip count or a multiple of it, we can safely use an - // unconditional branch for some iterations. + if (RuntimeTripCount) { + // When runtime unrolling, information about non-latch exits may be + // stale. + if (IsLatch && j != 0) return false; - } return None; - }; + } + + if (j != Info.BreakoutTrip && + (Info.TripMultiple == 0 || j % Info.TripMultiple != 0)) { + // If we know the trip count or a multiple of it, we can safely use an + // unconditional branch for some iterations. + return false; + } + return None; + }; - // Fold branches for iterations where we know that they will exit or not - // exit. - bool ExitOnTrue = !L->contains(ExitingBI->getSuccessor(0)); - for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { + // Fold branches for iterations where we know that they will exit or not + // exit. + for (const auto &Pair : ExitInfos) { + const ExitInfo &Info = Pair.second; + for (unsigned i = 0, e = Info.ExitingBlocks.size(); i != e; ++i) { // The branch destination. unsigned j = (i + 1) % e; - Optional KnownWillExit = WillExit(i, j); + bool IsLatch = Pair.first == LatchBlock; + Optional KnownWillExit = WillExit(Info, i, j, IsLatch); if (!KnownWillExit) continue; // TODO: Also fold known-exiting branches for non-latch exits. - if (*KnownWillExit && !LatchIsExiting) + if (*KnownWillExit && !IsLatch) continue; - SetDest(ExitingBlocks[i], *KnownWillExit, ExitOnTrue); + SetDest(Info.ExitingBlocks[i], *KnownWillExit, Info.ExitOnTrue); } } - // When completely unrolling, the last latch becomes unreachable. if (!LatchIsExiting && CompletelyUnroll) changeToUnreachable(Latches.back()->getTerminator(), /* UseTrap */ false, Index: llvm/test/Transforms/LoopUnroll/multiple-exits.ll =================================================================== --- llvm/test/Transforms/LoopUnroll/multiple-exits.ll +++ llvm/test/Transforms/LoopUnroll/multiple-exits.ll @@ -9,49 +9,49 @@ ; CHECK-NEXT: br label [[LOOP:%.*]] ; CHECK: loop: ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH:%.*]], label [[EXIT:%.*]] +; CHECK-NEXT: br label [[LATCH:%.*]] ; CHECK: latch: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_1:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_1:%.*]] ; CHECK: exit: ; CHECK-NEXT: ret void ; CHECK: latch.1: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_2:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_2:%.*]] ; CHECK: latch.2: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_3:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_3:%.*]] ; CHECK: latch.3: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_4:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_4:%.*]] ; CHECK: latch.4: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_5:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_5:%.*]] ; CHECK: latch.5: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_6:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_6:%.*]] ; CHECK: latch.6: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_7:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_7:%.*]] ; CHECK: latch.7: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_8:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_8:%.*]] ; CHECK: latch.8: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 true, label [[LATCH_9:%.*]], label [[EXIT]] +; CHECK-NEXT: br label [[LATCH_9:%.*]] ; CHECK: latch.9: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: call void @bar() -; CHECK-NEXT: br i1 false, label [[LATCH_10:%.*]], label [[EXIT]] +; CHECK-NEXT: br i1 false, label [[LATCH_10:%.*]], label [[EXIT:%.*]] ; CHECK: latch.10: ; CHECK-NEXT: call void @bar() ; CHECK-NEXT: br label [[EXIT]] Index: llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll =================================================================== --- llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll +++ llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll @@ -168,7 +168,7 @@ ; CHECK-NEXT: call void @bar(i32 [[TMP0]]) ; CHECK-NEXT: br i1 [[COND:%.*]], label [[FOR_BODY:%.*]], label [[FOR_END:%.*]] ; CHECK: for.body: -; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]], label [[FOR_END]] +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]] ; CHECK: for.body.for.body_crit_edge: ; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 1 ; CHECK-NEXT: [[DOTPRE:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT]], align 4 @@ -177,14 +177,14 @@ ; CHECK: for.end: ; CHECK-NEXT: ret void ; CHECK: for.body.1: -; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]], label [[FOR_END]] +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]] ; CHECK: for.body.for.body_crit_edge.1: ; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_1:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 2 ; CHECK-NEXT: [[DOTPRE_1:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_1]], align 4 ; CHECK-NEXT: call void @bar(i32 [[DOTPRE_1]]) ; CHECK-NEXT: br i1 [[COND]], label [[FOR_BODY_2:%.*]], label [[FOR_END]] ; CHECK: for.body.2: -; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]], label [[FOR_END]] +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]] ; CHECK: for.body.for.body_crit_edge.2: ; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_2:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 3 ; CHECK-NEXT: [[DOTPRE_2:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_2]], align 4