diff --git a/llvm/include/llvm/Analysis/CFGPrinter.h b/llvm/include/llvm/Analysis/CFGPrinter.h --- a/llvm/include/llvm/Analysis/CFGPrinter.h +++ b/llvm/include/llvm/Analysis/CFGPrinter.h @@ -26,6 +26,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/Support/DOTGraphTraits.h" #include "llvm/Support/FormatVariadic.h" @@ -276,14 +277,10 @@ if (Attrs.size()) return Attrs; - MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); + MDNode *WeightsNode = getBranchWeightMDNode(*TI); if (!WeightsNode) return ""; - MDString *MDName = cast(WeightsNode->getOperand(0)); - if (MDName->getString() != "branch_weights") - return ""; - OpNo = I.getSuccessorIndex() + 1; if (OpNo >= WeightsNode->getNumOperands()) return ""; diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h --- a/llvm/include/llvm/IR/Instructions.h +++ b/llvm/include/llvm/IR/Instructions.h @@ -3620,8 +3620,6 @@ bool Changed = false; protected: - static MDNode *getProfBranchWeightsMD(const SwitchInst &SI); - MDNode *buildProfBranchWeightsMD(); void init(); diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h --- a/llvm/include/llvm/IR/ProfDataUtils.h +++ b/llvm/include/llvm/IR/ProfDataUtils.h @@ -34,6 +34,27 @@ /// otherwise. bool hasBranchWeightMD(const Instruction &I); +/// Checks if an instruction has valid Branch Weight Metadata +/// +/// \param I The instruction to check +/// \returns True if I has an MD_prof node containing valid Branch Weights, +/// i.e., one weight for each successor. False otherwise. +bool hasValidBranchWeightMD(const Instruction &I); + +/// Get the branch weights metadata node +/// +/// \param I The Instruction to get the weights from. +/// \returns A pointer to I's branch weights metadata node, if it exists. +/// Nullptr otherwise. 
+MDNode *getBranchWeightMDNode(const Instruction &I); + +/// Get the valid branch weights metadata node +/// +/// \param I The Instruction to get the weights from. +/// \returns A pointer to I's valid branch weights metadata node, if it exists. +/// Nullptr otherwise. +MDNode *getValidBranchWeightMDNode(const Instruction &I); + /// Extract branch weights from MD_prof metadata /// /// \param ProfileData A pointer to an MDNode. @@ -70,5 +91,13 @@ /// metadata was found. bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights); +/// Retrieve the total of all weights from an instruction. +/// +/// \param I The instruction to extract the total weight from +/// \param [out] TotalWeights output variable to fill with total weights +/// \returns True on success with profile total weights filled in. False if no +/// metadata was found. +bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights); + } // namespace llvm #endif diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp --- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp +++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp @@ -383,18 +383,13 @@ isa(TI) || isa(TI))) return false; - MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); + MDNode *WeightsNode = getValidBranchWeightMDNode(*TI); if (!WeightsNode) return false; // Check that the number of successors is manageable. assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors"); - // Ensure there are weights for all of the successors. Note that the first - // operand to the metadata node is a name, not a weight. - if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1) - return false; - // Build up the final weights that will be used in a temporary buffer. // Compute the sum of all weights to later decide whether they need to // be scaled to fit in 32 bits. 
@@ -403,7 +398,7 @@ SmallVector UnreachableIdxs; SmallVector ReachableIdxs; - extractBranchWeights(*TI, Weights); + extractBranchWeights(WeightsNode, Weights); for (unsigned I = 0, E = Weights.size(); I != E; ++I) { WeightSum += Weights[I]; const LoopBlock SrcLoopBB = getLoopBlock(BB); diff --git a/llvm/lib/IR/Instruction.cpp b/llvm/lib/IR/Instruction.cpp --- a/llvm/lib/IR/Instruction.cpp +++ b/llvm/lib/IR/Instruction.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" using namespace llvm; @@ -855,13 +856,8 @@ } void Instruction::swapProfMetadata() { - MDNode *ProfileData = getMetadata(LLVMContext::MD_prof); - if (!ProfileData || ProfileData->getNumOperands() != 3 || - !isa(ProfileData->getOperand(0))) - return; - - MDString *MDName = cast(ProfileData->getOperand(0)); - if (MDName->getString() != "branch_weights") + MDNode *ProfileData = getBranchWeightMDNode(*this); + if (!ProfileData || ProfileData->getNumOperands() != 3) return; // The first operand is the name. Fetch them backwards and build a new one. 
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp --- a/llvm/lib/IR/Instructions.cpp +++ b/llvm/lib/IR/Instructions.cpp @@ -31,6 +31,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AtomicOrdering.h" @@ -4572,15 +4573,6 @@ growHungoffUses(ReservedSpace); } -MDNode * -SwitchInstProfUpdateWrapper::getProfBranchWeightsMD(const SwitchInst &SI) { - if (MDNode *ProfileData = SI.getMetadata(LLVMContext::MD_prof)) - if (auto *MDName = dyn_cast(ProfileData->getOperand(0))) - if (MDName->getString() == "branch_weights") - return ProfileData; - return nullptr; -} - MDNode *SwitchInstProfUpdateWrapper::buildProfBranchWeightsMD() { assert(Changed && "called only if metadata has changed"); @@ -4599,7 +4591,7 @@ } void SwitchInstProfUpdateWrapper::init() { - MDNode *ProfileData = getProfBranchWeightsMD(SI); + MDNode *ProfileData = getBranchWeightMDNode(SI); if (!ProfileData) return; @@ -4609,11 +4601,8 @@ } SmallVector Weights; - for (unsigned CI = 1, CE = SI.getNumSuccessors(); CI <= CE; ++CI) { - ConstantInt *C = mdconst::extract(ProfileData->getOperand(CI)); - uint32_t CW = C->getValue().getZExtValue(); - Weights.push_back(CW); - } + if (!extractBranchWeights(ProfileData, Weights)) + return; this->Weights = std::move(Weights); } @@ -4686,7 +4675,7 @@ SwitchInstProfUpdateWrapper::CaseWeightOpt SwitchInstProfUpdateWrapper::getSuccessorWeight(const SwitchInst &SI, unsigned idx) { - if (MDNode *ProfileData = getProfBranchWeightsMD(SI)) + if (MDNode *ProfileData = getBranchWeightMDNode(SI)) if (ProfileData->getNumOperands() == SI.getNumSuccessors() + 1) return mdconst::extract(ProfileData->getOperand(idx + 1)) ->getValue() diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp --- a/llvm/lib/IR/Metadata.cpp +++ b/llvm/lib/IR/Metadata.cpp @@ -1544,7 +1544,7 @@ getOpcode() == Instruction::Switch)
&& "Looking for branch weights on something besides branch"); - return ::extractProfTotalWeight(getMetadata(LLVMContext::MD_prof), TotalVal); + return ::extractProfTotalWeight(*this, TotalVal); } void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) { diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp --- a/llvm/lib/IR/ProfDataUtils.cpp +++ b/llvm/lib/IR/ProfDataUtils.cpp @@ -101,6 +101,24 @@ return isBranchWeightMD(ProfileData); } +bool hasValidBranchWeightMD(const Instruction &I) { + return getValidBranchWeightMDNode(I) != nullptr; +} + +MDNode *getBranchWeightMDNode(const Instruction &I) { + auto *ProfileData = I.getMetadata(LLVMContext::MD_prof); + if (!isBranchWeightMD(ProfileData)) + return nullptr; + return ProfileData; +} + +MDNode *getValidBranchWeightMDNode(const Instruction &I) { + auto *ProfileData = getBranchWeightMDNode(I); + if (ProfileData && ProfileData->getNumOperands() == 1 + I.getNumSuccessors()) + return ProfileData; + return nullptr; +} + bool extractBranchWeights(const MDNode *ProfileData, SmallVectorImpl &Weights) { if (!isBranchWeightMD(ProfileData)) @@ -161,4 +179,8 @@ return false; } +bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) { + return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal); +} + } // namespace llvm diff --git a/llvm/lib/Transforms/IPO/PartialInlining.cpp b/llvm/lib/Transforms/IPO/PartialInlining.cpp --- a/llvm/lib/Transforms/IPO/PartialInlining.cpp +++
b/llvm/lib/Transforms/IPO/PartialInlining.cpp @@ -716,8 +716,7 @@ BranchInst *BR = dyn_cast(E->getTerminator()); if (!BR || BR->isUnconditional()) continue; - uint64_t T, F; - if (extractBranchWeights(*BR, T, F)) + if (hasBranchWeightMD(*BR)) return true; } return false; diff --git a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp --- a/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp +++ b/llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp @@ -29,6 +29,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Support/BranchProbability.h" #include "llvm/Support/CommandLine.h" @@ -575,32 +576,26 @@ return true; } -// Returns true and sets the true probability and false probability of an -// MD_prof metadata if it's well-formed. -static bool checkMDProf(MDNode *MD, BranchProbability &TrueProb, - BranchProbability &FalseProb) { - if (!MD) return false; - MDString *MDName = cast(MD->getOperand(0)); - if (MDName->getString() != "branch_weights" || - MD->getNumOperands() != 3) +// Constructs the true and false branch probabilities if the instruction has +// valid branch weights. Returns true when this was successful, false otherwise. 
+static bool extractBranchProbabilities(Instruction *I, + BranchProbability &TrueProb, + BranchProbability &FalseProb) { + uint64_t TrueWeight; + uint64_t FalseWeight; + if (!extractBranchWeights(*I, TrueWeight, FalseWeight)) return false; - ConstantInt *TrueWeight = mdconst::extract(MD->getOperand(1)); - ConstantInt *FalseWeight = mdconst::extract(MD->getOperand(2)); - if (!TrueWeight || !FalseWeight) - return false; - uint64_t TrueWt = TrueWeight->getValue().getZExtValue(); - uint64_t FalseWt = FalseWeight->getValue().getZExtValue(); - uint64_t SumWt = TrueWt + FalseWt; + uint64_t SumWeight = TrueWeight + FalseWeight; - assert(SumWt >= TrueWt && SumWt >= FalseWt && + assert(SumWeight >= TrueWeight && SumWeight >= FalseWeight && "Overflow calculating branch probabilities."); // Guard against 0-to-0 branch weights to avoid a division-by-zero crash. - if (SumWt == 0) + if (SumWeight == 0) return false; - TrueProb = BranchProbability::getBranchProbability(TrueWt, SumWt); - FalseProb = BranchProbability::getBranchProbability(FalseWt, SumWt); + TrueProb = BranchProbability::getBranchProbability(TrueWeight, SumWeight); + FalseProb = BranchProbability::getBranchProbability(FalseWeight, SumWeight); return true; } @@ -639,8 +634,7 @@ if (!BI->isConditional()) return false; BranchProbability ThenProb, ElseProb; - if (!checkMDProf(BI->getMetadata(LLVMContext::MD_prof), - ThenProb, ElseProb)) + if (!extractBranchProbabilities(BI, ThenProb, ElseProb)) return false; BasicBlock *IfThen = BI->getSuccessor(0); BasicBlock *IfElse = BI->getSuccessor(1); @@ -669,8 +663,7 @@ DenseSet &FalseBiasedSelectsGlobal, DenseMap &SelectBiasMap) { BranchProbability TrueProb, FalseProb; - if (!checkMDProf(SI->getMetadata(LLVMContext::MD_prof), - TrueProb, FalseProb)) + if (!extractBranchProbabilities(SI, TrueProb, FalseProb)) return false; CHR_DEBUG(dbgs() << "SI " << *SI << " "); CHR_DEBUG(dbgs() << "TrueProb " << TrueProb << " "); diff --git a/llvm/lib/Transforms/Scalar/JumpThreading.cpp 
b/llvm/lib/Transforms/Scalar/JumpThreading.cpp --- a/llvm/lib/Transforms/Scalar/JumpThreading.cpp +++ b/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -2522,18 +2522,7 @@ bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { const Instruction *TI = BB->getTerminator(); assert(TI->getNumSuccessors() > 1 && "not a split"); - - MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); - if (!WeightsNode) - return false; - - MDString *MDName = cast(WeightsNode->getOperand(0)); - if (MDName->getString() != "branch_weights") - return false; - - // Ensure there are weights for all of the successors. Note that the first - // operand to the metadata node is a name, not a weight. - return WeightsNode->getNumOperands() == TI->getNumSuccessors() + 1; + return hasValidBranchWeightMD(*TI); } /// Update the block frequency of BB and branch weight and the metadata on the diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -191,6 +191,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" @@ -974,37 +975,24 @@ LatchExitBlock->getTerminatingDeoptimizeCall()) return false; - auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) { - if (!ProfileData || !ProfileData->getOperand(0)) - return false; - if (MDString *MDS = dyn_cast(ProfileData->getOperand(0))) - if (!MDS->getString().equals("branch_weights")) - return false; - if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors()) - return false; - return true; - }; - MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof); // Latch terminator has no valid profile data, so nothing to check // profitability on. 
- if (!IsValidProfileData(LatchProfileData, LatchTerm)) + if (!hasValidBranchWeightMD(*LatchTerm)) return true; auto ComputeBranchProbability = [&](const BasicBlock *ExitingBlock, const BasicBlock *ExitBlock) -> BranchProbability { auto *Term = ExitingBlock->getTerminator(); - MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof); unsigned NumSucc = Term->getNumSuccessors(); - if (IsValidProfileData(ProfileData, Term)) { - uint64_t Numerator = 0, Denominator = 0, ProfVal = 0; - for (unsigned i = 0; i < NumSucc; i++) { - ConstantInt *CI = - mdconst::extract(ProfileData->getOperand(i + 1)); - ProfVal = CI->getValue().getZExtValue(); + if (MDNode *ProfileData = getValidBranchWeightMDNode(*Term)) { + SmallVector Weights; + extractBranchWeights(ProfileData, Weights); + uint64_t Numerator = 0, Denominator = 0; + for (auto [i, Weight] : llvm::enumerate(Weights)) { if (Term->getSuccessor(i) == ExitBlock) - Numerator += ProfVal; - Denominator += ProfVal; + Numerator += Weight; + Denominator += Weight; } return BranchProbability::getBranchProbability(Numerator, Denominator); } else { diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp --- a/llvm/lib/Transforms/Utils/Local.cpp +++ b/llvm/lib/Transforms/Utils/Local.cpp @@ -62,6 +62,7 @@ #include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ProfDataUtils.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/IR/User.h" @@ -210,20 +211,18 @@ // Check to see if this branch is going to the same place as the default // dest. If so, eliminate it as an explicit compare. if (i->getCaseSuccessor() == DefaultDest) { - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); + MDNode *MD = getValidBranchWeightMDNode(*SI); unsigned NCases = SI->getNumCases(); // Fold the case metadata into the default if there will be any branches // left, unless the metadata doesn't match the switch. 
- if (NCases > 1 && MD && MD->getNumOperands() == 2 + NCases) { + if (NCases > 1 && MD) { // Collect branch weights into a vector. SmallVector Weights; - for (unsigned MD_i = 1, MD_e = MD->getNumOperands(); MD_i < MD_e; - ++MD_i) { - auto *CI = mdconst::extract(MD->getOperand(MD_i)); - Weights.push_back(CI->getValue().getZExtValue()); - } + extractBranchWeights(MD, Weights); + // Merge weight of this case to the default weight. unsigned idx = i->getCaseIndex(); + // TODO: Add overflow check. Weights[0] += Weights[idx+1]; // Remove weight for this case. std::swap(Weights[idx+1], Weights.back()); @@ -313,18 +312,14 @@ BranchInst *NewBr = Builder.CreateCondBr(Cond, FirstCase.getCaseSuccessor(), SI->getDefaultDest()); - MDNode *MD = SI->getMetadata(LLVMContext::MD_prof); - if (MD && MD->getNumOperands() == 3) { - ConstantInt *SICase = - mdconst::dyn_extract(MD->getOperand(2)); - ConstantInt *SIDef = - mdconst::dyn_extract(MD->getOperand(1)); - assert(SICase && SIDef); + SmallVector Weights; + if (extractBranchWeights(*SI, Weights) && Weights.size() == 2) { + uint64_t DefWeight = Weights[0]; + uint64_t CaseWeight = Weights[1]; // The TrueWeight should be the weight for the single case of SI. NewBr->setMetadata(LLVMContext::MD_prof, - MDBuilder(BB->getContext()). - createBranchWeights(SICase->getValue().getZExtValue(), - SIDef->getValue().getZExtValue())); + MDBuilder(BB->getContext()) + .createBranchWeights(CaseWeight, DefWeight)); } // Update make.implicit metadata to the newly-created conditional branch.