[IR] Use ProfDataUtils helpers for branch-weight extraction

Replace hand-rolled MD_prof parsing in BranchProbabilityInfo,
Instruction::extractProfMetadata/extractProfTotalWeight, MisExpect and
SimplifyCFG with extractBranchWeights, extractTotalBranchWeights and
hasBranchWeightMD from llvm/IR/ProfDataUtils.h, and add an
Instruction::extractProfMetadata overload that fills in all weights.

diff --git a/llvm/include/llvm/IR/Instruction.h b/llvm/include/llvm/IR/Instruction.h
--- a/llvm/include/llvm/IR/Instruction.h
+++ b/llvm/include/llvm/IR/Instruction.h
@@ -355,6 +355,11 @@
   /// Returns false if no metadata or invalid metadata was found.
   bool extractProfMetadata(uint64_t &TrueVal, uint64_t &FalseVal) const;
 
+  /// Retrieve the raw weight values of a conditional branch, select, or switch.
+  /// Returns true on success with profile weights filled in.
+  /// Returns false if no metadata or invalid metadata was found.
+  bool extractProfMetadata(SmallVectorImpl<uint32_t> &Weights) const;
+
   /// Retrieve total raw weight values of a branch.
   /// Returns true on success with profile total weights filled in.
   /// Returns false if no metadata was found.
diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -400,24 +401,21 @@
   SmallVector<uint32_t, 2> Weights;
   SmallVector<unsigned, 2> UnreachableIdxs;
   SmallVector<unsigned, 2> ReachableIdxs;
-  Weights.reserve(TI->getNumSuccessors());
-  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(I));
-    if (!Weight)
-      return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+
+  // Bail out if the weights cannot be extracted; the loop below and the size
+  // assertion rely on Weights being fully populated.
+  if (!extractBranchWeights(*TI, Weights))
+    return false;
+  for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+    WeightSum += Weights[I];
     const LoopBlock SrcLoopBB = getLoopBlock(BB);
-    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
+    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
         *EstimatedWeight <= static_cast<uint32_t>(BlockExecWeight::UNREACHABLE))
-      UnreachableIdxs.push_back(I - 1);
+      UnreachableIdxs.push_back(I);
     else
-      ReachableIdxs.push_back(I - 1);
+      ReachableIdxs.push_back(I);
   }
 
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/TrackingMDRef.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
@@ -1505,26 +1506,26 @@
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select) &&
       "Looking for branch weights on something besides branch or select");
-
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData || ProfileData->getNumOperands() != 3)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return false;
-
-  auto *CITrue = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(1));
-  auto *CIFalse = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
-  if (!CITrue || !CIFalse)
+  SmallVector<uint32_t, 2> Weights;
+  if (!extractBranchWeights(*this, Weights) || Weights.size() != 2)
     return false;
 
-  TrueVal = CITrue->getValue().getZExtValue();
-  FalseVal = CIFalse->getValue().getZExtValue();
+  TrueVal = Weights[0];
+  FalseVal = Weights[1];
 
   return true;
 }
 
