Index: lib/Transforms/Coroutines/CoroInstr.h
===================================================================
--- lib/Transforms/Coroutines/CoroInstr.h
+++ lib/Transforms/Coroutines/CoroInstr.h
@@ -80,6 +80,16 @@
   enum { AlignArg, PromiseArg, CoroutineArg, InfoArg };
 
 public:
+  /// Returns an llvm.coro.alloc intrinsic that uses this coro.id, or null if
+  /// there is none.
+  IntrinsicInst *getCoroAlloc() {
+    for (User *U : users())
+      if (auto *II = dyn_cast<IntrinsicInst>(U))
+        if (II->getIntrinsicID() == Intrinsic::coro_alloc)
+          return II;
+    return nullptr;
+  }
+
   IntrinsicInst *getCoroBegin() {
     for (User *U : users())
       if (auto *II = dyn_cast<IntrinsicInst>(U))
Index: lib/Transforms/Coroutines/CoroSplit.cpp
===================================================================
--- lib/Transforms/Coroutines/CoroSplit.cpp
+++ lib/Transforms/Coroutines/CoroSplit.cpp
@@ -361,14 +361,142 @@
   FPM.doFinalization();
 }
 
+// Coroutine has no suspend points. Remove heap allocation for the coroutine
+// frame if possible.
+//
+// If the coro.id has an associated llvm.coro.alloc, the frame is placed in an
+// alloca instead: coro.alloc is replaced with 'false' and coro.begin with the
+// alloca. Otherwise coro.begin is replaced with the memory passed to it.
+// coro::replaceCoroFree is invoked on the coro.id, with Elide set only when
+// the alloca frame is created.
+static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) {
+  auto *CoroId = CoroBegin->getId();
+  auto *AllocInst = CoroId->getCoroAlloc();
+  coro::replaceCoroFree(CoroId, /*Elide=*/AllocInst != nullptr);
+  if (AllocInst) {
+    IRBuilder<> Builder(AllocInst);
+    // FIXME: Need to handle overaligned members.
+    auto *Frame = Builder.CreateAlloca(FrameTy);
+    auto *VFrame = Builder.CreateBitCast(Frame, Builder.getInt8PtrTy());
+    // The frame is now an alloca, so no dynamic allocation is needed: fold
+    // coro.alloc to false.
+    AllocInst->replaceAllUsesWith(Builder.getFalse());
+    AllocInst->eraseFromParent();
+    CoroBegin->replaceAllUsesWith(VFrame);
+  } else {
+    CoroBegin->replaceAllUsesWith(CoroBegin->getMem());
+  }
+  CoroBegin->eraseFromParent();
+}
+
+// Look for a very simple pattern:
+//   coro.save
+//   no other calls
+//   resume or destroy call
+//   coro.suspend
+//
+// If there are other calls between coro.save and coro.suspend, they can
+// potentially resume or destroy the coroutine, so it is unsafe to eliminate a
+// suspend point.
+//
+// Returns true if the suspend point was removed. In that case uses of the
+// coro.suspend are replaced with the raw index of the resume/destroy
+// sub-function, so control goes straight to the resume or cleanup path.
+static bool simplifySuspendPoint(CoroSuspendInst *Suspend,
+                                 CoroBeginInst *CoroBegin) {
+  auto *Save = Suspend->getCoroSave(); // NOTE(review): assumed non-null.
+  auto *BB = Suspend->getParent();
+  if (BB != Save->getParent()) // Save and suspend must share a block.
+    return false;
+
+  CallSite SingleCallSite;
+
+  // Check that we have only one CallSite between coro.save and coro.suspend.
+  for (Instruction *I = Save->getNextNode(); I != Suspend;
+       I = I->getNextNode()) {
+    if (isa(I)) // NOTE(review): template argument missing.
+      continue;
+    if (isa(I)) // NOTE(review): template argument missing.
+      continue;
+    if (CallSite CS = CallSite(I)) {
+      if (SingleCallSite) // A second call: not the simple pattern.
+        return false;
+      else
+        SingleCallSite = CS;
+    }
+  }
+  auto *CallInstr = SingleCallSite.getInstruction();
+  if (!CallInstr) // No call between save and suspend.
+    return false;
+
+  auto *Callee = SingleCallSite.getCalledValue();
+
+  if (isa(Callee)) // NOTE(review): template argument missing.
+    return false;
+
+  // See if the callsite is for resumption or destruction of the coroutine.
+  Callee = Callee->stripPointerCasts();
+  auto *SubFn = dyn_cast(Callee); // NOTE(review): template argument missing.
+  if (!SubFn)
+    return false;
+
+  // Does not refer to the current coroutine, we cannot do anything with it.
+  if (SubFn->getFrame() != CoroBegin)
+    return false;
+
+  // Replace llvm.coro.suspend with the value that results in resumption over
+  // the resume or cleanup path.
+  Suspend->replaceAllUsesWith(SubFn->getRawIndex());
+  Suspend->eraseFromParent();
+  Save->eraseFromParent();
+
+  // No longer need a call to coro.resume or coro.destroy.
+  CallInstr->eraseFromParent();
+  // NOTE(review): if the callee was a cast of SubFn, that cast still uses it.
+  if (SubFn->user_empty())
+    SubFn->eraseFromParent();
+
+  return true;
+}
+
+// Simplify suspend points and drop the erased ones from Shape.CoroSuspends.
+static void simplifySuspendPoints(coro::Shape &Shape) {
+  auto &S = Shape.CoroSuspends;
+  size_t I = 0, N = S.size(); // [0,I) kept, [I,N) pending, [N,end) removed.
+  if (N == 0)
+    return;
+  for (;;) {
+    if (simplifySuspendPoint(S[I], Shape.CoroBegin)) {
+      if (--N == I)
+        break;
+      std::swap(S[I], S[N]); // Re-examine slot I with the last pending entry.
+      continue;
+    }
+    if (++I == N)
+      break;
+  }
+  S.resize(N); // Drop the entries whose suspend points were erased.
+}
+
 static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) {
   coro::Shape Shape(F);
   if (!Shape.CoroBegin)
     return;
 
+  simplifySuspendPoints(Shape);
   buildCoroutineFrame(F, Shape);
   replaceFrameSize(Shape);
 
+  // If there are no suspend points, no split is required: just remove the
+  // heap allocation/deallocation where possible and clean up in place.
+  if (Shape.CoroSuspends.empty()) {
+    handleNoSuspendCoroutine(Shape.CoroBegin, Shape.FrameTy);
+    removeCoroEnds(Shape);
+    postSplitCleanup(F);
+    coro::updateCallGraph(F, {}, CG, SCC); // No clones were created.
+    return;
+  }
+
   auto *ResumeEntry = createResumeEntryBlock(F, Shape);
   auto ResumeClone = createClone(F, ".resume", Shape, ResumeEntry, 0);
   auto DestroyClone = createClone(F, ".destroy", Shape, ResumeEntry, 1);
Index: test/Transforms/Coroutines/no-suspend.ll
===================================================================
--- /dev/null
+++ test/Transforms/Coroutines/no-suspend.ll
@@ -0,0 +1,189 @@
+; Test coroutines that have no suspend points, or only removable ones.
+; RUN: opt < %s -O2 -enable-coroutines -S | FileCheck %s
+
+; A coroutine with no suspend points is turned into:
+;
+; CHECK-LABEL: define void @no_suspends(
+; CHECK-NEXT: entry:
+; CHECK-NEXT:   call void @print(i32 %n)
+; CHECK-NEXT:   ret void
+;
+define void @no_suspends(i32 %n) {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
+  br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin
+dyn.alloc:
+  %size = call i32 @llvm.coro.size.i32()
+  %alloc = call i8* @malloc(i32 %size)
+  br label %coro.begin
+coro.begin:
+  %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ]
+  %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi)
+  br label %body
+body:
+  call void @print(i32 %n)
+  br label %cleanup
+cleanup:
+  %mem = call i8* @llvm.coro.free(token %id, i8* %hdl)
+  %need.dyn.free = icmp ne i8* %mem, null
+  br i1 %need.dyn.free, label %dyn.free, label %suspend
+dyn.free:
+  call void @free(i8* %mem)
+  br label %suspend
+suspend:
+  call void @llvm.coro.end(i8* %hdl, i1 false)
+  ret void
+}
+
+; simplifySuspendPoint will detect that the coroutine resumes itself and will
+; replace the suspend with a jump to the %resume label, turning it into a
+; no-suspend coroutine.
+;
+; CHECK-LABEL: define void @simplify_resume(
+; CHECK-NEXT: entry:
+; CHECK-NEXT:   call void @print(i32 0)
+; CHECK-NEXT:   ret void
+;
+define void @simplify_resume() {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
+  br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin
+dyn.alloc:
+  %size = call i32 @llvm.coro.size.i32()
+  %alloc = call i8* @malloc(i32 %size)
+  br label %coro.begin
+coro.begin:
+  %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ]
+  %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi)
+  br label %body
+body:
+  %save = call token @llvm.coro.save(i8* %hdl)
+  call void @llvm.coro.resume(i8* %hdl) ; Resumes this same coroutine.
+  %0 = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %0, label %suspend [i8 0, label %resume
+                                i8 1, label %pre.cleanup]
+resume:
+  call void @print(i32 0)
+  br label %cleanup
+
+pre.cleanup:
+  call void @print(i32 1)
+  br label %cleanup
+
+cleanup:
+  %mem = call i8* @llvm.coro.free(token %id, i8* %hdl)
+  call void @free(i8* %mem)
+  br label %suspend
+suspend:
+  call void @llvm.coro.end(i8* %hdl, i1 false)
+  ret void
+}
+
+; simplifySuspendPoint will detect that the coroutine destroys itself and will
+; replace the suspend with a jump to the %pre.cleanup label, turning it into a
+; no-suspend coroutine.
+;
+; CHECK-LABEL: define void @simplify_destroy(
+; CHECK-NEXT: entry:
+; CHECK-NEXT:   call void @print(i32 1)
+; CHECK-NEXT:   ret void
+;
+define void @simplify_destroy() {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
+  br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin
+dyn.alloc:
+  %size = call i32 @llvm.coro.size.i32()
+  %alloc = call i8* @malloc(i32 %size)
+  br label %coro.begin
+coro.begin:
+  %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ]
+  %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi)
+  br label %body
+body:
+  %save = call token @llvm.coro.save(i8* %hdl)
+  call void @llvm.coro.destroy(i8* %hdl) ; Destroys this same coroutine.
+  %0 = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %0, label %suspend [i8 0, label %resume
+                                i8 1, label %pre.cleanup]
+resume:
+  call void @print(i32 0)
+  br label %cleanup
+
+pre.cleanup:
+  call void @print(i32 1)
+  br label %cleanup
+
+cleanup:
+  %mem = call i8* @llvm.coro.free(token %id, i8* %hdl)
+  call void @free(i8* %mem)
+  br label %suspend
+suspend:
+  call void @llvm.coro.end(i8* %hdl, i1 false)
+  ret void
+}
+
+; simplifySuspendPoint will not simplify a suspend point if there are other
+; calls between coro.save and coro.suspend: they could resume or destroy the
+; coroutine, so the suspend point must be kept.
+;
+; CHECK-LABEL: define void @cannot_simplify(
+; CHECK-NEXT: entry:
+; CHECK-NEXT:   call i8* @malloc
+
+define void @cannot_simplify() {
+entry:
+  %id = call token @llvm.coro.id(i32 0, i8* null, i8* null, i8* null)
+  %need.dyn.alloc = call i1 @llvm.coro.alloc(token %id)
+  br i1 %need.dyn.alloc, label %dyn.alloc, label %coro.begin
+dyn.alloc:
+  %size = call i32 @llvm.coro.size.i32()
+  %alloc = call i8* @malloc(i32 %size)
+  br label %coro.begin
+coro.begin:
+  %phi = phi i8* [ null, %entry ], [ %alloc, %dyn.alloc ]
+  %hdl = call noalias i8* @llvm.coro.begin(token %id, i8* %phi)
+  br label %body
+body:
+  %save = call token @llvm.coro.save(i8* %hdl)
+  call void @foo() ; Opaque call: may resume or destroy the coroutine.
+  call void @llvm.coro.destroy(i8* %hdl)
+  %0 = call i8 @llvm.coro.suspend(token %save, i1 false)
+  switch i8 %0, label %suspend [i8 0, label %resume
+                                i8 1, label %pre.cleanup]
+resume:
+  call void @print(i32 0)
+  br label %cleanup
+
+pre.cleanup:
+  call void @print(i32 1)
+  br label %cleanup
+
+cleanup:
+  %mem = call i8* @llvm.coro.free(token %id, i8* %hdl)
+  call void @free(i8* %mem)
+  br label %suspend
+suspend:
+  call void @llvm.coro.end(i8* %hdl, i1 false)
+  ret void
+}
+
+declare i8* @malloc(i32)
+declare void @free(i8*)
+declare void @print(i32)
+declare void @foo()
+
+declare token @llvm.coro.id(i32, i8*, i8*, i8*)
+declare i1 @llvm.coro.alloc(token)
+declare i32 @llvm.coro.size.i32()
+declare i8* @llvm.coro.begin(token, i8*)
+declare token @llvm.coro.save(i8*)
+declare i8 @llvm.coro.suspend(token, i1)
+declare i8* @llvm.coro.free(token, i8*)
+declare void @llvm.coro.end(i8*, i1)
+
+declare void @llvm.coro.resume(i8*)
+declare void @llvm.coro.destroy(i8*)