Index: llvm/include/llvm/IR/Instructions.h =================================================================== --- llvm/include/llvm/IR/Instructions.h +++ llvm/include/llvm/IR/Instructions.h @@ -3468,7 +3468,7 @@ /// This action invalidates iterators for all cases following the one removed, /// including the case_end() iterator. It returns an iterator for the next /// case. - CaseIt removeCase(CaseIt I); + CaseIt removeCase(CaseIt I, bool DropPerfMetadata = false); unsigned getNumSuccessors() const { return getNumOperands()/2; } BasicBlock *getSuccessor(unsigned idx) const { Index: llvm/lib/IR/Instructions.cpp =================================================================== --- llvm/lib/IR/Instructions.cpp +++ llvm/lib/IR/Instructions.cpp @@ -28,6 +28,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" @@ -3770,6 +3771,13 @@ /// addCase - Add an entry to the switch instruction... /// void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) { + // Drop metadata. + MDNode *ProfileData = getMetadata(LLVMContext::MD_prof); + if (ProfileData && ProfileData->getNumOperands() == getNumSuccessors() + 1) + if (auto *MDName = dyn_cast(ProfileData->getOperand(0))) + if (MDName->getString() == "branch_weights") + setMetadata(LLVMContext::MD_prof, nullptr); + unsigned NewCaseIdx = getNumCases(); unsigned OpNo = getNumOperands(); if (OpNo+2 > ReservedSpace) @@ -3784,7 +3792,7 @@ /// removeCase - This method removes the specified case and its successor /// from the switch instruction. 
-SwitchInst::CaseIt SwitchInst::removeCase(CaseIt I) { +SwitchInst::CaseIt SwitchInst::removeCase(CaseIt I, bool DropPerfMetadata) { unsigned idx = I->getCaseIndex(); assert(2 + idx*2 < getNumOperands() && "Case index out of range!!!"); @@ -3792,6 +3800,33 @@ unsigned NumOps = getNumOperands(); Use *OL = getOperandList(); + // Update metadata. + MDNode *ProfileData = getMetadata(LLVMContext::MD_prof); + if (ProfileData && ProfileData->getNumOperands() == getNumSuccessors() + 1) + if (auto *MDName = dyn_cast(ProfileData->getOperand(0))) + if (MDName->getString() == "branch_weights") { + if (DropPerfMetadata) { + setMetadata(LLVMContext::MD_prof, nullptr); + } else { + SmallVector Weights; + uint32_t SW = 0; + for (unsigned CI = 0, CE = getNumSuccessors() - 1; CI < CE; ++CI) { + // At the removed position we put the end of the list. + unsigned WI = (CI == idx + 1 ? CE : CI) + 1; + ConstantInt *C = mdconst::extract( + ProfileData->getOperand(WI)); + uint32_t CW = C->getValue().getZExtValue(); + SW += CW; + Weights.push_back(CW); + } + assert(Weights.size() == getNumSuccessors() - 1); + auto *NewProfileData = !SW || Weights.size() < 2 ? nullptr : + MDBuilder(getParent()->getContext()) + .createBranchWeights(Weights); + setMetadata(LLVMContext::MD_prof, NewProfileData); + } + } + // Overwrite this case with the end of the list. if (2 + (idx + 1) * 2 != NumOps) { OL[2 + idx * 2] = OL[NumOps - 2]; Index: llvm/lib/IR/Verifier.cpp =================================================================== --- llvm/lib/IR/Verifier.cpp +++ llvm/lib/IR/Verifier.cpp @@ -2440,6 +2440,20 @@ "Duplicate integer as switch case", &SI, Case.getCaseValue()); } + // Check consistency of !prof branch_weights metadata if any. 
+ auto *TI = &SI; + auto *ProfileData = TI->getMetadata(LLVMContext::MD_prof); + if (ProfileData != nullptr && ProfileData->getNumOperands() > 0) { + if (auto *ProfDataName = dyn_cast(ProfileData->getOperand(0))) + if (ProfDataName->getString().equals("branch_weights")) { + Assert(ProfileData->getNumOperands() == TI->getNumSuccessors() + 1, + "Number of !prof brunch_weights operands differ from number of successors"); + for (unsigned i = 1; i < ProfileData->getNumOperands(); ++i) + Assert(mdconst::dyn_extract(ProfileData->getOperand(i)), + "!prof brunch_weights operand is not a const int"); + } + } + visitTerminator(SI); } Index: llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp =================================================================== --- llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -732,18 +732,20 @@ for (auto CasePair : reverse(ExitCases)) { ConstantInt *CaseVal = CasePair.first; BasicBlock *UnswitchedBB = CasePair.second; - + // TODO set weight. NewSI->addCase(CaseVal, UnswitchedBB); } // If the default was unswitched, re-point it and add explicit cases for // entering the loop. if (DefaultExitBB) { + // TODO set weight. NewSI->setDefaultDest(DefaultExitBB); // We removed all the exit cases, so we just copy the cases to the // unswitched switch. for (auto Case : SI.cases()) + // TODO set weight. NewSI->addCase(Case.getCaseValue(), NewPH); } @@ -779,6 +781,7 @@ // being simple and keeping the number of edges from this switch to // successors the same, and avoiding any PHI update complexity. auto LastCaseI = std::prev(SI.case_end()); + // TODO fix the default branch weight. 
SI.setDefaultDest(LastCaseI->getCaseSuccessor()); SI.removeCase(LastCaseI); } Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp =================================================================== --- llvm/lib/Transforms/Utils/SimplifyCFG.cpp +++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp @@ -884,7 +884,7 @@ Weights.pop_back(); } i->getCaseSuccessor()->removePredecessor(TI->getParent()); - SI->removeCase(i); + SI->removeCase(i, true); } } if (HasWeight && Weights.size() >= 2) @@ -4461,7 +4461,7 @@ // Prune unused values from PHI nodes. CaseI->getCaseSuccessor()->removePredecessor(SI->getParent()); - SI->removeCase(CaseI); + SI->removeCase(CaseI, true); } if (HasWeight && Weights.size() >= 2) { SmallVector MDWeights(Weights.begin(), Weights.end()); Index: llvm/test/CodeGen/AArch64/ragreedy-csr.ll =================================================================== --- llvm/test/CodeGen/AArch64/ragreedy-csr.ll +++ llvm/test/CodeGen/AArch64/ragreedy-csr.ll @@ -278,8 +278,8 @@ !991 = !{!"branch_weights", i32 8677007, i32 4606493} !992 = !{!"branch_weights", i32 -1172426948, i32 145094705} !993 = !{!"branch_weights", i32 1468914, i32 5683688} -!994 = !{!"branch_weights", i32 114025221, i32 -1217548794, i32 -1199521551, i32 87712616} -!995 = !{!"branch_weights", i32 1853716452, i32 -444717951, i32 932776759} +!994 = !{!"branch_weights", i32 114025221, i32 -1217548794, i32 -1199521551} +!995 = !{!"branch_weights", i32 1853716452, i32 -444717951} !996 = !{!"branch_weights", i32 1004870, i32 20259} !997 = !{!"branch_weights", i32 20071, i32 189} !998 = !{!"branch_weights", i32 -1020255939, i32 572177766} Index: llvm/test/CodeGen/X86/ragreedy-bug.ll =================================================================== --- llvm/test/CodeGen/X86/ragreedy-bug.ll +++ llvm/test/CodeGen/X86/ragreedy-bug.ll @@ -291,8 +291,8 @@ !991 = !{!"branch_weights", i32 8677007, i32 4606493} !992 = !{!"branch_weights", i32 -1172426948, i32 145094705} !993 = !{!"branch_weights", i32 1468914, i32 
5683688} -!994 = !{!"branch_weights", i32 114025221, i32 -1217548794, i32 -1199521551, i32 87712616} -!995 = !{!"branch_weights", i32 1853716452, i32 -444717951, i32 932776759} +!994 = !{!"branch_weights", i32 114025221, i32 -1217548794, i32 87712616} +!995 = !{!"branch_weights", i32 1853716452, i32 932776759} !996 = !{!"branch_weights", i32 1004870, i32 20259} !997 = !{!"branch_weights", i32 20071, i32 189} !998 = !{!"branch_weights", i32 -1020255939, i32 572177766} Index: llvm/test/Transforms/CorrelatedValuePropagation/basic.ll =================================================================== --- llvm/test/Transforms/CorrelatedValuePropagation/basic.ll +++ llvm/test/Transforms/CorrelatedValuePropagation/basic.ll @@ -152,7 +152,7 @@ ; CHECK-NEXT: switch i32 [[S]], label [[OUT]] [ ; CHECK-NEXT: i32 -2, label [[NEXT:%.*]] ; CHECK-NEXT: i32 -1, label [[NEXT]] -; CHECK-NEXT: ] +; CHECK-NEXT: !prof ![[MD0:[0-9]+]] ; CHECK: out: ; CHECK-NEXT: [[P:%.*]] = phi i32 [ 1, [[ENTRY:%.*]] ], [ -1, [[NEGATIVE]] ] ; CHECK-NEXT: ret i32 [[P]] @@ -171,7 +171,7 @@ i32 -2, label %next i32 2, label %out i32 3, label %out - ] + ], !prof !{!"branch_weights", i32 99, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6} out: %p = phi i32 [ 1, %entry ], [ -1, %negative ], [ -1, %negative ], [ -1, %negative ], [ -1, %negative ], [ -1, %negative ] @@ -723,3 +723,5 @@ exit: ret i1 %cmp } + +; CHECK: ![[MD0]] = !{!"branch_weights", i32 99, i32 4, i32 3} Index: llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll =================================================================== --- llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll +++ llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll @@ -41,6 +41,8 @@ br label %loop_begin +; CHECK: loop_begin: +; CHECK: !prof ![[MD0:[0-9]+]] loop_begin: %var_val = load i32, i32* %var @@ -48,7 +50,7 @@ switch i32 %c, label %default [ i32 1, label %inc i32 2, label %dec - ] + ], !prof !{!"branch_weights", i32 1, i32 1, i32 1} inc: call void @incf() noreturn 
nounwind @@ -183,3 +185,5 @@ declare void @incf() noreturn declare void @decf() noreturn declare void @conv() convergent + +; CHECK: ![[MD0]] = !{!"branch_weights", i32 1, i32 1} Index: llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll =================================================================== --- llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll +++ llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll @@ -135,13 +135,13 @@ i32 13, label %loop_exit1 i32 2, label %loop2 i32 42, label %loop_exit3 - ] + ], !prof !{!"branch_weights", i32 99, i32 1, i32 2, i32 3, i32 4, i32 5} ; CHECK: loop_begin: ; CHECK-NEXT: load ; CHECK-NEXT: switch i32 %cond2, label %loop2 [ ; CHECK-NEXT: i32 0, label %loop0 ; CHECK-NEXT: i32 1, label %loop1 -; CHECK-NEXT: ] +; CHECK-NEXT: ], !prof ![[MD0:[0-9]+]] loop0: call void @some_func() noreturn nounwind @@ -1243,3 +1243,5 @@ ; CHECK: loopexit: ; CHECK-NEXT: ret } + +; CHECK: ![[MD0]] = !{!"branch_weights", i32 99, i32 1, i32 2}