diff --git a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
--- a/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
+++ b/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
@@ -281,6 +281,16 @@
   RematerializedValueMapTy RematerializedValues;
 };
 
+struct RematerializationCandidateRecord {
+  // Chain from derived pointer to base.
+  SmallVector<Instruction *, 3> ChainToBase;
+  // Original base.
+  Value *RootOfChain = nullptr;
+  // Cost of chain.
+  InstructionCost Cost;
+};
+using RematCandTy = MapVector<Value *, RematerializationCandidateRecord>;
+
 } // end anonymous namespace
 
 static ArrayRef<Use> GetDeoptBundleOperands(const CallBase *Call) {
@@ -2221,27 +2231,25 @@
   return true;
 }
 
-// From the statepoint live set pick values that are cheaper to recompute then
-// to relocate. Remove this values from the live set, rematerialize them after
-// statepoint and record them in "Info" structure. Note that similar to
-// relocated values we don't do any user adjustments here.
-static void rematerializeLiveValues(CallBase *Call,
-                                    PartiallyConstructedSafepointRecord &Info,
-                                    PointerToBaseTy &PointerToBase,
-                                    TargetTransformInfo &TTI) {
+// Find derived pointers that can be recomputed cheaply enough and fill
+// RematerizationCandidates with such candidates (keyed by derived pointer).
+static void
+findRematerializationCandidates(const PointerToBaseTy &PointerToBase,
+                                RematCandTy &RematerizationCandidates,
+                                TargetTransformInfo &TTI) {
   const unsigned int ChainLengthThreshold = 10;
 
-  // Record values we are going to delete from this statepoint live set.
-  // We can not di this in following loop due to iterator invalidation.
-  SmallVector<Value *, 32> LiveValuesToBeDeleted;
+  for (const auto &P2B : PointerToBase) {
+    auto *Derived = P2B.first;
+    auto *Base = P2B.second;
+    // Consider only derived pointers.
+    if (Derived == Base)
+      continue;
 
-  for (Value *LiveValue: Info.LiveSet) {
-    // For each live pointer find its defining chain
+    // For each derived pointer find its defining chain.
     SmallVector<Instruction *, 3> ChainToBase;
-    assert(PointerToBase.count(LiveValue));
     Value *RootOfChain =
-      findRematerializableChainToBasePointer(ChainToBase,
-                                             LiveValue);
+        findRematerializableChainToBasePointer(ChainToBase, Derived);
 
     // Nothing to do, or chain is too long
     if ( ChainToBase.size() == 0 ||
@@ -2250,9 +2258,9 @@
 
     // Handle the scenario where the RootOfChain is not equal to the
     // Base Value, but they are essentially the same phi values.
-    if (RootOfChain != PointerToBase[LiveValue]) {
+    if (RootOfChain != Base) {
       PHINode *OrigRootPhi = dyn_cast<PHINode>(RootOfChain);
-      PHINode *AlternateRootPhi = dyn_cast<PHINode>(PointerToBase[LiveValue]);
+      PHINode *AlternateRootPhi = dyn_cast<PHINode>(Base);
       if (!OrigRootPhi || !AlternateRootPhi)
         continue;
       // PHI nodes that have the same incoming values, and belonging to the same
@@ -2266,33 +2274,61 @@
       // deficiency in the findBasePointer algorithm.
       if (!AreEquivalentPhiNodes(*OrigRootPhi, *AlternateRootPhi))
         continue;
-      // Now that the phi nodes are proved to be the same, assert that
-      // findBasePointer's newly generated AlternateRootPhi is present in the
-      // liveset of the call.
-      assert(Info.LiveSet.count(AlternateRootPhi));
     }
-    // Compute cost of this chain
+    // Compute cost of this chain.
     InstructionCost Cost = chainToBasePointerCost(ChainToBase, TTI);
     // TODO: We can also account for cases when we will be able to remove some
     //       of the rematerialized values by later optimization passes. I.e if
     //       we rematerialized several intersecting chains. Or if original values
     //       don't have any uses besides this statepoint.
 
+    // Ok, there is a candidate.
+    RematerializationCandidateRecord Record;
+    Record.ChainToBase = ChainToBase;
+    Record.RootOfChain = RootOfChain;
+    Record.Cost = Cost;
+    RematerizationCandidates.insert({Derived, Record});
+  }
+}
+
+// From the statepoint live set pick values that are cheaper to recompute than
+// to relocate. Remove these values from the live set, rematerialize them after
+// statepoint and record them in "Info" structure. Note that similar to
+// relocated values we don't do any user adjustments here.
+static void rematerializeLiveValues(CallBase *Call,
+                                    PartiallyConstructedSafepointRecord &Info,
+                                    PointerToBaseTy &PointerToBase,
+                                    const RematCandTy &RematerizationCandidates,
+                                    TargetTransformInfo &TTI) {
+  // Record values we are going to delete from this statepoint live set.
+  // We can not do this in following loop due to iterator invalidation.
+  SmallVector<Value *, 32> LiveValuesToBeDeleted;
+
+  for (Value *LiveValue : Info.LiveSet) {
+    auto It = RematerizationCandidates.find(LiveValue);
+    if (It == RematerizationCandidates.end())
+      continue;
+
+    const RematerializationCandidateRecord &Record = It->second;
+
+    InstructionCost Cost = Record.Cost;
     // For invokes we need to rematerialize each chain twice - for normal and
     // for unwind basic blocks. Model this by multiplying cost by two.
-    if (isa<InvokeInst>(Call)) {
+    if (isa<InvokeInst>(Call))
       Cost *= 2;
-    }
-    // If it's too expensive - skip it
+
+    // If it's too expensive - skip it.
     if (Cost >= RematerializationThreshold)
       continue;
 
     // Remove value from the live set
     LiveValuesToBeDeleted.push_back(LiveValue);
 
-    // Clone instructions and record them inside "Info" structure
+    // Clone instructions and record them inside "Info" structure.
 
-    // Walk backwards to visit top-most instructions first
+    // Copy the shared candidate chain: it is reversed in place below.
+    SmallVector<Instruction *, 3> ChainToBase = Record.ChainToBase;
+    // Walk backwards to visit top-most instructions first.
     std::reverse(ChainToBase.begin(), ChainToBase.end());
 
     // Utility function which clones all instructions from "ChainToBase"
@@ -2352,7 +2388,7 @@
       Instruction *InsertBefore = Call->getNextNode();
       assert(InsertBefore);
       Instruction *RematerializedValue = rematerializeChain(
-          InsertBefore, RootOfChain, PointerToBase[LiveValue]);
+          InsertBefore, Record.RootOfChain, PointerToBase[LiveValue]);
       Info.RematerializedValues[RematerializedValue] = LiveValue;
     } else {
       auto *Invoke = cast<InvokeInst>(Call);
@@ -2363,9 +2399,9 @@
           &*Invoke->getUnwindDest()->getFirstInsertionPt();
 
       Instruction *NormalRematerializedValue = rematerializeChain(
-          NormalInsertBefore, RootOfChain, PointerToBase[LiveValue]);
+          NormalInsertBefore, Record.RootOfChain, PointerToBase[LiveValue]);
       Instruction *UnwindRematerializedValue = rematerializeChain(
-          UnwindInsertBefore, RootOfChain, PointerToBase[LiveValue]);
+          UnwindInsertBefore, Record.RootOfChain, PointerToBase[LiveValue]);
 
       Info.RematerializedValues[NormalRematerializedValue] = LiveValue;
       Info.RematerializedValues[UnwindRematerializedValue] = LiveValue;
@@ -2563,11 +2599,16 @@
 
   Holders.clear();
 
+  // Compute the cost of possible re-materialization of derived pointers.
+  RematCandTy RematerizationCandidates;
+  findRematerializationCandidates(PointerToBase, RematerizationCandidates, TTI);
+
   // In order to reduce live set of statepoint we might choose to rematerialize
   // some values instead of relocating them. This is purely an optimization and
   // does not influence correctness.
   for (size_t i = 0; i < Records.size(); i++)
-    rematerializeLiveValues(ToUpdate[i], Records[i], PointerToBase, TTI);
+    rematerializeLiveValues(ToUpdate[i], Records[i], PointerToBase,
+                            RematerizationCandidates, TTI);
 
   // We need this to safely RAUW and delete call or invoke return values that
   // may themselves be live over a statepoint. For details, please see usage in