Use the shared MD_prof helpers from llvm/IR/ProfDataUtils.h instead of
per-file metadata parsing.

  * BranchProbabilityInfo.cpp: read the weights with extractMDProfWeights()
    and index successors from 0. Bail out when extraction fails, as the
    removed loop did when an operand was not a ConstantInt.
  * Metadata.cpp: only the include is added in this patch.
    NOTE(review): no use of the header is visible here - confirm it is needed.
  * InstrProf.cpp: replace the hand-written null / operand-count / "VP" tag
    checks with isValueProfileMD().
    NOTE(review): presumably the helper covers all three removed checks -
    confirm against ProfDataUtils.
  * MisExpect.cpp: drop the local extractWeights() helper; both callers now
    use extractMDProfWeights() and return early when it fails.
  * SimplifyCFG.cpp: drop the local HasBranchWeights() helper in favour of
    hasBranchWeightMD().

NOTE(review): template argument lists (`<...>`) appear to have been stripped
from this copy of the patch (e.g. `SmallVector Weights;`, `static_cast(...)`,
`Optional> extractWeights`). Tokens are reproduced as found; regenerate the
patch from the working tree before applying.

diff --git a/llvm/lib/Analysis/BranchProbabilityInfo.cpp b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
--- a/llvm/lib/Analysis/BranchProbabilityInfo.cpp
+++ b/llvm/lib/Analysis/BranchProbabilityInfo.cpp
@@ -30,6 +30,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
 #include "llvm/InitializePasses.h"
@@ -400,24 +401,19 @@
   SmallVector Weights;
   SmallVector UnreachableIdxs;
   SmallVector ReachableIdxs;
-  Weights.reserve(TI->getNumSuccessors());
-  for (unsigned I = 1, E = WeightsNode->getNumOperands(); I != E; ++I) {
-    ConstantInt *Weight =
-        mdconst::dyn_extract(WeightsNode->getOperand(I));
-    if (!Weight)
-      return false;
-    assert(Weight->getValue().getActiveBits() <= 32 &&
-           "Too many bits for uint32_t");
-    Weights.push_back(Weight->getZExtValue());
-    WeightSum += Weights.back();
+
+  if (!extractMDProfWeights(*TI, Weights))
+    return false;
+  for (unsigned I = 0, E = Weights.size(); I != E; ++I) {
+    WeightSum += Weights[I];
     const LoopBlock SrcLoopBB = getLoopBlock(BB);
-    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I - 1));
+    const LoopBlock DstLoopBB = getLoopBlock(TI->getSuccessor(I));
     auto EstimatedWeight = getEstimatedEdgeWeight({SrcLoopBB, DstLoopBB});
     if (EstimatedWeight &&
         *EstimatedWeight <= static_cast(BlockExecWeight::UNREACHABLE))
-      UnreachableIdxs.push_back(I - 1);
+      UnreachableIdxs.push_back(I);
     else
-      ReachableIdxs.push_back(I - 1);
+      ReachableIdxs.push_back(I);
   }
 
   assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
diff --git a/llvm/lib/IR/Metadata.cpp b/llvm/lib/IR/Metadata.cpp
--- a/llvm/lib/IR/Metadata.cpp
+++ b/llvm/lib/IR/Metadata.cpp
@@ -40,6 +40,7 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/TrackingMDRef.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Value.h"
diff --git a/llvm/lib/ProfileData/InstrProf.cpp b/llvm/lib/ProfileData/InstrProf.cpp
--- a/llvm/lib/ProfileData/InstrProf.cpp
+++ b/llvm/lib/ProfileData/InstrProf.cpp
@@ -29,6 +29,7 @@
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/ProfileData/InstrProfReader.h"
 #include "llvm/Support/Casting.h"
