[LoopUnroll] Unroll loops whose unconditional latch has no single exiting branch

Summary of the change (derived from the hunks below):
- Precondition: UnrollLoop now bails out only when there is no latch branch
  (LatchBI is null), or when the latch branch is conditional but the latch is
  not an exiting block (LatchIsExiting is false). The removed check that also
  demanded an exiting branch (ExitingBI) is gone, so ExitingBI may now be null.
- When ExitingBI is null: ContinueOnTrue stays None, LoopExit stays nullptr,
  ExitingBlocks/ExitingSucc are never populated, and the NewIDom computation
  only consults ExitingBlocks[0] when ExitingBI is non-null.
- setDest's ContinueOnTrue parameter becomes an Optional and is asserted to
  hold a value whenever NeedConditional is true; the exiting-block call site
  passes None because it is only reached when NeedConditional is false.
- nonlatchcondbr.ll gains @test3: a loop with two exiting blocks (for.header,
  for.body) whose latch (for.body.for.body_crit_edge) is not exiting.

NOTE(review): this copy has lost its line breaks and every `<...>` template
argument (e.g. `dyn_cast(...)`, `cast(...)`, `isa(...)`, `Optional
ContinueOnTrue`, `std::vector Latches`), so it will neither apply nor compile
as-is; regenerate it from the original commit before use.
NOTE(review): the unchanged context message "Can't unroll; a conditional latch
must exit the loop" lacks the trailing "\n" and the leading indent that the
neighbouring LLVM_DEBUG messages use; consider fixing it in a follow-up.

diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp --- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp @@ -306,8 +306,8 @@ // The current loop unroll pass can unroll loops that have // (1) single latch; and - // (2a) latch is an exiting block; or - // (2b) latch is unconditional and there exists a single exiting block. + // (2a) latch is unconditional; or + // (2b) latch is conditional and is an exiting block. // FIXME: The implementation can be extended to work with more complicated // cases, e.g. loops with multiple latches. BasicBlock *Header = L->getHeader(); @@ -321,18 +321,18 @@ ExitingBI = LatchBI; else if (BasicBlock *ExitingBlock = L->getExitingBlock()) ExitingBI = dyn_cast(ExitingBlock->getTerminator()); - if (!LatchBI || !ExitingBI) { - LLVM_DEBUG(dbgs() << " Can't unroll; loop not terminated by a conditional " - "branch in latch or a single exiting block.\n"); - return LoopUnrollResult::Unmodified; - } - if (LatchBI->isConditional() && LatchBI != ExitingBI) { + if (!LatchBI || (LatchBI->isConditional() && !LatchIsExiting)) { LLVM_DEBUG( dbgs() << "Can't unroll; a conditional latch must exit the loop"); return LoopUnrollResult::Unmodified; } - LLVM_DEBUG(dbgs() << " Exiting Block = " << ExitingBI->getParent()->getName() - << "\n"); + LLVM_DEBUG({ + if (ExitingBI) + dbgs() << " Exiting Block = " << ExitingBI->getParent()->getName() + << "\n"; + else + dbgs() << " No single exiting block\n"; + }); if (Header->hasAddressTaken()) { // The loop-rotate pass can be helpful to avoid this in many cases. 
@@ -523,8 +523,12 @@ if (!LatchIsExiting) ++NumUnrolledNotLatch; - bool ContinueOnTrue = L->contains(ExitingBI->getSuccessor(0)); - BasicBlock *LoopExit = ExitingBI->getSuccessor(ContinueOnTrue); + Optional ContinueOnTrue = None; + BasicBlock *LoopExit = nullptr; + if (ExitingBI) { + ContinueOnTrue = L->contains(ExitingBI->getSuccessor(0)); + LoopExit = ExitingBI->getSuccessor(*ContinueOnTrue); + } // For the first iteration of the loop, we should use the precloned values for // PHI nodes. Insert associations now. @@ -540,8 +544,10 @@ std::vector Latches; Headers.push_back(Header); Latches.push_back(LatchBlock); - ExitingBlocks.push_back(ExitingBI->getParent()); - ExitingSucc.push_back(ExitingBI->getSuccessor(!ContinueOnTrue)); + if (ExitingBI) { + ExitingBlocks.push_back(ExitingBI->getParent()); + ExitingSucc.push_back(ExitingBI->getSuccessor(!(*ContinueOnTrue))); + } // The current on-the-fly SSA update requires blocks to be processed in // reverse postorder so that LastValueMap contains the correct value at each @@ -634,10 +640,12 @@ // Keep track of the exiting block and its successor block contained in // the loop for the current iteration. - if (*BB == ExitingBlocks[0]) - ExitingBlocks.push_back(New); - if (*BB == ExitingSucc[0]) - ExitingSucc.push_back(New); + if (ExitingBI) { + if (*BB == ExitingBlocks[0]) + ExitingBlocks.push_back(New); + if (*BB == ExitingSucc[0]) + ExitingSucc.push_back(New); + } NewBlocks.push_back(New); UnrolledLoopBlocks.push_back(New); @@ -689,13 +697,15 @@ } auto setDest = [](BasicBlock *Src, BasicBlock *Dest, BasicBlock *BlockInLoop, - bool NeedConditional, bool ContinueOnTrue, + bool NeedConditional, Optional ContinueOnTrue, bool IsDestLoopExit) { auto *Term = cast(Src->getTerminator()); if (NeedConditional) { // Update the conditional branch's successor for the following // iteration. 
- Term->setSuccessor(!ContinueOnTrue, Dest); + assert(ContinueOnTrue.hasValue() && + "Expecting valid ContinueOnTrue when NeedConditional is true"); + Term->setSuccessor(!(*ContinueOnTrue), Dest); } else { // Remove phi operands at this loop exit if (!IsDestLoopExit) { @@ -780,7 +790,7 @@ if (NeedConditional) continue; setDest(ExitingBlocks[i], ExitingSucc[i], ExitingSucc[i], NeedConditional, - ContinueOnTrue, false); + None, false); } // When completely unrolling, the last latch becomes unreachable. @@ -805,9 +815,7 @@ ChildrenToUpdate.push_back(ChildBB); } BasicBlock *NewIDom; - BasicBlock *&TermBlock = ExitingBlocks[0]; - auto &TermBlocks = ExitingBlocks; - if (BB == TermBlock) { + if (ExitingBI && BB == ExitingBlocks[0]) { // The latch is special because we emit unconditional branches in // some cases where the original loop contained a conditional branch. // Since the latch is always at the bottom of the loop, if the latch @@ -816,12 +824,13 @@ // latch which ends in a conditional branch, or the last latch if // there is no such latch. // For loops exiting from non latch exiting block, we limit the - // supported loops to have a single exiting block. - NewIDom = TermBlocks.back(); - for (unsigned i = 0, e = TermBlocks.size(); i != e; ++i) { - Instruction *Term = TermBlocks[i]->getTerminator(); + // branch simplification to single exiting block loops. 
+ NewIDom = ExitingBlocks.back(); + for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { + Instruction *Term = ExitingBlocks[i]->getTerminator(); if (isa(Term) && cast(Term)->isConditional()) { - NewIDom = DT->findNearestCommonDominator(TermBlocks[i], Latches[i]); + NewIDom = + DT->findNearestCommonDominator(ExitingBlocks[i], Latches[i]); break; } } diff --git a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll --- a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll +++ b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll @@ -155,5 +155,80 @@ ret void } +; Check that the loop unroll pass correctly handles loops with +; (1) multiple exiting blocks; and +; (2) loop latch is not an exiting block. + +define void @test3(i32* noalias %A, i1 %cond) { +; CHECK-LABEL: @test3( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = load i32, i32* [[A:%.*]], align 4 +; CHECK-NEXT: call void @bar(i32 [[TMP0]]) +; CHECK-NEXT: br label [[FOR_HEADER:%.*]] +; CHECK: for.header: +; CHECK-NEXT: [[TMP1:%.*]] = phi i32 [ [[TMP0]], [[ENTRY:%.*]] ], [ [[DOTPRE_3:%.*]], [[FOR_BODY_FOR_BODY_CRIT_EDGE_3:%.*]] ] +; CHECK-NEXT: [[I:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[INC_3:%.*]], [[FOR_BODY_FOR_BODY_CRIT_EDGE_3]] ] +; CHECK-NEXT: call void @bar(i32 [[TMP1]]) +; CHECK-NEXT: br i1 [[COND:%.*]], label [[FOR_BODY:%.*]], label [[FOR_END:%.*]] +; CHECK: for.body: +; CHECK-NEXT: [[INC:%.*]] = add nuw nsw i64 [[I]], 1 +; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]], label [[FOR_END]] +; CHECK: for.body.for.body_crit_edge: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 [[INC]] +; CHECK-NEXT: [[DOTPRE:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE]]) +; CHECK-NEXT: br i1 [[COND]], label [[FOR_BODY_1:%.*]], label [[FOR_END]] +; CHECK: for.end: +; CHECK-NEXT: ret void +; CHECK: for.body.1: +; CHECK-NEXT: [[INC_1:%.*]] = add 
nuw nsw i64 [[INC]], 1 +; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]], label [[FOR_END]] +; CHECK: for.body.for.body_crit_edge.1: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_1:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 [[INC_1]] +; CHECK-NEXT: [[DOTPRE_1:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_1]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE_1]]) +; CHECK-NEXT: br i1 [[COND]], label [[FOR_BODY_2:%.*]], label [[FOR_END]] +; CHECK: for.body.2: +; CHECK-NEXT: [[INC_2:%.*]] = add nuw nsw i64 [[INC_1]], 1 +; CHECK-NEXT: br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]], label [[FOR_END]] +; CHECK: for.body.for.body_crit_edge.2: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_2:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 [[INC_2]] +; CHECK-NEXT: [[DOTPRE_2:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_2]], align 4 +; CHECK-NEXT: call void @bar(i32 [[DOTPRE_2]]) +; CHECK-NEXT: br i1 [[COND]], label [[FOR_BODY_3:%.*]], label [[FOR_END]] +; CHECK: for.body.3: +; CHECK-NEXT: [[INC_3]] = add nuw nsw i64 [[INC_2]], 1 +; CHECK-NEXT: br i1 false, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_3]], label [[FOR_END]] +; CHECK: for.body.for.body_crit_edge.3: +; CHECK-NEXT: [[ARRAYIDX_PHI_TRANS_INSERT_3:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 [[INC_3]] +; CHECK-NEXT: [[DOTPRE_3]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_3]], align 4 +; CHECK-NEXT: br label [[FOR_HEADER]], !llvm.loop !2 +; +entry: + %0 = load i32, i32* %A, align 4 + call void @bar(i32 %0) + br label %for.header + +for.header: + %1 = phi i32 [ %0, %entry ], [ %.pre, %for.body.for.body_crit_edge ] + %i = phi i64 [ 0, %entry ], [ %inc, %for.body.for.body_crit_edge ] + %arrayidx = getelementptr inbounds i32, i32* %A, i64 %i + call void @bar(i32 %1) + br i1 %cond, label %for.body, label %for.end + +for.body: + %inc = add nsw i64 %i, 1 + %cmp = icmp slt i64 %inc, 4 + br i1 %cmp, label %for.body.for.body_crit_edge, label %for.end + 
+for.body.for.body_crit_edge: + %arrayidx.phi.trans.insert = getelementptr inbounds i32, i32* %A, i64 %inc + %.pre = load i32, i32* %arrayidx.phi.trans.insert, align 4 + br label %for.header + +for.end: + ret void +} + declare void @bar(i32) declare i1 @foo(i64)