diff --git a/llvm/include/llvm/IR/ProfDataUtils.h b/llvm/include/llvm/IR/ProfDataUtils.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/ProfDataUtils.h
@@ -0,0 +1,73 @@
+//===- llvm/IR/ProfDataUtils.h - Profiling Metadata Utilities ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains utilities for querying and extracting profiling metadata
+// (MD_prof) attached to instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_PROFDATAUTILS_H
+#define LLVM_IR_PROFDATAUTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Metadata.h"
+
+namespace llvm {
+
+class Instruction;
+
+/// Checks if an Instruction has MD_prof Metadata
+bool hasMdProfMD(const Instruction &I);
+
+/// Checks if an MDNode contains Branch Weight Metadata
+bool isBranchWeightMD(MDNode *ProfileData);
+
+/// Checks if an instruction has Branch Weight Metadata
+///
+/// \param I The instruction to check
+/// \return True if I has an MD_prof node containing Branch Weights. False
+/// otherwise.
+bool hasBranchWeightMD(const Instruction &I);
+
+/// Extract branch weights from MD_prof metadata
+///
+/// \param ProfileData A pointer to an MDNode.
+/// \param Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false Weights
+/// will be cleared.
+bool extractMDProfWeights(MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Extract branch weights from MD_prof metadata
+///
+/// \param I The Instruction to extract weights from.
+/// \param Weights An output vector to fill with branch weights
+/// \return True if weights were extracted, False otherwise. When false Weights
+/// will be cleared.
+bool extractMDProfWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights);
+
+/// Checks if an MDNode contains Value Profile Metadata
+bool isValueProfileMD(MDNode *ProfileData);
+
+/// Checks if an Instruction has Value Profile Metadata
+bool hasValueProfileMD(const Instruction &I);
+
+/// Checks if an MDNode contains Function Entry Count Metadata
+/// ("function_entry_count" or "synthetic_function_entry_count").
+bool isFunctionEntryMD(MDNode *ProfileData);
+
+/// Checks if an Instruction has Function Entry Count Metadata
+///
+/// Note: entry counts are normally attached to a Function rather than to an
+/// Instruction; this only inspects the MD_prof node attached to \p I.
+bool hasFunctionEntryMD(const Instruction &I);
+
+} // namespace llvm
+#endif // LLVM_IR_PROFDATAUTILS_H
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -49,6 +49,7 @@
   PassRegistry.cpp
   PassTimingInfo.cpp
   PrintPasses.cpp
+  ProfDataUtils.cpp
   SafepointIRVerifier.cpp
   ProfileSummary.cpp
   PseudoProbe.cpp
diff --git a/llvm/lib/IR/ProfDataUtils.cpp b/llvm/lib/IR/ProfDataUtils.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/IR/ProfDataUtils.cpp
@@ -0,0 +1,166 @@
+//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements utilities for working with profiling metadata.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace llvm;
+
+namespace {
+
+// MD_prof nodes have the following layout
+//
+// In general:
+// { String name, Array of i32 }
+//
+// In terms of Types:
+// { MDString, [i32, i32, ...]}
+//
+// Concretely for Branch Weights
+// { "branch_weights", [i32 1, i32 10000]}
+//
+// We maintain some constants here to ensure that we access the branch weights
+// correctly, and can change the behavior in the future
+
+// The index at which the weights vector starts
+constexpr unsigned WeightsIndex = 1;
+
+// The minimum number of operands for branch weights
+constexpr unsigned BWMinOpCount = 3;
+
+// The minimum number of operands for Value Profile
+// Value taken from InstrProf.cpp
+constexpr unsigned VPMinOpCount = 5;
+
+// The minimum number of operands for Function Entry Count metadata:
+// { "function_entry_count", i64 Count, [i64 GUID, ...] }
+constexpr unsigned FEMinOpCount = 2;
+
+/// Copies the integer operands of \p ProfileData, starting at \p WeightsIdx,
+/// into \p Weights. On failure (a non-ConstantInt operand) \p Weights is
+/// cleared and false is returned.
+bool extractWeights(MDNode *ProfileData, unsigned NOps,
+                    SmallVectorImpl<uint32_t> &Weights, unsigned WeightsIdx) {
+  // Assume preconditions are already met (i.e. this is valid metadata)
+  assert(ProfileData && "ProfileData was nullptr in extractWeights");
+  assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
+  Weights.resize(NOps - WeightsIdx);
+  for (unsigned Idx = 0, E = NOps - WeightsIdx; Idx != E; ++Idx) {
+    ConstantInt *Weight = mdconst::dyn_extract<ConstantInt>(
+        ProfileData->getOperand(Idx + WeightsIdx));
+    if (!Weight) {
+      Weights.clear();
+      return false;
+    }
+    assert(Weight->getValue().getActiveBits() <= 32 &&
+           "Too many bits for uint32_t");
+    Weights[Idx] = Weight->getZExtValue();
+  }
+  return true;
+}
+
+/// Returns true if \p ProfData is an MD_prof node whose first operand is the
+/// MDString \p Name and which has at least \p MinOps operands.
+bool isTargetMD(MDNode *ProfData, const char *Name, unsigned MinOps) {
+  if (!ProfData || !Name || MinOps < 2)
+    return false;
+
+  unsigned NOps = ProfData->getNumOperands();
+  if (NOps < MinOps)
+    return false;
+
+  auto *ProfDataName = dyn_cast<MDString>(ProfData->getOperand(0));
+  if (!ProfDataName)
+    return false;
+
+  return ProfDataName->getString().equals(Name);
+}
+
+} // namespace
+
+namespace llvm {
+
+// TODO: we may not need many of the APIs here, since branch_weights are so much
+// more common than other MD_prof types. Perhaps we should remove support
+// outside of branch weights?
+
+/// Checks if an Instruction has MD_prof Metadata
+bool hasMdProfMD(const Instruction &I) {
+  return nullptr != I.getMetadata(LLVMContext::MD_prof);
+}
+
+/// Checks if an MDNode contains Value Profile Metadata
+bool isValueProfileMD(MDNode *ProfileData) {
+  return isTargetMD(ProfileData, "VP", VPMinOpCount);
+}
+
+/// Checks if an Instruction has Value Profile Metadata
+bool hasValueProfileMD(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return isValueProfileMD(ProfileData);
+}
+
+/// Checks if an MDNode contains Function Entry Count Metadata. Both real and
+/// synthetic entry counts are recognized.
+bool isFunctionEntryMD(MDNode *ProfileData) {
+  return isTargetMD(ProfileData, "function_entry_count", FEMinOpCount) ||
+         isTargetMD(ProfileData, "synthetic_function_entry_count",
+                    FEMinOpCount);
+}
+
+/// Checks if an Instruction has Function Entry Metadata
+bool hasFunctionEntryMD(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return isFunctionEntryMD(ProfileData);
+}
+
+/// Checks if an MDNode contains Branch Weight Metadata
+bool isBranchWeightMD(MDNode *ProfileData) {
+  return isTargetMD(ProfileData, "branch_weights", BWMinOpCount);
+}
+
+/// Checks if an Instruction has Branch Weight Metadata
+bool hasBranchWeightMD(const Instruction &I) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return isBranchWeightMD(ProfileData);
+}
+
+// TODO: Some MisExpect helper functions can be rewritten in terms of these
+// helpers
+bool extractMDProfWeights(MDNode *ProfileData,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  if (!isBranchWeightMD(ProfileData)) {
+    // Honor the documented contract: Weights is cleared on failure.
+    Weights.clear();
+    return false;
+  }
+
+  unsigned NOps = ProfileData->getNumOperands();
+
+  return extractWeights(ProfileData, NOps, Weights, WeightsIndex);
+}
+
+bool extractMDProfWeights(const Instruction &I,
+                          SmallVectorImpl<uint32_t> &Weights) {
+  auto *ProfileData = I.getMetadata(LLVMContext::MD_prof);
+  return extractMDProfWeights(ProfileData, Weights);
+}
+
+} // namespace llvm
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ 
b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -30,6 +30,7 @@
   UnrollLoopTest.cpp
   ValueMapperTest.cpp
   VFABIUtils.cpp