+bool Instruction::extractProfMetadata(
+    SmallVectorImpl<uint32_t> &Weights) const {
+  assert((getOpcode() == Instruction::Br ||
+          getOpcode() == Instruction::Select ||
+          getOpcode() == Instruction::Switch) &&
+         "Looking for branch weights on something besides branch, select, or "
+         "switch");
+  return extractBranchWeights(*this, Weights);
+}
+
 bool Instruction::extractProfTotalWeight(uint64_t &TotalVal) const {
   assert(
       (getOpcode() == Instruction::Br || getOpcode() == Instruction::Select ||
@@ -1533,32 +1534,7 @@
        getOpcode() == Instruction::Switch) &&
       "Looking for branch weights on something besides branch");
 
-  TotalVal = 0;
-  auto *ProfileData = getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return false;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName)
-    return false;
-
-  if (ProfDataName->getString().equals("branch_weights")) {
-    TotalVal = 0;
-    for (unsigned i = 1; i < ProfileData->getNumOperands(); i++) {
-      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(i));
-      if (!V)
-        return false;
-      TotalVal += V->getValue().getZExtValue();
-    }
-    return true;
-  } else if (ProfDataName->getString().equals("VP") &&
-             ProfileData->getNumOperands() > 3) {
-    TotalVal = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2))
-                   ->getValue()
-                   .getZExtValue();
-    return true;
-  }
-  return false;
+  return extractTotalBranchWeights(getMetadata(LLVMContext::MD_prof), TotalVal);
 }
 
 void GlobalObject::copyMetadata(const GlobalObject *Other, unsigned Offset) {
diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -118,34 +119,6 @@
 namespace llvm {
 namespace misexpect {
 
-// Helper function to extract branch weights into a vector
-Optional<SmallVector<uint32_t, 4>> extractWeights(Instruction *I,
-                                                  LLVMContext &Ctx) {
-  assert(I && "MisExpect::extractWeights given invalid pointer");
-
-  auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return None;
-
-  unsigned NOps = ProfileData->getNumOperands();
-  if (NOps < 3)
-    return None;
-
-  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return None;
-
-  SmallVector<uint32_t, 4> Weights(NOps - 1);
-  for (unsigned Idx = 1; Idx < NOps; Idx++) {
-    ConstantInt *Value =
-        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
-    uint32_t V = Value->getZExtValue();
-    Weights[Idx - 1] = V;
-  }
-
-  return Weights;
-}
-
 // TODO: when clang allows c++17, use std::clamp instead
 uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
   if (value > hi)
@@ -218,19 +191,17 @@
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef<uint32_t> RealWeights) {
-  auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
-  if (!ExpectedWeightsOpt)
+  SmallVector<uint32_t> ExpectedWeights;
+  if (!extractBranchWeights(I, ExpectedWeights))
     return;
-  auto ExpectedWeights = ExpectedWeightsOpt.getValue();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
 void checkFrontendInstrumentation(Instruction &I,
                                   const ArrayRef<uint32_t> ExpectedWeights) {
-  auto RealWeightsOpt = extractWeights(&I, I.getContext());
-  if (!RealWeightsOpt)
+  SmallVector<uint32_t> RealWeights;
+  if (!extractBranchWeights(I, RealWeights))
     return;
-  auto RealWeights = RealWeightsOpt.getValue();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -57,6 +57,7 @@
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1056,15 +1057,6 @@
   return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
 }
 
-static inline bool HasBranchWeights(const Instruction *I) {
-  MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
-  if (ProfMD && ProfMD->getOperand(0))
-    if (MDString *MDS = dyn_cast<MDString>(ProfMD->getOperand(0)))
-      return MDS->getString().equals("branch_weights");
-
-  return false;
-}
-
 /// Get Weights of a given terminator, the default weight is at the front
 /// of the vector. If TI is a conditional eq, we need to swap the branch-weight
 /// metadata.
@@ -1183,8 +1175,8 @@
 
   // Update the branch weight metadata along the way
   SmallVector<uint64_t, 8> Weights;
-  bool PredHasWeights = HasBranchWeights(PTI);
-  bool SuccHasWeights = HasBranchWeights(TI);
+  bool PredHasWeights = hasBranchWeightMD(*PTI);
+  bool SuccHasWeights = hasBranchWeightMD(*TI);
 
   if (PredHasWeights) {
     GetBranchWeights(PTI, Weights);
@@ -4408,7 +4400,7 @@
   // Get weight for TrueBB and FalseBB.
   uint32_t TrueWeight = 0, FalseWeight = 0;
   SmallVector<uint64_t, 8> Weights;
-  bool HasWeights = HasBranchWeights(SI);
+  bool HasWeights = hasBranchWeightMD(*SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
@@ -5268,7 +5260,7 @@
   BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
 
   // Update weight for the newly-created conditional branch.
-  if (HasBranchWeights(SI)) {
+  if (hasBranchWeightMD(*SI)) {
     SmallVector<uint64_t, 8> Weights;
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {