Index: llvm/trunk/include/llvm/IR/Function.h =================================================================== --- llvm/trunk/include/llvm/IR/Function.h +++ llvm/trunk/include/llvm/IR/Function.h @@ -191,6 +191,12 @@ AttributeSet::FunctionIndex, Kind, Value)); } + /// @brief Remove function attribute from this function. + void removeFnAttr(StringRef Kind) { + setAttributes(AttributeSets.removeAttribute( + getContext(), AttributeSet::FunctionIndex, Kind)); + } + /// Set the entry count for this function. void setEntryCount(uint64_t Count); Index: llvm/trunk/lib/Transforms/Coroutines/CoroEarly.cpp =================================================================== --- llvm/trunk/lib/Transforms/Coroutines/CoroEarly.cpp +++ llvm/trunk/lib/Transforms/Coroutines/CoroEarly.cpp @@ -52,6 +52,14 @@ switch (CS.getIntrinsicID()) { default: continue; + case Intrinsic::coro_begin: + // Mark a function that comes out of the frontend that has a coro.begin + // with a coroutine attribute. cast<> never returns null (it asserts + // on a type mismatch), so no null check is needed; binding no local + // here also keeps later case labels from jumping over an initializer. + if (cast<CoroBeginInst>(&I)->getInfo().isPreSplit()) + F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT); + break; case Intrinsic::coro_resume: lowerResumeOrDestroy(CS, CoroSubFnInst::ResumeIndex); break; @@ -80,7 +88,8 @@ // This pass has work to do only if we find intrinsics we are going to lower // in the module.
bool doInitialization(Module &M) override { - if (coro::declaresIntrinsics(M, {"llvm.coro.resume", "llvm.coro.destroy"})) + if (coro::declaresIntrinsics( + M, {"llvm.coro.begin", "llvm.coro.resume", "llvm.coro.destroy"})) L = llvm::make_unique<Lowerer>(M); return false; } Index: llvm/trunk/lib/Transforms/Coroutines/CoroElide.cpp =================================================================== --- llvm/trunk/lib/Transforms/Coroutines/CoroElide.cpp +++ llvm/trunk/lib/Transforms/Coroutines/CoroElide.cpp @@ -14,7 +14,6 @@ #include "CoroInternal.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/IR/ConstantFolder.h" #include "llvm/IR/InstIterator.h" #include "llvm/Pass.h" @@ -108,7 +107,32 @@ return true; } +// See if there are any coro.subfn.addr instructions referring to coro.devirt +// trigger; if so, replace them with the devirt trigger function's address. +static bool replaceDevirtTrigger(Function &F) { + SmallVector<CoroSubFnInst *, 1> DevirtAddr; + for (auto &I : instructions(F)) + if (auto *SubFn = dyn_cast<CoroSubFnInst>(&I)) + if (SubFn->getIndex() == CoroSubFnInst::RestartTrigger) + DevirtAddr.push_back(SubFn); + + if (DevirtAddr.empty()) + return false; + + Module &M = *F.getParent(); + Function *DevirtFn = M.getFunction(CORO_DEVIRT_TRIGGER_FN); + assert(DevirtFn && "coro.devirt.trigger function not found"); + replaceWithConstant(DevirtFn, DevirtAddr); + + return true; +} + bool CoroElide::runOnFunction(Function &F) { + bool Changed = false; + + if (F.hasFnAttribute(CORO_PRESPLIT_ATTR)) + Changed = replaceDevirtTrigger(F); + // Collect all PostSplit coro.begins.
SmallVector<CoroBeginInst *, 4> CoroBegins; for (auto &I : instructions(F)) @@ -117,9 +141,7 @@ CoroBegins.push_back(CB); if (CoroBegins.empty()) - return false; - - bool Changed = false; + return Changed; for (auto *CB : CoroBegins) Changed |= replaceIndirectCalls(CB); Index: llvm/trunk/lib/Transforms/Coroutines/CoroInstr.h =================================================================== --- llvm/trunk/lib/Transforms/Coroutines/CoroInstr.h +++ llvm/trunk/lib/Transforms/Coroutines/CoroInstr.h @@ -34,10 +34,11 @@ public: enum ResumeKind { + RestartTrigger = -1, ResumeIndex, DestroyIndex, IndexLast, - IndexFirst = ResumeIndex + IndexFirst = RestartTrigger }; Value *getFrame() const { return getArgOperand(FrameArg); } @@ -90,6 +91,7 @@ bool hasOutlinedParts() const { return OutlinedParts != nullptr; } bool isPostSplit() const { return Resumers != nullptr; } + bool isPreSplit() const { return !isPostSplit(); } }; Info getInfo() const { Info Result; Index: llvm/trunk/lib/Transforms/Coroutines/CoroInternal.h =================================================================== --- llvm/trunk/lib/Transforms/Coroutines/CoroInternal.h +++ llvm/trunk/lib/Transforms/Coroutines/CoroInternal.h @@ -24,6 +24,21 @@ void initializeCoroElidePass(PassRegistry &); void initializeCoroCleanupPass(PassRegistry &); +// CoroEarly pass marks every function that has coro.begin with a string +// attribute "coroutine.presplit"="0". CoroSplit pass processes the coroutine +// twice. First, it lets it go through complete IPO optimization pipeline as a +// single function. It forces restart of the pipeline by inserting an indirect +// call to an empty function "coro.devirt.trigger" which is devirtualized by +// CoroElide pass that triggers a restart of the pipeline by CGPassManager. +// When CoroSplit pass sees the same coroutine the second time, it splits it up, +// and adds the coroutine subfunctions to the SCC to be processed by IPO.
+ #define CORO_PRESPLIT_ATTR "coroutine.presplit" +#define UNPREPARED_FOR_SPLIT "0" +#define PREPARED_FOR_SPLIT "1" + +#define CORO_DEVIRT_TRIGGER_FN "coro.devirt.trigger" + namespace coro { bool declaresIntrinsics(Module &M, std::initializer_list<StringRef>); Index: llvm/trunk/lib/Transforms/Coroutines/CoroSplit.cpp =================================================================== --- llvm/trunk/lib/Transforms/Coroutines/CoroSplit.cpp +++ llvm/trunk/lib/Transforms/Coroutines/CoroSplit.cpp @@ -17,6 +17,66 @@ #define DEBUG_TYPE "coro-split" +// We present a coroutine to LLVM as an ordinary function with suspension +// points marked up with intrinsics. We let the optimizer party on the coroutine +// as a single function for as long as possible. Shortly before the coroutine is +// eligible to be inlined into its callers, we split up the coroutine into parts +// corresponding to initial, resume and destroy invocations of the coroutine, +// add them to the current SCC and restart the IPO pipeline to optimize the +// coroutine subfunctions we extracted before proceeding to the caller of the +// coroutine. + +// When we see the coroutine the first time, we insert an indirect call to a +// devirt trigger function and mark the coroutine as now being ready for +// splitting.
+static void prepareForSplit(Function &F, CallGraph &CG) { + Module &M = *F.getParent(); + assert(M.getFunction(CORO_DEVIRT_TRIGGER_FN) && + "coro.devirt.trigger function not found"); + + F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT); + + // Insert an indirect call sequence that will be devirtualized by CoroElide + // pass: + // %0 = call i8* @llvm.coro.subfn.addr(i8* null, i8 -1) + // %1 = bitcast i8* %0 to void(i8*)* + // call void %1(i8* null) + coro::LowererBase Lowerer(M); + Instruction *InsertPt = F.getEntryBlock().getTerminator(); + auto *Null = ConstantPointerNull::get(Type::getInt8PtrTy(F.getContext())); + auto *DevirtFnAddr = + Lowerer.makeSubFnCall(Null, CoroSubFnInst::RestartTrigger, InsertPt); + auto *IndirectCall = CallInst::Create(DevirtFnAddr, Null, "", InsertPt); + + // Update the call graph with the indirect call we just added. + CG[&F]->addCalledFunction(IndirectCall, CG.getCallsExternalNode()); +} + +// Make sure that there is a devirtualization trigger function that CoroSplit +// pass uses to force a restart of the CGSCC pipeline. If the devirt trigger +// function is not found, we will create one and add it to the current SCC.
+static void createDevirtTriggerFunc(CallGraph &CG, CallGraphSCC &SCC) { + Module &M = CG.getModule(); + if (M.getFunction(CORO_DEVIRT_TRIGGER_FN)) + return; + + LLVMContext &C = M.getContext(); + auto *FnTy = FunctionType::get(Type::getVoidTy(C), Type::getInt8PtrTy(C), + /*IsVarArgs=*/false); + Function *DevirtFn = + Function::Create(FnTy, GlobalValue::LinkageTypes::PrivateLinkage, + CORO_DEVIRT_TRIGGER_FN, &M); + DevirtFn->addFnAttr(Attribute::AlwaysInline); + auto *Entry = BasicBlock::Create(C, "entry", DevirtFn); + ReturnInst::Create(C, Entry); + + auto *Node = CG.getOrInsertFunction(DevirtFn); + + SmallVector<CallGraphNode *, 8> Nodes(SCC.begin(), SCC.end()); + Nodes.push_back(Node); + SCC.initialize(Nodes); +} + //===----------------------------------------------------------------------===// // Top Level Driver //===----------------------------------------------------------------------===// @@ -27,13 +87,51 @@ static char ID; // Pass identification, replacement for typeid CoroSplit() : CallGraphSCCPass(ID) {} - bool runOnSCC(CallGraphSCC &SCC) override { return false; } + bool Run = false; + + // A coroutine is identified by the presence of the coro.begin intrinsic; if + // we don't have any, this pass has nothing to do. + bool doInitialization(CallGraph &CG) override { + Run = coro::declaresIntrinsics(CG.getModule(), {"llvm.coro.begin"}); + return CallGraphSCCPass::doInitialization(CG); + } + + bool runOnSCC(CallGraphSCC &SCC) override { + if (!Run) + return false; + + // Find coroutines for processing.
+ SmallVector<Function *, 4> Coroutines; + for (CallGraphNode *CGN : SCC) + if (auto *F = CGN->getFunction()) + if (F->hasFnAttribute(CORO_PRESPLIT_ATTR)) + Coroutines.push_back(F); + + if (Coroutines.empty()) + return false; + + CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph(); + createDevirtTriggerFunc(CG, SCC); + + for (Function *F : Coroutines) { + Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR); + StringRef Value = Attr.getValueAsString(); + DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F->getName() + << "' state: " << Value << "\n"); + if (Value == UNPREPARED_FOR_SPLIT) { + prepareForSplit(*F, CG); + continue; + } + F->removeFnAttr(CORO_PRESPLIT_ATTR); + } + return true; + } + void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); CallGraphSCCPass::getAnalysisUsage(AU); } }; - } char CoroSplit::ID = 0; Index: llvm/trunk/test/Transforms/Coroutines/restart-trigger.ll =================================================================== --- llvm/trunk/test/Transforms/Coroutines/restart-trigger.ll +++ llvm/trunk/test/Transforms/Coroutines/restart-trigger.ll @@ -0,0 +1,17 @@ +; Verifies that the restart trigger forces an IPO pipeline restart and the same +; coroutine is looked at by CoroSplit pass twice. +; RUN: opt < %s -S -O0 -enable-coroutines -debug-only=coro-split 2>&1 | FileCheck %s +; RUN: opt < %s -S -O1 -enable-coroutines -debug-only=coro-split 2>&1 | FileCheck %s +; REQUIRES: asserts + +; CHECK: CoroSplit: Processing coroutine 'f' state: 0 +; CHECK-NEXT: CoroSplit: Processing coroutine 'f' state: 1 + +declare i8* @llvm.coro.begin(i8*, i32, i8*, i8*) + +; a coroutine start function +define i8* @f() { +entry: + %hdl = call i8* @llvm.coro.begin(i8* null, i32 0, i8* null, i8* null) + ret i8* %hdl +}