Index: llvm/include/llvm/Analysis/LoopInfo.h
===================================================================
--- llvm/include/llvm/Analysis/LoopInfo.h
+++ llvm/include/llvm/Analysis/LoopInfo.h
@@ -274,6 +274,14 @@
   /// dedicated exits.
   void getUniqueExitBlocks(SmallVectorImpl<BlockT *> &ExitBlocks) const;
 
+  /// Return all unique successor blocks of this loop except those that are
+  /// reached only from the latch block. An exit block that is reached from the
+  /// latch *and* from another block of the loop is still returned.
+  /// These are the blocks _outside of the current loop_ which are branched to.
+  /// This assumes that loop exits are in canonical form, i.e. all exits are
+  /// dedicated exits, and that the loop has a single latch block.
+  void getUniqueNonLatchExitBlocks(SmallVectorImpl<BlockT *> &ExitBlocks) const;
+
   /// If getUniqueExitBlocks would return exactly one block, return that block.
   /// Otherwise return null.
   BlockT *getUniqueExitBlock() const;
Index: llvm/include/llvm/Analysis/LoopInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/LoopInfoImpl.h
+++ llvm/include/llvm/Analysis/LoopInfoImpl.h
@@ -95,21 +95,25 @@
   return true;
 }
 
-template <class BlockT, class LoopT>
-void LoopBase<BlockT, LoopT>::getUniqueExitBlocks(
-    SmallVectorImpl<BlockT *> &ExitBlocks) const {
+// Helper to collect the unique exits of loop L. Only loop blocks for which
+// Pred returns true are scanned for edges that leave the loop.
+template <class BlockT, class LoopT, typename PredicateT>
+void getUniqueExitBlocksHelper(const LoopT *L,
+                               SmallVectorImpl<BlockT *> &ExitBlocks,
+                               PredicateT Pred) {
   typedef GraphTraits<BlockT *> BlockTraits;
   typedef GraphTraits<Inverse<BlockT *>> InvBlockTraits;
 
-  assert(hasDedicatedExits() &&
+  assert(L->hasDedicatedExits() &&
          "getUniqueExitBlocks assumes the loop has canonical form exits!");
 
   SmallVector<BlockT *, 32> SwitchExitBlocks;
-  for (BlockT *Block : this->blocks()) {
+  auto Filtered = make_filter_range(L->blocks(), Pred);
+  for (BlockT *Block : Filtered) {
     SwitchExitBlocks.clear();
     for (BlockT *Successor : children<BlockT *>(Block)) {
       // If block is inside the loop then it is not an exit block.
-      if (contains(Successor))
+      if (L->contains(Successor))
         continue;
 
       BlockT *FirstPred = *InvBlockTraits::child_begin(Successor);
@@ -141,6 +145,35 @@
 }
 
 template <class BlockT, class LoopT>
+void LoopBase<BlockT, LoopT>::getUniqueExitBlocks(
+    SmallVectorImpl<BlockT *> &ExitBlocks) const {
+  getUniqueExitBlocksHelper(this, ExitBlocks,
+                            [](const BlockT *BB) { return true; });
+}
+
+template <class BlockT, class LoopT>
+void LoopBase<BlockT, LoopT>::getUniqueNonLatchExitBlocks(
+    SmallVectorImpl<BlockT *> &ExitBlocks) const {
+  typedef GraphTraits<Inverse<BlockT *>> InvBlockTraits;
+  const BlockT *Latch = getLoopLatch();
+  assert(Latch && "Latch block must exist");
+  // Skipping the latch in getUniqueExitBlocksHelper is not enough: the helper
+  // de-duplicates an exit via the exit's first predecessor (FirstPred), so an
+  // exit whose first predecessor is the latch would be dropped even when
+  // another loop block also branches to it. Instead take every unique exit and
+  // keep those with at least one non-latch predecessor; exits are dedicated,
+  // so all of their predecessors are inside the loop.
+  SmallVector<BlockT *, 8> UniqueExits;
+  getUniqueExitBlocks(UniqueExits);
+  for (BlockT *Exit : UniqueExits)
+    if (std::any_of(InvBlockTraits::child_begin(Exit),
+                    InvBlockTraits::child_end(Exit),
+                    [Latch](const BlockT *Pred) { return Pred != Latch; }))
+      ExitBlocks.push_back(Exit);
+}
+
+template <class BlockT, class LoopT>
 BlockT *LoopBase<BlockT, LoopT>::getUniqueExitBlock() const {
   SmallVector<BlockT *, 8> UniqueExitBlocks;
   getUniqueExitBlocks(UniqueExitBlocks);
Index: llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
===================================================================
--- llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -424,10 +424,9 @@
 
-/// Returns true if we can safely unroll a multi-exit/exiting loop. OtherExits
-/// is populated with all the loop exit blocks other than the LatchExit block.
+/// Returns true if we can safely unroll a multi-exit/exiting loop whose latch
+/// branches out of the loop to \p LatchExit.
-static bool
-canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits,
-                             BasicBlock *LatchExit, bool PreserveLCSSA,
-                             bool UseEpilogRemainder) {
+static bool canSafelyUnrollMultiExitLoop(Loop *L, BasicBlock *LatchExit,
+                                         bool PreserveLCSSA,
+                                         bool UseEpilogRemainder) {
 
   // We currently have some correctness constrains in unrolling a multi-exit
   // loop. Check for these below.
@@ -435,11 +434,6 @@
   // We rely on LCSSA form being preserved when the exit blocks are transformed.
   if (!PreserveLCSSA)
     return false;
-  SmallVector<BasicBlock *, 4> Exits;
-  L->getUniqueExitBlocks(Exits);
-  for (auto *BB : Exits)
-    if (BB != LatchExit)
-      OtherExits.push_back(BB);
 
   // TODO: Support multiple exiting blocks jumping to the `LatchExit` when
   // UnrollRuntimeMultiExit is true. This will need updating the logic in
@@ -469,9 +463,6 @@
     bool PreserveLCSSA, bool UseEpilogRemainder) {
 
-#if !defined(NDEBUG)
-  SmallVector<BasicBlock *, 8> OtherExitsDummyCheck;
-  assert(canSafelyUnrollMultiExitLoop(L, OtherExitsDummyCheck, LatchExit,
-                                      PreserveLCSSA, UseEpilogRemainder) &&
+  assert(canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA,
+                                      UseEpilogRemainder) &&
          "Should be safe to unroll before checking profitability!");
-#endif
 
@@ -595,8 +586,9 @@
 
-  // These are exit blocks other than the target of the latch exiting block.
+  // These are the exit blocks reached from loop blocks other than the latch.
SmallVector OtherExits; + L->getUniqueNonLatchExitBlocks(OtherExits); bool isMultiExitUnrollingEnabled = - canSafelyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, + canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA, UseEpilogRemainder) && canProfitablyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, UseEpilogRemainder); Index: llvm/unittests/Analysis/LoopInfoTest.cpp =================================================================== --- llvm/unittests/Analysis/LoopInfoTest.cpp +++ llvm/unittests/Analysis/LoopInfoTest.cpp @@ -1110,3 +1110,49 @@ L->isAuxiliaryInductionVariable(Instruction_mulopcode, SE)); }); } + +// Examine getUniqueExitBlocks/getUniqueNonLatchExitBlocks functions. +TEST(LoopInfoTest, LoopUniqueExitBlocks) { + const char *ModuleStr = + "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n" + "define void @foo(i32 %n, i1 %cond) {\n" + "entry:\n" + " br label %for.cond\n" + "for.cond:\n" + " %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]\n" + " %cmp = icmp slt i32 %i.0, %n\n" + " br i1 %cond, label %for.inc, label %for.end1\n" + "for.inc:\n" + " %inc = add nsw i32 %i.0, 1\n" + " br i1 %cmp, label %for.cond, label %for.end2, !llvm.loop !0\n" + "for.end1:\n" + " br label %for.end\n" + "for.end2:\n" + " br label %for.end\n" + "for.end:\n" + " ret void\n" + "}\n" + "!0 = distinct !{!0, !1}\n" + "!1 = !{!\"llvm.loop.distribute.enable\", i1 true}\n"; + + // Parse the module. + LLVMContext Context; + std::unique_ptr M = makeLLVMModule(Context, ModuleStr); + + runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) { + Function::iterator FI = F.begin(); + // First basic block is entry - skip it. + BasicBlock *Header = &*(++FI); + assert(Header->getName() == "for.cond"); + Loop *L = LI.getLoopFor(Header); + + SmallVector Exits; + // This loop has 2 unique exits. + L->getUniqueExitBlocks(Exits); + EXPECT_TRUE(Exits.size() == 2); + // And one unique non latch exit. 
+ Exits.clear(); + L->getUniqueNonLatchExitBlocks(Exits); + EXPECT_TRUE(Exits.size() == 1); + }); +}