diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -81,8 +81,8 @@ // TODO: Should these be here or in LoopUnroll? STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled"); STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)"); -STATISTIC(NumUnrolledWithHeader, "Number of loops unrolled without a " - "conditional latch (completely or otherwise)"); +STATISTIC(NumUnrolledNotLatch, "Number of loops unrolled without a conditional " + "latch (completely or otherwise)"); static cl::opt<bool> UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden, @@ -304,48 +304,39 @@ return LoopUnrollResult::Unmodified; } - // The current loop unroll pass can unroll loops with a single latch or header - // that's a conditional branch exiting the loop. + // The current loop unroll pass can unroll loops that have + // (1) single latch + // (2a) latch is the exiting block; or + // (2b) latch is unconditional and there exists at least one exiting block. // FIXME: The implementation can be extended to work with more complicated // cases, e.g. loops with multiple latches. BasicBlock *Header = L->getHeader(); - BranchInst *HeaderBI = dyn_cast<BranchInst>(Header->getTerminator()); - BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator()); - // FIXME: Support loops without conditional latch and multiple exiting blocks. 
- if (!BI || - (BI->isUnconditional() && (!HeaderBI || HeaderBI->isUnconditional() || - L->getExitingBlock() != Header))) { - LLVM_DEBUG(dbgs() << " Can't unroll; loop not terminated by a conditional " - "branch in the latch or header.\n"); - return LoopUnrollResult::Unmodified; - } - - auto CheckLatchSuccessors = [&](unsigned S1, unsigned S2) { - return BI->isConditional() && BI->getSuccessor(S1) == Header && - !L->contains(BI->getSuccessor(S2)); - }; - - // If we have a conditional latch, it must exit the loop. - if (BI && BI->isConditional() && !CheckLatchSuccessors(0, 1) && - !CheckLatchSuccessors(1, 0)) { - LLVM_DEBUG( - dbgs() << "Can't unroll; a conditional latch must exit the loop"); - return LoopUnrollResult::Unmodified; + BranchInst *BI = nullptr; + if (L->isLoopExiting(LatchBlock)) + BI = dyn_cast_or_null<BranchInst>(LatchBlock->getTerminator()); + else { + SmallVector<BasicBlock *, 4> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (BasicBlock *ExitingBlock : ExitingBlocks) + if (BranchInst *ExitingBI = + dyn_cast_or_null<BranchInst>(ExitingBlock->getTerminator())) { + BI = ExitingBI; + + // Give higher priority for exiting blocks that have small constant trip + // count equals to the desired unroll factor. + assert(ULO.Count > 0 && "Expecting valid unroll count"); + if (SE && ULO.Count == SE->getSmallConstantTripCount(L, ExitingBlock)) + break; + } } - - auto CheckHeaderSuccessors = [&](unsigned S1, unsigned S2) { - return HeaderBI && HeaderBI->isConditional() && - L->contains(HeaderBI->getSuccessor(S1)) && - !L->contains(HeaderBI->getSuccessor(S2)); - }; - - // If we do not have a conditional latch, the header must exit the loop. 
- if (BI && !BI->isConditional() && HeaderBI && HeaderBI->isConditional() && - !CheckHeaderSuccessors(0, 1) && !CheckHeaderSuccessors(1, 0)) { - LLVM_DEBUG(dbgs() << "Can't unroll; conditional header must exit the loop"); + if (!BI) { + LLVM_DEBUG(dbgs() << " Can't unroll; loop not terminated by a conditional " + "branch in latch or exiting blocks.\n"); return LoopUnrollResult::Unmodified; } + LLVM_DEBUG(dbgs() << " Exiting Block = " << BI->getParent()->getName() + << "\n"); if (Header->hasAddressTaken()) { // The loop-rotate pass can be helpful to avoid this in many cases. @@ -534,17 +525,11 @@ SE->forgetTopmostLoop(L); } - bool ContinueOnTrue; - bool LatchIsExiting = BI->isConditional(); - BasicBlock *LoopExit = nullptr; - if (LatchIsExiting) { - ContinueOnTrue = L->contains(BI->getSuccessor(0)); - LoopExit = BI->getSuccessor(ContinueOnTrue); - } else { - NumUnrolledWithHeader++; - ContinueOnTrue = L->contains(HeaderBI->getSuccessor(0)); - LoopExit = HeaderBI->getSuccessor(ContinueOnTrue); - } + bool LatchIsExiting = BI->getParent() == LatchBlock; + if (!LatchIsExiting) + ++NumUnrolledNotLatch; + bool ContinueOnTrue = L->contains(BI->getSuccessor(0)); + BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue); // For the first iteration of the loop, we should use the precloned values for // PHI nodes. Insert associations now. 
@@ -555,21 +540,13 @@ } std::vector<BasicBlock *> Headers; - std::vector<BasicBlock *> HeaderSucc; + std::vector<BasicBlock *> ExitingBlocks; + std::vector<BasicBlock *> ExitingSucc; std::vector<BasicBlock *> Latches; Headers.push_back(Header); Latches.push_back(LatchBlock); - - if (!LatchIsExiting) { - auto *Term = cast<BranchInst>(Header->getTerminator()); - if (Term->isUnconditional() || L->contains(Term->getSuccessor(0))) { - assert(L->contains(Term->getSuccessor(0))); - HeaderSucc.push_back(Term->getSuccessor(0)); - } else { - assert(L->contains(Term->getSuccessor(1))); - HeaderSucc.push_back(Term->getSuccessor(1)); - } - } + ExitingBlocks.push_back(BI->getParent()); + ExitingSucc.push_back(BI->getSuccessor(!ContinueOnTrue)); // The current on-the-fly SSA update requires blocks to be processed in // reverse postorder so that LastValueMap contains the correct value at each @@ -660,12 +637,12 @@ if (*BB == LatchBlock) Latches.push_back(New); - // Keep track of the successor of the new header in the current iteration. - for (auto *Pred : predecessors(*BB)) - if (Pred == Header) { - HeaderSucc.push_back(New); - break; - } + // Keep track of the exiting block and its successor block contained in + // the loop for the current iteration. + if (*BB == ExitingBlocks[0]) + ExitingBlocks.push_back(New); + if (*BB == ExitingSucc[0]) + ExitingSucc.push_back(New); NewBlocks.push_back(New); UnrolledLoopBlocks.push_back(New); @@ -745,22 +722,19 @@ // Now that all the basic blocks for the unrolled iterations are in place, // set up the branches to connect them. - if (LatchIsExiting) { - // Set up latches to branch to the new header in the unrolled iterations or - // the loop exit for the last latch in a fully unrolled loop. - for (unsigned i = 0, e = Latches.size(); i != e; ++i) { - // The branch destination. 
- unsigned j = (i + 1) % e; - BasicBlock *Dest = Headers[j]; - bool NeedConditional = true; - - if (RuntimeTripCount && j != 0) { - NeedConditional = false; - } - - // For a complete unroll, make the last iteration end with a branch - // to the exit block. - if (CompletelyUnroll) { + for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { + // The branch destination. + unsigned j = (i + 1) % e; + BasicBlock *Dest = ExitingSucc[(LatchIsExiting ? j : i)]; + bool NeedConditional = true; + + if (RuntimeTripCount && j != 0) + NeedConditional = false; + + if (CompletelyUnroll) + if (LatchIsExiting) { + // For a complete unroll, make the last iteration end with a branch + // to the exit block. if (j == 0) Dest = LoopExit; // If using trip count upper bound to completely unroll, we need to keep @@ -771,40 +745,21 @@ "unrolling and runtime unrolling"); NeedConditional = (ULO.PreserveCondBr && j && !(ULO.PreserveOnlyFirst && i != 0)); - } else if (j != BreakoutTrip && - (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) { - // If we know the trip count or a multiple of it, we can safely use an - // unconditional branch for some iterations. - NeedConditional = false; - } - - setDest(Latches[i], Dest, Headers[i], NeedConditional); - } - } else { - // Setup headers to branch to their new successors in the unrolled - // iterations. - for (unsigned i = 0, e = Headers.size(); i != e; ++i) { - // The branch destination. - unsigned j = (i + 1) % e; - BasicBlock *Dest = HeaderSucc[i]; - bool NeedConditional = true; - - if (RuntimeTripCount && j != 0) - NeedConditional = false; - - if (CompletelyUnroll) + } else { // We cannot drop the conditional branch for the last condition, as we // may have to execute the loop body depending on the condition. 
NeedConditional = j == 0 || ULO.PreserveCondBr; - else if (j != BreakoutTrip && - (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) - // If we know the trip count or a multiple of it, we can safely use an - // unconditional branch for some iterations. - NeedConditional = false; + } + else if (j != BreakoutTrip && + (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) + // If we know the trip count or a multiple of it, we can safely use an + // unconditional branch for some iterations. + NeedConditional = false; - setDest(Headers[i], Dest, HeaderSucc[i], NeedConditional); - } + setDest(ExitingBlocks[i], Dest, ExitingSucc[i], NeedConditional); + } + if (!LatchIsExiting) { // Set up latches to branch to the new header in the unrolled iterations or // the loop exit for the last latch in a fully unrolled loop. @@ -841,8 +796,8 @@ ChildrenToUpdate.push_back(ChildBB); } BasicBlock *NewIDom; - BasicBlock *&TermBlock = LatchIsExiting ? LatchBlock : Header; - auto &TermBlocks = LatchIsExiting ? Latches : Headers; + BasicBlock *&TermBlock = ExitingBlocks[0]; + auto &TermBlocks = ExitingBlocks; if (BB == TermBlock) { // The latch is special because we emit unconditional branches in // some cases where the original loop contained a conditional branch. 
@@ -857,7 +812,7 @@ for (BasicBlock *Iter : TermBlocks) { Instruction *Term = Iter->getTerminator(); if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) { - NewIDom = Iter; + NewIDom = DT->findNearestCommonDominator(Iter, LatchBlock); break; } } diff --git a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll @@ -0,0 +1,69 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -loop-unroll -S | FileCheck %s +; RUN: opt < %s -passes='require<opt-remark-emit>,unroll' -S | FileCheck %s + +define void @foo(i32* noalias %A) { +; CHECK-LABEL: @foo( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = load i32, i32* [[A:%.*]], align 4 +; CHECK-NEXT: call void @bar(i32 [[TMP0]]) +; CHECK-NEXT: br label [[FOR_HEADER:%.*]] +; CHECK: for.header: +; CHECK-NEXT: call void @bar(i32 [[TMP0]]) +; CHECK-NEXT: br label [[FOR_BODY:%.*]] +; CHECK: for.body: +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]] +; CHECK: for.body.for.body_crit_edge: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 1 +; CHECK-NEXT: [[DOTPRE:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE]]) +; CHECK-NEXT: br label [[FOR_BODY_1:%.*]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; CHECK: for.body.1: +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]] +; CHECK: for.body.for.body_crit_edge.1: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_1:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 2 +; CHECK-NEXT: [[DOTPRE_1:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_1]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE_1]]) +; CHECK-NEXT: br label [[FOR_BODY_2:%.*]] +; CHECK: for.body.2: +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]] +; CHECK: for.body.for.body_crit_edge.2: +; CHECK-NEXT: 
[[ARRAYIDX_PHI_TRANS_INSERT_2:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 3 +; CHECK-NEXT: [[DOTPRE_2:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_2]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE_2]]) +; CHECK-NEXT: br label [[FOR_BODY_3:%.*]] +; CHECK: for.body.3: +; CHECK-NEXT: br i1 false, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_3:%.*]], label [[FOR_END:%.*]] +; CHECK: for.body.for.body_crit_edge.3: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_3:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 4 +; CHECK-NEXT: unreachable +; +entry: + %0 = load i32, i32* %A, align 4 + call void @bar(i32 %0) + br label %for.header + +for.header: + %1 = phi i32 [ %0, %entry ], [ %.pre, %for.body.for.body_crit_edge ] + %i = phi i64 [ 0, %entry ], [ %inc, %for.body.for.body_crit_edge ] + %arrayidx = getelementptr inbounds i32, i32* %A, i64 %i + call void @bar(i32 %1) + br label %for.body + +for.body: + %inc = add nsw i64 %i, 1 + %cmp = icmp slt i64 %inc, 4 + br i1 %cmp, label %for.body.for.body_crit_edge, label %for.end + +for.body.for.body_crit_edge: + %arrayidx.phi.trans.insert = getelementptr inbounds i32, i32* %A, i64 %inc + %.pre = load i32, i32* %arrayidx.phi.trans.insert, align 4 + br label %for.header + +for.end: + ret void +} + +declare void @bar(i32)