diff --git a/llvm/include/llvm/Analysis/LoopNestAnalysis.h b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
--- a/llvm/include/llvm/Analysis/LoopNestAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopNestAnalysis.h
@@ -61,10 +61,12 @@
   static unsigned getMaxPerfectDepth(const Loop &Root, ScalarEvolution &SE);
 
-  /// Recursivelly traverse all empty 'single successor' basic blocks of \p From
-  /// (if there are any). Return the last basic block found or \p End if it was
-  /// reached during the search.
+  /// Recursively traverse all empty 'single successor' basic blocks of \p From
+  /// (if there are any). When \p CheckUniquePred is true, the traversal does
+  /// not step into a block that lacks a unique predecessor. Return the last
+  /// basic block found or \p End if it was reached during the search.
   static const BasicBlock &skipEmptyBlockUntil(const BasicBlock *From,
-                                               const BasicBlock *End);
+                                               const BasicBlock *End,
+                                               bool CheckUniquePred = false);
 
   /// Return the outermost loop in the loop nest.
   Loop &getOutermostLoop() const { return *Loops.front(); }
diff --git a/llvm/lib/Analysis/LoopInfo.cpp b/llvm/lib/Analysis/LoopInfo.cpp
--- a/llvm/lib/Analysis/LoopInfo.cpp
+++ b/llvm/lib/Analysis/LoopInfo.cpp
@@ -20,6 +20,7 @@
 #include "llvm/Analysis/IVDescriptors.h"
 #include "llvm/Analysis/LoopInfoImpl.h"
 #include "llvm/Analysis/LoopIterator.h"
+#include "llvm/Analysis/LoopNestAnalysis.h"
 #include "llvm/Analysis/MemorySSA.h"
 #include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
@@ -380,10 +381,6 @@
   if (!ExitFromLatch)
     return nullptr;
 
-  BasicBlock *ExitFromLatchSucc = ExitFromLatch->getUniqueSuccessor();
-  if (!ExitFromLatchSucc)
-    return nullptr;
-
   BasicBlock *GuardBB = Preheader->getUniquePredecessor();
   if (!GuardBB)
     return nullptr;
@@ -397,7 +394,16 @@
   BasicBlock *GuardOtherSucc = (GuardBI->getSuccessor(0) == Preheader)
                                    ? GuardBI->getSuccessor(1)
                                    : GuardBI->getSuccessor(0);
