Index: llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
+++ llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
@@ -24,6 +24,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ValueHandle.h"
 #include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -547,6 +548,7 @@
 template <class BlockT> struct TypeMap {};
 template <> struct TypeMap<BasicBlock> {
   using BlockT = BasicBlock;
+  using BlockKeyT = AssertingVH<const BasicBlock>;
   using FunctionT = Function;
   using BranchProbabilityInfoT = BranchProbabilityInfo;
   using LoopT = Loop;
@@ -554,6 +556,7 @@
 };
 template <> struct TypeMap<MachineBasicBlock> {
   using BlockT = MachineBasicBlock;
+  using BlockKeyT = const MachineBasicBlock *;
   using FunctionT = MachineFunction;
   using BranchProbabilityInfoT = MachineBranchProbabilityInfo;
   using LoopT = MachineLoop;
@@ -845,6 +848,7 @@
   friend struct bfi_detail::BlockEdgesAdder<BT>;
 
   using BlockT = typename bfi_detail::TypeMap<BT>::BlockT;
+  using BlockKeyT = typename bfi_detail::TypeMap<BT>::BlockKeyT;
   using FunctionT = typename bfi_detail::TypeMap<BT>::FunctionT;
   using BranchProbabilityInfoT =
       typename bfi_detail::TypeMap<BT>::BranchProbabilityInfoT;
@@ -857,9 +861,13 @@
   const LoopInfoT *LI = nullptr;
   const FunctionT *F = nullptr;
 
+  // Value handle stored next to each block's BlockNode in `Nodes`. It is
+  // specialized after this class for BasicBlock and MachineBasicBlock.
+  class BFICallbackVH;
+
   // All blocks in reverse postorder.
   std::vector<const BlockT *> RPOT;
-  DenseMap<const BlockT *, BlockNode> Nodes;
+  DenseMap<BlockKeyT, std::pair<BlockNode, BFICallbackVH>> Nodes;
 
   using rpot_iterator = typename std::vector<const BlockT *>::const_iterator;
 
@@ -871,7 +879,7 @@
   BlockNode getNode(const rpot_iterator &I) const {
     return BlockNode(getIndex(I));
   }
-  BlockNode getNode(const BlockT *BB) const { return Nodes.lookup(BB); }
+  BlockNode getNode(const BlockT *BB) const { return Nodes.lookup(BB).first; }
 
   const BlockT *getBlock(const BlockNode &Node) const {
     assert(Node.Index < RPOT.size());
@@ -992,6 +1000,14 @@
 
   void setBlockFreq(const BlockT *BB, uint64_t Freq);
 
+  /// Drop \p BB from the block map; BFICallbackVH calls this on deletion.
+  void forgetBlock(const BlockT *BB) {
+    // We don't erase the corresponding items from `Freqs`, `RPOT` and other
+    // containers, to avoid invalidating indices. Doing so would have saved some
+    // memory, but it's not worth it.
+    Nodes.erase(BB);
+  }
+
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
   }
@@ -1019,6 +1035,42 @@
   }
 };
 
+/// Callback value handle stored next to each BasicBlock's BlockNode in
+/// `Nodes`. When the block is deleted, deleted() erases the block's entry via
+/// forgetBlock(), so the map never keeps an AssertingVH key to a dead block.
+template <>
+class BlockFrequencyInfoImpl<BasicBlock>::BFICallbackVH : public CallbackVH {
+  // Initialized so that default-constructed handles, which track no block,
+  // never carry an indeterminate pointer.
+  BlockFrequencyInfoImpl<BasicBlock> *BFIImpl = nullptr;
+
+public:
+  BFICallbackVH() = default;
+
+  BFICallbackVH(const BasicBlock *BB,
+                BlockFrequencyInfoImpl<BasicBlock> *BFIImpl)
+      : CallbackVH(const_cast<BasicBlock *>(BB)), BFIImpl(BFIImpl) {}
+
+  // Virtual since the class has virtual members (-Wnon-virtual-dtor).
+  virtual ~BFICallbackVH() = default;
+
+  void deleted() override {
+    // forgetBlock() erases the map entry that owns this handle, so `this`
+    // dangles afterwards: no member may be accessed after this call.
+    BFIImpl->forgetBlock(cast<BasicBlock>(getValPtr()));
+  }
+};
+
+/// Dummy implementation since MachineBasicBlocks aren't Values, so ValueHandles
+/// don't apply to them.
+template <>
+class BlockFrequencyInfoImpl<MachineBasicBlock>::BFICallbackVH {
+public:
+  BFICallbackVH() = default;
+  BFICallbackVH(const MachineBasicBlock *,
+                BlockFrequencyInfoImpl<MachineBasicBlock> *) {}
+};
+
 template <class BT>
 void BlockFrequencyInfoImpl<BT>::calculate(const FunctionT &F,
                                            const BranchProbabilityInfoT &BPI,
@@ -1066,7 +1118,7 @@
     // BlockNode for it assigned with a new index. The index can be determined
     // by the size of Freqs.
     BlockNode NewNode(Freqs.size());
-    Nodes[BB] = NewNode;
+    Nodes[BB] = {NewNode, BFICallbackVH(BB, this)};
     Freqs.emplace_back();
     BlockFrequencyInfoImplBase::setBlockFreq(NewNode, Freq);
   }
@@ -1086,7 +1138,7 @@
     BlockNode Node = getNode(I);
     LLVM_DEBUG(dbgs() << " - " << getIndex(I) << ": " << getBlockName(Node)
                       << "\n");
-    Nodes[*I] = Node;
+    Nodes[*I] = {Node, BFICallbackVH(*I, this)};
   }
 
   Working.reserve(RPOT.size());