diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -0,0 +1,80 @@
+//===- llvm/IR/ProfDataUtils.h - Profiling Metadata Utilities ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file declares helpers for querying MD_prof metadata and for extracting
+/// branch weights and total weights from it.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_PROFDATAUTILS_H
+#define LLVM_IR_PROFDATAUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Metadata.h"
+
+namespace llvm {
+
+/// Checks if an Instruction has MD_prof Metadata
+bool hasProfMD(const Instruction &I);
+
+/// Checks if an MDNode contains Branch Weight Metadata
+bool isBranchWeightMD(const MDNode *ProfileData);
+
+/// Checks if an instruction has Branch Weight Metadata
+///
+/// \param I The instruction to check
+/// \return True if I has an MD_prof node containing Branch Weights. False
+/// otherwise.
+bool hasBranchWeightMD(const Instruction &I);
+
+/// Extract branch weights from MD_prof metadata
+///
+/// \param ProfileData A pointer to an MDNode; may be null.
+/// \param [out] Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false Weights
+/// will be cleared.
+bool extractBranchWeights(const MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Extract branch weights attached to an Instruction
+///
+/// \param I The Instruction to extract weights from.
+/// \param [out] Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false Weights
+/// will be cleared.
+bool extractBranchWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Retrieve the raw weight values of a conditional branch or select.
+///
+/// \param I A branch or select instruction (asserted).
+/// \param [out] TrueVal Set to the first branch weight on success.
+/// \param [out] FalseVal Set to the second branch weight on success.
+/// \return True on success with profile weights filled in. False if no
+/// metadata was found, or if it is not branch weight metadata with exactly two
+/// weights; TrueVal and FalseVal are left unmodified in that case.
+bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
+                          uint64_t &FalseVal);
+
+/// Retrieve the total of all weights from MD_prof data.
+///
+/// For "branch_weights" nodes this is the sum of all weights; for "VP" nodes
+/// with more than three operands it is the value of operand 2.
+///
+/// \param ProfileData The profile data to extract the total weight from
+/// \param [out] TotalWeights Output variable to fill with the total weight.
+/// It is set to 0 whenever false is returned.
+/// \return True on success with the total weight filled in; false if
+/// ProfileData is null or unnamed, is not a supported kind of MD_prof node,
+/// or is a "VP" node whose total is not an integer constant.
+bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
+
+} // namespace llvm
+#endif // LLVM_IR_PROFDATAUTILS_H
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -49,6 +49,7 @@
   PassRegistry.cpp
   PassTimingInfo.cpp
   PrintPasses.cpp