-  return (GuardOtherSucc == ExitFromLatchSucc) ? GuardBI : nullptr;
+
+  // Check if ExitFromLatch (or any empty block reached from it through unique
+  // successors, each of which has a unique predecessor) is GuardOtherSucc. If
+  // so, GuardBI is the guard branch of the loop; otherwise the loop has no
+  // guard branch.
+  if (&LoopNest::skipEmptyBlockUntil(ExitFromLatch, GuardOtherSucc,
+                                     /*CheckUniquePred=*/true) ==
+      GuardOtherSucc)
+    return GuardBI;
+  return nullptr;
 }
 
 bool Loop::isCanonical(ScalarEvolution &SE) const {
diff --git a/llvm/lib/Analysis/LoopNestAnalysis.cpp b/llvm/lib/Analysis/LoopNestAnalysis.cpp
--- a/llvm/lib/Analysis/LoopNestAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopNestAnalysis.cpp
@@ -206,7 +206,8 @@
 }
 
 const BasicBlock &LoopNest::skipEmptyBlockUntil(const BasicBlock *From,
-                                                const BasicBlock *End) {
+                                                const BasicBlock *End,
+                                                bool CheckUniquePred) {
   assert(From && "Expecting valid From");
   assert(End && "Expecting valid End");
 
@@ -220,8 +221,9 @@
   // Visited is used to avoid running into an infinite loop.
   SmallPtrSet<const BasicBlock *, 4> Visited;
   const BasicBlock *BB = From->getUniqueSuccessor();
-  const BasicBlock *PredBB = BB;
-  while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB)) {
+  const BasicBlock *PredBB = From;
+  while (BB && BB != End && IsEmpty(BB) && !Visited.count(BB) &&
+         (!CheckUniquePred || BB->getUniquePredecessor())) {
     Visited.insert(BB);
     PredBB = BB;
     BB = BB->getUniqueSuccessor();
diff --git a/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll b/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll
--- a/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll
+++ b/llvm/test/Analysis/LoopNestAnalysis/imperfectnest.ll
@@ -424,70 +424,3 @@
 for.end13:
   ret void
 }
-
-; Test an imperfect loop nest of the form:
-;   for (int i = 0; i < nx; ++i)
-;     if (i > 5) { // user branch
-;       for (int j = 1; j <= 5; j+=2)
-;         y[j][i] = x[i][j] + j;
-;     }
-
-define void @imperf_nest_6(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
-; CHECK-LABEL: IsPerfect=false, Depth=2, OutermostLoop: imperf_nest_6_loop_i, Loops: ( imperf_nest_6_loop_i imperf_nest_6_loop_j )
-entry:
-  %cmp2 = icmp slt i32 0, %nx
-  br i1 %cmp2, label %imperf_nest_6_loop_i.lr.ph, label %for.end13
-
-imperf_nest_6_loop_i.lr.ph:
-  br label %imperf_nest_6_loop_i
-
-imperf_nest_6_loop_i:
-  %i.0 = phi i32 [ 0, %imperf_nest_6_loop_i.lr.ph ], [ %inc12, %for.inc11 ]
-  %cmp1 = icmp sgt i32 %i.0, 5
-  br i1 %cmp1, label %imperf_nest_6_loop_j.lr.ph, label %if.end
-
-imperf_nest_6_loop_j.lr.ph:
-  br label %imperf_nest_6_loop_j
-
-imperf_nest_6_loop_j:
-  %j.0 = phi i32 [ 1, %imperf_nest_6_loop_j.lr.ph ], [ %inc, %for.inc ]
-  %idxprom = sext i32 %i.0 to i64
-  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
-  %0 = load i32*, i32** %arrayidx, align 8
-  %idxprom5 = sext i32 %j.0 to i64
-  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
-  %1 = load i32, i32* %arrayidx6, align 4
-  %add = add nsw i32 %1, %j.0
-  %idxprom7 = sext i32 %j.0 to i64
-  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
-  %2 = load i32*, i32** %arrayidx8, align 8
-  %idxprom9 = sext i32 %i.0 to i64
-  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
-  store i32 %add, i32* %arrayidx10, align 4
-  br label %for.inc
-
-for.inc:
-  %inc = add nsw i32 %j.0, 2
-  %cmp3 = icmp sle i32 %inc, 5
-  br i1 %cmp3, label %imperf_nest_6_loop_j, label %for.cond2.for.end_crit_edge
-
-for.cond2.for.end_crit_edge:
-  br label %for.end
-
-for.end:
-  br label %if.end
-
-if.end:
-  br label %for.inc11
-
-for.inc11:
-  %inc12 = add nsw i32 %i.0, 1
-  %cmp = icmp slt i32 %inc12, %nx
-  br i1 %cmp, label %imperf_nest_6_loop_i, label %for.cond.for.end13_crit_edge
-
-for.cond.for.end13_crit_edge:
-  br label %for.end13
-
-for.end13:
-  ret void
-}
diff --git a/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll b/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll
--- a/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll
+++ b/llvm/test/Analysis/LoopNestAnalysis/perfectnest.ll
@@ -322,3 +322,76 @@
   %x.addr.0.lcssa = phi i32 [ %split7, %for.cond.for.end7_crit_edge ], [ %x, %entry ]
   ret i32 %x.addr.0.lcssa
 }