@@ -1043,22 +1044,11 @@
                               uint32_t &ActualNumValueData, uint64_t &TotalC,
                               bool GetNoICPValue) {
   MDNode *MD = Inst.getMetadata(LLVMContext::MD_prof);
-  if (!MD)
+  if (!isValueProfileMD(MD))
     return false;
 
   unsigned NOps = MD->getNumOperands();
 
-  if (NOps < 5)
-    return false;
-
-  // Operand 0 is a string tag "VP":
-  MDString *Tag = cast(MD->getOperand(0));
-  if (!Tag)
-    return false;
-
-  if (!Tag->getString().equals("VP"))
-    return false;
-
   // Now check kind:
   ConstantInt *KindInt = mdconst::dyn_extract(MD->getOperand(1));
   if (!KindInt)
diff --git a/llvm/lib/Transforms/Utils/MisExpect.cpp b/llvm/lib/Transforms/Utils/MisExpect.cpp
--- a/llvm/lib/Transforms/Utils/MisExpect.cpp
+++ b/llvm/lib/Transforms/Utils/MisExpect.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -118,34 +119,6 @@
 namespace llvm {
 namespace misexpect {
 
-// Helper function to extract branch weights into a vector
-Optional> extractWeights(Instruction *I,
-                         LLVMContext &Ctx) {
-  assert(I && "MisExpect::extractWeights given invalid pointer");
-
-  auto *ProfileData = I->getMetadata(LLVMContext::MD_prof);
-  if (!ProfileData)
-    return None;
-
-  unsigned NOps = ProfileData->getNumOperands();
-  if (NOps < 3)
-    return None;
-
-  auto *ProfDataName = dyn_cast(ProfileData->getOperand(0));
-  if (!ProfDataName || !ProfDataName->getString().equals("branch_weights"))
-    return None;
-
-  SmallVector Weights(NOps - 1);
-  for (unsigned Idx = 1; Idx < NOps; Idx++) {
-    ConstantInt *Value =
-        mdconst::dyn_extract(ProfileData->getOperand(Idx));
-    uint32_t V = Value->getZExtValue();
-    Weights[Idx - 1] = V;
-  }
-
-  return Weights;
-}
-
 // TODO: when clang allows c++17, use std::clamp instead
 uint32_t clamp(uint64_t value, uint32_t low, uint32_t hi) {
   if (value > hi)
@@ -218,19 +191,17 @@
 
 void checkBackendInstrumentation(Instruction &I,
                                  const ArrayRef RealWeights) {
-  auto ExpectedWeightsOpt = extractWeights(&I, I.getContext());
-  if (!ExpectedWeightsOpt)
+  SmallVector ExpectedWeights;
+  if (!extractMDProfWeights(I, ExpectedWeights))
     return;
-  auto ExpectedWeights = ExpectedWeightsOpt.getValue();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
 void checkFrontendInstrumentation(Instruction &I,
                                   const ArrayRef ExpectedWeights) {
-  auto RealWeightsOpt = extractWeights(&I, I.getContext());
-  if (!RealWeightsOpt)
+  SmallVector RealWeights;
+  if (!extractMDProfWeights(I, RealWeights))
     return;
-  auto RealWeights = RealWeightsOpt.getValue();
   verifyMisExpect(I, RealWeights, ExpectedWeights);
 }
 
diff --git a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
--- a/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -57,6 +57,7 @@
 #include "llvm/IR/NoFolder.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/ProfDataUtils.h"
 #include "llvm/IR/Type.h"
 #include "llvm/IR/Use.h"
 #include "llvm/IR/User.h"
@@ -1056,15 +1057,6 @@
   return LHS->getValue().ult(RHS->getValue()) ? 1 : -1;
 }
 
-static inline bool HasBranchWeights(const Instruction *I) {
-  MDNode *ProfMD = I->getMetadata(LLVMContext::MD_prof);
-  if (ProfMD && ProfMD->getOperand(0))
-    if (MDString *MDS = dyn_cast(ProfMD->getOperand(0)))
-      return MDS->getString().equals("branch_weights");
-
-  return false;
-}
-
 /// Get Weights of a given terminator, the default weight is at the front
 /// of the vector. If TI is a conditional eq, we need to swap the branch-weight
 /// metadata.
@@ -1183,8 +1175,8 @@
 
     // Update the branch weight metadata along the way
     SmallVector Weights;
-    bool PredHasWeights = HasBranchWeights(PTI);
-    bool SuccHasWeights = HasBranchWeights(TI);
+    bool PredHasWeights = hasBranchWeightMD(*PTI);
+    bool SuccHasWeights = hasBranchWeightMD(*TI);
 
     if (PredHasWeights) {
       GetBranchWeights(PTI, Weights);
@@ -4408,7 +4400,7 @@
   // Get weight for TrueBB and FalseBB.
   uint32_t TrueWeight = 0, FalseWeight = 0;
   SmallVector Weights;
-  bool HasWeights = HasBranchWeights(SI);
+  bool HasWeights = hasBranchWeightMD(*SI);
   if (HasWeights) {
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {
@@ -5268,7 +5260,7 @@
   BranchInst *NewBI = Builder.CreateCondBr(Cmp, ContiguousDest, OtherDest);
 
   // Update weight for the newly-created conditional branch.
-  if (HasBranchWeights(SI)) {
+  if (hasBranchWeightMD(*SI)) {
     SmallVector Weights;
     GetBranchWeights(SI, Weights);
     if (Weights.size() == 1 + SI->getNumCases()) {