Index: llvm/include/llvm/IR/Instructions.h
===================================================================
--- llvm/include/llvm/IR/Instructions.h
+++ llvm/include/llvm/IR/Instructions.h
@@ -3468,7 +3468,13 @@
   /// This action invalidates iterators for all cases following the one removed,
   /// including the case_end() iterator. It returns an iterator for the next
   /// case.
-  CaseIt removeCase(CaseIt I);
+  ///
+  /// If the switch has !prof branch_weights metadata with one weight per
+  /// successor, it is dropped when \p DropProfMetadata is true. Otherwise the
+  /// removed case's weight is discarded and the rest are reordered to match
+  /// the new case order; the metadata is still dropped if the remaining
+  /// weights are all zero or only one successor is left.
+  CaseIt removeCase(CaseIt I, bool DropProfMetadata = false);
 
   unsigned getNumSuccessors() const { return getNumOperands()/2; }
   BasicBlock *getSuccessor(unsigned idx) const {
Index: llvm/lib/IR/Instructions.cpp
===================================================================
--- llvm/lib/IR/Instructions.cpp
+++ llvm/lib/IR/Instructions.cpp
@@ -28,6 +28,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
@@ -3767,9 +3768,26 @@
   SubclassOptionalData = SI.SubclassOptionalData;
 }
 
+/// Return the !prof metadata attached to \p SI if it is a "branch_weights"
+/// node, or nullptr otherwise (including for an empty !prof node).
+static MDNode *getProfBranchWeightsMD(const SwitchInst *SI) {
+  if (MDNode *ProfileData = SI->getMetadata(LLVMContext::MD_prof))
+    if (ProfileData->getNumOperands() > 0)
+      if (auto *MDName = dyn_cast<MDString>(ProfileData->getOperand(0)))
+        if (MDName->getString() == "branch_weights")
+          return ProfileData;
+  return nullptr;
+}
+
 /// addCase - Add an entry to the switch instruction...
 ///
 void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) {
+  // The new successor has no known weight, so consistent branch_weights
+  // metadata would become stale; drop it.
+  if (MDNode *ProfileData = getProfBranchWeightsMD(this))
+    if (ProfileData->getNumOperands() == getNumSuccessors() + 1)
+      setMetadata(LLVMContext::MD_prof, nullptr);
+
   unsigned NewCaseIdx = getNumCases();
   unsigned OpNo = getNumOperands();
   if (OpNo+2 > ReservedSpace)
@@ -3784,11 +3802,44 @@
 
 /// removeCase - This method removes the specified case and its successor
 /// from the switch instruction.
-SwitchInst::CaseIt SwitchInst::removeCase(CaseIt I) {
+SwitchInst::CaseIt SwitchInst::removeCase(CaseIt I, bool DropProfMetadata) {
   unsigned idx = I->getCaseIndex();
 
   assert(2 + idx*2 < getNumOperands() && "Case index out of range!!!");
 
+  // Keep !prof branch_weights in sync with the successor list. Operand 0 is
+  // the "branch_weights" tag and operand CI + 1 is the weight of successor
+  // CI; case idx is successor idx + 1. Metadata whose size does not match
+  // the successor count is left untouched.
+  if (MDNode *ProfileData = getProfBranchWeightsMD(this))
+    if (ProfileData->getNumOperands() == getNumSuccessors() + 1) {
+      if (DropProfMetadata) {
+        setMetadata(LLVMContext::MD_prof, nullptr);
+      } else {
+        SmallVector<uint32_t, 8> Weights;
+        uint64_t SW = 0;
+        for (unsigned CI = 0, CE = getNumSuccessors() - 1; CI < CE; ++CI) {
+          // The removed case's slot is refilled with the end of the list, so
+          // it takes the weight of the old last successor (index CE).
+          unsigned WI = (CI == idx + 1 ? CE : CI) + 1;
+          ConstantInt *C =
+              mdconst::extract<ConstantInt>(ProfileData->getOperand(WI));
+          uint32_t CW = C->getValue().getZExtValue();
+          SW += CW;
+          Weights.push_back(CW);
+        }
+        assert(Weights.size() == getNumSuccessors() - 1);
+        // All-zero weights, or a switch left with a single successor, carry
+        // no useful information. getParent() is null for a switch that is
+        // not in a basic block, so take the context from the instruction.
+        auto *NewProfileData = !SW || Weights.size() < 2
+                                   ? nullptr
+                                   : MDBuilder(getContext())
+                                         .createBranchWeights(Weights);
+        setMetadata(LLVMContext::MD_prof, NewProfileData);
+      }
+    }
+
   unsigned NumOps = getNumOperands();
   Use *OL = getOperandList();
 
Index: llvm/lib/IR/Verifier.cpp
===================================================================
--- llvm/lib/IR/Verifier.cpp
+++ llvm/lib/IR/Verifier.cpp
@@ -2440,6 +2440,20 @@
              "Duplicate integer as switch case", &SI, Case.getCaseValue());
   }
 
+  // Check consistency of !prof branch_weights metadata if any.
+  auto *ProfileData = SI.getMetadata(LLVMContext::MD_prof);
+  if (ProfileData != nullptr && ProfileData->getNumOperands() > 0) {
+    if (auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0)))
+      if (ProfDataName->getString().equals("branch_weights")) {
+        Assert(ProfileData->getNumOperands() == SI.getNumSuccessors() + 1,
+               "Number of !prof branch_weights operands differs from number "
+               "of successors");
+        for (unsigned i = 1; i < ProfileData->getNumOperands(); ++i)
+          Assert(mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i)),
+                 "!prof branch_weights operand is not a const int");
+      }
+  }
+
   visitTerminator(SI);
 }
 
Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -884,7 +884,7 @@
           Weights.pop_back();
         }
         i->getCaseSuccessor()->removePredecessor(TI->getParent());