+  ProfDataUtils.cpp
   SafepointIRVerifier.cpp
   ProfileSummary.cpp
   PseudoProbe.cpp
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -0,0 +1,177 @@
+//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for working with MD_prof metadata.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+namespace {
+
+// MD_prof nodes have the following layout
+//
+// In general:
+// { String name, Array of i32 }
+//
+// In terms of Types:
+// { MDString, [i32, i32, ...]}
+//
+// Concretely for Branch Weights
+// { "branch_weights", [i32 1, i32 10000]}
+//
+// We maintain some constants here to ensure that we access the branch weights
+// correctly, and can change the behavior in the future if the layout changes
+
+// The index at which the weights vector starts
+constexpr unsigned WeightsIdx = 1;
+
+// The minimum number of operands for MD_prof nodes with branch weights
+constexpr unsigned MinBWOps = 3;
+
+// Copies every operand from WeightsIdx onward into Weights, replacing its
+// previous contents. Callers must have validated ProfileData (e.g. via
+// isTargetMD); malformed operands are only diagnosed by assertions.
+bool extractWeights(const MDNode *ProfileData,
+                    SmallVectorImpl<uint32_t> &Weights) {
+  assert(ProfileData && "ProfileData was nullptr in extractWeights");
+  unsigned NOps = ProfileData->getNumOperands();
+
+  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+  Weights.resize(NOps - WeightsIdx);
+
+  for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
+    ConstantInt *Weight =
+        mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+    assert(Weight && "Malformed branch_weight in MD_prof node");
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights[Idx - WeightsIdx] = Weight->getZExtValue();
+  }
+  return true;
+}
+
+// We may want to add support for other MD_prof types, so provide an abstraction
+// for checking the metadata type. Returns true if ProfData is non-null, has at
+// least MinOps operands, and its first operand is the MDString Name.
+bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
+  // TODO: This routine may be simplified if MD_prof used an enum instead of a
+  // string to differentiate the types of MD_prof nodes.
+  if (!ProfData || !Name || MinOps < 2)
+    return false;
+
+  unsigned NOps = ProfData->getNumOperands();
+  if (NOps < MinOps)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  return ProfDataName->getString() == Name;
+}
+
+} // namespace
+
+namespace llvm {
+
+bool hasProfMD(const Instruction &I) {
+  return nullptr != I.getMetadata(LLVMContext::MD_prof);
+}
+
+bool isBranchWeightMD(const MDNode *ProfileData) {
+  return isTargetMD(ProfileData, "branch_weights", MinBWOps);
+}
+
+bool hasBranchWeightMD(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return isBranchWeightMD(ProfileData);
+}
+
+bool extractBranchWeights(const MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  if (!isBranchWeightMD(ProfileData)) {
+    // Honor the documented contract: Weights is cleared on failure so callers
+    // never observe stale values from a previous query.
+    Weights.clear();
+    return false;
+  }
+  return extractWeights(ProfileData, Weights);
+}
+
+bool extractBranchWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return extractBranchWeights(ProfileData, Weights);
+}
+
+bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
+                          uint64_t &FalseVal) {
+  assert((I.getOpcode() == Instruction::Br ||
+          I.getOpcode() == Instruction::Select) &&
+         "Looking for branch weights on something besides branch or select");
+
+  SmallVector<uint32_t, 2> Weights;
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  if (!extractBranchWeights(ProfileData, Weights))
+    return false;
+
+  // Exactly two weights are required so that Weights[1] below is in bounds.
+  if (Weights.size() != 2)
+    return false;
+
+  TrueVal = Weights[0];
+  FalseVal = Weights[1];
+  return true;
+}
+
+bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights) {
+  TotalWeights = 0;
+  // Operand 0 must exist before it can be inspected as the node's name.
+  if (!ProfileData || ProfileData->getNumOperands() == 0)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfileData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  if (ProfDataName->getString() == "branch_weights") {
+    for (unsigned Idx = WeightsIdx, E = ProfileData->getNumOperands();
+         Idx != E; ++Idx) {
+      auto *V = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(Idx));
+      assert(V && "Malformed branch_weight in MD_prof node");
+      TotalWeights += V->getValue().getZExtValue();
+    }
+    return true;
+  }
+
+  if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) {
+    // Operand 2 of a VP node holds the total; reject the node rather than
+    // dereferencing null when it is not an integer constant.
+    auto *Total = mdconst::dyn_extract<ConstantInt>(ProfileData->getOperand(2));
+    if (!Total)
+      return false;
+    TotalWeights = Total->getValue().getZExtValue();
+    return true;
+  }
+
+  return false;
+}
+
+} // namespace llvm
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -30,6 +30,7 @@
   UnrollLoopTest.cpp
   ValueMapperTest.cpp
   VFABIUtils.cpp
+  ProfDataUtilTest.cpp
   )
 
 set_property(TARGET UtilsTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")
diff --git a/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
@@ -0,0 +1,93 @@
+//===- ProfDataUtilTest.cpp - Unit tests for ProfDataUtils ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/BreakCriticalEdges.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("ProfDataUtilsTests", errs());
+  return Mod;
+}
+
+TEST(ProfDataUtils, extractWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1, !prof !1
+bb0:
+  %0 = mul i32 1, 2
+  br label %bb1
+bb1:
+  ret void
+}
+
+!1 = !{!"branch_weights", i32 1, i32 100000}
+)IR");
+  ASSERT_TRUE(M);
+  Function *F = M->getFunction("foo");
+  auto &I = F->getEntryBlock().front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  ASSERT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  ASSERT_NE(ProfileData, nullptr);
+  EXPECT_TRUE(hasProfMD(I));
+  SmallVector<uint32_t> Weights;
+  ASSERT_TRUE(extractBranchWeights(ProfileData, Weights));
+  ASSERT_EQ(Weights.size(), 2U);
+  EXPECT_EQ(Weights[0], 1U);
+  EXPECT_EQ(Weights[1], 100000U);
+}
+
+TEST(ProfDataUtils, NoWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1
+bb0:
+  %0 = mul i32 1, 2
+  br label %bb1
+bb1:
+  ret void
+}
+)IR");
+  ASSERT_TRUE(M);
+  Function *F = M->getFunction("foo");
+  auto &I = F->getEntryBlock().front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  ASSERT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  EXPECT_EQ(ProfileData, nullptr);
+  EXPECT_FALSE(hasProfMD(I));
+  SmallVector<uint32_t> Weights;
+  EXPECT_FALSE(extractBranchWeights(ProfileData, Weights));
+  EXPECT_EQ(Weights.size(), 0U);
+}