diff --git a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h --- a/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h +++ b/llvm/include/llvm/Transforms/Coroutines/CoroSplit.h @@ -22,7 +22,13 @@ namespace llvm { struct CoroSplitPass : PassInfoMixin { - CoroSplitPass(bool OptimizeFrame = false) : OptimizeFrame(OptimizeFrame) {} + const std::function MaterializableCallback; + + CoroSplitPass(bool OptimizeFrame = false); + CoroSplitPass(std::function MaterializableCallback, + bool OptimizeFrame = false) + : MaterializableCallback(std::move(MaterializableCallback)), + OptimizeFrame(OptimizeFrame) {} PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR); diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -318,8 +318,6 @@ LLVM_DEBUG(dump()); } -static bool materializable(Instruction &V); - namespace { // RematGraph is used to construct a DAG for rematerializable instructions @@ -342,9 +340,12 @@ using RematNodeMap = SmallMapVector, 8>; RematNodeMap Remats; + const std::function &MaterializableCallback; SuspendCrossingInfo &Checker; - RematGraph(Instruction *I, SuspendCrossingInfo &Checker) : Checker(Checker) { + RematGraph(const std::function &MaterializableCallback, + Instruction *I, SuspendCrossingInfo &Checker) + : MaterializableCallback(MaterializableCallback), Checker(Checker) { std::unique_ptr FirstNode = std::make_unique(I); EntryNode = FirstNode.get(); std::deque> WorkList; @@ -367,7 +368,7 @@ Remats[N->Node] = std::move(NUPtr); for (auto &Def : N->Node->operands()) { Instruction *D = dyn_cast(Def.get()); - if (!D || !materializable(*D) || + if (!D || !MaterializableCallback(*D) || !Checker.isDefinitionAcrossSuspend(*D, FirstUse)) continue; @@ -2211,11 +2212,12 @@ rewritePHIs(*BB); } +/// Default 
materializable callback, used when CoroSplitPass is given no custom one. // Check for instructions that we can recreate on resume as opposed to spill // the result into a coroutine frame. -static bool materializable(Instruction &V) { - return isa(&V) || isa(&V) || - isa(&V) || isa(&V) || isa(&V); +bool coro::defaultMaterializable(Instruction &V) { + return (isa(&V) || isa(&V) || + isa(&V) || isa(&V) || isa(&V)); } // Check for structural coroutine intrinsics that should not be spilled into @@ -2887,14 +2889,16 @@ } } -static void doRematerializations(Function &F, SuspendCrossingInfo &Checker) { +static void doRematerializations( + Function &F, SuspendCrossingInfo &Checker, + const std::function &MaterializableCallback) { SpillInfo Spills; // See if there are materializable instructions across suspend points // We record these as the starting point to also identify materializable // defs of uses in these operations for (Instruction &I : instructions(F)) { - if (!materializable(I)) + if (!MaterializableCallback(I)) continue; for (User *U : I.users()) if (Checker.isDefinitionAcrossSuspend(I, U)) @@ -2925,7 +2929,8 @@ continue; // Constructor creates the whole RematGraph for the given Use - auto RematUPtr = std::make_unique(U, Checker); + auto RematUPtr = + std::make_unique(MaterializableCallback, U, Checker); LLVM_DEBUG(dbgs() << "***** Next remat group *****\n"; ReversePostOrderTraversal RPOT(RematUPtr.get()); @@ -2943,7 +2948,9 @@ rewriteMaterializableInstructions(AllRemats); } -void coro::buildCoroutineFrame(Function &F, Shape &Shape) { +void coro::buildCoroutineFrame( + Function &F, Shape &Shape, + const std::function &MaterializableCallback) { // Don't eliminate swifterror in async functions that won't be split. if (Shape.ABI != coro::ABI::Async || !Shape.CoroSuspends.empty()) eliminateSwiftError(F, Shape); @@ -2994,7 +3001,7 @@ // Build suspend crossing info. 
SuspendCrossingInfo Checker(F, Shape); - doRematerializations(F, Checker); + doRematerializations(F, Checker, MaterializableCallback); FrameDataInfo FrameData; SmallVector LocalAllocas; diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -261,7 +261,10 @@ void buildFrom(Function &F); }; -void buildCoroutineFrame(Function &F, Shape &Shape); +bool defaultMaterializable(Instruction &V); +void buildCoroutineFrame( + Function &F, Shape &Shape, + const std::function &MaterializableCallback); CallInst *createMustTailCall(DebugLoc Loc, Function *MustTailCallFn, ArrayRef Arguments, IRBuilder<> &); } // End namespace coro. diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -1929,10 +1929,10 @@ }; } -static coro::Shape splitCoroutine(Function &F, - SmallVectorImpl &Clones, - TargetTransformInfo &TTI, - bool OptimizeFrame) { +static coro::Shape +splitCoroutine(Function &F, SmallVectorImpl &Clones, + TargetTransformInfo &TTI, bool OptimizeFrame, + const std::function &MaterializableCallback) { PrettyStackTraceFunction prettyStackTrace(F); // The suspend-crossing algorithm in buildCoroutineFrame get tripped @@ -1944,7 +1944,7 @@ return Shape; simplifySuspendPoints(Shape); - buildCoroutineFrame(F, Shape); + buildCoroutineFrame(F, Shape, MaterializableCallback); replaceFrameSizeAndAlignment(Shape); // If there are no suspend points, no split required, just remove @@ -2104,6 +2104,10 @@ Fns.push_back(PrepareFn); } +CoroSplitPass::CoroSplitPass(bool OptimizeFrame) + : MaterializableCallback(coro::defaultMaterializable), + OptimizeFrame(OptimizeFrame) {} + PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, 
CGSCCUpdateResult &UR) { @@ -2142,8 +2146,9 @@ F.setSplittedCoroutine(); SmallVector Clones; - const coro::Shape Shape = splitCoroutine( - F, Clones, FAM.getResult(F), OptimizeFrame); + const coro::Shape Shape = + splitCoroutine(F, Clones, FAM.getResult(F), + OptimizeFrame, MaterializableCallback); updateCallGraphAfterCoroutineSplit(*N, Shape, Clones, C, CG, AM, UR, FAM); if (!Shape.CoroSuspends.empty()) { diff --git a/llvm/unittests/Transforms/CMakeLists.txt b/llvm/unittests/Transforms/CMakeLists.txt --- a/llvm/unittests/Transforms/CMakeLists.txt +++ b/llvm/unittests/Transforms/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Coroutines) add_subdirectory(IPO) add_subdirectory(Scalar) add_subdirectory(Utils) diff --git a/llvm/unittests/Transforms/Coroutines/CMakeLists.txt b/llvm/unittests/Transforms/Coroutines/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/llvm/unittests/Transforms/Coroutines/CMakeLists.txt @@ -0,0 +1,18 @@ +set(LLVM_LINK_COMPONENTS + Analysis + AsmParser + Core + Coroutines + Passes + Support + TargetParser + TransformUtils + ) + +add_llvm_unittest(CoroTests + ExtraRematTest.cpp + ) + +target_link_libraries(CoroTests PRIVATE LLVMTestingSupport) + +set_property(TARGET CoroTests PROPERTY FOLDER "Tests/UnitTests/TransformTests") diff --git a/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/Transforms/Coroutines/ExtraRematTest.cpp @@ -0,0 +1,180 @@ +//===- ExtraRematTest.cpp - Coroutines unit tests -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Module.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Testing/Support/Error.h" +#include "llvm/Transforms/Coroutines/CoroSplit.h" +#include "gtest/gtest.h" + +namespace llvm { + +/// Fixture owning the parsed module and the pass/analysis managers. +struct ExtraRematTest : public testing::Test { + ModulePassManager MPM; + PassBuilder PB; + LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + LLVMContext Context; + std::unique_ptr M; + + ExtraRematTest() { + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + } + + BasicBlock *getBasicBlockByName(Function *F, StringRef Name) const { + for (BasicBlock &BB : *F) { + if (BB.getName() == Name) + return &BB; + } + return nullptr; + } + + CallInst *getCallByName(BasicBlock *BB, StringRef Name) const { + for (Instruction &I : *BB) + if (CallInst *CI = dyn_cast(&I)) + if (Function *Callee = CI->getCalledFunction()) // null if indirect + if (Callee->getName() == Name) + return CI; + return nullptr; + } + + void ParseAssembly(const StringRef IR) { + SMDiagnostic Error; + M = parseAssemblyString(IR, Error, Context); + std::string errMsg; + raw_string_ostream os(errMsg); + Error.print("", os); + + // A failure here means that the test itself is buggy. 
+ if (!M) + report_fatal_error(os.str().c_str()); + } +}; + +static StringRef Text = R"( + define ptr @f(i32 %n) presplitcoroutine { + entry: + %id = call token @llvm.coro.id(i32 0, ptr null, ptr null, ptr null) + %size = call i32 @llvm.coro.size.i32() + %alloc = call ptr @malloc(i32 %size) + %hdl = call ptr @llvm.coro.begin(token %id, ptr %alloc) + + %inc1 = add i32 %n, 1 + %val2 = call i32 @should.remat(i32 %inc1) + %sp1 = call i8 @llvm.coro.suspend(token none, i1 false) + switch i8 %sp1, label %suspend [i8 0, label %resume1 + i8 1, label %cleanup] + resume1: + %inc2 = add i32 %val2, 1 + %sp2 = call i8 @llvm.coro.suspend(token none, i1 false) + switch i8 %sp2, label %suspend [i8 0, label %resume2 + i8 1, label %cleanup] + + resume2: + call void @print(i32 %val2) + call void @print(i32 %inc2) + br label %cleanup + + cleanup: + %mem = call ptr @llvm.coro.free(token %id, ptr %hdl) + call void @free(ptr %mem) + br label %suspend + suspend: + call i1 @llvm.coro.end(ptr %hdl, i1 0) + ret ptr %hdl + } + + declare ptr @llvm.coro.free(token, ptr) + declare i32 @llvm.coro.size.i32() + declare i8 @llvm.coro.suspend(token, i1) + declare void @llvm.coro.resume(ptr) + declare void @llvm.coro.destroy(ptr) + + declare token @llvm.coro.id(i32, ptr, ptr, ptr) + declare i1 @llvm.coro.alloc(token) + declare ptr @llvm.coro.begin(token, ptr) + declare i1 @llvm.coro.end(ptr, i1) + + declare i32 @should.remat(i32) + + declare noalias ptr @malloc(i32) + declare void @print(i32) + declare void @free(ptr) + )"; + +// Materializable callback with extra rematerialization +static bool ExtraMaterializable(Instruction &I) { + if (isa(&I) || isa(&I) || + isa(&I) || isa(&I) || isa(&I)) + return true; + + if (auto *CI = dyn_cast(&I)) { + auto *CalledFunc = CI->getCalledFunction(); + if (CalledFunc && CalledFunc->getName().startswith("should.remat")) + return true; + } + + return false; +} + +TEST_F(ExtraRematTest, TestCoroRematDefault) { + ParseAssembly(Text); + + ASSERT_TRUE(M); + + CGSCCPassManager CGPM; + 
CGPM.addPass(CoroSplitPass()); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); + + // Find the split function f.resume and its resume1 block + Function *F = M->getFunction("f.resume"); + ASSERT_TRUE(F) << "could not find split function f.resume"; + + BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); + ASSERT_TRUE(Resume1) << "could not find expected BB resume1 in split function"; + + // With default materialization the call should not have been + // rematerialized + CallInst *CI = getCallByName(Resume1, "should.remat"); + ASSERT_FALSE(CI); +} + +TEST_F(ExtraRematTest, TestCoroRematWithCallback) { + ParseAssembly(Text); + + ASSERT_TRUE(M); + + CGSCCPassManager CGPM; + CGPM.addPass( + CoroSplitPass(std::function(ExtraMaterializable))); + MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM))); + MPM.run(*M, MAM); + + // Verify that extra rematerializable instruction has been rematerialized + Function *F = M->getFunction("f.resume"); + ASSERT_TRUE(F) << "could not find split function f.resume"; + + BasicBlock *Resume1 = getBasicBlockByName(F, "resume1"); + ASSERT_TRUE(Resume1) << "could not find expected BB resume1 in split function"; + + // With callback the extra rematerialization of the function should have + // happened + CallInst *CI = getCallByName(Resume1, "should.remat"); + ASSERT_TRUE(CI); +} +} // namespace llvm