Index: llvm/include/llvm/IR/Instructions.h
===================================================================
--- llvm/include/llvm/IR/Instructions.h
+++ llvm/include/llvm/IR/Instructions.h
@@ -3186,6 +3186,9 @@
   // -2
   static const unsigned DefaultPseudoIndex = static_cast<unsigned>(~0L-1);
 
+  /// Optional branch weight of a switch successor (!prof branch_weights).
+  using CaseWeightOpt = Optional<uint32_t>;
+
   template <typename CaseHandleT> class CaseIteratorImpl;
 
   /// A handle to a particular switch case. It exposes a convenient interface
@@ -3260,8 +3263,9 @@
     }
 
     /// Sets the new successor for current case.
-    void setSuccessor(BasicBlock *S) {
-      SI->setSuccessor(getSuccessorIndex(), S);
+    /// If \p W is provided, it is also used as this case's branch weight.
+    void setSuccessor(BasicBlock *S, CaseWeightOpt W = None) {
+      SI->setSuccessor(getSuccessorIndex(), S, W);
     }
   };
 
@@ -3357,8 +3361,16 @@
     return cast<BasicBlock>(getOperand(1));
   }
 
-  void setDefaultDest(BasicBlock *DefaultCase) {
-    setOperand(1, reinterpret_cast<Value*>(DefaultCase));
+  /// Returns the branch weight of the default destination, or None if no
+  /// "branch_weights" profile covers it.
+  CaseWeightOpt getDefaultCaseWeight() const {
+    return getSuccessorWeight(0);
+  }
+
+  /// Sets the default destination. If \p W is provided, it is also used as
+  /// the default destination's branch weight.
+  void setDefaultDest(BasicBlock *DefaultCase, CaseWeightOpt W = None) {
+    setSuccessor(0, DefaultCase, W);
   }
 
   /// Return the number of 'cases' in this switch instruction, excluding the
@@ -3459,7 +3471,8 @@
   /// Note:
   /// This action invalidates case_end(). Old case_end() iterator will
   /// point to the added case.
-  void addCase(ConstantInt *OnVal, BasicBlock *Dest);
+  /// If \p W is provided, it becomes the branch weight of the new case.
+  void addCase(ConstantInt *OnVal, BasicBlock *Dest, CaseWeightOpt W = None);
 
   /// This method removes the specified case and its successor from the switch
   /// instruction. Note that this operation may reorder the remaining cases at
@@ -3475,10 +3488,14 @@
     assert(idx < getNumSuccessors() &&"Successor idx out of range for switch!");
     return cast<BasicBlock>(getOperand(idx*2+1));
   }
-  void setSuccessor(unsigned idx, BasicBlock *NewSucc) {
-    assert(idx < getNumSuccessors() && "Successor # out of range for switch!");
-    setOperand(idx * 2 + 1, NewSucc);
-  }
+
+  /// Returns the branch weight of successor \p idx (0 is the default
+  /// destination), or None if no "branch_weights" profile covers it.
+  CaseWeightOpt getSuccessorWeight(unsigned idx) const;
+
+  /// Sets successor \p idx to \p NewSucc. If \p W is provided, it also becomes
+  /// that successor's branch weight (a zero weight alone creates no profile).
+  void setSuccessor(unsigned idx, BasicBlock *NewSucc, CaseWeightOpt W = None);
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
Index: llvm/lib/IR/Instructions.cpp
===================================================================
--- llvm/lib/IR/Instructions.cpp
+++ llvm/lib/IR/Instructions.cpp
@@ -3768,16 +3768,84 @@
   SubclassOptionalData = SI.SubclassOptionalData;
 }
 
