diff --git a/clang/lib/CodeGen/CGCoroutine.cpp b/clang/lib/CodeGen/CGCoroutine.cpp --- a/clang/lib/CodeGen/CGCoroutine.cpp +++ b/clang/lib/CodeGen/CGCoroutine.cpp @@ -547,7 +547,7 @@ auto *EntryBB = Builder.GetInsertBlock(); auto *AllocBB = createBasicBlock("coro.alloc"); - auto *InitBB = createBasicBlock("coro.init"); + auto *BeginBB = createBasicBlock("coro.begin"); auto *FinalBB = createBasicBlock("coro.final"); auto *RetBB = createBasicBlock("coro.ret"); @@ -564,7 +564,7 @@ auto *CoroAlloc = Builder.CreateCall( CGM.getIntrinsic(llvm::Intrinsic::coro_alloc), {CoroId}); - Builder.CreateCondBr(CoroAlloc, AllocBB, InitBB); + Builder.CreateCondBr(CoroAlloc, AllocBB, BeginBB); EmitBlock(AllocBB); auto *AllocateCall = EmitScalarExpr(S.getAllocate()); @@ -577,17 +577,17 @@ // See if allocation was successful. auto *NullPtr = llvm::ConstantPointerNull::get(Int8PtrTy); auto *Cond = Builder.CreateICmpNE(AllocateCall, NullPtr); - Builder.CreateCondBr(Cond, InitBB, RetOnFailureBB); + Builder.CreateCondBr(Cond, BeginBB, RetOnFailureBB); // If not, return OnAllocFailure object. EmitBlock(RetOnFailureBB); EmitStmt(RetOnAllocFailure); } else { - Builder.CreateBr(InitBB); + Builder.CreateBr(BeginBB); } - EmitBlock(InitBB); + EmitBlock(BeginBB); // Pass the result of the allocation to coro.begin. auto *Phi = Builder.CreatePHI(VoidPtrTy, 2); @@ -606,12 +606,36 @@ CodeGenFunction::RunCleanupsScope ResumeScope(*this); EHStack.pushCleanup(NormalAndEHCleanup, S.getDeallocate()); + // Wrap around the parameter copies with a coro.init() check. + // This allows us to perform the parameter copies in the init function, but + // not in the ramp function. + auto *InitBB = createBasicBlock("coro.init"); + auto *InitReadyBB = createBasicBlock("coro.init.ready"); + auto *CoroInit = + Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::coro_init)); + Builder.CreateCondBr(CoroInit, InitBB, InitReadyBB); + + EmitBlock(InitBB); + SmallVector FrameAllocas; // Create parameter copies.
We do it before creating a promise, since an // evolution of coroutine TS may allow promise constructor to observe // parameter copies. + int ID = 0; for (auto *PM : S.getParamMoves()) { EmitStmt(PM); ParamReplacer.addCopy(cast(PM)); + llvm::AllocaInst *Alloca = cast( + GetAddrOfLocalVar(cast(cast(PM)->getSingleDecl())) + .getPointer()); + Alloca->setMetadata( + "coroutine_frame_alloca", + llvm::MDNode::get( + getLLVMContext(), + { + llvm::ConstantAsMetadata::get( + Builder.getInt1(false)) /*IsPromise*/, + llvm::ConstantAsMetadata::get(Builder.getInt32(ID++)), + })); // TODO: if(CoroParam(...)) need to surround ctor and dtor // for the copy, so that llvm can elide it if the copy is // not needed. @@ -619,12 +643,23 @@ EmitStmt(S.getPromiseDeclStmt()); + Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::coro_init_end)); + Builder.CreateBr(InitReadyBB); + EmitBlock(InitReadyBB); + Address PromiseAddr = GetAddrOfLocalVar(S.getPromiseDecl()); - auto *PromiseAddrVoidPtr = - new llvm::BitCastInst(PromiseAddr.getPointer(), VoidPtrTy, "", CoroId); - // Update CoroId to refer to the promise. We could not do it earlier because - // promise local variable was not emitted yet.
- CoroId->setArgOperand(1, PromiseAddrVoidPtr); + llvm::AllocaInst *PromiseAlloca = + cast(PromiseAddr.getPointer()); + + PromiseAlloca->setMetadata( + "coroutine_frame_alloca", + llvm::MDNode::get( + getLLVMContext(), + { + llvm::ConstantAsMetadata::get( + Builder.getInt1(true)) /*IsPromise*/, + llvm::ConstantAsMetadata::get(Builder.getInt32(ID++)), + })); // Now we have the promise, initialize the GRO GroManager.EmitGroInit(); diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -1274,6 +1274,12 @@ ReadOnly>, NoCapture>]>; +def int_coro_frame_get : Intrinsic<[llvm_ptr_ty], + [llvm_ptr_ty, llvm_ptr_ty, llvm_i1_ty, llvm_i32_ty], + [IntrNoMem]>; +def int_coro_init: Intrinsic<[llvm_i1_ty], [], []>; +def int_coro_init_end: Intrinsic<[], [], []>; + ///===-------------------------- Other Intrinsics --------------------------===// // def int_trap : Intrinsic<[], [], [IntrNoReturn, IntrCold]>, diff --git a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp --- a/llvm/lib/Transforms/Coroutines/CoroEarly.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroEarly.cpp @@ -8,10 +8,15 @@ #include "llvm/Transforms/Coroutines/CoroEarly.h" #include "CoroInternal.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" #include "llvm/Pass.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -145,6 +150,121 @@ CB->setCannotDuplicate(); } +static void splitRampFunction(Function &F) { + Module *M = F.getParent(); + LLVMContext &C = M->getContext(); + { + CoroBeginInst *CoroBegin = cast( + &*llvm::find_if(instructions(F), + [](Instruction &I) { return isa(&I); })); + Instruction *InsertPoint = CoroBegin->getNextNode(); + + for (Instruction &I :
make_early_inc_range(instructions(F))) { + auto *AI = dyn_cast(&I); + if (!AI) + continue; + auto *MD = AI->getMetadata("coroutine_frame_alloca"); + if (!MD) + continue; + + auto *IsPromise = cast(MD->getOperand(0))->getValue(); + auto *SlotID = cast(MD->getOperand(1))->getValue(); + auto *VoidPt = + new BitCastInst(AI, llvm::Type::getInt8PtrTy(C), "", InsertPoint); + auto *FrameGet = CallInst::Create( + Intrinsic::getDeclaration(M, Intrinsic::coro_frame_get), + {CoroBegin, VoidPt, IsPromise, SlotID}, "", InsertPoint); + auto *NewPtr = new BitCastInst(FrameGet, AI->getType(), "", InsertPoint); + AI->replaceUsesWithIf(NewPtr, + [&](Use &U) { return U.getUser() != VoidPt; }); + } + } + + Function *NewF; + { + // Create the (module-internal) split ramp function, and clone. + llvm::Type *NewFArgTypes[] = {llvm::Type::getInt8PtrTy(C)}; + auto newFuncType = + FunctionType::get(F.getReturnType(), NewFArgTypes, false); + NewF = Function::Create(newFuncType, + GlobalValue::LinkageTypes::InternalLinkage, + F.getName() + ".ramp"); + NewF->addFnAttr(Attribute::NoInline); + M->getFunctionList().push_back(NewF); + ValueToValueMapTy VMap; + for (Argument &A : F.args()) + VMap[&A] = UndefValue::get(A.getType()); + SmallVector Returns; + CloneFunctionInto(NewF, &F, VMap, CloneFunctionChangeType::LocalChangesOnly, + Returns); + } + + { + // Process the init function.
+ IntrinsicInst *CoroBegin = nullptr; + IntrinsicInst *CoroInitEnd = nullptr; + for (Instruction &I : make_early_inc_range(instructions(F))) { + auto *II = dyn_cast(&I); + if (!II) + continue; + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::coro_begin: + CoroBegin = II; + break; + case Intrinsic::coro_init: + II->replaceAllUsesWith( + llvm::ConstantInt::get(llvm::Type::getInt1Ty(C), 1)); + II->eraseFromParent(); + break; + case Intrinsic::coro_init_end: + CoroInitEnd = II; + break; + } + } + assert(CoroInitEnd && CoroInitEnd->getNextNode() == + CoroInitEnd->getParent()->getTerminator() && + "coro.init.end call should be at the end of the init block"); + CoroInitEnd->getNextNode()->eraseFromParent(); + CallInst *Ret = CallInst::Create(NewF, {CoroBegin}, "", CoroInitEnd); + if (F.getReturnType()->isVoidTy()) + ReturnInst::Create(C, nullptr, CoroInitEnd); + else + ReturnInst::Create(C, Ret, CoroInitEnd); + CoroInitEnd->eraseFromParent(); + removeUnreachableBlocks(F); + F.addFnAttr(CORO_PRESPLIT_ATTR, DO_NOT_PROCESS); + } + + { + // Process the ramp function. + for (Instruction &I : make_early_inc_range(instructions(*NewF))) { + auto *II = dyn_cast(&I); + if (!II) + continue; + switch (II->getIntrinsicID()) { + default: + continue; + case Intrinsic::coro_begin: + II->replaceAllUsesWith(NewF->getArg(0)); + break; + case Intrinsic::coro_init: + II->replaceAllUsesWith( + llvm::ConstantInt::get(llvm::Type::getInt1Ty(C), 0)); + break; + case Intrinsic::coro_alloc: + II->replaceAllUsesWith( + llvm::ConstantInt::get(llvm::Type::getInt1Ty(C), 0)); + break; + } + II->eraseFromParent(); + } + removeUnreachableBlocks(*NewF); + NewF->addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT_RAMP); + } +} + bool Lowerer::lowerEarlyIntrinsics(Function &F) { bool Changed = false; CoroIdInst *CoroId = nullptr; @@ -179,7 +299,6 @@ // with a coroutine attribute.
if (auto *CII = cast(&I)) { if (CII->getInfo().isPreSplit()) { - F.addFnAttr(CORO_PRESPLIT_ATTR, UNPREPARED_FOR_SPLIT); setCannotDuplicate(CII); CII->setCoroutineSelf(); CoroId = cast(&I); @@ -210,9 +329,11 @@ // Make sure that all CoroFree reference the coro.id intrinsic. // Token type is not exposed through coroutine C/C++ builtins to plain C, so // we allow specifying none and fixing it up here. - if (CoroId) + if (CoroId) { for (CoroFreeInst *CF : CoroFrees) CF->setArgOperand(0, CoroId); + splitRampFunction(F); + } return Changed; } @@ -226,6 +347,10 @@ } PreservedAnalyses CoroEarlyPass::run(Function &F, FunctionAnalysisManager &) { + if (F.getFnAttribute(CORO_PRESPLIT_ATTR).getValueAsString() == + UNPREPARED_FOR_SPLIT_RAMP) + return PreservedAnalyses::all(); + Module &M = *F.getParent(); if (!declaresCoroEarlyIntrinsics(M) || !Lowerer(M).lowerEarlyIntrinsics(F)) return PreservedAnalyses::all(); diff --git a/llvm/lib/Transforms/Coroutines/CoroInternal.h b/llvm/lib/Transforms/Coroutines/CoroInternal.h --- a/llvm/lib/Transforms/Coroutines/CoroInternal.h +++ b/llvm/lib/Transforms/Coroutines/CoroInternal.h @@ -37,9 +37,11 @@ // Async lowering similarily triggers a restart of the pipeline after it has // split the coroutine.
#define CORO_PRESPLIT_ATTR "coroutine.presplit" -#define UNPREPARED_FOR_SPLIT "0" +#define DO_NOT_PROCESS "0" #define PREPARED_FOR_SPLIT "1" #define ASYNC_RESTART_AFTER_SPLIT "2" +#define UNPREPARED_FOR_SPLIT_RAMP "3" +#define PREPARED_FOR_SPLIT_INIT "4" #define CORO_DEVIRT_TRIGGER_FN "coro.devirt.trigger" diff --git a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp --- a/llvm/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroSplit.cpp @@ -2049,6 +2049,77 @@ Fns.push_back(PrepareFn); } +static Function *getCoroInitFunction(Function &RampFunc) { + StringRef RampName = RampFunc.getName(); + assert(RampName.endswith(".ramp") && "Ramp function name must end with .ramp"); + StringRef InitName = RampName.substr(0, RampName.size() - 5); + return RampFunc.getParent()->getFunction(InitName); +} + +static Function *inlineRampFunction(Function &F) { + CallInst *RampCall = cast( + &*llvm::find_if(instructions(F), [&](const Instruction &I) { + if (const CallInst *CI = dyn_cast(&I)) + if (const Function *Callee = CI->getCalledFunction()) + return Callee->getName() == (F.getName() + ".ramp").str(); + return false; + })); + // InlineFunction erases RampCall, so remember the callee before inlining. + Function *RampFunc = RampCall->getCalledFunction(); + InlineFunctionInfo IFI; + InlineFunction(*RampCall, IFI); + + SmallVector CoroIds; + CoroBeginInst *CoroBegin = nullptr; + SmallVector CoroFrameGets; + for (Instruction &I : instructions(F)) { + auto *II = dyn_cast(&I); + if (!II) + continue; + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::coro_id: + CoroIds.push_back(II); + break; + case Intrinsic::coro_begin: + CoroBegin = cast(II); + break; + case Intrinsic::coro_frame_get: + CoroFrameGets.push_back(II); + break; + } + } + assert(CoroIds.size() == 2 && "There must be two coro.id calls, from the " + "init function and ramp function respectively"); + CoroIdInst *RealId = cast(CoroBegin->getId()); + for (IntrinsicInst *I : CoroIds) + if (I != RealId) + I->replaceAllUsesWith(RealId); + DenseMap FrameSlotMap; + for (IntrinsicInst *FrameGet : CoroFrameGets) { + bool
IsPromise = cast(FrameGet->getOperand(2))->getZExtValue(); + uint32_t SlotID = + cast(FrameGet->getOperand(3))->getZExtValue(); + auto Itr = FrameSlotMap.find(SlotID); + Instruction *Ptr; + if (Itr == FrameSlotMap.end()) { + Ptr = cast(FrameGet->getOperand(1)); + FrameSlotMap[SlotID] = Ptr; + } else { + Ptr = Itr->second; + } + FrameGet->replaceAllUsesWith(Ptr); + FrameGet->eraseFromParent(); + if (IsPromise) { + RealId->setOperand(1, new BitCastInst(Ptr->stripPointerCasts(), + Ptr->getType(), "", RealId)); + } + } + + return RampFunc; +} + PreservedAnalyses CoroSplitPass::run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM, LazyCallGraph &CG, CGSCCUpdateResult &UR) { @@ -2082,6 +2153,8 @@ } } + SmallVector UnpreparedInitFuncs; + SmallVector InlinedRampFuncs; // Split all the coroutines. for (LazyCallGraph::Node *N : Coroutines) { Function &F = N->getFunction(); @@ -2089,12 +2162,24 @@ StringRef Value = Attr.getValueAsString(); LLVM_DEBUG(dbgs() << "CoroSplit: Processing coroutine '" << F.getName() << "' state: " << Value << "\n"); - if (Value == UNPREPARED_FOR_SPLIT) { + if (Value == DO_NOT_PROCESS) + continue; + if (Value == UNPREPARED_FOR_SPLIT_RAMP) { // Enqueue a second iteration of the CGSCC pipeline on this SCC. UR.CWorklist.insert(&C); - F.addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT); + // Once we allow the ramp function to be optimized, we will split + // the init function directly and ignore the ramp function.
+ F.addFnAttr(CORO_PRESPLIT_ATTR, DO_NOT_PROCESS); + UnpreparedInitFuncs.push_back(getCoroInitFunction(F)); continue; } + if (Value == PREPARED_FOR_SPLIT_INIT) { + Function *RampFunc = inlineRampFunction(F); + InlinedRampFuncs.push_back(RampFunc); + RampFunc->removeDeadConstantUsers(); + RampFunc->dropAllReferences(); + updateCGAndAnalysisManagerForCGSCCPass(CG, C, *N, AM, UR, FAM); + } F.removeFnAttr(CORO_PRESPLIT_ATTR); SmallVector Clones; @@ -2109,6 +2194,23 @@ UR.RCWorklist.insert(CG.lookupRefSCC(CG.get(*Clones[0]))); } } + for (Function *F : UnpreparedInitFuncs) + F->addFnAttr(CORO_PRESPLIT_ATTR, PREPARED_FOR_SPLIT_INIT); + for (Function *DeadF : InlinedRampFuncs) { + auto &DeadC = *CG.lookupSCC(*CG.lookup(*DeadF)); + FAM.clear(*DeadF, DeadF->getName()); + AM.clear(DeadC, DeadC.getName()); + auto &DeadRC = DeadC.getOuterRefSCC(); + CG.removeDeadFunction(*DeadF); + + // Mark the relevant parts of the call graph as invalid so we don't visit + // them. + UR.InvalidatedSCCs.insert(&DeadC); + UR.InvalidatedRefSCCs.insert(&DeadRC); + + DeadF->getBasicBlockList().clear(); + M.getFunctionList().remove(DeadF); + } if (!PrepareFns.empty()) { for (auto *PrepareFn : PrepareFns) { @@ -2179,6 +2281,7 @@ createDevirtTriggerFunc(CG, SCC); // Split all the coroutines. + // FIXME: adapt to the new split model for (Function *F : Coroutines) { Attribute Attr = F->getFnAttribute(CORO_PRESPLIT_ATTR); StringRef Value = Attr.getValueAsString(); @@ -2190,7 +2293,7 @@ F->removeFnAttr(CORO_PRESPLIT_ATTR); continue; } - if (Value == UNPREPARED_FOR_SPLIT) { + if (Value == UNPREPARED_FOR_SPLIT_RAMP) { prepareForSplit(*F, CG); continue; }