+  ProfDataUtilTest.cpp
   )
 
 set_property(TARGET UtilsTests PROPERTY FOLDER "Tests/UnitTests/TransformsTests")
diff --git a/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp
@@ -0,0 +1,106 @@
+//===- ProfDataUtilTest.cpp - Unit tests for ProfDataUtils ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/ProfDataUtils.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/BreakCriticalEdges.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("ProfDataUtilsTests", errs());
+  return Mod;
+}
+
+TEST(ProfDataUtils, extractWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1, !prof !1
+bb0:
+  %0 = mul i32 1, 2
+  br label %bb1
+bb1:
+  ret void
+}
+
+!1 = !{!"branch_weights", i32 1, i32 100000}
+)IR");
+  ASSERT_TRUE(M);
+  Function *F = M->getFunction("foo");
+  ASSERT_NE(nullptr, F);
+  auto &Entry = F->getEntryBlock();
+  auto &I = Entry.front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  ASSERT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  ASSERT_NE(ProfileData, nullptr);
+  EXPECT_TRUE(hasMdProfMD(I));
+  EXPECT_TRUE(isBranchWeightMD(ProfileData));
+  EXPECT_TRUE(hasBranchWeightMD(I));
+  EXPECT_FALSE(hasValueProfileMD(I));
+  EXPECT_FALSE(hasFunctionEntryMD(I));
+  SmallVector<uint32_t, 2> Weights;
+  EXPECT_TRUE(extractMDProfWeights(ProfileData, Weights));
+  // Check the size before indexing so a failed extraction cannot read past
+  // the end of Weights.
+  ASSERT_EQ(Weights.size(), 2U);
+  EXPECT_EQ(Weights[0], 1U);
+  EXPECT_EQ(Weights[1], 100000U);
+}
+
+TEST(ProfDataUtils, NoWeights) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define void @foo(i1 %cond0) {
+entry:
+  br i1 %cond0, label %bb0, label %bb1
+bb0:
+  %0 = mul i32 1, 2
+  br label %bb1
+bb1:
+  ret void
+}
+)IR");
+  ASSERT_TRUE(M);
+  Function *F = M->getFunction("foo");
+  ASSERT_NE(nullptr, F);
+  auto &Entry = F->getEntryBlock();
+  auto &I = Entry.front();
+  auto *Branch = dyn_cast<BranchInst>(&I);
+  ASSERT_NE(nullptr, Branch);
+  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
+  EXPECT_EQ(ProfileData, nullptr);
+  EXPECT_FALSE(hasMdProfMD(I));
+  EXPECT_FALSE(hasBranchWeightMD(I));
+  SmallVector<uint32_t, 2> Weights;
+  EXPECT_FALSE(extractMDProfWeights(ProfileData, Weights));
+  EXPECT_EQ(Weights.size(), 0U);
+}