Index: llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -55,7 +55,7 @@ /// Update the dominator tree after removing one exiting predecessor of a loop /// exit block. static void updateLoopExitIDom(BasicBlock *LoopExitBB, Loop &L, - DominatorTree &DT) { + DominatorTree &DT) { assert(pred_begin(LoopExitBB) != pred_end(LoopExitBB) && "Cannot have empty predecessors of the loop exit block if we split " "off a block to unswitch!"); @@ -137,6 +137,98 @@ } } +/// Check that all the LCSSA PHI nodes in the loop exit block have trivial +/// incoming values along this edge. +static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, + BasicBlock &ExitBB) { + for (Instruction &I : ExitBB) { + auto *PN = dyn_cast(&I); + if (!PN) + // No more PHIs to check. + return true; + + // If the incoming value for this edge isn't loop invariant the unswitch + // won't be trivial. + if (!L.isLoopInvariant(PN->getIncomingValueForBlock(&ExitingBB))) + return false; + } + llvm_unreachable("Basic blocks should never be empty!"); +} + +/// Rewrite the PHI nodes in an unswitched loop exit basic block. +/// +/// Requires that the loop exit and unswitched basic block are the same, and +/// that the exiting block was a unique predecessor of that block. Rewrites the +/// PHI nodes in that block such that what were LCSSA PHI nodes become trivial +/// PHI nodes from the old preheader that now contains the unswitched +/// terminator. +static void rewritePHINodesForUnswitchedExitBlock(BasicBlock &UnswitchedBB, + BasicBlock &OldExitingBB, + BasicBlock &OldPH) { + for (Instruction &I : UnswitchedBB) { + auto *PN = dyn_cast(&I); + if (!PN) + // No more PHIs to check. + break; + + // When the loop exit is directly unswitched we just need to update the + // incoming basic block. We loop to handle weird cases with repeated + // incoming blocks, but expect to typically only have one operand here. + for (auto i : llvm::seq(0, PN->getNumOperands())) { + assert(PN->getIncomingBlock(i) == &OldExitingBB && + "Found incoming block different from unique predecessor!"); + PN->setIncomingBlock(i, &OldPH); + } + } +} + +/// Rewrite the PHI nodes in the loop exit basic block and the split off +/// unswitched block. +/// +/// Because the exit block remains an exit from the loop, this rewrites the +/// LCSSA PHI nodes in it to remove the unswitched edge and introduces PHI +/// nodes into the unswitched basic block to select between the value in the +/// old preheader and the loop exit. +static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, + BasicBlock &UnswitchedBB, + BasicBlock &OldExitingBB, + BasicBlock &OldPH) { + assert(&ExitBB != &UnswitchedBB && + "Must have different loop exit and unswitched blocks!"); + Instruction *InsertPt = &*UnswitchedBB.begin(); + for (Instruction &I : ExitBB) { + auto *PN = dyn_cast(&I); + if (!PN) + // No more PHIs to check. + break; + + auto *NewPN = PHINode::Create(PN->getType(), /*NumReservedValues*/ 2, + PN->getName() + ".split", InsertPt); + + // Walk backwards over the old PHI node's inputs to minimize the cost of + // removing each one. We have to do this weird loop manually so that we + // create the same number of new incoming edges in the new PHI as we expect + // each case-based edge to be included in the unswitched switch in some + // cases. + // FIXME: This is really, really gross. It would be much cleaner if LLVM + // allowed us to create a single entry for a predecessor block without + // having separate entries for each "edge" even though these edges are + // required to produce identical results. + for (int i = PN->getNumIncomingValues() - 1; i >= 0; --i) { + if (PN->getIncomingBlock(i) != &OldExitingBB) + continue; + + Value *Incoming = PN->removeIncomingValue(i); + NewPN->addIncoming(Incoming, &OldPH); + } + + // Now replace the old PHI with the new one and wire the old one in as an + // input to the new one. + PN->replaceAllUsesWith(NewPN); + NewPN->addIncoming(PN, &ExitBB); + } +} + /// Unswitch a trivial branch if the condition is loop invariant. /// /// This routine should only be called when loop code leading to the branch has @@ -187,10 +279,8 @@ assert(L.contains(ContinueBB) && "Cannot have both successors exit and still be in the loop!"); - // If the loop exit block contains phi nodes, this isn't trivial. - // FIXME: We should examine the PHI to determine whether or not we can handle - // it trivially. - if (isa(LoopExitBB->begin())) + auto *ParentBB = BI.getParent(); + if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) return false; DEBUG(dbgs() << " unswitching trivial branch when: " << CondVal @@ -209,14 +299,13 @@ BasicBlock *UnswitchedBB; if (BasicBlock *PredBB = LoopExitBB->getUniquePredecessor()) { (void)PredBB; - assert(PredBB == BI.getParent() && "A branch's parent is't a predecessor!"); + assert(PredBB == BI.getParent() && + "A branch's parent isn't a predecessor!"); UnswitchedBB = LoopExitBB; } else { UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI); } - BasicBlock *ParentBB = BI.getParent(); - // Now splice the branch to gate reaching the new preheader and re-point its // successors. OldPH->getInstList().splice(std::prev(OldPH->end()), @@ -229,6 +318,13 @@ // terminator. BranchInst::Create(ContinueBB, ParentBB); + // Rewrite the relevant PHI nodes. + if (UnswitchedBB == LoopExitBB) + rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH); + else + rewritePHINodesForExitAndUnswitchedBlocks(*LoopExitBB, *UnswitchedBB, + *ParentBB, *OldPH); + // Now we need to update the dominator tree. updateDTAfterUnswitch(UnswitchedBB, OldPH, DT); // But if we split something off of the loop exit block then we also removed @@ -278,6 +374,8 @@ if (!L.isLoopInvariant(LoopCond)) return false; + auto *ParentBB = SI.getParent(); + // FIXME: We should compute this once at the start and update it! SmallVector ExitBlocks; L.getExitBlocks(ExitBlocks); @@ -287,12 +385,13 @@ SmallVector ExitCaseIndices; for (auto Case : SI.cases()) { auto *SuccBB = Case.getCaseSuccessor(); - if (ExitBlockSet.count(SuccBB) && !isa(SuccBB->begin())) + if (ExitBlockSet.count(SuccBB) && + areLoopExitPHIsLoopInvariant(L, *ParentBB, *SuccBB)) ExitCaseIndices.push_back(Case.getCaseIndex()); } BasicBlock *DefaultExitBB = nullptr; if (ExitBlockSet.count(SI.getDefaultDest()) && - !isa(SI.getDefaultDest()->begin()) && + areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && !isa(SI.getDefaultDest()->getTerminator())) DefaultExitBB = SI.getDefaultDest(); else if (ExitCaseIndices.empty()) @@ -330,7 +429,6 @@ if (CommonSuccBB) { SI.setDefaultDest(CommonSuccBB); } else { - BasicBlock *ParentBB = SI.getParent(); BasicBlock *UnreachableBB = BasicBlock::Create( ParentBB->getContext(), Twine(ParentBB->getName()) + ".unreachable_default", @@ -358,30 +456,44 @@ // Now add the unswitched switch. auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH); - // Split any exit blocks with remaining in-loop predecessors. We walk in - // reverse so that we split in the same order as the cases appeared. This is - // purely for convenience of reading the resulting IR, but it doesn't cost - // anything really. + // Rewrite the IR for the unswitched basic blocks. This requires two steps. + // First, we split any exit blocks with remaining in-loop predecessors. Then + // we update the PHIs in one of two ways depending on if there was a split. + // We walk in reverse so that we split in the same order as the cases + // appeared. This is purely for convenience of reading the resulting IR, but + // it doesn't cost anything really. + SmallPtrSet UnswitchedExitBBs; SmallDenseMap SplitExitBBMap; // Handle the default exit if necessary. // FIXME: It'd be great if we could merge this with the loop below but LLVM's // ranges aren't quite powerful enough yet. - if (DefaultExitBB && !pred_empty(DefaultExitBB)) { - auto *SplitBB = - SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); - updateLoopExitIDom(DefaultExitBB, L, DT); - DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; + if (DefaultExitBB) { + if (pred_empty(DefaultExitBB)) { + UnswitchedExitBBs.insert(DefaultExitBB); + rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH); + } else { + auto *SplitBB = + SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); + rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, + *ParentBB, *OldPH); + updateLoopExitIDom(DefaultExitBB, L, DT); + DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; + } } // Note that we must use a reference in the for loop so that we update the // container. for (auto &CasePair : reverse(ExitCases)) { // Grab a reference to the exit block in the pair so that we can update it. - BasicBlock *&ExitBB = CasePair.second; + BasicBlock *ExitBB = CasePair.second; // If this case is the last edge into the exit block, we can simply reuse it // as it will no longer be a loop exit. No mapping necessary. - if (pred_empty(ExitBB)) + if (pred_empty(ExitBB)) { + // Only rewrite once. + if (UnswitchedExitBBs.insert(ExitBB).second) + rewritePHINodesForUnswitchedExitBlock(*ExitBB, *ParentBB, *OldPH); continue; + } // Otherwise we need to split the exit block so that we retain an exit // block from the loop and a target for the unswitched condition. @@ -389,9 +501,12 @@ if (!SplitExitBB) { // If this is the first time we see this, do the split and remember it. SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); + rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, + *ParentBB, *OldPH); updateLoopExitIDom(ExitBB, L, DT); } - ExitBB = SplitExitBB; + // Update the case pair to point to the split block. + CasePair.second = SplitExitBB; } // Now add the unswitched cases. We do this in reverse order as we built them Index: llvm/trunk/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll =================================================================== --- llvm/trunk/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll +++ llvm/trunk/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll @@ -183,3 +183,202 @@ ; CHECK: [[UNREACHABLE]]: ; CHECK-NEXT: unreachable } + +; This test contains a trivially unswitchable branch with an LCSSA phi node in +; a loop exit block. +define i32 @test5(i1 %cond1, i32 %x, i32 %y) { +; CHECK-LABEL: @test5( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %{{.*}}, label %entry.split, label %loop_exit +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + br i1 %cond1, label %latch, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: br label %latch + +latch: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: latch: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + %result1 = phi i32 [ %x, %loop_begin ] + %result2 = phi i32 [ %y, %loop_begin ] + %result = add i32 %result1, %result2 + ret i32 %result +; CHECK: loop_exit: +; CHECK-NEXT: %[[R1:.*]] = phi i32 [ %x, %entry ] +; CHECK-NEXT: %[[R2:.*]] = phi i32 [ %y, %entry ] +; CHECK-NEXT: %[[R:.*]] = add i32 %[[R1]], %[[R2]] +; CHECK-NEXT: ret i32 %[[R]] +} + +; This test contains a trivially unswitchable branch with a real phi node in LCSSA +; position in a shared exit block where a different path through the loop +; produces a non-invariant input to the PHI node. +define i32 @test6(i32* %var, i1 %cond1, i1 %cond2, i32 %x, i32 %y) { +; CHECK-LABEL: @test6( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: br i1 %{{.*}}, label %entry.split, label %loop_exit.split +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + br i1 %cond1, label %continue, label %loop_exit +; CHECK: loop_begin: +; CHECK-NEXT: br label %continue + +continue: + %var_val = load i32, i32* %var + br i1 %cond2, label %latch, label %loop_exit +; CHECK: continue: +; CHECK-NEXT: load +; CHECK-NEXT: br i1 %cond2, label %latch, label %loop_exit + +latch: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: latch: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + %result1 = phi i32 [ %x, %loop_begin ], [ %var_val, %continue ] + %result2 = phi i32 [ %var_val, %continue ], [ %y, %loop_begin ] + %result = add i32 %result1, %result2 + ret i32 %result +; CHECK: loop_exit: +; CHECK-NEXT: %[[R1:.*]] = phi i32 [ %var_val, %continue ] +; CHECK-NEXT: %[[R2:.*]] = phi i32 [ %var_val, %continue ] +; CHECK-NEXT: br label %loop_exit.split +; +; CHECK: loop_exit.split: +; CHECK-NEXT: %[[R1S:.*]] = phi i32 [ %x, %entry ], [ %[[R1]], %loop_exit ] +; CHECK-NEXT: %[[R2S:.*]] = phi i32 [ %y, %entry ], [ %[[R2]], %loop_exit ] +; CHECK-NEXT: %[[R:.*]] = add i32 %[[R1S]], %[[R2S]] +; CHECK-NEXT: ret i32 %[[R]] +} + +; This test contains a trivially unswitchable switch with an LCSSA phi node in +; a loop exit block. +define i32 @test7(i32 %cond1, i32 %x, i32 %y) { +; CHECK-LABEL: @test7( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 %cond1, label %entry.split [ +; CHECK-NEXT: i32 0, label %loop_exit +; CHECK-NEXT: i32 1, label %loop_exit +; CHECK-NEXT: ] +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + switch i32 %cond1, label %latch [ + i32 0, label %loop_exit + i32 1, label %loop_exit + ] +; CHECK: loop_begin: +; CHECK-NEXT: br label %latch + +latch: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: latch: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + %result1 = phi i32 [ %x, %loop_begin ], [ %x, %loop_begin ] + %result2 = phi i32 [ %y, %loop_begin ], [ %y, %loop_begin ] + %result = add i32 %result1, %result2 + ret i32 %result +; CHECK: loop_exit: +; CHECK-NEXT: %[[R1:.*]] = phi i32 [ %x, %entry ], [ %x, %entry ] +; CHECK-NEXT: %[[R2:.*]] = phi i32 [ %y, %entry ], [ %y, %entry ] +; CHECK-NEXT: %[[R:.*]] = add i32 %[[R1]], %[[R2]] +; CHECK-NEXT: ret i32 %[[R]] +} + +; This test contains a trivially unswitchable switch with a real phi node in +; LCSSA position in a shared exit block where a different path through the loop +; produces a non-invariant input to the PHI node. +define i32 @test8(i32* %var, i32 %cond1, i32 %cond2, i32 %x, i32 %y) { +; CHECK-LABEL: @test8( +entry: + br label %loop_begin +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 %cond1, label %entry.split [ +; CHECK-NEXT: i32 0, label %loop_exit.split +; CHECK-NEXT: i32 1, label %loop_exit2 +; CHECK-NEXT: i32 2, label %loop_exit.split +; CHECK-NEXT: ] +; +; CHECK: entry.split: +; CHECK-NEXT: br label %loop_begin + +loop_begin: + switch i32 %cond1, label %continue [ + i32 0, label %loop_exit + i32 1, label %loop_exit2 + i32 2, label %loop_exit + ] +; CHECK: loop_begin: +; CHECK-NEXT: br label %continue + +continue: + %var_val = load i32, i32* %var + switch i32 %cond2, label %latch [ + i32 0, label %loop_exit + ] +; CHECK: continue: +; CHECK-NEXT: load +; CHECK-NEXT: switch i32 %cond2, label %latch [ +; CHECK-NEXT: i32 0, label %loop_exit +; CHECK-NEXT: ] + +latch: + call void @some_func() noreturn nounwind + br label %loop_begin +; CHECK: latch: +; CHECK-NEXT: call +; CHECK-NEXT: br label %loop_begin + +loop_exit: + %result1.1 = phi i32 [ %x, %loop_begin ], [ %x, %loop_begin ], [ %var_val, %continue ] + %result1.2 = phi i32 [ %var_val, %continue ], [ %y, %loop_begin ], [ %y, %loop_begin ] + %result1 = add i32 %result1.1, %result1.2 + ret i32 %result1 +; CHECK: loop_exit: +; CHECK-NEXT: %[[R1:.*]] = phi i32 [ %var_val, %continue ] +; CHECK-NEXT: %[[R2:.*]] = phi i32 [ %var_val, %continue ] +; CHECK-NEXT: br label %loop_exit.split +; +; CHECK: loop_exit.split: +; CHECK-NEXT: %[[R1S:.*]] = phi i32 [ %x, %entry ], [ %x, %entry ], [ %[[R1]], %loop_exit ] +; CHECK-NEXT: %[[R2S:.*]] = phi i32 [ %y, %entry ], [ %y, %entry ], [ %[[R2]], %loop_exit ] +; CHECK-NEXT: %[[R:.*]] = add i32 %[[R1S]], %[[R2S]] +; CHECK-NEXT: ret i32 %[[R]] + +loop_exit2: + %result2.1 = phi i32 [ %x, %loop_begin ] + %result2.2 = phi i32 [ %y, %loop_begin ] + %result2 = add i32 %result2.1, %result2.2 + ret i32 %result2 +; CHECK: loop_exit2: +; CHECK-NEXT: %[[R1:.*]] = phi i32 [ %x, %entry ] +; CHECK-NEXT: %[[R2:.*]] = phi i32 [ %y, %entry ] +; CHECK-NEXT: %[[R:.*]] = add i32 %[[R1]], %[[R2]] +; CHECK-NEXT: ret i32 %[[R]] +}