-/// addCase - Add an entry to the switch instruction...
-///
-void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) {
-  // Drop metdadata.
-  MDNode *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (ProfileData && ProfileData->getNumOperands() == getNumSuccessors() + 1)
+/// Returns the !prof metadata of \p SI if it is a "branch_weights" node,
+/// otherwise nullptr.
+static MDNode *getProfBranchWeightsMD(const SwitchInst *SI) {
+  MDNode *ProfileData = SI->getMetadata(LLVMContext::MD_prof);
+  if (ProfileData)
     if (auto *MDName = dyn_cast<MDString>(ProfileData->getOperand(0)))
       if (MDName->getString() == "branch_weights")
-        setMetadata(LLVMContext::MD_prof, nullptr);
+        return ProfileData;
+  return nullptr;
+}
+
+/// Sets the branch weight of successor \p BI of \p SI (0 is the default
+/// destination) to \p BW. The resulting "branch_weights" metadata has exactly
+/// SI->getNumSuccessors() weights; weights the existing profile lacks (e.g.
+/// for a case that was just appended) are filled with 0. If every weight is
+/// 0, the metadata is removed instead.
+static void setProfBranchWeight(SwitchInst *SI, unsigned BI, uint32_t BW) {
+  const unsigned NumSuccs = SI->getNumSuccessors();
+  assert(BI < NumSuccs && "case index is out of range");
+
+  MDNode *ProfileData = getProfBranchWeightsMD(SI);
+  // Without a branch_weights profile a zero weight adds no information; bail
+  // out so that a !prof node of another kind is not dropped either.
+  if (!ProfileData && !BW)
+    return;
+
+  SmallVector<uint32_t, 8> Weights(NumSuccs, 0);
+  bool AllZero = true;
+  for (unsigned CI = 0; CI < NumSuccs; ++CI) {
+    // Operand 0 is the "branch_weights" tag; weight CI is operand CI + 1.
+    if (CI == BI)
+      Weights[CI] = BW;
+    else if (ProfileData && CI + 1 < ProfileData->getNumOperands())
+      Weights[CI] =
+          mdconst::extract<ConstantInt>(ProfileData->getOperand(CI + 1))
+              ->getValue()
+              .getZExtValue();
+    // Use a flag rather than a uint32_t sum, which could wrap around to 0.
+    if (Weights[CI])
+      AllZero = false;
+  }
+
+  auto *NewProfileData =
+      AllZero ? nullptr
+              : MDBuilder(SI->getContext()).createBranchWeights(Weights);
+  SI->setMetadata(LLVMContext::MD_prof, NewProfileData);
+}
+
+SwitchInst::CaseWeightOpt SwitchInst::getSuccessorWeight(unsigned idx) const {
+  assert(idx < getNumSuccessors() && "Successor # out of range for switch!");
+  // Operand 0 is the "branch_weights" tag; weight idx is operand idx + 1. A
+  // profile too short to cover idx yields None instead of an out-of-range
+  // operand access.
+  if (MDNode *ProfileData = getProfBranchWeightsMD(this))
+    if (idx + 1 < ProfileData->getNumOperands())
+      return static_cast<uint32_t>(
+          mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
+              ->getValue()
+              .getZExtValue());
+
+  return None;
+}
+
+void SwitchInst::setSuccessor(unsigned idx, BasicBlock *NewSucc,
+                              CaseWeightOpt W) {
+  assert(idx < getNumSuccessors() && "Successor # out of range for switch!");
+  // The profile is only touched when a weight is supplied.
+  if (W)
+    setProfBranchWeight(this, idx, *W);
+
+  setOperand(idx * 2 + 1, NewSucc);
+}
 
+/// addCase - Add an entry to the switch instruction. If \p W is provided it
+/// becomes the branch weight of the new case; otherwise an existing
+/// "branch_weights" profile that lacks a weight for it gets a zero weight.
+void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest,
+                         CaseWeightOpt W) {
   unsigned NewCaseIdx = getNumCases();
   unsigned OpNo = getNumOperands();
   if (OpNo+2 > ReservedSpace)
@@ -3787,7 +3855,18 @@
   setNumHungOffUseOperands(OpNo+2);
   CaseHandle Case(this, NewCaseIdx);
   Case.setValue(OnVal);
-  Case.setSuccessor(Dest);
+
+  if (!W)
+    if (MDNode *ProfileData = getProfBranchWeightsMD(this))
+      if (ProfileData->getNumOperands() == getNumSuccessors() + 1) {
+        // The profile already holds a weight for the new case; keep it.
+        Case.setSuccessor(Dest);
+        return;
+      }
+
+  // Record W (0 if none) as the new case's weight; this also resizes an
+  // existing branch_weights profile to the new number of successors.
+  Case.setSuccessor(Dest, W ? *W : 0);
 }
 
 /// removeCase - This method removes the specified case and its successor
@@ -3797,35 +3876,33 @@
 
   assert(2 + idx*2 < getNumOperands() && "Case index out of range!!!");
 
-  unsigned NumOps = getNumOperands();
-  Use *OL = getOperandList();
-
-  // Update metdadata.
-  MDNode *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (ProfileData && ProfileData->getNumOperands() == getNumSuccessors() + 1)
-    if (auto *MDName = dyn_cast<MDString>(ProfileData->getOperand(0)))
-      if (MDName->getString() == "branch_weights") {
-        if (DropPerfMetadata) {
-          setMetadata(LLVMContext::MD_prof, nullptr);
-        } else {
-          SmallVector<uint32_t, 8> Weights;
-          uint32_t SW = 0;
-          for (unsigned CI = 0, CE = getNumSuccessors() - 1; CI < CE; ++CI) {
-            // At the removed position we put the end of the list.
-            unsigned WI = (CI == idx + 1 ? CE : CI) + 1;
-            ConstantInt *C = mdconst::extract<ConstantInt>(
-                ProfileData->getOperand(WI));
-            uint32_t CW = C->getValue().getZExtValue();
-            SW += CW;
-            Weights.push_back(CW);
-          }
-          assert(Weights.size() == getNumSuccessors() - 1);
-          auto *NewProfileData = !SW || Weights.size() < 2 ? nullptr :
-              MDBuilder(getParent()->getContext())
-                  .createBranchWeights(Weights);
-          setMetadata(LLVMContext::MD_prof, NewProfileData);
+  if (MDNode *ProfileData = getProfBranchWeightsMD(this))
+    if (ProfileData->getNumOperands() == getNumSuccessors() + 1) {
+      if (DropPerfMetadata) {
+        setMetadata(LLVMContext::MD_prof, nullptr);
+      } else {
+        SmallVector<uint32_t, 8> Weights;
+        // Sum in 64 bits so that nonzero uint32_t weights cannot wrap to 0.
+        uint64_t SW = 0;
+        for (unsigned CI = 0, CE = getNumSuccessors() - 1; CI < CE; ++CI) {
+          // At the removed position we put the end of the list.
+          unsigned WI = (CI == idx + 1 ? CE : CI) + 1;
+          ConstantInt *C = mdconst::extract<ConstantInt>(
+              ProfileData->getOperand(WI));
+          uint32_t CW = C->getValue().getZExtValue();
+          SW += CW;
+          Weights.push_back(CW);
         }
+        assert(Weights.size() == getNumSuccessors() - 1);
+        auto *NewProfileData = !SW || Weights.size() < 2 ? nullptr :
+            MDBuilder(getContext())
+                .createBranchWeights(Weights);
+        setMetadata(LLVMContext::MD_prof, NewProfileData);
       }
+    }
+
+  unsigned NumOps = getNumOperands();
+  Use *OL = getOperandList();
 
   // Overwrite this case with the end of the list.
   if (2 + (idx + 1) * 2 != NumOps) {