Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp =================================================================== --- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -594,11 +594,13 @@ ExitCaseIndices.push_back(Case.getCaseIndex()); } BasicBlock *DefaultExitBB = nullptr; + SwitchInstProfUpdateWrapper::CaseWeightOpt DefaultCaseWeight = + SwitchInstProfUpdateWrapper::getSuccessorWeight(SI, 0); if (!L.contains(SI.getDefaultDest()) && areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && - !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) + !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) { DefaultExitBB = SI.getDefaultDest(); - else if (ExitCaseIndices.empty()) + } else if (ExitCaseIndices.empty()) return false; LLVM_DEBUG(dbgs() << " unswitching trivial switch...\n"); @@ -622,8 +624,11 @@ // Store the exit cases into a separate data structure and remove them from // the switch. - SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases; + SmallVector<std::tuple<ConstantInt *, BasicBlock *, + SwitchInstProfUpdateWrapper::CaseWeightOpt>, + 4> ExitCases; ExitCases.reserve(ExitCaseIndices.size()); + SwitchInstProfUpdateWrapper SIW(SI); // We walk the case indices backwards so that we remove the last case first // and don't disrupt the earlier indices. for (unsigned Index : reverse(ExitCaseIndices)) { @@ -633,9 +638,10 @@ if (!ExitL || ExitL->contains(OuterL)) OuterL = ExitL; // Save the value of this case. - ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); + auto W = SIW.getSuccessorWeight(CaseI->getSuccessorIndex()); + ExitCases.emplace_back(CaseI->getCaseValue(), CaseI->getCaseSuccessor(), W); // Delete the unswitched cases. - SI.removeCase(CaseI); + SIW.removeCase(CaseI); } if (SE) { @@ -673,6 +679,7 @@ // Now add the unswitched switch. auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH); + SwitchInstProfUpdateWrapper NewSIW(*NewSI); // Rewrite the IR for the unswitched basic blocks. This requires two steps.
// First, we split any exit blocks with remaining in-loop predecessors. Then @@ -700,9 +707,9 @@ } // Note that we must use a reference in the for loop so that we update the // container. - for (auto &CasePair : reverse(ExitCases)) { + for (auto &ExitCase : reverse(ExitCases)) { // Grab a reference to the exit block in the pair so that we can update it. - BasicBlock *ExitBB = CasePair.second; + BasicBlock *ExitBB = std::get<1>(ExitCase); // If this case is the last edge into the exit block, we can simply reuse it // as it will no longer be a loop exit. No mapping necessary. @@ -724,27 +731,39 @@ /*FullUnswitch*/ true); } // Update the case pair to point to the split block. - CasePair.second = SplitExitBB; + std::get<1>(ExitCase) = SplitExitBB; } // Now add the unswitched cases. We do this in reverse order as we built them // in reverse order. - for (auto CasePair : reverse(ExitCases)) { - ConstantInt *CaseVal = CasePair.first; - BasicBlock *UnswitchedBB = CasePair.second; + for (auto &ExitCase : reverse(ExitCases)) { + ConstantInt *CaseVal = std::get<0>(ExitCase); + BasicBlock *UnswitchedBB = std::get<1>(ExitCase); - NewSI->addCase(CaseVal, UnswitchedBB); + NewSIW.addCase(CaseVal, UnswitchedBB, std::get<2>(ExitCase)); } // If the default was unswitched, re-point it and add explicit cases for // entering the loop. if (DefaultExitBB) { - NewSI->setDefaultDest(DefaultExitBB); + NewSIW->setDefaultDest(DefaultExitBB); + NewSIW.setSuccessorWeight(0, DefaultCaseWeight); // We removed all the exit cases, so we just copy the cases to the // unswitched switch. - for (auto Case : SI.cases()) - NewSI->addCase(Case.getCaseValue(), NewPH); + for (const auto &Case : SI.cases()) + NewSIW.addCase(Case.getCaseValue(), NewPH, + SIW.getSuccessorWeight(Case.getSuccessorIndex())); + } else if (DefaultCaseWeight) { + // We have to set branch weight of the default case. 
+ uint64_t SW = *DefaultCaseWeight; + for (const auto &Case : SI.cases()) { + auto W = SIW.getSuccessorWeight(Case.getSuccessorIndex()); + assert(W && + "case weight must be defined as default case weight is defined"); + SW += *W; + } + NewSIW.setSuccessorWeight(0, SW); } // If we ended up with a common successor for every path through the switch @@ -769,7 +788,7 @@ /*KeepOneInputPHIs*/ true); } // Now nuke the switch and replace it with a direct branch. - SI.eraseFromParent(); + SIW.eraseFromParent(); BranchInst::Create(CommonSuccBB, BB); } else if (DefaultExitBB) { assert(SI.getNumCases() > 0 && @@ -779,8 +798,11 @@ // being simple and keeping the number of edges from this switch to // successors the same, and avoiding any PHI update complexity. auto LastCaseI = std::prev(SI.case_end()); + SI.setDefaultDest(LastCaseI->getCaseSuccessor()); - SI.removeCase(LastCaseI); + SIW.setSuccessorWeight( + 0, SIW.getSuccessorWeight(LastCaseI->getSuccessorIndex())); + SIW.removeCase(LastCaseI); } // Walk the unswitched exit blocks and the unswitched split blocks and update Index: llvm/test/Transforms/SimpleLoopUnswitch/basictest-profmd.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimpleLoopUnswitch/basictest-profmd.ll @@ -0,0 +1,34 @@ +; RUN: opt -passes='loop(unswitch),verify' -S < %s | FileCheck %s +; RUN: opt -enable-mssa-loop-dependency=true -verify-memoryssa -passes='loop(unswitch),verify' -S < %s | FileCheck %s + +declare void @incf() +declare void @decf() + +define i32 @test2(i32 %c) { +; CHECK-LABEL: @test2( + br label %loop_begin + +; CHECK: !prof ![[MD0:[0-9]+]] +; CHECK: loop_begin: +; CHECK: !prof ![[MD1:[0-9]+]] +loop_begin: + + switch i32 %c, label %default [ + i32 1, label %inc + i32 2, label %dec + ], !prof !{!"branch_weights", i32 99, i32 1, i32 2} + +inc: + call void @incf() + br label %loop_begin + +dec: + call void @decf() + br label %loop_begin + +default: + ret i32 0 +} + +; CHECK: ![[MD0]] 
= !{!"branch_weights", i32 99, i32 1, i32 2} +; CHECK: ![[MD1]] = !{!"branch_weights", i32 2, i32 1} Index: llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-profmd.ll =================================================================== --- /dev/null +++ llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch-profmd.ll @@ -0,0 +1,228 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; then metadata checks MDn were added manually. +; RUN: opt -passes='loop(unswitch),verify' -S < %s | FileCheck %s +; RUN: opt -enable-mssa-loop-dependency=true -verify-memoryssa -passes='loop(unswitch),verify' -S < %s | FileCheck %s + +declare void @some_func() + +; Test for a trivially unswitchable switch with non-default case exiting. +define i32 @test2(i32* %var, i32 %cond1, i32 %cond2) { +; CHECK-LABEL: @test2( +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 [[COND2:%.*]], label [[ENTRY_SPLIT:%.*]] [ +; CHECK-NEXT: i32 2, label [[LOOP_EXIT2:%.*]] +; CHECK-NEXT: ], !prof ![[MD0:[0-9]+]] +; CHECK: entry.split: +; CHECK-NEXT: br label [[LOOP_BEGIN:%.*]] +; CHECK: loop_begin: +; CHECK-NEXT: [[VAR_VAL:%.*]] = load i32, i32* [[VAR:%.*]] +; CHECK-NEXT: switch i32 [[COND2]], label [[LOOP2:%.*]] [ +; CHECK-NEXT: i32 0, label [[LOOP0:%.*]] +; CHECK-NEXT: i32 1, label [[LOOP1:%.*]] +; CHECK-NEXT: ], !prof ![[MD1:[0-9]+]] +; CHECK: loop0: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH:%.*]] +; CHECK: loop1: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH]] +; CHECK: loop2: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH]] +; CHECK: loop_latch: +; CHECK-NEXT: br label [[LOOP_BEGIN]] +; CHECK: loop_exit1: +; CHECK-NEXT: ret i32 0 +; CHECK: loop_exit2: +; CHECK-NEXT: ret i32 0 +; CHECK: loop_exit3: +; CHECK-NEXT: ret i32 0 +; +entry: + br label %loop_begin + +loop_begin: + %var_val = load i32, i32* %var + switch i32 %cond2, label %loop2 [ + i32 0, label %loop0 + i32 1, label 
%loop1 + i32 2, label %loop_exit2 + ], !prof !{!"branch_weights", i32 99, i32 100, i32 101, i32 102} + +loop0: + call void @some_func() + br label %loop_latch + +loop1: + call void @some_func() + br label %loop_latch + +loop2: + call void @some_func() + br label %loop_latch + +loop_latch: + br label %loop_begin + +loop_exit1: + ret i32 0 + +loop_exit2: + ret i32 0 + +loop_exit3: + ret i32 0 +} + +; Test for a trivially unswitchable switch with only the default case exiting. +define i32 @test3(i32* %var, i32 %cond1, i32 %cond2) { +; CHECK-LABEL: @test3( +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 [[COND2:%.*]], label [[LOOP_EXIT2:%.*]] [ +; CHECK-NEXT: i32 0, label [[ENTRY_SPLIT:%.*]] +; CHECK-NEXT: i32 1, label [[ENTRY_SPLIT]] +; CHECK-NEXT: i32 2, label [[ENTRY_SPLIT]] +; CHECK-NEXT: ], !prof ![[MD2:[0-9]+]] +; CHECK: entry.split: +; CHECK-NEXT: br label [[LOOP_BEGIN:%.*]] +; CHECK: loop_begin: +; CHECK-NEXT: [[VAR_VAL:%.*]] = load i32, i32* [[VAR:%.*]] +; CHECK-NEXT: switch i32 [[COND2]], label [[LOOP2:%.*]] [ +; CHECK-NEXT: i32 0, label [[LOOP0:%.*]] +; CHECK-NEXT: i32 1, label [[LOOP1:%.*]] +; CHECK-NEXT: ], !prof ![[MD3:[0-9]+]] +; CHECK: loop0: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH:%.*]] +; CHECK: loop1: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH]] +; CHECK: loop2: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH]] +; CHECK: loop_latch: +; CHECK-NEXT: br label [[LOOP_BEGIN]] +; CHECK: loop_exit1: +; CHECK-NEXT: ret i32 0 +; CHECK: loop_exit2: +; CHECK-NEXT: ret i32 0 +; CHECK: loop_exit3: +; CHECK-NEXT: ret i32 0 +; +entry: + br label %loop_begin + +loop_begin: + %var_val = load i32, i32* %var + switch i32 %cond2, label %loop_exit2 [ + i32 0, label %loop0 + i32 1, label %loop1 + i32 2, label %loop2 + ], !prof !{!"branch_weights", i32 99, i32 100, i32 101, i32 102} + +loop0: + call void @some_func() + br label %loop_latch + +loop1: + call void @some_func() + br 
label %loop_latch + +loop2: + call void @some_func() + br label %loop_latch + +loop_latch: + br label %loop_begin + +loop_exit1: + ret i32 0 + +loop_exit2: + ret i32 0 + +loop_exit3: + ret i32 0 +} + +; Test for a trivially unswitchable switch with multiple exiting cases and +; multiple looping cases. +define i32 @test4(i32* %var, i32 %cond1, i32 %cond2) { +; CHECK-LABEL: @test4( +; CHECK-NEXT: entry: +; CHECK-NEXT: switch i32 [[COND2:%.*]], label [[LOOP_EXIT2:%.*]] [ +; CHECK-NEXT: i32 13, label [[LOOP_EXIT1:%.*]] +; CHECK-NEXT: i32 42, label [[LOOP_EXIT3:%.*]] +; CHECK-NEXT: i32 0, label [[ENTRY_SPLIT:%.*]] +; CHECK-NEXT: i32 1, label [[ENTRY_SPLIT]] +; CHECK-NEXT: i32 2, label [[ENTRY_SPLIT]] +; CHECK-NEXT: ], !prof ![[MD4:[0-9]+]] +; CHECK: entry.split: +; CHECK-NEXT: br label [[LOOP_BEGIN:%.*]] +; CHECK: loop_begin: +; CHECK-NEXT: [[VAR_VAL:%.*]] = load i32, i32* [[VAR:%.*]] +; CHECK-NEXT: switch i32 [[COND2]], label [[LOOP2:%.*]] [ +; CHECK-NEXT: i32 0, label [[LOOP0:%.*]] +; CHECK-NEXT: i32 1, label [[LOOP1:%.*]] +; CHECK-NEXT: ], !prof ![[MD3:[0-9]+]] +; CHECK: loop0: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH:%.*]] +; CHECK: loop1: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH]] +; CHECK: loop2: +; CHECK-NEXT: call void @some_func() +; CHECK-NEXT: br label [[LOOP_LATCH]] +; CHECK: loop_latch: +; CHECK-NEXT: br label [[LOOP_BEGIN]] +; CHECK: loop_exit1: +; CHECK-NEXT: ret i32 0 +; CHECK: loop_exit2: +; CHECK-NEXT: ret i32 0 +; CHECK: loop_exit3: +; CHECK-NEXT: ret i32 0 +; +entry: + br label %loop_begin + +loop_begin: + %var_val = load i32, i32* %var + switch i32 %cond2, label %loop_exit2 [ + i32 0, label %loop0 + i32 1, label %loop1 + i32 13, label %loop_exit1 + i32 2, label %loop2 + i32 42, label %loop_exit3 + ], !prof !{!"branch_weights", i32 99, i32 100, i32 101, i32 113, i32 102, i32 142} + +loop0: + call void @some_func() + br label %loop_latch + +loop1: + call void @some_func() + br label 
%loop_latch + +loop2: + call void @some_func() + br label %loop_latch + +loop_latch: + br label %loop_begin + +loop_exit1: + ret i32 0 + +loop_exit2: + ret i32 0 + +loop_exit3: + ret i32 0 +} + +; CHECK: ![[MD0]] = !{!"branch_weights", i32 300, i32 102} +; CHECK: ![[MD1]] = !{!"branch_weights", i32 99, i32 100, i32 101} +; CHECK: ![[MD2]] = !{!"branch_weights", i32 99, i32 100, i32 101, i32 102} +; CHECK: ![[MD3]] = !{!"branch_weights", i32 102, i32 100, i32 101} +; CHECK: ![[MD4]] = !{!"branch_weights", i32 99, i32 113, i32 142, i32 100, i32 101, i32 102}