Index: llvm/include/llvm/IR/Instructions.h
===================================================================
--- llvm/include/llvm/IR/Instructions.h
+++ llvm/include/llvm/IR/Instructions.h
@@ -3186,6 +3186,10 @@
   // -2
   static const unsigned DefaultPseudoIndex = static_cast<unsigned>(~0L-1);
 
+  /// Optional !prof branch weight of a switch successor; None means that no
+  /// weight is known or specified.
+  using CaseWeightOpt = Optional<uint32_t>;
+
   template <typename CaseHandleT> class CaseIteratorImpl;
 
   /// A handle to a particular switch case. It exposes a convenient interface
@@ -3260,8 +3264,9 @@
     }
 
     /// Sets the new successor for current case.
-    void setSuccessor(BasicBlock *S) {
-      SI->setSuccessor(getSuccessorIndex(), S);
+    /// If \p W is provided, the case's !prof branch weight is updated too.
+    void setSuccessor(BasicBlock *S, CaseWeightOpt W = None) {
+      SI->setSuccessor(getSuccessorIndex(), S, W);
     }
   };
 
@@ -3357,8 +3362,14 @@
     return cast<BasicBlock>(getOperand(1));
   }
 
-  void setDefaultDest(BasicBlock *DefaultCase) {
-    setOperand(1, reinterpret_cast<Value*>(DefaultCase));
+  /// Return the !prof branch weight of the default destination, or None
+  /// if no weight is recorded for it.
+  CaseWeightOpt getDefaultCaseWeight() const { return getSuccessorWeight(0); }
+
+  /// Set the default destination. If \p W is provided, its !prof branch
+  /// weight is updated as well.
+  void setDefaultDest(BasicBlock *DefaultCase, CaseWeightOpt W = None) {
+    setSuccessor(0, DefaultCase, W);
   }
 
   /// Return the number of 'cases' in this switch instruction, excluding the
@@ -3459,7 +3470,8 @@
   /// Note:
   /// This action invalidates case_end(). Old case_end() iterator will
   /// point to the added case.
-  void addCase(ConstantInt *OnVal, BasicBlock *Dest);
+  /// If \p W is provided, it is used as the new case's !prof branch weight.
+  void addCase(ConstantInt *OnVal, BasicBlock *Dest, CaseWeightOpt W = None);
 
   /// This method removes the specified case and its successor from the switch
   /// instruction. Note that this operation may reorder the remaining cases at
@@ -3475,10 +3487,14 @@
     assert(idx < getNumSuccessors() &&"Successor idx out of range for switch!");
     return cast<BasicBlock>(getOperand(idx*2+1));
   }
-  void setSuccessor(unsigned idx, BasicBlock *NewSucc) {
-    assert(idx < getNumSuccessors() && "Successor # out of range for switch!");
-    setOperand(idx * 2 + 1, NewSucc);
-  }
+
+  /// Return the !prof branch weight of successor \p idx, or None if no weight
+  /// is recorded for it.
+  CaseWeightOpt getSuccessorWeight(unsigned idx) const;
+
+  /// Set successor \p idx to \p NewSucc. If \p W is provided, the successor's
+  /// !prof branch weight is updated as well.
+  void setSuccessor(unsigned idx, BasicBlock *NewSucc, CaseWeightOpt W = None);
 
   // Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Instruction *I) {
Index: llvm/lib/IR/Instructions.cpp
===================================================================
--- llvm/lib/IR/Instructions.cpp
+++ llvm/lib/IR/Instructions.cpp
@@ -3776,14 +3776,67 @@
   return nullptr;
 }
 
-/// addCase - Add an entry to the switch instruction...
-///
-void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest) {
-  // Drop metdadata.
+/// Set the !prof branch weight of successor \p BI of \p SI to \p BW and
+/// rebuild the branch_weights metadata with one entry per successor.
+/// Weights missing from the existing metadata are treated as 0; passing
+/// BI == SI->getNumSuccessors() appends a weight. The metadata is removed
+/// when all weights are zero.
+static void setProfBranchWeight(SwitchInst *SI, unsigned BI, uint32_t BW) {
+  assert(BI <= SI->getNumSuccessors() && "case index is in range");
+  unsigned NumWeights = std::max(BI + 1, SI->getNumSuccessors());
+  uint64_t SW = 0;
+  SmallVector<uint32_t, 8> Weights;
+
+  if (MDNode *ProfileData = getProfBranchWeightsMD(SI))
+    for (unsigned CI = 0; CI < NumWeights; ++CI) {
+      uint32_t CW = 0;
+      if (CI == BI)
+        CW = BW;
+      else if (CI + 1 < ProfileData->getNumOperands())
+        CW = mdconst::extract<ConstantInt>(ProfileData->getOperand(CI + 1))
+                 ->getValue()
+                 .getZExtValue();
+      Weights.push_back(CW);
+      SW += CW;
+    }
+
+  if (Weights.empty() && BW) {
+    Weights.assign(NumWeights, 0);
+    Weights[BI] = BW;
+    SW = BW;
+  }
+
+  auto *NewProfileData =
+      !SW ? nullptr : MDBuilder(SI->getParent()->getContext())
+                          .createBranchWeights(Weights);
+  SI->setMetadata(LLVMContext::MD_prof, NewProfileData);
+}
+
+SwitchInst::CaseWeightOpt SwitchInst::getSuccessorWeight(unsigned idx) const {
   if (MDNode *ProfileData = getProfBranchWeightsMD(this))
-    if (ProfileData->getNumOperands() == getNumSuccessors() + 1)
-      setMetadata(LLVMContext::MD_prof, nullptr);
+    // Guard against metadata with fewer entries than successors.
+    if (idx + 1 < ProfileData->getNumOperands())
+      return mdconst::extract<ConstantInt>(ProfileData->getOperand(idx + 1))
+          ->getValue()
+          .getZExtValue();
+
+  return None;
+}
+
+void SwitchInst::setSuccessor(unsigned idx, BasicBlock *NewSucc,
+                              CaseWeightOpt W) {
+  assert(idx < getNumSuccessors() && "Successor # out of range for switch!");
+  // Only touch the !prof metadata when a weight was explicitly requested.
+  if (W)
+    setProfBranchWeight(this, idx, *W);
+  setOperand(idx * 2 + 1, NewSucc);
+}
+
+/// addCase - Add an entry to the switch instruction...
+///
+void SwitchInst::addCase(ConstantInt *OnVal, BasicBlock *Dest,
+                         CaseWeightOpt W) {
 
   unsigned NewCaseIdx = getNumCases();
   unsigned OpNo = getNumOperands();
   if (OpNo+2 > ReservedSpace)
@@ -3793,7 +3846,19 @@
   setNumHungOffUseOperands(OpNo+2);
   CaseHandle Case(this, NewCaseIdx);
   Case.setValue(OnVal);
-  Case.setSuccessor(Dest);
+
+  if (!W)
+    if (MDNode *ProfileData = getProfBranchWeightsMD(this))
+      if (ProfileData->getNumOperands() == getNumSuccessors() + 1) {
+        // The branch_weights metadata already has an entry for every
+        // successor; leave it untouched.
+        Case.setSuccessor(Dest);
+        return;
+      }
+
+  // Otherwise pass an explicit weight (0 if none was given) so that existing
+  // branch_weights metadata is extended to the new number of successors.
+  Case.setSuccessor(Dest, W ? *W : 0);
 }
 
 /// removeCase - This method removes the specified case and its successor