[LoopInfo] Use getLoopLatches() in getLoopID()/setLoopID() for multi-latch loops

When a loop has no unique latch, getLoopID() used to visit every block of the
loop and returned nullptr as soon as it met a block whose terminator does not
branch to the header (such a block never carries MD_loop), so a loop whose ID
was attached to all of its latches still reported no ID. Both getLoopID() and
setLoopID() now visit exactly the loop latches.

Adds LoopInfoTest.LoopWithMultipleLatches, which sets a loop ID on a loop with
two latches and checks that each latch terminator carries it and that
getLoopID() returns it.

Index: lib/Analysis/LoopInfo.cpp
===================================================================
--- lib/Analysis/LoopInfo.cpp
+++ lib/Analysis/LoopInfo.cpp
@@ -218,20 +218,13 @@
   } else {
     assert(!getLoopLatch() &&
            "The loop should have no single latch at this point");
-    // Go through each predecessor of the loop header and check the
+    // Go through each loop latch and check the
     // terminator for the metadata.
-    BasicBlock *H = getHeader();
-    for (BasicBlock *BB : this->blocks()) {
+    SmallVector<BasicBlock *, 4> LoopLatches;
+    this->getLoopLatches(LoopLatches);
+    for (BasicBlock *BB : LoopLatches) {
       TerminatorInst *TI = BB->getTerminator();
-      MDNode *MD = nullptr;
-
-      // Check if this terminator branches to the loop header.
-      for (BasicBlock *Successor : TI->successors()) {
-        if (Successor == H) {
-          MD = TI->getMetadata(LLVMContext::MD_loop);
-          break;
-        }
-      }
+      MDNode *MD = TI->getMetadata(LLVMContext::MD_loop);
 
       if (!MD)
         return nullptr;
@@ -259,13 +252,10 @@
     assert(!getLoopLatch() &&
            "The loop should have no single latch at this point");
 
-    BasicBlock *H = getHeader();
-    for (BasicBlock *BB : this->blocks()) {
-      TerminatorInst *TI = BB->getTerminator();
-      for (BasicBlock *Successor : TI->successors()) {
-        if (Successor == H)
-          TI->setMetadata(LLVMContext::MD_loop, LoopID);
-      }
+    SmallVector<BasicBlock *, 4> LoopLatches;
+    this->getLoopLatches(LoopLatches);
+    for (BasicBlock *BB : LoopLatches) {
+      BB->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID);
     }
   }
 }
Index: unittests/Analysis/LoopInfoTest.cpp
===================================================================
--- unittests/Analysis/LoopInfoTest.cpp
+++ unittests/Analysis/LoopInfoTest.cpp
@@ -82,6 +82,58 @@
   });
 }
 
+TEST(LoopInfoTest, LoopWithMultipleLatches) {
+  const char *ModuleStr =
+      "define void @foo(i1 %cond, i32 %n) {\n"
+      "entry:\n"
+      "  br i1 undef, label %header, label %for.end\n"
+      "header:\n"
+      "  %i.0 = phi i32 [ 0, %entry ], [ %inc, %latch1 ], [ %inc, %latch2 ]\n"
+      "  %cmp = icmp slt i32 %i.0, %n\n"
+      "  br i1 %cmp, label %for.inc, label %for.end\n"
+      "for.inc:\n"
+      "  %inc = add nsw i32 %i.0, 1\n"
+      "  br i1 %cond, label %latch1, label %latch2\n"
+      "latch1:\n"
+      "  br label %header\n"
+      "latch2:\n"
+      "  br label %header\n"
+      "for.end:\n"
+      "  ret void\n"
+      "}\n";
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) {
+    BasicBlock *header = nullptr, *latch1 = nullptr, *latch2 = nullptr;
+    // Iterate over the function and find our interesting blocks.
+    for (BasicBlock &BB : F) {
+      if (BB.getName() == "header")
+        header = &BB;
+      else if (BB.getName() == "latch1")
+        latch1 = &BB;
+      else if (BB.getName() == "latch2")
+        latch2 = &BB;
+    }
+    ASSERT_TRUE(header && latch1 && latch2);
+    Loop *L = LI.getLoopFor(header);
+    ASSERT_TRUE(LI.isLoopHeader(header));
+    // We start out with no loop ID.
+    EXPECT_EQ(L->getLoopID(), nullptr);
+    // Now set a loop ID. With no unique latch, this must attach the
+    // metadata to the terminator of every loop latch.
+    MDNode *LoopID = MDNode::get(Context, ArrayRef<Metadata *>(nullptr));
+    LoopID->replaceOperandWith(0, LoopID);
+    L->setLoopID(LoopID);
+    // Make sure the ID landed on every latch and getLoopID() reads it back.
+    EXPECT_EQ(latch1->getTerminator()->getMetadata(LLVMContext::MD_loop),
+              LoopID);
+    EXPECT_EQ(latch2->getTerminator()->getMetadata(LLVMContext::MD_loop),
+              LoopID);
+    EXPECT_EQ(L->getLoopID(), LoopID);
+  });
+}
+
 TEST(LoopInfoTest, PreorderTraversals) {
   const char *ModuleStr = "define void @f() {\n"
                           "entry:\n"