[Coroutines] Let CoroSplitPass take a custom rematerialization predicate

CoroSplitPass gains an optional MaterializableCallback that decides which
instructions are recreated after a suspend point instead of having their
value spilled into the coroutine frame. The default predicate,
coro::defaultMaterializable, keeps the previous cast/GEP/binop/cmp/select
set. The callback is passed down by const reference so the std::function is
not copied at every call level.

diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
--- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
+++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h
@@ -22,7 +22,15 @@
 namespace llvm {
 
 struct CoroSplitPass : PassInfoMixin<CoroSplitPass> {
-  CoroSplitPass(bool OptimizeFrame = false) : OptimizeFrame(OptimizeFrame) {}
+  /// Predicate deciding whether an instruction may be recreated on resume
+  /// instead of having its value spilled into the coroutine frame.
+  const std::function<bool(Instruction &)> MaterializableCallback;
+
+  CoroSplitPass(bool OptimizeFrame = false);
+  CoroSplitPass(std::function<bool(Instruction &)> MaterializableCallback,
+                bool OptimizeFrame = false)
+      : MaterializableCallback(std::move(MaterializableCallback)),
+        OptimizeFrame(OptimizeFrame) {}
 
   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM,
                         LazyCallGraph &CG, CGSCCUpdateResult &UR);
diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
--- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp
@@ -318,8 +318,6 @@
   LLVM_DEBUG(dump());
 }
 
