Index: include/llvm/IR/Function.h
===================================================================
--- include/llvm/IR/Function.h
+++ include/llvm/IR/Function.h
@@ -19,11 +19,16 @@
 #define LLVM_IR_FUNCTION_H
 
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CallingConv.h"
+#include "llvm/IR/Constants.h"
 #include "llvm/IR/GlobalObject.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Metadata.h"
 #include "llvm/Support/Compiler.h"
 
 namespace llvm {
@@ -194,6 +199,40 @@
                                  AttributeSet::FunctionIndex, Kind, Value));
   }
 
+  /// @brief Set the entry count for this function.
+  ///
+  /// The count is attached as !prof metadata built by
+  /// MDBuilder::createFunctionEntryCount().
+  void setEntryCount(uint64_t Count) {
+    MDBuilder MDB(getContext());
+    setMetadata(LLVMContext::MD_prof, MDB.createFunctionEntryCount(Count));
+  }
+
+  /// @brief Return true if the function has a known entry count.
+  ///
+  /// That is, its !prof attachment is tagged "function_entry_count" and has
+  /// room for the count operand that getEntryCount() reads.
+  bool hasEntryCount() const {
+    MDNode *MD = getMetadata(LLVMContext::MD_prof);
+    // Check the operand count first: indexing a node with fewer than two
+    // operands (here or in getEntryCount()) would be out of bounds.
+    if (MD && MD->getNumOperands() >= 2 && MD->getOperand(0))
+      if (MDString *MDS = dyn_cast<MDString>(MD->getOperand(0)))
+        return MDS->getString().equals("function_entry_count");
+    return false;
+  }
+
+  /// @brief Get the entry count for this function.
+  ///
+  /// @returns None if the function has no entry count attached.
+  Optional<uint64_t> getEntryCount() const {
+    if (!hasEntryCount())
+      return None;
+    MDNode *MD = getMetadata(LLVMContext::MD_prof);
+    ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(1));
+    return CI->getValue().getZExtValue();
+  }
+
   /// @brief Return true if the function has the attribute.
   bool hasFnAttribute(Attribute::AttrKind Kind) const {
     return AttributeSets.hasAttribute(AttributeSet::FunctionIndex, Kind);
Index: include/llvm/IR/MDBuilder.h
===================================================================
--- include/llvm/IR/MDBuilder.h
+++ include/llvm/IR/MDBuilder.h
@@ -60,6 +60,9 @@
   /// \brief Return metadata containing a number of branch weights.
   MDNode *createBranchWeights(ArrayRef<uint32_t> Weights);
 
+  /// \brief Return metadata containing the entry count for a function.
+  MDNode *createFunctionEntryCount(uint64_t Count);
+
   //===------------------------------------------------------------------===//
   // Range metadata.
   //===------------------------------------------------------------------===//
Index: lib/IR/MDBuilder.cpp
===================================================================
--- lib/IR/MDBuilder.cpp
+++ lib/IR/MDBuilder.cpp
@@ -53,6 +53,16 @@
   return MDNode::get(Context, Vals);
 }
 
+MDNode *MDBuilder::createFunctionEntryCount(uint64_t Count) {
+  SmallVector<Metadata *, 2> Vals(2);
+  Vals[0] = createString("function_entry_count");
+
+  Type *Int64Ty = Type::getInt64Ty(Context);
+  Vals[1] = createConstant(ConstantInt::get(Int64Ty, Count));
+
+  return MDNode::get(Context, Vals);
+}
+
 MDNode *MDBuilder::createRange(const APInt &Lo, const APInt &Hi) {
   assert(Lo.getBitWidth() == Hi.getBitWidth() && "Mismatched bitwidths!");
 
Index: unittests/IR/MetadataTest.cpp
===================================================================
--- unittests/IR/MetadataTest.cpp
+++ unittests/IR/MetadataTest.cpp
@@ -2272,4 +2272,12 @@
   EXPECT_FALSE(verifyFunction(*F));
 }
 
+TEST_F(FunctionAttachmentTest, EntryCount) {
+  Function *F = getFunction("foo");
+  EXPECT_FALSE(F->hasEntryCount());
+  F->setEntryCount(12304);
+  EXPECT_TRUE(F->hasEntryCount());
+  EXPECT_EQ(12304u, *F->getEntryCount());
+}
+
 }