Index: include/llvm/Analysis/BlockFrequencyInfo.h
===================================================================
--- include/llvm/Analysis/BlockFrequencyInfo.h
+++ include/llvm/Analysis/BlockFrequencyInfo.h
@@ -51,6 +51,10 @@
   /// floating points.
   BlockFrequency getBlockFreq(const BasicBlock *BB) const;
 
+  // Set the frequency of the given basic block. Return true if it succeeds and
+  // false otherwise, e.g. when no block frequency information is available.
+  bool setBlockFreq(const BasicBlock *BB, uint64_t Freq);
+
   // Print the block frequency Freq to OS using the current functions entry
   // frequency to convert freq into a relative decimal form.
   raw_ostream &printBlockFreq(raw_ostream &OS, const BlockFrequency Freq) const;
Index: include/llvm/Analysis/BlockFrequencyInfoImpl.h
===================================================================
--- include/llvm/Analysis/BlockFrequencyInfoImpl.h
+++ include/llvm/Analysis/BlockFrequencyInfoImpl.h
@@ -456,6 +456,10 @@
 
   BlockFrequency getBlockFreq(const BlockNode &Node) const;
 
+  /// Set the integer frequency of \p Node. Returns false if \p Node is
+  /// invalid or does not refer to an existing frequency slot.
+  bool setBlockFreq(const BlockNode &Node, uint64_t Freq);
+
   raw_ostream &printBlockFreq(raw_ostream &OS, const BlockNode &Node) const;
   raw_ostream &printBlockFreq(raw_ostream &OS,
                               const BlockFrequency &Freq) const;
@@ -886,6 +890,9 @@
   BlockFrequency getBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getBlockFreq(getNode(BB));
   }
+  /// Set the frequency of \p BB, first creating a BlockNode for it if \p BB
+  /// was added after the frequencies were computed.
+  bool setBlockFreq(const BlockT *BB, uint64_t Freq);
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
   }
@@ -938,6 +945,23 @@
   finalizeMetrics();
 }
 
+template <class BT>
+bool BlockFrequencyInfoImpl<BT>::setBlockFreq(const BlockT *BB,
+                                              uint64_t Freq) {
+  // BB already has a BlockNode: update it in place (single map lookup).
+  auto I = Nodes.find(BB);
+  if (I != Nodes.end())
+    return BlockFrequencyInfoImplBase::setBlockFreq(I->second, Freq);
+
+  // Otherwise BB is a block added after BFI was computed. Give it a new
+  // BlockNode whose index is the current size of Freqs, and grow Freqs so
+  // that the new index refers to a real slot.
+  BlockNode NewNode(Freqs.size());
+  Nodes[BB] = NewNode;
+  Freqs.emplace_back();
+  return BlockFrequencyInfoImplBase::setBlockFreq(NewNode, Freq);
+}
+
 template <class BT> void BlockFrequencyInfoImpl<BT>::initializeRPOT() {
   const BlockT *Entry = F->begin();
   RPOT.reserve(F->size());
Index: lib/Analysis/BlockFrequencyInfo.cpp
===================================================================
--- lib/Analysis/BlockFrequencyInfo.cpp
+++ lib/Analysis/BlockFrequencyInfo.cpp
@@ -150,6 +150,11 @@
   return BFI ? BFI->getBlockFreq(BB) : 0;
 }
 
+bool BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) {
+  // Mirrors getBlockFreq: without frequency information there is nothing to set.
+  return BFI && BFI->setBlockFreq(BB, Freq);
+}
+
 /// Pop up a ghostview window with the current block frequency propagation
 /// rendered using dot.
 void BlockFrequencyInfo::view() const {
Index: lib/Analysis/BlockFrequencyInfoImpl.cpp
===================================================================
--- lib/Analysis/BlockFrequencyInfoImpl.cpp
+++ lib/Analysis/BlockFrequencyInfoImpl.cpp
@@ -527,6 +527,18 @@
   return Freqs[Node.Index].Scaled;
 }
 
+bool BlockFrequencyInfoImplBase::setBlockFreq(const BlockNode &Node,
+                                              uint64_t Freq) {
+  // Reject invalid nodes and indices without a Freqs entry so that a stale or
+  // foreign node cannot write out of bounds.
+  if (!Node.isValid() || Node.Index >= Freqs.size())
+    return false;
+  // Only the integer frequency is updated; the floating-point Scaled value
+  // returned by the function above keeps its previous (or default) value.
+  Freqs[Node.Index].Integer = Freq;
+  return true;
+}
+
 std::string
 BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const {
   return std::string();