Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -594,11 +594,13 @@
       ExitCaseIndices.push_back(Case.getCaseIndex());
   }
   BasicBlock *DefaultExitBB = nullptr;
+  SwitchInst::CaseWeightOpt DefaultExitWeight;
   if (!L.contains(SI.getDefaultDest()) &&
       areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) &&
-      !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator()))
+      !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) {
     DefaultExitBB = SI.getDefaultDest();
-  else if (ExitCaseIndices.empty())
+    DefaultExitWeight = SI.getDefaultCaseWeight();
+  } else if (ExitCaseIndices.empty())
     return false;
 
   LLVM_DEBUG(dbgs() << "    unswitching trivial switch...\n");
@@ -622,7 +624,7 @@
 
   // Store the exit cases into a separate data structure and remove them from
   // the switch.
-  SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases;
+  SmallVector<std::tuple<ConstantInt *, BasicBlock *, SwitchInst::CaseWeightOpt>, 4> ExitCases;
   ExitCases.reserve(ExitCaseIndices.size());
   // We walk the case indices backwards so that we remove the last case first
   // and don't disrupt the earlier indices.
@@ -633,7 +635,9 @@
     if (!ExitL || ExitL->contains(OuterL))
       OuterL = ExitL;
     // Save the value of this case.
-    ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()});
+    ExitCases.emplace_back(CaseI->getCaseValue(),
+                           CaseI->getCaseSuccessor(),
+                           SI.getSuccessorWeight(CaseI->getSuccessorIndex()));
     // Delete the unswitched cases.
     SI.removeCase(CaseI);
   }
@@ -700,9 +704,9 @@
   }
   // Note that we must use a reference in the for loop so that we update the
   // container.
-  for (auto &CasePair : reverse(ExitCases)) {
+  for (auto &ExitCase : reverse(ExitCases)) {
     // Grab a reference to the exit block in the pair so that we can update it.
-    BasicBlock *ExitBB = CasePair.second;
+    BasicBlock *ExitBB = std::get<1>(ExitCase);
 
     // If this case is the last edge into the exit block, we can simply reuse it
     // as it will no longer be a loop exit. No mapping necessary.
@@ -724,27 +728,32 @@
                                                 /*FullUnswitch*/ true);
     }
     // Update the case pair to point to the split block.
-    CasePair.second = SplitExitBB;
+    std::get<1>(ExitCase) = SplitExitBB;
   }
 
   // Now add the unswitched cases. We do this in reverse order as we built them
   // in reverse order.
-  for (auto CasePair : reverse(ExitCases)) {
-    ConstantInt *CaseVal = CasePair.first;
-    BasicBlock *UnswitchedBB = CasePair.second;
+  for (auto &ExitCase : reverse(ExitCases)) {
+    ConstantInt *CaseVal = std::get<0>(ExitCase);
+    BasicBlock *UnswitchedBB = std::get<1>(ExitCase);
 
-    NewSI->addCase(CaseVal, UnswitchedBB);
+    NewSI->addCase(CaseVal, UnswitchedBB, std::get<2>(ExitCase));
   }
 
   // If the default was unswitched, re-point it and add explicit cases for
   // entering the loop.
   if (DefaultExitBB) {
-    NewSI->setDefaultDest(DefaultExitBB);
+    NewSI->setDefaultDest(DefaultExitBB, DefaultExitWeight);
 
     // We removed all the exit cases, so we just copy the cases to the
     // unswitched switch.
-    for (auto Case : SI.cases())
-      NewSI->addCase(Case.getCaseValue(), NewPH);
+    for (const auto &Case : SI.cases())
+      NewSI->addCase(Case.getCaseValue(), NewPH,
+                     SI.getSuccessorWeight(Case.getSuccessorIndex()));
   }
+  // TODO(review): when the default is not unswitched, NewSI's default edge is
+  // never assigned a weight here, although the unswitched cases are added with
+  // weights above. Presumably it should carry the old default weight plus the
+  // weights of the cases that stay in the loop -- confirm and handle.
 
   // If we ended up with a common successor for every path through the switch
@@ -779,7 +788,9 @@
     // being simple and keeping the number of edges from this switch to
     // successors the same, and avoiding any PHI update complexity.
     auto LastCaseI = std::prev(SI.case_end());
-    SI.setDefaultDest(LastCaseI->getCaseSuccessor());
+
+    SI.setDefaultDest(LastCaseI->getCaseSuccessor(),
+                      SI.getSuccessorWeight(LastCaseI->getSuccessorIndex()));
     SI.removeCase(LastCaseI);
   }
 
Index: llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll
===================================================================
--- llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll
+++ llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll
@@ -50,7 +50,7 @@
   switch i32 %c, label %default [
     i32 1, label %inc
     i32 2, label %dec
-  ], !prof !{!"branch_weights", i32 1, i32 1, i32 1}
+  ], !prof !{!"branch_weights", i32 99, i32 1, i32 2}
 
 inc:
   call void @incf() noreturn nounwind
@@ -186,4 +186,4 @@
 declare void @decf() noreturn
 declare void @conv() convergent
 
-; CHECK: ![[MD0]] = !{!"branch_weights", i32 1, i32 1}
+; CHECK: ![[MD0]] = !{!"branch_weights", i32 2, i32 1}
Index: llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
===================================================================
--- llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
+++ llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
@@ -135,7 +135,7 @@
     i32 13, label %loop_exit1
     i32 2, label %loop2
     i32 42, label %loop_exit3
-  ], !prof !{!"branch_weights", i32 1, i32 1, i32 1, i32 1, i32 1, i32 1}
+  ], !prof !{!"branch_weights", i32 99, i32 1, i32 2, i32 3, i32 4, i32 5}
 ; CHECK:       loop_begin:
 ; CHECK-NEXT:    load
 ; CHECK-NEXT:    switch i32 %cond2, label %loop2 [
@@ -1244,4 +1244,4 @@
 ; CHECK-NEXT:    ret
 }
 
-; CHECK: ![[MD0]] = !{!"branch_weights", i32 1, i32 1, i32 1}
+; CHECK: ![[MD0]] = !{!"branch_weights", i32 4, i32 1, i32 2}