diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -81,8 +81,8 @@ // TODO: Should these be here or in LoopUnroll? STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled"); STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)"); -STATISTIC(NumUnrolledWithHeader, "Number of loops unrolled without a " - "conditional latch (completely or otherwise)"); +STATISTIC(NumUnrolledNotLatch, "Number of loops unrolled without a conditional " + "latch (completely or otherwise)"); static cl::opt<bool> UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(false), cl::Hidden, @@ -304,48 +304,26 @@ return LoopUnrollResult::Unmodified; } - // The current loop unroll pass can unroll loops with a single latch or header - // that's a conditional branch exiting the loop. + // The current loop unroll pass can unroll loops that have + // (1) single latch; and + // (2a) latch is an exiting block; or + // (2b) latch is unconditional and there exists a single exiting block. // FIXME: The implementation can be extended to work with more complicated // cases, e.g. loops with multiple latches. BasicBlock *Header = L->getHeader(); - BranchInst *HeaderBI = dyn_cast<BranchInst>(Header->getTerminator()); - BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator()); - // FIXME: Support loops without conditional latch and multiple exiting blocks.
- if (!BI || - (BI->isUnconditional() && (!HeaderBI || HeaderBI->isUnconditional() || - L->getExitingBlock() != Header))) { + auto *LatchBI = dyn_cast_or_null<BranchInst>(LatchBlock->getTerminator()); + BranchInst *BI = L->isLoopExiting(LatchBlock) ? LatchBI : nullptr; + if (!BI && LatchBI && LatchBI->isUnconditional()) + if (BasicBlock *ExitingBlock = L->getExitingBlock()) + BI = dyn_cast_or_null<BranchInst>(ExitingBlock->getTerminator()); + if (!BI) { LLVM_DEBUG(dbgs() << " Can't unroll; loop not terminated by a conditional " - "branch in the latch or header.\n"); - return LoopUnrollResult::Unmodified; - } - - auto CheckLatchSuccessors = [&](unsigned S1, unsigned S2) { - return BI->isConditional() && BI->getSuccessor(S1) == Header && - !L->contains(BI->getSuccessor(S2)); - }; - - // If we have a conditional latch, it must exit the loop. - if (BI && BI->isConditional() && !CheckLatchSuccessors(0, 1) && - !CheckLatchSuccessors(1, 0)) { - LLVM_DEBUG( - dbgs() << "Can't unroll; a conditional latch must exit the loop"); - return LoopUnrollResult::Unmodified; - } - - auto CheckHeaderSuccessors = [&](unsigned S1, unsigned S2) { - return HeaderBI && HeaderBI->isConditional() && - L->contains(HeaderBI->getSuccessor(S1)) && - !L->contains(HeaderBI->getSuccessor(S2)); - }; - - // If we do not have a conditional latch, the header must exit the loop. - if (BI && !BI->isConditional() && HeaderBI && HeaderBI->isConditional() && - !CheckHeaderSuccessors(0, 1) && !CheckHeaderSuccessors(1, 0)) { - LLVM_DEBUG(dbgs() << "Can't unroll; conditional header must exit the loop"); + "branch in latch or a single exiting block.\n"); return LoopUnrollResult::Unmodified; } + LLVM_DEBUG(dbgs() << " Exiting Block = " << BI->getParent()->getName() + << "\n"); if (Header->hasAddressTaken()) { // The loop-rotate pass can be helpful to avoid this in many cases.
@@ -534,17 +512,11 @@ SE->forgetTopmostLoop(L); } - bool ContinueOnTrue; - bool LatchIsExiting = BI->isConditional(); - BasicBlock *LoopExit = nullptr; - if (LatchIsExiting) { - ContinueOnTrue = L->contains(BI->getSuccessor(0)); - LoopExit = BI->getSuccessor(ContinueOnTrue); - } else { - NumUnrolledWithHeader++; - ContinueOnTrue = L->contains(HeaderBI->getSuccessor(0)); - LoopExit = HeaderBI->getSuccessor(ContinueOnTrue); - } + bool LatchIsExiting = BI->getParent() == LatchBlock; + if (!LatchIsExiting) + ++NumUnrolledNotLatch; + bool ContinueOnTrue = L->contains(BI->getSuccessor(0)); + BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue); // For the first iteration of the loop, we should use the precloned values for // PHI nodes. Insert associations now. @@ -555,21 +527,13 @@ } std::vector<BasicBlock*> Headers; - std::vector<BasicBlock *> HeaderSucc; + std::vector<BasicBlock *> ExitingBlocks; + std::vector<BasicBlock *> ExitingSucc; std::vector<BasicBlock*> Latches; Headers.push_back(Header); Latches.push_back(LatchBlock); - - if (!LatchIsExiting) { - auto *Term = cast<BranchInst>(Header->getTerminator()); - if (Term->isUnconditional() || L->contains(Term->getSuccessor(0))) { - assert(L->contains(Term->getSuccessor(0))); - HeaderSucc.push_back(Term->getSuccessor(0)); - } else { - assert(L->contains(Term->getSuccessor(1))); - HeaderSucc.push_back(Term->getSuccessor(1)); - } - } + ExitingBlocks.push_back(BI->getParent()); + ExitingSucc.push_back(BI->getSuccessor(!ContinueOnTrue)); // The current on-the-fly SSA update requires blocks to be processed in // reverse postorder so that LastValueMap contains the correct value at each @@ -660,12 +624,12 @@ if (*BB == LatchBlock) Latches.push_back(New); - // Keep track of the successor of the new header in the current iteration. - for (auto *Pred : predecessors(*BB)) - if (Pred == Header) { - HeaderSucc.push_back(New); - break; - } + // Keep track of the exiting block and its successor block contained in + // the loop for the current iteration.
+ if (*BB == ExitingBlocks[0]) + ExitingBlocks.push_back(New); + if (*BB == ExitingSucc[0]) + ExitingSucc.push_back(New); NewBlocks.push_back(New); UnrolledLoopBlocks.push_back(New); @@ -784,7 +748,7 @@ if (!LatchIsExiting) { // If the latch is not exiting, we may be able to simplify the conditional // branches in the unrolled exiting blocks. - for (unsigned i = 0, e = Headers.size(); i != e; ++i) { + for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { // The branch destination. unsigned j = (i + 1) % e; bool NeedConditional = true; @@ -807,7 +771,7 @@ // already correct. if (NeedConditional) continue; - setDest(Headers[i], HeaderSucc[i], HeaderSucc[i], NeedConditional, + setDest(ExitingBlocks[i], ExitingSucc[i], ExitingSucc[i], NeedConditional, ContinueOnTrue, false); } @@ -833,8 +797,8 @@ ChildrenToUpdate.push_back(ChildBB); } BasicBlock *NewIDom; - BasicBlock *&TermBlock = LatchIsExiting ? LatchBlock : Header; - auto &TermBlocks = LatchIsExiting ? Latches : Headers; + BasicBlock *TermBlock = ExitingBlocks[0]; + auto &TermBlocks = ExitingBlocks; if (BB == TermBlock) { // The latch is special because we emit unconditional branches in // some cases where the original loop contained a conditional branch.
@@ -849,7 +813,9 @@ - for (BasicBlock *Iter : TermBlocks) { - Instruction *Term = Iter->getTerminator(); + for (unsigned i = 0, e = TermBlocks.size(); i != e; ++i) { + Instruction *Term = TermBlocks[i]->getTerminator(); if (isa<BranchInst>(Term) && cast<BranchInst>(Term)->isConditional()) { - NewIDom = Iter; + // Later iterations are reachable only through Latches[i], so the + // idom of the exit is the NCD of this exiting block and that latch. + NewIDom = DT->findNearestCommonDominator(TermBlocks[i], Latches[i]); break; } } diff --git a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll @@ -0,0 +1,71 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -loop-unroll -verify-dom-info -S | FileCheck %s +; RUN: opt < %s -passes='require<opt-remark-emit>,unroll' -S | FileCheck %s + +; The loop's only exiting block (for.body) is not its latch; the latch +; (for.body.for.body_crit_edge) branches unconditionally to for.header. +define void @foo(i32* noalias %A) { +; CHECK-LABEL: @foo( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = load i32, i32* [[A:%.*]], align 4 +; CHECK-NEXT: call void @bar(i32 [[TMP0]]) +; CHECK-NEXT: br label [[FOR_HEADER:%.*]] +; CHECK: for.header: +; CHECK-NEXT: call void @bar(i32 [[TMP0]]) +; CHECK-NEXT: br label [[FOR_BODY:%.*]] +; CHECK: for.body: +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]] +; CHECK: for.body.for.body_crit_edge: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 1 +; CHECK-NEXT: [[DOTPRE:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE]]) +; CHECK-NEXT: br label [[FOR_BODY_1:%.*]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; CHECK: for.body.1: +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]] +; CHECK: for.body.for.body_crit_edge.1: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_1:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 2 +; CHECK-NEXT: [[DOTPRE_1:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_1]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE_1]]) +; CHECK-NEXT: br label [[FOR_BODY_2:%.*]] +; CHECK: for.body.2: +; CHECK-NEXT: br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]] +; CHECK: for.body.for.body_crit_edge.2: +; CHECK-NEXT:
[[ARRAYIDX_PHI_TRANS_INSERT_2:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 3 +; CHECK-NEXT: [[DOTPRE_2:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_2]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE_2]]) +; CHECK-NEXT: br label [[FOR_BODY_3:%.*]] +; CHECK: for.body.3: +; CHECK-NEXT: br i1 false, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_3:%.*]], label [[FOR_END:%.*]] +; CHECK: for.body.for.body_crit_edge.3: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_3:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 4 +; CHECK-NEXT: unreachable +; +entry: + %0 = load i32, i32* %A, align 4 + call void @bar(i32 %0) + br label %for.header + +for.header: + %1 = phi i32 [ %0, %entry ], [ %.pre, %for.body.for.body_crit_edge ] + %i = phi i64 [ 0, %entry ], [ %inc, %for.body.for.body_crit_edge ] + %arrayidx = getelementptr inbounds i32, i32* %A, i64 %i + call void @bar(i32 %1) + br label %for.body + +for.body: + %inc = add nsw i64 %i, 1 + %cmp = icmp slt i64 %inc, 4 + br i1 %cmp, label %for.body.for.body_crit_edge, label %for.end + +for.body.for.body_crit_edge: + %arrayidx.phi.trans.insert = getelementptr inbounds i32, i32* %A, i64 %inc + %.pre = load i32, i32* %arrayidx.phi.trans.insert, align 4 + br label %for.header + +for.end: + ret void +} + +declare void @bar(i32)