Preserve !prof branch weights when SimpleLoopUnswitch trivially unswitches a
switch.

The weight of every exit case (and of the default destination, when it is
unswitched) is read through SwitchInstProfUpdateWrapper before the case is
removed from the in-loop switch, and is passed along again when the case is
added to the new switch created in OldPH. Cases are removed from the in-loop
switch through the same wrapper. Two regression tests check the resulting
branch_weights metadata on both switches.

Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -594,11 +594,13 @@
       ExitCaseIndices.push_back(Case.getCaseIndex());
   }
   BasicBlock *DefaultExitBB = nullptr;
+  SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultExitWeight;
   if (!L.contains(SI.getDefaultDest()) &&
       areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) &&
-      !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator()))
+      !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) {
     DefaultExitBB = SI.getDefaultDest();
-  else if (ExitCaseIndices.empty())
+    DefaultExitWeight = SwitchInstProfUpdateWrapper::getSuccessorWeight(SI, 0);
+  } else if (ExitCaseIndices.empty())
     return false;
 
   LLVM_DEBUG(dbgs() << " unswitching trivial switch...\n");
@@ -622,8 +624,11 @@
 
   // Store the exit cases into a separate data structure and remove them from
   // the switch.
-  SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases;
+  SmallVector<std::tuple<ConstantInt *, BasicBlock *,
+                         SwitchInstProfUpdateWrapper::CaseWeightOpt>,
+              4> ExitCases;
   ExitCases.reserve(ExitCaseIndices.size());
+  SwitchInstProfUpdateWrapper SIW(SI);
   // We walk the case indices backwards so that we remove the last case first
   // and don't disrupt the earlier indices.
   for (unsigned Index : reverse(ExitCaseIndices)) {
@@ -633,9 +638,10 @@
     if (!ExitL || ExitL->contains(OuterL))
       OuterL = ExitL;
     // Save the value of this case.
-    ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()});
+    auto W = SIW.getSuccessorWeight(CaseI->getSuccessorIndex());
+    ExitCases.emplace_back(CaseI->getCaseValue(), CaseI->getCaseSuccessor(), W);
     // Delete the unswitched cases.
-    SI.removeCase(CaseI);
+    SIW.removeCase(CaseI);
   }
 
   if (SE) {
@@ -672,7 +678,8 @@
   OldPH->getTerminator()->eraseFromParent();
 
   // Now add the unswitched switch.
-  auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH);
+  SwitchInstProfUpdateWrapper NewSI =
+      *SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH);
 
   // Rewrite the IR for the unswitched basic blocks. This requires two steps.
   // First, we split any exit blocks with remaining in-loop predecessors. Then
@@ -700,9 +707,9 @@
   }
   // Note that we must use a reference in the for loop so that we update the
   // container.
-  for (auto &CasePair : reverse(ExitCases)) {
-    // Grab a reference to the exit block in the pair so that we can update it.
-    BasicBlock *ExitBB = CasePair.second;
+  for (auto &ExitCase : reverse(ExitCases)) {
+    // Grab a reference to the exit block in the tuple so that we can update it.
+    BasicBlock *ExitBB = std::get<1>(ExitCase);
 
     // If this case is the last edge into the exit block, we can simply reuse it
     // as it will no longer be a loop exit. No mapping necessary.
@@ -724,27 +731,29 @@
                                                 /*FullUnswitch*/ true);
     }
-    // Update the case pair to point to the split block.
-    CasePair.second = SplitExitBB;
+    // Update the case tuple to point to the split block.
+    std::get<1>(ExitCase) = SplitExitBB;
   }
 
   // Now add the unswitched cases. We do this in reverse order as we built them
   // in reverse order.
-  for (auto CasePair : reverse(ExitCases)) {
-    ConstantInt *CaseVal = CasePair.first;
-    BasicBlock *UnswitchedBB = CasePair.second;
+  for (auto &ExitCase : reverse(ExitCases)) {
+    ConstantInt *CaseVal = std::get<0>(ExitCase);
+    BasicBlock *UnswitchedBB = std::get<1>(ExitCase);
 
-    NewSI->addCase(CaseVal, UnswitchedBB);
+    NewSI.addCase(CaseVal, UnswitchedBB, std::get<2>(ExitCase));
   }
 
   // If the default was unswitched, re-point it and add explicit cases for
   // entering the loop.
   if (DefaultExitBB) {
     NewSI->setDefaultDest(DefaultExitBB);
+    NewSI.setSuccessorWeight(0, DefaultExitWeight);
 
     // We removed all the exit cases, so we just copy the cases to the
     // unswitched switch.
-    for (auto Case : SI.cases())
-      NewSI->addCase(Case.getCaseValue(), NewPH);
+    for (const auto &Case : SI.cases())
+      NewSI.addCase(Case.getCaseValue(), NewPH,
+                    SIW.getSuccessorWeight(Case.getSuccessorIndex()));
   }
 
   // If we ended up with a common successor for every path through the switch
@@ -769,7 +778,7 @@
                                       /*KeepOneInputPHIs*/ true);
     }
     // Now nuke the switch and replace it with a direct branch.
-    SI.eraseFromParent();
+    SIW.eraseFromParent();
     BranchInst::Create(CommonSuccBB, BB);
   } else if (DefaultExitBB) {
     assert(SI.getNumCases() > 0 &&
@@ -779,8 +788,11 @@
     // being simple and keeping the number of edges from this switch to
     // successors the same, and avoiding any PHI update complexity.
     auto LastCaseI = std::prev(SI.case_end());
+
     SI.setDefaultDest(LastCaseI->getCaseSuccessor());
-    SI.removeCase(LastCaseI);
+    SIW.setSuccessorWeight(
+        0, SIW.getSuccessorWeight(LastCaseI->getSuccessorIndex()));
+    SIW.removeCase(LastCaseI);
   }
 
   // Walk the unswitched exit blocks and the unswitched split blocks and update
Index: llvm/test/Transforms/SimpleLoopUnswitch/basictest-profmd.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimpleLoopUnswitch/basictest-profmd.ll
@@ -0,0 +1,46 @@
+; RUN: opt -passes='loop(unswitch),verify' -S < %s | FileCheck %s
+; RUN: opt -enable-mssa-loop-dependency=true -verify-memoryssa -passes='loop(unswitch),verify' -S < %s | FileCheck %s
+
+; The noduplicate call in %dec must not be duplicated, but the loop-exiting default case is
+; still unswitched; the checks below verify the branch weights on both resulting switches.
+
+; CHECK-LABEL: @test2(
+define i32 @test2(i32* %var) {
+  %mem = alloca i32
+  store i32 2, i32* %mem
+  %c = load i32, i32* %mem
+
+  br label %loop_begin
+
+; CHECK: !prof ![[MD0:[0-9]+]]
+; CHECK: loop_begin:
+; CHECK: !prof ![[MD1:[0-9]+]]
+loop_begin:
+
+  %var_val = load i32, i32* %var
+
+  switch i32 %c, label %default [
+      i32 1, label %inc
+      i32 2, label %dec
+  ], !prof !{!"branch_weights", i32 99, i32 1, i32 2}
+
+inc:
+  call void @incf() noreturn nounwind
+  br label %loop_begin
+dec:
+; CHECK: call void @decf()
+; CHECK-NOT: call void @decf()
+  call void @decf() noreturn nounwind noduplicate
+  br label %loop_begin
+default:
+  br label %loop_exit
+loop_exit:
+  ret i32 0
+; CHECK: }
+}
+
+declare void @incf() noreturn
+declare void @decf() noreturn
+
+; CHECK: ![[MD0]] = !{!"branch_weights", i32 99, i32 1, i32 2}
+; CHECK: ![[MD1]] = !{!"branch_weights", i32 2, i32 1}
Index: llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-profmd.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-profmd.ll
@@ -0,0 +1,83 @@
+; RUN: opt -passes='loop(unswitch),verify' -S < %s | FileCheck %s
+; RUN: opt -enable-mssa-loop-dependency=true -verify-memoryssa -passes='loop(unswitch),verify' -S < %s | FileCheck %s
+
+declare void @some_func() noreturn
+
+; Test for a trivially unswitchable switch with multiple exiting cases and
+; multiple looping cases.
+define i32 @test4(i32* %var, i32 %cond1, i32 %cond2) {
+; CHECK-LABEL: @test4(
+entry:
+  br label %loop_begin
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    switch i32 %cond2, label %loop_exit2 [
+; CHECK-NEXT:      i32 13, label %loop_exit1
+; CHECK-NEXT:      i32 42, label %loop_exit3
+; CHECK-NEXT:      i32 0, label %entry.split
+; CHECK-NEXT:      i32 1, label %entry.split
+; CHECK-NEXT:      i32 2, label %entry.split
+; CHECK-NEXT:    ], !prof ![[MD0:[0-9]+]]
+;
+; CHECK:       entry.split:
+; CHECK-NEXT:    br label %loop_begin
+
+loop_begin:
+  %var_val = load i32, i32* %var
+  switch i32 %cond2, label %loop_exit2 [
+    i32 0, label %loop0
+    i32 1, label %loop1
+    i32 13, label %loop_exit1
+    i32 2, label %loop2
+    i32 42, label %loop_exit3
+  ], !prof !{!"branch_weights", i32 99, i32 100, i32 101, i32 113, i32 102, i32 142}
+; CHECK:       loop_begin:
+; CHECK-NEXT:    load
+; CHECK-NEXT:    switch i32 %cond2, label %loop2 [
+; CHECK-NEXT:      i32 0, label %loop0
+; CHECK-NEXT:      i32 1, label %loop1
+; CHECK-NEXT:    ], !prof ![[MD1:[0-9]+]]
+
+loop0:
+  call void @some_func() noreturn nounwind
+  br label %loop_latch
+; CHECK:       loop0:
+; CHECK-NEXT:    call
+; CHECK-NEXT:    br label %loop_latch
+
+loop1:
+  call void @some_func() noreturn nounwind
+  br label %loop_latch
+; CHECK:       loop1:
+; CHECK-NEXT:    call
+; CHECK-NEXT:    br label %loop_latch
+
+loop2:
+  call void @some_func() noreturn nounwind
+  br label %loop_latch
+; CHECK:       loop2:
+; CHECK-NEXT:    call
+; CHECK-NEXT:    br label %loop_latch
+
+loop_latch:
+  br label %loop_begin
+; CHECK:       loop_latch:
+; CHECK-NEXT:    br label %loop_begin
+
+loop_exit1:
+  ret i32 0
+; CHECK:       loop_exit1:
+; CHECK-NEXT:    ret
+
+loop_exit2:
+  ret i32 0
+; CHECK:       loop_exit2:
+; CHECK-NEXT:    ret
+
+loop_exit3:
+  ret i32 0
+; CHECK:       loop_exit3:
+; CHECK-NEXT:    ret
+}
+
+; CHECK: ![[MD0]] = !{!"branch_weights", i32 99, i32 113, i32 142, i32 100, i32 101, i32 102}
+; CHECK: ![[MD1]] = !{!"branch_weights", i32 102, i32 100, i32 101}