diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h --- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h +++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h @@ -22,7 +22,13 @@ namespace llvm { struct CoroSplitPass : PassInfoMixin { - CoroSplitPass(bool OptimizeFrame = false) : OptimizeFrame(OptimizeFrame) {} + const std::function MaterializableCallback; + + CoroSplitPass(bool OptimizeFrame = false); + CoroSplitPass(std::function MaterializableCallback, + bool OptimizeFrame = false) + : MaterializableCallback(MaterializableCallback), + OptimizeFrame(OptimizeFrame) {} PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR); diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -318,8 +318,6 @@ LLVM_DEBUG(dump()); } -static bool materializable(Instruction &V); - namespace { // RematGraph is used to construct a DAG for rematerializable instructions @@ -342,10 +340,13 @@ using RematNodeMap = SmallMapVector, 8>; RematNodeMap Remats; + const std::function MaterializableCallback; SuspendCrossingInfo &Checker; - RematGraph(Instruction *I, SuspendCrossingInfo &Checker) : Checker(Checker) { - assert(materializable(*I)); + RematGraph(std::function MaterializableCallback, + Instruction *I, SuspendCrossingInfo &Checker) + : MaterializableCallback(MaterializableCallback), Checker(Checker) { + assert(MaterializableCallback(*I)); std::unique_ptr FirstNode = std::make_unique(I); EntryNode = FirstNode.get(); std::deque> WorkList; @@ -368,7 +369,7 @@ Remats[N->Node] = std::move(NUPtr); for (auto &Def : N->Node->operands()) { Instruction *D = dyn_cast(Def.get()); - if (!D || !materializable(*D) || + if (!D || !MaterializableCallback(*D) || !Checker.isDefinitionAcrossSuspend(*D, FirstUse)) continue; @@ -2213,11 +2214,12 @@ rewritePHIs(*BB); } +/// Default materializable callback // Check for instructions that we can recreate on resume as opposed to spill // the result into a coroutine frame. -static bool materializable(Instruction &V) { - return isa(&V) || isa(&V) || - isa(&V) || isa(&V) || isa(&V); +bool coro::defaultMaterializable(Instruction &V) { + return (isa(&V) || isa(&V) || + isa(&V) || isa(&V) || isa(&V)); } // Check for structural coroutine intrinsics that should not be spilled into @@ -2889,14 +2891,16 @@ } } -static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) { +static void doRematerializations( + Function &F, SuspendCrossingInfo &Checker, + const std::function MaterializableCallback) { SpillInfo Spills; // See if there are materializable instructions across suspend points // We record these as the starting point to also identify materializable // defs of uses in these operations for (Instruction &I : instructions(F)) { - if (!materializable(I)) + if (!MaterializableCallback(I)) continue; for (User *U : I.users()) if (Checker.isDefinitionAcrossSuspend(I, U)) @@ -2927,7 +2931,8 @@ continue; // Constructor creates the whole RematGraph for the given Use - auto RematUPtr = std::make_unique(U, Checker); + auto RematUPtr = + std::make_unique(MaterializableCallback, U, Checker); LLVM_DEBUG(dbgs() << "***** Next remat group *****\n"; ReversePostOrderTraversal RPOT(RematUPtr.get()); @@ -2945,7 +2950,9 @@ rewriteMaterializableInstructions(AllRemats); } -void coro::buildCoroutineFrame(Function &F, Shape &Shape) { +void coro::buildCoroutineFrame( + Function &F, Shape &Shape, + const std::function MaterializableCallback) { // Don't eliminate swifterror in async functions that won't be split. if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty()) eliminateSwiftError(F, Shape); @@ -2996,11 +3003,12 @@ // Build suspend crossing info. SuspendCrossingInfo Checker(F, Shape); - doRematerializations(F, Checker); + doRematerializations(F, Checker, MaterializableCallback); FrameDataInfo FrameData; SmallVector LocalAllocas; SmallVector DeadInstructions; + if (Shape.ABI != coro::ABI::Async && Shape.ABI != coro::ABI::Retcon && Shape.ABI != coro::ABI::RetconOnce) sinkLifetimeStartMarkers(F, Shape, Checker); diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -261,7 +261,10 @@ void buildFrom(Function &F); }; -void buildCoroutineFrame(Function &F, Shape &Shape); +bool defaultMaterializable(Instruction &V); +void buildCoroutineFrame( + Function &F, Shape &Shape, + std::function MaterializableCallback); CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, ArrayRef Arguments, IRBuilder<> &); } // End namespace coro. diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -1929,10 +1929,10 @@ }; } -static coro::Shape splitCoroutine(Function &F, - SmallVectorImpl &Clones, - TargetTransformInfo &TTI, - bool OptimizeFrame) { +static coro::Shape +splitCoroutine(Function &F, SmallVectorImpl &Clones, + TargetTransformInfo &TTI, bool OptimizeFrame, + std::function MaterializableCallback) { PrettyStackTraceFunction prettyStackTrace(F); // The suspend-crossing algorithm in buildCoroutineFrame get tripped @@ -1944,7 +1944,7 @@ return Shape; simplifySuspendPoints(Shape); - buildCoroutineFrame(F, Shape); + buildCoroutineFrame(F, Shape, MaterializableCallback); replaceFrameSizeAndAlignment(Shape); // If there are no suspend points, no split required, just remove @@ -2104,6 +2104,10 @@ Fns.push_back(PrepareFn); } +CoroSplitPass::CoroSplitPass(bool OptimizeFrame) + : MaterializableCallback(coro::defaultMaterializable), + OptimizeFrame(OptimizeFrame) {} + PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { @@ -2142,8 +2146,9 @@ F.setSplittedCoroutine(); SmallVector Clones; - const coro::Shape Shape = splitCoroutine( - F, Clones, FAM.getResult(F), OptimizeFrame); + const coro::Shape Shape = + splitCoroutine(F, Clones, FAM.getResult(F), + OptimizeFrame, MaterializableCallback); updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM); if (!Shape.CoroSuspends.empty()) {