diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
--- a/llvm/lib/Analysis/LoopInfo.cpp
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -396,7 +396,30 @@
   BasicBlock *GuardOtherSucc = (GuardBI->getSuccessor(0) == Preheader)
                                    ? GuardBI->getSuccessor(1)
                                    : GuardBI->getSuccessor(0);
-  return (GuardOtherSucc == ExitFromLatchSucc) ? GuardBI : nullptr;
+  if (GuardOtherSucc != ExitFromLatchSucc)
+    return nullptr;
+
+  // Only report the guard if every instruction in the preheader and in the
+  // exit block is safe to execute speculatively. Terminators, and PHIs for
+  // which PHINode::hasConstantValue() returns a value, are accepted without
+  // consulting isSafeToSpeculativelyExecute().
+  auto IsSafeToSpeculate = [](const BasicBlock &BB) {
+    return llvm::all_of(BB, [](const Instruction &I) {
+      if (I.isTerminator())
+        return true;
+
+      if (const auto *PN = dyn_cast<PHINode>(&I))
+        if (PN->hasConstantValue())
+          return true;
+
+      return isSafeToSpeculativelyExecute(&I);
+    });
+  };
+
+  if (!IsSafeToSpeculate(*Preheader) || !IsSafeToSpeculate(*ExitFromLatch))
+    return nullptr;
+
+  return GuardBI;
 }
 
 bool Loop::isCanonical(ScalarEvolution &SE) const {
diff --git a/llvm/unittests/Analysis/LoopInfoTest.cpp b/llvm/unittests/Analysis/LoopInfoTest.cpp
--- a/llvm/unittests/Analysis/LoopInfoTest.cpp
+++ b/llvm/unittests/Analysis/LoopInfoTest.cpp
@@ -1373,6 +1373,96 @@
       });
 }
 
+TEST(LoopInfoTest, LoopPreheaderNotSafe) {
+  const char *ModuleStr =
+      "define void @foo(i64 %N) {\n"
+      "entry:\n"
+      "  %guard = icmp slt i64 0, %N\n"
+      "  br i1 %guard, label %for.preheader, label %for.end\n"
+      "for.preheader:\n"
+      "  call void @bar()\n"
+      "  br label %for.body\n"
+      "for.body:\n"
+      "  %i = phi i64 [ %inc, %for.body ], [ 0, %for.preheader ]\n"
+      "  call void @bar()\n"
+      "  %inc = add nsw i64 %i, 1\n"
+      "  %cmp = icmp slt i64 %inc, %N\n"
+      "  br i1 %cmp, label %for.body, label %for.exit\n"
+      "for.exit:\n"
+      "  br label %for.end\n"
+      "for.end:\n"
+      "  ret void\n"
+      "}\n"
+      "declare void @bar()\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfoPlus(
+      *M, "foo",
+      [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        Function::iterator FI = F.begin();
+        // The first two basic blocks are entry and for.preheader - skip them.
+        ++FI;
+        BasicBlock *Header = &*(++FI);
+        assert(Header && "No header");
+        Loop *L = LI.getLoopFor(Header);
+        ASSERT_NE(L, nullptr);
+        EXPECT_TRUE(L->isLoopSimplifyForm());
+
+        // No loop guard because the loop preheader contains instructions that
+        // are not safe to execute speculatively.
+        EXPECT_EQ(L->getLoopGuardBranch(), nullptr);
+        EXPECT_FALSE(L->isGuarded());
+      });
+}
+
+TEST(LoopInfoTest, LoopExitNotSafe) {
+  const char *ModuleStr =
+      "define void @foo(i64 %N) {\n"
+      "entry:\n"
+      "  %guard = icmp slt i64 0, %N\n"
+      "  br i1 %guard, label %for.preheader, label %for.end\n"
+      "for.preheader:\n"
+      "  br label %for.body\n"
+      "for.body:\n"
+      "  %i = phi i64 [ %inc, %for.body ], [ 0, %for.preheader ]\n"
+      "  call void @bar()\n"
+      "  %inc = add nsw i64 %i, 1\n"
+      "  %cmp = icmp slt i64 %inc, %N\n"
+      "  br i1 %cmp, label %for.body, label %for.exit\n"
+      "for.exit:\n"
+      "  call void @bar()\n"
+      "  br label %for.end\n"
+      "for.end:\n"
+      "  ret void\n"
+      "}\n"
+      "declare void @bar()\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfoPlus(
+      *M, "foo",
+      [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+        Function::iterator FI = F.begin();
+        // The first two basic blocks are entry and for.preheader - skip them.
+        ++FI;
+        BasicBlock *Header = &*(++FI);
+        assert(Header && "No header");
+        Loop *L = LI.getLoopFor(Header);
+        ASSERT_NE(L, nullptr);
+        EXPECT_TRUE(L->isLoopSimplifyForm());
+
+        // No loop guard because the loop exit contains instructions that are
+        // not safe to execute speculatively.
+        EXPECT_EQ(L->getLoopGuardBranch(), nullptr);
+        EXPECT_FALSE(L->isGuarded());
+      });
+}
+
 // Examine getUniqueExitBlocks/getUniqueNonLatchExitBlocks functions.
 TEST(LoopInfoTest, LoopUniqueExitBlocks) {
   const char *ModuleStr =