diff --git a/llvm/include/llvm/Analysis/LoopInfo.h b/llvm/include/llvm/Analysis/LoopInfo.h
--- a/llvm/include/llvm/Analysis/LoopInfo.h
+++ b/llvm/include/llvm/Analysis/LoopInfo.h
@@ -270,16 +270,12 @@
 
   /// Return all unique successor blocks of this loop.
   /// These are the blocks _outside of the current loop_ which are branched to.
-  /// This assumes that loop exits are in canonical form, i.e. all exits are
-  /// dedicated exits.
   void getUniqueExitBlocks(SmallVectorImpl<BlockT *> &ExitBlocks) const;
 
   /// Return all unique successor blocks of this loop except successors from
   /// Latch block are not considered. If the exit comes from Latch has also
   /// non Latch predecessor in a loop it will be added to ExitBlocks.
   /// These are the blocks _outside of the current loop_ which are branched to.
-  /// This assumes that loop exits are in canonical form, i.e. all exits are
-  /// dedicated exits.
   void getUniqueNonLatchExitBlocks(SmallVectorImpl<BlockT *> &ExitBlocks) const;
 
   /// If getUniqueExitBlocks would return exactly one block, return that block.
diff --git a/llvm/include/llvm/Analysis/LoopInfoImpl.h b/llvm/include/llvm/Analysis/LoopInfoImpl.h
--- a/llvm/include/llvm/Analysis/LoopInfoImpl.h
+++ b/llvm/include/llvm/Analysis/LoopInfoImpl.h
@@ -101,47 +101,16 @@
 void getUniqueExitBlocksHelper(const LoopT *L,
                                SmallVectorImpl<BlockT *> &ExitBlocks,
                                PredicateT Pred) {
-  typedef GraphTraits<BlockT *> BlockTraits;
-  typedef GraphTraits<Inverse<BlockT *>> InvBlockTraits;
-
-  assert(L->hasDedicatedExits() &&
-         "getUniqueExitBlocks assumes the loop has canonical form exits!");
-
-  SmallVector<BlockT *, 32> SwitchExitBlocks;
+  assert(!L->isInvalid() && "Loop not in a valid state!");
+  // Each exit is reported once, in first-seen order, even when several loop
+  // blocks (non-dedicated exits) or several edges (e.g. a switch) reach it.
+  SmallPtrSet<BlockT *, 32> Visited;
   auto Filtered = make_filter_range(L->blocks(), Pred);
-  for (BlockT *Block : Filtered) {
-    SwitchExitBlocks.clear();
-    for (BlockT *Successor : children<BlockT *>(Block)) {
-      // If block is inside the loop then it is not an exit block.
-      if (L->contains(Successor))
-        continue;
-
-      BlockT *FirstPred = *InvBlockTraits::child_begin(Successor);
-
-      // If current basic block is this exit block's first predecessor then only
-      // insert exit block in to the output ExitBlocks vector. This ensures that
-      // same exit block is not inserted twice into ExitBlocks vector.
-      if (Block != FirstPred)
-        continue;
-
-      // If a terminator has more then two successors, for example SwitchInst,
-      // then it is possible that there are multiple edges from current block to
-      // one exit block.
-      if (std::distance(BlockTraits::child_begin(Block),
-                        BlockTraits::child_end(Block)) <= 2) {
-        ExitBlocks.push_back(Successor);
-        continue;
-      }
-
-      // In case of multiple edges from current block to exit block, collect
-      // only one edge in ExitBlocks. Use switchExitBlocks to keep track of
-      // duplicate edges.
-      if (!is_contained(SwitchExitBlocks, Successor)) {
-        SwitchExitBlocks.push_back(Successor);
-        ExitBlocks.push_back(Successor);
-      }
-    }
-  }
+  for (BlockT *BB : Filtered)
+    for (BlockT *Successor : children<BlockT *>(BB))
+      if (!L->contains(Successor))
+        if (Visited.insert(Successor).second)
+          ExitBlocks.push_back(Successor);
 }
 
 template <class BlockT, class LoopT>
diff --git a/llvm/unittests/Analysis/LoopInfoTest.cpp b/llvm/unittests/Analysis/LoopInfoTest.cpp
--- a/llvm/unittests/Analysis/LoopInfoTest.cpp
+++ b/llvm/unittests/Analysis/LoopInfoTest.cpp
@@ -1156,3 +1156,46 @@
     EXPECT_TRUE(Exits.size() == 1);
   });
 }
+
+// Regression test for getUniqueNonLatchExitBlocks.
+// It should detect an exit that is reached from both latch and non-latch blocks.
+TEST(LoopInfoTest, LoopNonLatchUniqueExitBlocks) {
+  const char *ModuleStr =
+      "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
+      "define void @foo(i32 %n, i1 %cond) {\n"
+      "entry:\n"
+      " br label %for.cond\n"
+      "for.cond:\n"
+      " %i.0 = phi i32 [ 0, %entry ], [ %inc, %for.inc ]\n"
+      " %cmp = icmp slt i32 %i.0, %n\n"
+      " br i1 %cond, label %for.inc, label %for.end\n"
+      "for.inc:\n"
+      " %inc = add nsw i32 %i.0, 1\n"
+      " br i1 %cmp, label %for.cond, label %for.end, !llvm.loop !0\n"
+      "for.end:\n"
+      " ret void\n"
+      "}\n"
+      "!0 = distinct !{!0, !1}\n"
+      "!1 = !{!\"llvm.loop.distribute.enable\", i1 true}\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) {
+    Function::iterator FI = F.begin();
+    // First basic block is entry - skip it.
+    BasicBlock *Header = &*(++FI);
+    assert(Header->getName() == "for.cond");
+    Loop *L = LI.getLoopFor(Header);
+
+    SmallVector<BasicBlock *, 8> Exits;
+    // This loop has 1 unique exit (for.end, reached from header and latch).
+    L->getUniqueExitBlocks(Exits);
+    EXPECT_EQ(Exits.size(), 1u);
+    // And one unique non-latch exit (the header's edge to for.end).
+    Exits.clear();
+    L->getUniqueNonLatchExitBlocks(Exits);
+    EXPECT_EQ(Exits.size(), 1u);
+  });
+}