Index: lib/Transforms/Coroutines/CoroInstr.h =================================================================== --- lib/Transforms/Coroutines/CoroInstr.h +++ lib/Transforms/Coroutines/CoroInstr.h @@ -80,11 +80,11 @@ enum { AlignArg, PromiseArg, CoroutineArg, InfoArg }; public: - IntrinsicInst *getCoroAlloc() { + CoroAllocInst *getCoroAlloc() { for (User *U : users()) if (auto *II = dyn_cast<IntrinsicInst>(U)) if (II->getIntrinsicID() == Intrinsic::coro_alloc) - return II; + return cast<CoroAllocInst>(II); return nullptr; } Index: lib/Transforms/Coroutines/CoroSplit.cpp =================================================================== --- lib/Transforms/Coroutines/CoroSplit.cpp +++ lib/Transforms/Coroutines/CoroSplit.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Scalar.h" @@ -273,18 +274,6 @@ // FIXME: coming in upcoming patches: // replaceUnwindCoroEnds(Shape.CoroEnds, VMap); - // We only store resume(0) and destroy(1) addresses in the coroutine frame. - // The cleanup(2) clone is only used during devirtualization when coroutine is - // eligible for heap elision and thus does not participate in indirect calls - // and does not need its address to be stored in the coroutine frame. - if (FnIndex < 2) { - // Store the address of this clone in the coroutine frame. - Builder.SetInsertPoint(Shape.FramePtr->getNextNode()); - auto *G = Builder.CreateConstInBoundsGEP2_32(Shape.FrameTy, Shape.FramePtr, - 0, FnIndex, "fn.addr"); - Builder.CreateStore(NewF, G); - } - // Eliminate coro.free from the clones, replacing it with 'null' in cleanup, // to suppress deallocation code. coro::replaceCoroFree(cast<CoroIdInst>(VMap[Shape.CoroBegin->getId()]), @@ -348,6 +337,31 @@ CoroBegin->getId()->setInfo(BC); } +// Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame. 
+static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, + Function *DestroyFn, Function *CleanupFn) { + + IRBuilder<> Builder(Shape.FramePtr->getNextNode()); + auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32( + Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField, + "resume.addr"); + Builder.CreateStore(ResumeFn, ResumeAddr); + + Value *DestroyOrCleanupFn = DestroyFn; + + CoroIdInst *CoroId = Shape.CoroBegin->getId(); + if (CoroAllocInst *CA = CoroId->getCoroAlloc()) { + // If there is a CoroAlloc and it returns false (meaning we elide the + // allocation), use CleanupFn instead of DestroyFn. + DestroyOrCleanupFn = Builder.CreateSelect(CA, DestroyFn, CleanupFn); + } + + auto *DestroyAddr = Builder.CreateConstInBoundsGEP2_32( + Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::DestroyField, + "destroy.addr"); + Builder.CreateStore(DestroyOrCleanupFn, DestroyAddr); +} + static void postSplitCleanup(Function &F) { removeUnreachableBlocks(F); llvm::legacy::FunctionPassManager FPM(F.getParent()); @@ -496,7 +510,15 @@ postSplitCleanup(*DestroyClone); postSplitCleanup(*CleanupClone); + // Store addresses of resume/destroy/cleanup functions in the coroutine frame. + updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); + + // Create a constant array referring to resume/destroy/cleanup functions + // pointed to by the last argument of @llvm.coro.id, so that the CoroElide + // pass can determine the correct function to call. setCoroInfo(F, Shape.CoroBegin, {ResumeClone, DestroyClone, CleanupClone}); + + // Update call graph and add the functions we created to the SCC. 
coro::updateCallGraph(F, {ResumeClone, DestroyClone, CleanupClone}, CG, SCC); } Index: test/Transforms/Coroutines/coro-split-00.ll =================================================================== --- test/Transforms/Coroutines/coro-split-00.ll +++ test/Transforms/Coroutines/coro-split-00.ll @@ -4,9 +4,17 @@ define i8* @f() "coroutine.presplit"="1" { entry: %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null) + %need.alloc = call i1 @llvm.coro.alloc(token %id) + br i1 %need.alloc, label %dyn.alloc, label %begin + +dyn.alloc: %size = call i32 @llvm.coro.size.i32() %alloc = call i8* @malloc(i32 %size) - %hdl = call i8* @llvm.coro.begin(token %id, i8* %alloc) + br label %begin + +begin: + %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ] + %hdl = call i8* @llvm.coro.begin(token %id, i8* %phi) call void @print(i32 0) %0 = call i8 @llvm.coro.suspend(token none, i1 false) switch i8 %0, label %suspend [i8 0, label %resume @@ -26,6 +34,10 @@ ; CHECK-LABEL: @f( ; CHECK: call i8* @malloc +; CHECK: @llvm.coro.begin(token %id, i8* %phi) +; CHECK: store void (%f.Frame*)* @f.resume, void (%f.Frame*)** %resume.addr +; CHECK: %[[SEL:.+]] = select i1 %need.alloc, void (%f.Frame*)* @f.destroy, void (%f.Frame*)* @f.cleanup +; CHECK: store void (%f.Frame*)* %[[SEL]], void (%f.Frame*)** %destroy.addr ; CHECK: call void @print(i32 0) ; CHECK-NOT: call void @print(i32 1) ; CHECK-NOT: call void @free( @@ -45,6 +57,12 @@ ; CHECK: call void @free( ; CHECK: ret void +; CHECK-LABEL: @f.cleanup( +; CHECK-NOT: call i8* @malloc +; CHECK-NOT: call void @print( +; CHECK-NOT: call void @free( +; CHECK: ret void + declare i8* @llvm.coro.free(token, i8*) declare i32 @llvm.coro.size.i32() declare i8 @llvm.coro.suspend(token, i1) @@ -52,7 +70,7 @@ declare void @llvm.coro.destroy(i8*) declare token @llvm.coro.id(i32, i8*, i8*, i8*) -declare i8* @llvm.coro.alloc(token) +declare i1 @llvm.coro.alloc(token) declare i8* @llvm.coro.begin(token, i8*) declare void @llvm.coro.end(i8*, i1)