diff --git a/llvm/include/llvm/Analysis/CFGPrinter.h b/llvm/include/llvm/Analysis/CFGPrinter.h --- a/llvm/include/llvm/Analysis/CFGPrinter.h +++ b/llvm/include/llvm/Analysis/CFGPrinter.h @@ -26,6 +26,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/DOTGraphTraits.h" #include "llvm/Support/FormatVariadic.h" @@ -276,14 +277,10 @@ if (Attrs.size()) return Attrs; - MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); + MDNode *WeightsNode = getBranchWeightMDNode(*TI); if (!WeightsNode) return ""; - MDString *MDName = cast(WeightsNode->getOperand(0)); - if (MDName->getString() != "branch_weights") - return ""; - OpNo = I.getSuccessorIndex() + 1; if (OpNo >= WeightsNode->getNumOperands()) return ""; diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -3620,8 +3620,6 @@ bool Changed = false; protected: - static MDNode *getProfBranchWeightsMD(const SwitchInst &SI); - MDNode *buildProfBranchWeightsMD(); void init(); diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -34,6 +34,13 @@ /// otherwise. bool hasBranchWeightMD(const Instruction &I); +/// Get the branch weights metadata node +/// +/// \param I The Instruction to get the weights from. +/// \returns A pointer to I's branch weights metadata node, if it exists. +/// Nullptr otherwise. +MDNode *getBranchWeightMDNode(const Instruction &I); + /// Extract branch weights from MD_prof metadata /// /// \param ProfileData A pointer to an MDNode. @@ -52,7 +59,8 @@ bool extractBranchWeights(const Instruction &I, SmallVectorImpl &Weights); -/// Extract branch weights from a conditional branch or select Instruction. +/// Extract branch weights from a conditional branch, select, or switch +/// Instruction. /// /// \param I The instruction to extract branch weights from. /// \param [out] TrueVal will contain the branch weight for the True branch @@ -70,5 +78,13 @@ /// metadata was found. bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights); +/// Retrieve the total of all weights from an instruction. +/// +/// \param I The instruction to extract the total weight from +/// \param [out] TotalWeights input variable to fill with total weights +/// \returns True on success with profile total weights filled in. False if no +/// metadata was found. +bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights); + } // namespace llvm #endif diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp --- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp +++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -383,7 +383,7 @@ isa(TI) || isa(TI))) return false; - MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); + MDNode *WeightsNode = getBranchWeightMDNode(*TI); if (!WeightsNode) return false; @@ -403,7 +403,7 @@ SmallVector UnreachableIdxs; SmallVector ReachableIdxs; - extractBranchWeights(*TI, Weights); + extractBranchWeights(WeightsNode, Weights); for (unsigned I = 0, E = Weights.size(); I != E; ++I) { WeightSum += Weights[I]; const LoopBlock SrcLoopBB = getLoopBlock(BB); diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp --- a/llvm/lib/IR/Instruction.cpp +++ b/llvm/lib/IR/Instruction.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" using namespace llvm; @@ -855,13 +856,8 @@ } void Instruction::swapProfMetadata() { - MDNode *ProfileData = getMetadata(LLVMContext::MD_prof); - if (!ProfileData || ProfileData->getNumOperands() != 3 || - !isa(ProfileData->getOperand(0))) - return; - - MDString *MDName = cast(ProfileData->getOperand(0)); - if (MDName->getString() != "branch_weights") + MDNode *ProfileData = getBranchWeightMDNode(*this); + if (!ProfileData || ProfileData->getNumOperands() != 3) return; // The first operand is the name. Fetch them backwards and build a new one. diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" @@ -4572,15 +4573,6 @@ growHungoffUses(ReservedSpace); } -MDNode * -SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) { - if (MDNode *ProfileData = SI.getMetadata(LLVMContext::MD_prof)) - if (auto *MDName = dyn_cast(ProfileData->getOperand(0))) - if (MDName->getString() == "branch_weights") - return ProfileData; - return nullptr; -} - MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { assert(Changed && "called only if metadata has changed"); @@ -4599,21 +4591,9 @@ } void SwitchInstProfUpdateWrapper::init() { - MDNode *ProfileData = getProfBranchWeightsMD(SI); - if (!ProfileData) - return; - - if (ProfileData->getNumOperands() != SI.getNumSuccessors() + 1) { - llvm_unreachable("number of prof branch_weights metadata operands does " - "not correspond to number of succesors"); - } - SmallVector Weights; - for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) { - ConstantInt *C = mdconst::extract(ProfileData->getOperand(CI)); - uint32_t CW = C->getValue().getZExtValue(); - Weights.push_back(CW); - } + if (!extractBranchWeights(SI, Weights)) + return; this->Weights = std::move(Weights); } @@ -4686,7 +4666,7 @@ SwitchInstProfUpdateWrapper::CaseWeightOpt SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI, unsigned idx) { - if (MDNode *ProfileData = getProfBranchWeightsMD(SI)) + if (MDNode *ProfileData = getBranchWeightMDNode(SI)) if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1) return mdconst::extract(ProfileData->getOperand(idx + 1)) ->getValue() diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -1544,7 +1544,7 @@ getOpcode() == Instruction::Switch) && "Looking for branch weights on something besides branch"); - return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal); + return ::extractProfTotalWeight(*this, TotalVal); } void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) { diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -43,7 +43,7 @@ constexpr unsigned WeightsIdx = 1; // the minimum number of operands for MD_prof nodes with branch weights -constexpr unsigned MinBWOps = 3; +constexpr unsigned MinBWOps = 2; bool extractWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { @@ -101,6 +101,13 @@ return isBranchWeightMD(ProfileData); } +MDNode *getBranchWeightMDNode(const Instruction &I) { + auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); + if (!isBranchWeightMD(ProfileData)) + return nullptr; + return ProfileData; +} + bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { if (!isBranchWeightMD(ProfileData)) @@ -117,8 +124,10 @@ bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal, uint64_t &FalseVal) { assert((I.getOpcode() == Instruction::Br || - I.getOpcode() == Instruction::Select) && - "Looking for branch weights on something besides branch or select"); + I.getOpcode() == Instruction::Select || + I.getOpcode() == Instruction::Switch) && + "Looking for branch weights on something besides branch, select, or " + "switch"); SmallVector Weights; auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); @@ -161,4 +170,8 @@ return false; } +bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { + return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); +} + } // namespace llvm diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp --- a/llvm/lib/Transforms/IPO/PartialInlining.cpp +++ b/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -716,8 +716,7 @@ BranchInst *BR = dyn_cast(E->getTerminator()); if (!BR || BR->isUnconditional()) continue; - uint64_t T, F; - if (extractBranchWeights(*BR, T, F)) + if (getBranchWeightMDNode(*BR)) return true; } return false; diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp --- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" @@ -577,30 +578,23 @@ // Returns true and sets the true probability and false probability of an // MD_prof metadata if it's well-formed. -static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb, +static bool checkMDProf(Instruction *I, BranchProbability &TrueProb, BranchProbability &FalseProb) { - if (!MD) return false; - MDString *MDName = cast(MD->getOperand(0)); - if (MDName->getString() != "branch_weights" || - MD->getNumOperands() != 3) + uint64_t TrueWeight; + uint64_t FalseWeight; + if (!extractBranchWeights(*I, TrueWeight, FalseWeight)) return false; - ConstantInt *TrueWeight = mdconst::extract(MD->getOperand(1)); - ConstantInt *FalseWeight = mdconst::extract(MD->getOperand(2)); - if (!TrueWeight || !FalseWeight) - return false; - uint64_t TrueWt = TrueWeight->getValue().getZExtValue(); - uint64_t FalseWt = FalseWeight->getValue().getZExtValue(); - uint64_t SumWt = TrueWt + FalseWt; + uint64_t SumWt = TrueWeight + FalseWeight; - assert(SumWt >= TrueWt && SumWt >= FalseWt && + assert(SumWt >= TrueWeight && SumWt >= FalseWeight && "Overflow calculating branch probabilities."); // Guard against 0-to-0 branch weights to avoid a division-by-zero crash. if (SumWt == 0) return false; - TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt); - FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt); + TrueProb = BranchProbability::getBranchProbability(TrueWeight, SumWt); + FalseProb = BranchProbability::getBranchProbability(FalseWeight, SumWt); return true; } @@ -639,8 +633,7 @@ if (!BI->isConditional()) return false; BranchProbability ThenProb, ElseProb; - if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof), - ThenProb, ElseProb)) + if (!checkMDProf(BI, ThenProb, ElseProb)) return false; BasicBlock *IfThen = BI->getSuccessor(0); BasicBlock *IfElse = BI->getSuccessor(1); @@ -669,8 +662,7 @@ DenseSet &FalseBiasedSelectsGlobal, DenseMap &SelectBiasMap) { BranchProbability TrueProb, FalseProb; - if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof), - TrueProb, FalseProb)) + if (!checkMDProf(SI, TrueProb, FalseProb)) return false; CHR_DEBUG(dbgs() << "SI " << *SI << " "); CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " "); diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/llvm/lib/Transforms/Scalar/JumpThreading.cpp --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -2523,14 +2523,9 @@ const Instruction *TI = BB->getTerminator(); assert(TI->getNumSuccessors() > 1 && "not a split"); - MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); + auto *WeightsNode = getBranchWeightMDNode(*TI); if (!WeightsNode) return false; - - MDString *MDName = cast(WeightsNode->getOperand(0)); - if (MDName->getString() != "branch_weights") - return false; - // Ensure there are weights for all of the successors. Note that the first // operand to the metadata node is a name, not a weight. return WeightsNode->getNumOperands() == TI->getNumSuccessors() + 1; diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -191,6 +191,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -975,16 +976,10 @@ return false; auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) { - if (!ProfileData || !ProfileData->getOperand(0)) - return false; - if (MDString *MDS = dyn_cast(ProfileData->getOperand(0))) - if (!MDS->getString().equals("branch_weights")) - return false; - if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors()) - return false; - return true; + return ProfileData && + ProfileData->getNumOperands() == 1 + Term->getNumSuccessors(); }; - MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof); + MDNode *LatchProfileData = getBranchWeightMDNode(*LatchTerm); // Latch terminator has no valid profile data, so nothing to check // profitability on. if (!IsValidProfileData(LatchProfileData, LatchTerm)) @@ -994,7 +989,7 @@ [&](const BasicBlock *ExitingBlock, const BasicBlock *ExitBlock) -> BranchProbability { auto *Term = ExitingBlock->getTerminator(); - MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof); + MDNode *ProfileData = getBranchWeightMDNode(*Term); unsigned NumSucc = Term->getNumSuccessors(); if (IsValidProfileData(ProfileData, Term)) { uint64_t Numerator = 0, Denominator = 0, ProfVal = 0; diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -62,6 +62,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -210,20 +211,18 @@ // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. if (i->getCaseSuccessor() == DefaultDest) { - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = getBranchWeightMDNode(*SI); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches // left, unless the metadata doesn't match the switch. if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) { // Collect branch weights into a vector. SmallVector Weights; - for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; - ++MD_i) { - auto *CI = mdconst::extract(MD->getOperand(MD_i)); - Weights.push_back(CI->getValue().getZExtValue()); - } + extractBranchWeights(*SI, Weights); + // Merge weight of this case to the default weight. unsigned idx = i->getCaseIndex(); + // TODO: no overflow check Weights[0] += Weights[idx+1]; // Remove weight for this case. std::swap(Weights[idx+1], Weights.back()); @@ -313,18 +312,13 @@ BranchInst *NewBr = Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(), SI->getDefaultDest()); - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); - if (MD && MD->getNumOperands() == 3) { - ConstantInt *SICase = - mdconst::dyn_extract(MD->getOperand(2)); - ConstantInt *SIDef = - mdconst::dyn_extract(MD->getOperand(1)); - assert(SICase && SIDef); + uint64_t DefWeight; + uint64_t CaseWeight; + if (extractBranchWeights(*SI, DefWeight, CaseWeight)) { // The TrueWeight should be the weight for the single case of SI. NewBr->setMetadata(LLVMContext::MD_prof, - MDBuilder(BB->getContext()). - createBranchWeights(SICase->getValue().getZExtValue(), - SIDef->getValue().getZExtValue())); + MDBuilder(BB->getContext()) + .createBranchWeights(CaseWeight, DefWeight)); } // Update make.implicit metadata to the newly-created conditional branch.