+
+; Test a perfect loop nest of the form:
+;   for (int i = 0; i < nx; ++i)
+;     if (i < ny) { // guard branch for the j-loop
+;       for (int j=i; j < ny; j+=1)
+;         y[j][i] = x[i][j] + j;
+;     }
+
+define double @perf_nest_guard_branch(i32** %y, i32** %x, i32 signext %nx, i32 signext %ny) {
+; CHECK-LABEL: IsPerfect=true, Depth=1, OutermostLoop: perf_nest_guard_branch_loop_j, Loops: ( perf_nest_guard_branch_loop_j )
+; CHECK-LABEL: IsPerfect=true, Depth=2, OutermostLoop: perf_nest_guard_branch_loop_i, Loops: ( perf_nest_guard_branch_loop_i perf_nest_guard_branch_loop_j )
+entry:
+  %cmp2 = icmp slt i32 0, %nx
+  br i1 %cmp2, label %perf_nest_guard_branch_loop_i.lr.ph, label %for.end13
+
+perf_nest_guard_branch_loop_i.lr.ph:              ; preds = %entry
+  br label %perf_nest_guard_branch_loop_i
+
+perf_nest_guard_branch_loop_i:                    ; preds = %perf_nest_guard_branch_loop_i.lr.ph, %for.inc11
+  %i.0 = phi i32 [ 0, %perf_nest_guard_branch_loop_i.lr.ph ], [ %inc12, %for.inc11 ]
+  %cmp1 = icmp slt i32 %i.0, %ny
+  br i1 %cmp1, label %perf_nest_guard_branch_loop_j.lr.ph, label %if.end
+
+perf_nest_guard_branch_loop_j.lr.ph:              ; preds = %perf_nest_guard_branch_loop_i
+  br label %perf_nest_guard_branch_loop_j
+
+perf_nest_guard_branch_loop_j:                    ; preds = %perf_nest_guard_branch_loop_j.lr.ph, %for.inc
+  %j.0 = phi i32 [ %i.0, %perf_nest_guard_branch_loop_j.lr.ph ], [ %inc, %for.inc ]
+  %idxprom = sext i32 %i.0 to i64
+  %arrayidx = getelementptr inbounds i32*, i32** %x, i64 %idxprom
+  %0 = load i32*, i32** %arrayidx, align 8
+  %idxprom5 = sext i32 %j.0 to i64
+  %arrayidx6 = getelementptr inbounds i32, i32* %0, i64 %idxprom5
+  %1 = load i32, i32* %arrayidx6, align 4
+  %add = add nsw i32 %1, %j.0
+  %idxprom7 = sext i32 %j.0 to i64
+  %arrayidx8 = getelementptr inbounds i32*, i32** %y, i64 %idxprom7
+  %2 = load i32*, i32** %arrayidx8, align 8
+  %idxprom9 = sext i32 %i.0 to i64
+  %arrayidx10 = getelementptr inbounds i32, i32* %2, i64 %idxprom9
+  store i32 %add, i32* %arrayidx10, align 4
+  br label %for.inc
+
+for.inc:                                          ; preds = %perf_nest_guard_branch_loop_j
+  %inc = add nsw i32 %j.0, 1
+  %cmp3 = icmp slt i32 %inc, %ny
+  br i1 %cmp3, label %perf_nest_guard_branch_loop_j, label %for.cond2.for.end_crit_edge
+
+for.cond2.for.end_crit_edge:                      ; preds = %for.inc
+  br label %for.end
+
+for.end:                                          ; preds = %for.cond2.for.end_crit_edge
+  br label %if.end
+
+if.end:                                           ; preds = %for.end, %perf_nest_guard_branch_loop_i
+  br label %for.inc11
+
+for.inc11:                                        ; preds = %if.end
+  %inc12 = add nsw i32 %i.0, 1
+  %cmp = icmp slt i32 %inc12, %nx
+  br i1 %cmp, label %perf_nest_guard_branch_loop_i, label %for.cond.for.end13_crit_edge
+
+for.cond.for.end13_crit_edge:                     ; preds = %for.inc11
+  br label %for.end13
+
+for.end13:                                        ; preds = %for.cond.for.end13_crit_edge, %entry
+  %arrayidx14 = getelementptr inbounds i32*, i32** %y, i64 0
+  %3 = load i32*, i32** %arrayidx14, align 8
+  %arrayidx15 = getelementptr inbounds i32, i32* %3, i64 0
+  %4 = load i32, i32* %arrayidx15, align 4
+  %conv = sitofp i32 %4 to double
+  ret double %conv
+}
diff --git a/llvm/unittests/Analysis/LoopInfoTest.cpp b/llvm/unittests/Analysis/LoopInfoTest.cpp
--- a/llvm/unittests/Analysis/LoopInfoTest.cpp
+++ b/llvm/unittests/Analysis/LoopInfoTest.cpp
@@ -1500,3 +1500,54 @@
     EXPECT_FALSE(L->isRotatedForm());
   });
 }
+
+TEST(LoopInfoTest, LoopUserBranch) {
+  const char *ModuleStr =
+      "target datalayout = \"e-m:o-i64:64-f80:128-n8:16:32:64-S128\"\n"
+      "define void @foo(i32* %B, i64 signext %nx, i1 %cond) {\n"
+      "entry:\n"
+      "  br i1 %cond, label %bb, label %guard\n"
+      "guard:\n"
+      "  %cmp.guard = icmp slt i64 0, %nx\n"
+      "  br i1 %cmp.guard, label %for.i.preheader, label %for.end\n"
+      "for.i.preheader:\n"
+      "  br label %for.i\n"
+      "for.i:\n"
+      "  %i = phi i64 [ 0, %for.i.preheader ], [ %inc13, %for.i ]\n"
+      "  %Bi = getelementptr inbounds i32, i32* %B, i64 %i\n"
+      "  store i32 0, i32* %Bi, align 4\n"
+      "  %inc13 = add nsw i64 %i, 1\n"
+      "  %cmp = icmp slt i64 %inc13, %nx\n"
+      "  br i1 %cmp, label %for.i, label %for.i.exit\n"
+      "for.i.exit:\n"
+      "  br label %bb\n"
+      "bb:\n"
+      "  br label %for.end\n"
+      "for.end:\n"
+      "  ret void\n"
+      "}\n";
+
+  // Parse the module.
+  LLVMContext Context;
+  std::unique_ptr<Module> M = makeLLVMModule(Context, ModuleStr);
+
+  runWithLoopInfo(*M, "foo", [&](Function &F, LoopInfo &LI) {
+    Function::iterator FI = F.begin();
+    // The first basic block is entry - skip it.
+    BasicBlock *Guard = &*(++FI);
+    assert(Guard->getName() == "guard");
+
+    // The next basic block is for.i.preheader - skip it.
+    ++FI;
+    BasicBlock *Header = &*(++FI);
+    assert(Header->getName() == "for.i");
+
+    Loop *L = LI.getLoopFor(Header);
+    ASSERT_NE(L, nullptr);
+
+    // for.end, the guard's other successor, is reached from the loop exit only
+    // through bb, which is also a successor of entry and thus has no unique
+    // predecessor, so guard must not be reported as the loop guard branch.
+    EXPECT_EQ(L->getLoopGuardBranch(), nullptr);
+  });
+}