-        SI->removeCase(i);
+        SI->removeCase(i, true);
       }
     }
     if (HasWeight && Weights.size() >= 2)
@@ -4461,7 +4461,7 @@
 
     // Prune unused values from PHI nodes.
     CaseI->getCaseSuccessor()->removePredecessor(SI->getParent());
-    SI->removeCase(CaseI);
+    SI->removeCase(CaseI, true);
   }
   if (HasWeight && Weights.size() >= 2) {
     SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end());
Index: llvm/test/Transforms/CorrelatedValuePropagation/basic.ll
===================================================================
--- llvm/test/Transforms/CorrelatedValuePropagation/basic.ll
+++ llvm/test/Transforms/CorrelatedValuePropagation/basic.ll
@@ -152,7 +152,7 @@
 ; CHECK-NEXT: switch i32 [[S]], label [[OUT]] [
 ; CHECK-NEXT: i32 -2, label [[NEXT:%.*]]
 ; CHECK-NEXT: i32 -1, label [[NEXT]]
-; CHECK-NEXT: ]
+; CHECK-NEXT: ], !prof ![[MD0:[0-9]+]]
 ; CHECK: out:
 ; CHECK-NEXT: [[P:%.*]] = phi i32 [ 1, [[ENTRY:%.*]] ], [ -1, [[NEGATIVE]] ]
 ; CHECK-NEXT: ret i32 [[P]]
@@ -171,7 +171,7 @@
   i32 -2, label %next
   i32 2, label %out
   i32 3, label %out
-  ]
+  ], !prof !{!"branch_weights", i32 99, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6}
 
 out:
   %p = phi i32 [ 1, %entry ], [ -1, %negative ], [ -1, %negative ], [ -1, %negative ], [ -1, %negative ], [ -1, %negative ]
@@ -723,3 +723,5 @@
 exit:
   ret i1 %cmp
 }
+
+; CHECK: ![[MD0]] = !{!"branch_weights", i32 99, i32 4, i32 3}
Index: llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll
===================================================================
--- llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll
+++ llvm/test/Transforms/SimpleLoopUnswitch/basictest.ll
@@ -41,6 +41,8 @@
   br label %loop_begin
 
 
+; CHECK: loop_begin:
+; CHECK: !prof ![[MD0:[0-9]+]]
 loop_begin:
   %var_val = load i32, i32* %var
 
@@ -48,7 +50,7 @@
   switch i32 %c, label %default [
     i32 1, label %inc
     i32 2, label %dec
-  ]
+  ], !prof !{!"branch_weights", i32 1, i32 1, i32 1}
 
 inc:
   call void @incf() noreturn nounwind
@@ -183,3 +185,5 @@
 declare void @incf() noreturn
 declare void @decf() noreturn
 declare void @conv() convergent
+
+; CHECK: ![[MD0]] = !{!"branch_weights", i32 1, i32 1}
Index: llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
===================================================================
--- llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
+++ llvm/test/Transforms/SimpleLoopUnswitch/trivial-unswitch.ll
@@ -135,13 +135,13 @@
     i32 13, label %loop_exit1
     i32 2, label %loop2
     i32 42, label %loop_exit3
-  ]
+  ], !prof !{!"branch_weights", i32 1, i32 1, i32 1, i32 1, i32 1, i32 1}
 ; CHECK: loop_begin:
 ; CHECK-NEXT: load
 ; CHECK-NEXT: switch i32 %cond2, label %loop2 [
 ; CHECK-NEXT: i32 0, label %loop0
 ; CHECK-NEXT: i32 1, label %loop1
-; CHECK-NEXT: ]
+; CHECK-NEXT: ], !prof ![[MD0:[0-9]+]]
 
 loop0:
   call void @some_func() noreturn nounwind
@@ -1243,3 +1243,5 @@
 ; CHECK: loopexit:
 ; CHECK-NEXT: ret
 }
+
+; CHECK: ![[MD0]] = !{!"branch_weights", i32 1, i32 1, i32 1}