-static bool materializable(Instruction &V);
-
 namespace {
 struct RematNode {
   Instruction *Node;
@@ -333,9 +331,12 @@
   using RematNodeMap =
       SmallMapVector<Instruction *, std::unique_ptr<RematNode>, 8>;
   RematNodeMap Remats;
+  const std::function<bool(Instruction &)> MaterializableCallback;
   SuspendCrossingInfo &Checker;
 
-  RematGraph(Instruction *I, SuspendCrossingInfo &Checker) : Checker(Checker) {
+  RematGraph(const std::function<bool(Instruction &)> &MaterializableCallback,
+             Instruction *I, SuspendCrossingInfo &Checker)
+      : MaterializableCallback(MaterializableCallback), Checker(Checker) {
     std::unique_ptr<RematNode> FirstNode = std::make_unique<RematNode>(I);
     EntryNode = FirstNode.get();
     std::deque<std::unique_ptr<RematNode>> WorkList;
@@ -356,7 +357,7 @@
     Remats[N->Node] = std::move(NUPtr);
     for (auto &Def : N->Node->operands()) {
       if (Instruction *D = dyn_cast<Instruction>(Def.get())) {
-        if (materializable(*D) &&
+        if (MaterializableCallback(*D) &&
             Checker.isDefinitionAcrossSuspend(*D, FirstUse)) {
           if (Remats.count(D)) {
             // Already have this in the graph
@@ -2208,11 +2209,12 @@
     rewritePHIs(*BB);
 }
 
+/// Default materializable callback.
 // Check for instructions that we can recreate on resume as opposed to spill
 // the result into a coroutine frame.
-static bool materializable(Instruction &V) {
+bool coro::defaultMaterializable(Instruction &V) {
   return isa<CastInst>(&V) || isa<GetElementPtrInst>(&V) ||
          isa<BinaryOperator>(&V) || isa<CmpInst>(&V) || isa<SelectInst>(&V);
 }
 
 // Check for structural coroutine intrinsics that should not be spilled into
@@ -2886,14 +2888,16 @@
   }
 }
 
-static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) {
+static void doRematerializations(
+    Function &F, SuspendCrossingInfo &Checker,
+    const std::function<bool(Instruction &)> &MaterializableCallback) {
   SpillInfo Spills;
 
   // See if there are materializable instructions across suspend points
   // We record these as the starting point to also identify materializable
   // defs of uses in these operations
   for (Instruction &I : instructions(F)) {
-    if (materializable(I)) {
+    if (MaterializableCallback(I)) {
       for (User *U : I.users())
         if (Checker.isDefinitionAcrossSuspend(I, U))
           Spills[&I].push_back(cast<Instruction>(U));
@@ -2930,7 +2934,7 @@
       if (!AllRemats.count(U)) {
         // Constructor creates the whole RematGraph for the given Use
         std::unique_ptr<RematGraph> RematUPtr =
-            std::make_unique<RematGraph>(U, Checker);
+            std::make_unique<RematGraph>(MaterializableCallback, U, Checker);
 
         LLVM_DEBUG(
             dbgs() << "***** Next remat group *****\n";
@@ -2955,7 +2959,9 @@
   Spills.clear();
 }
 
-void coro::buildCoroutineFrame(Function &F, Shape &Shape) {
+void coro::buildCoroutineFrame(
+    Function &F, Shape &Shape,
+    const std::function<bool(Instruction &)> &MaterializableCallback) {
   // Don't eliminate swifterror in async functions that won't be split.
   if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty())
     eliminateSwiftError(F, Shape);
@@ -3006,6 +3012,6 @@
   // Build suspend crossing info.
   SuspendCrossingInfo Checker(F, Shape);
 
-  doRematerializations(F, Checker);
+  doRematerializations(F, Checker, MaterializableCallback);
 
   FrameDataInfo FrameData;
diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h
--- a/llvm/lib/Transforms/Coroutines/CoroInternal.h
+++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h
@@ -261,7 +261,11 @@
   void buildFrom(Function &F);
 };
 
-void buildCoroutineFrame(Function &F, Shape &Shape);
+/// Default rematerialization predicate used by CoroSplitPass.
+bool defaultMaterializable(Instruction &V);
+void buildCoroutineFrame(
+    Function &F, Shape &Shape,
+    const std::function<bool(Instruction &)> &MaterializableCallback);
 CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn,
                              ArrayRef<Value *> Arguments, IRBuilder<> &);
 } // End namespace coro.
diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
--- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp
@@ -1929,10 +1929,10 @@
   };
 }
 
-static coro::Shape splitCoroutine(Function &F,
-                                  SmallVectorImpl<Function *> &Clones,
-                                  TargetTransformInfo &TTI,
-                                  bool OptimizeFrame) {
+static coro::Shape splitCoroutine(
+    Function &F, SmallVectorImpl<Function *> &Clones, TargetTransformInfo &TTI,
+    bool OptimizeFrame,
+    const std::function<bool(Instruction &)> &MaterializableCallback) {
   PrettyStackTraceFunction prettyStackTrace(F);
 
   // The suspend-crossing algorithm in buildCoroutineFrame get tripped
@@ -1944,7 +1944,7 @@
     return Shape;
 
   simplifySuspendPoints(Shape);
-  buildCoroutineFrame(F, Shape);
+  buildCoroutineFrame(F, Shape, MaterializableCallback);
   replaceFrameSizeAndAlignment(Shape);
 
   // If there are no suspend points, no split required, just remove
@@ -2104,6 +2104,10 @@
     Fns.push_back(PrepareFn);
 }
 
+CoroSplitPass::CoroSplitPass(bool OptimizeFrame)
+    : MaterializableCallback(coro::defaultMaterializable),
+      OptimizeFrame(OptimizeFrame) {}
+
 PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C,
                                      CGSCCAnalysisManager &AM,
                                      LazyCallGraph &CG, CGSCCUpdateResult &UR) {
@@ -2142,8 +2146,9 @@
     F.setSplittedCoroutine();
 
     SmallVector<Function *, 4> Clones;
-    const coro::Shape Shape = splitCoroutine(
-        F, Clones, FAM.getResult<TargetIRAnalysis>(F), OptimizeFrame);
+    const coro::Shape Shape =
+        splitCoroutine(F, Clones, FAM.getResult<TargetIRAnalysis>(F),
+                       OptimizeFrame, MaterializableCallback);
     updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM);
 
     if (!Shape.CoroSuspends.empty()) {