diff --git a/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h b/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
--- a/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
+++ b/llvm/include/llvm/Analysis/BlockFrequencyInfoImpl.h
@@ -24,6 +24,7 @@
 #include "llvm/ADT/Twine.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/ValueHandle.h"
 #include "llvm/Support/BlockFrequency.h"
 #include "llvm/Support/BranchProbability.h"
 #include "llvm/Support/CommandLine.h"
@@ -547,6 +548,7 @@
 template <class BlockT> struct TypeMap {};
 template <> struct TypeMap<BasicBlock> {
   using BlockT = BasicBlock;
+  using BlockKeyT = AssertingVH<const BasicBlock>;
   using FunctionT = Function;
   using BranchProbabilityInfoT = BranchProbabilityInfo;
   using LoopT = Loop;
@@ -554,6 +556,7 @@
 };
 template <> struct TypeMap<MachineBasicBlock> {
   using BlockT = MachineBasicBlock;
+  using BlockKeyT = const MachineBasicBlock *;
   using FunctionT = MachineFunction;
   using BranchProbabilityInfoT = MachineBranchProbabilityInfo;
   using LoopT = MachineLoop;
@@ -845,6 +848,7 @@
   friend struct bfi_detail::BlockEdgesAdder<BT>;
 
   using BlockT = typename bfi_detail::TypeMap<BT>::BlockT;
+  using BlockKeyT = typename bfi_detail::TypeMap<BT>::BlockKeyT;
   using FunctionT = typename bfi_detail::TypeMap<BT>::FunctionT;
   using BranchProbabilityInfoT =
       typename bfi_detail::TypeMap<BT>::BranchProbabilityInfoT;
@@ -857,9 +861,13 @@
   const LoopInfoT *LI = nullptr;
   const FunctionT *F = nullptr;
 
+  /// Value handle stored next to each node so that entries for deleted blocks
+  /// can be dropped from Nodes. Specialized per block type.
+  class BFICallbackVH;
+
   // All blocks in reverse postorder.
   std::vector<const BlockT *> RPOT;
-  DenseMap<const BlockT *, BlockNode> Nodes;
+  DenseMap<BlockKeyT, std::pair<BlockNode, BFICallbackVH>> Nodes;
 
   using rpot_iterator = typename std::vector<const BlockT *>::const_iterator;
 
@@ -871,7 +879,13 @@
   BlockNode getNode(const rpot_iterator &I) const {
     return BlockNode(getIndex(I));
   }
-  BlockNode getNode(const BlockT *BB) const { return Nodes.lookup(BB); }
+
+  /// Returns the node for \p BB, or a default-constructed BlockNode if \p BB
+  /// is not tracked. Uses find() to avoid copying the stored value handle.
+  BlockNode getNode(const BlockT *BB) const {
+    auto It = Nodes.find(BB);
+    return It == Nodes.end() ? BlockNode() : It->second.first;
+  }
 
   const BlockT *getBlock(const BlockNode &Node) const {
     assert(Node.Index < RPOT.size());
@@ -992,6 +1006,14 @@
 
   void setBlockFreq(const BlockT *BB, uint64_t Freq);
 
+  /// Stop tracking \p BB, e.g. because the block is being deleted.
+  void forgetBlock(const BlockT *BB) {
+    // We don't erase corresponding items from `Freqs`, `RPOT` and other
+    // containers to avoid invalidating indices. Doing so would have saved some
+    // memory, but it's not worth it.
+    Nodes.erase(BB);
+  }
+
   Scaled64 getFloatingBlockFreq(const BlockT *BB) const {
     return BlockFrequencyInfoImplBase::getFloatingBlockFreq(getNode(BB));
   }
@@ -1019,6 +1041,38 @@
   }
 };
 
+/// Value handle for IR basic blocks: when the tracked block is deleted, it
+/// removes the block from the owning BlockFrequencyInfoImpl.
+template <>
+class BlockFrequencyInfoImpl<BasicBlock>::BFICallbackVH : public CallbackVH {
+  BlockFrequencyInfoImpl *BFIImpl = nullptr;
+
+public:
+  BFICallbackVH() = default;
+
+  BFICallbackVH(const BasicBlock *BB, BlockFrequencyInfoImpl *BFIImpl)
+      : CallbackVH(BB), BFIImpl(BFIImpl) {}
+
+  // CallbackVH's destructor is non-virtual (and protected), so give this
+  // polymorphic class a virtual one to keep -Wnon-virtual-dtor quiet.
+  virtual ~BFICallbackVH() = default;
+
+  void deleted() override {
+    // forgetBlock() erases the map entry that owns this handle, so *this is
+    // destroyed by the call and must not be used afterwards.
+    BFIImpl->forgetBlock(cast<BasicBlock>(getValPtr()));
+  }
+};
+
+/// Dummy implementation since MachineBasicBlocks aren't Values, so ValueHandles
+/// don't apply to them.
+template <>
+class BlockFrequencyInfoImpl<MachineBasicBlock>::BFICallbackVH {
+public:
+  BFICallbackVH() = default;
+  BFICallbackVH(const MachineBasicBlock *, BlockFrequencyInfoImpl *) {}
+};
+
 template <class BT>
 void BlockFrequencyInfoImpl<BT>::calculate(const FunctionT &F,
                                            const BranchProbabilityInfoT &BPI,
@@ -1066,7 +1120,7 @@
     // BlockNode for it assigned with a new index. The index can be determined
     // by the size of Freqs.
     BlockNode NewNode(Freqs.size());
-    Nodes[BB] = NewNode;
+    Nodes[BB] = {NewNode, BFICallbackVH(BB, this)};
     Freqs.emplace_back();
     BlockFrequencyInfoImplBase::setBlockFreq(NewNode, Freq);
   }
@@ -1086,7 +1140,7 @@
     BlockNode Node = getNode(I);
     LLVM_DEBUG(dbgs() << " - " << getIndex(I) << ": " << getBlockName(Node)
                       << "\n");
-    Nodes[*I] = Node;
+    Nodes[*I] = {Node, BFICallbackVH(*I, this)};
   }
 
   Working.reserve(RPOT.size());
diff --git a/llvm/include/llvm/IR/ValueHandle.h b/llvm/include/llvm/IR/ValueHandle.h
--- a/llvm/include/llvm/IR/ValueHandle.h
+++ b/llvm/include/llvm/IR/ValueHandle.h
@@ -414,6 +414,7 @@
 public:
   CallbackVH() : ValueHandleBase(Callback) {}
   CallbackVH(Value *P) : ValueHandleBase(Callback, P) {}
+  CallbackVH(const Value *P) : CallbackVH(const_cast<Value *>(P)) {}
 
   operator Value*() const {
     return getValPtr();