diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
--- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp
+++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Transforms/Coroutines/CoroElide.h"
 #include "CoroInternal.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/IR/Dominators.h"
@@ -27,8 +28,9 @@
   SmallVector CoroBegins;
   SmallVector CoroAllocs;
   SmallVector ResumeAddr;
-  SmallVector DestroyAddr;
+  DenseMap> DestroyAddr;
   SmallVector CoroFrees;
+  CoroSuspendInst *CoroFinalSuspend = nullptr;
 
   Lowerer(Module &M) : LowererBase(M) {}
 
@@ -146,33 +148,62 @@
   if (CoroAllocs.empty())
     return false;
 
-  // Check that for every coro.begin there is a coro.destroy directly
-  // referencing the SSA value of that coro.begin along a non-exceptional path.
+  // Check that for every coro.begin there is at least one coro.destroy directly
+  // referencing the SSA value of that coro.begin along each
+  // non-exceptional path.
   // If the value escaped, then coro.destroy would have been referencing a
   // memory location storing that value and not the virtual register.
 
-  // First gather all of the non-exceptional terminators for the function.
   SmallPtrSet Terminators;
-  for (BasicBlock &B : *F) {
-    auto *TI = B.getTerminator();
-    if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
-        !isa(TI))
-      Terminators.insert(TI);
+  bool HasMultiPred = false;
+  // First gather all of the non-exceptional terminators for the function.
+  // Consider the final coro.suspend as the real terminator when the current
+  // function is a coroutine.
+  if (CoroFinalSuspend) {
+    // If block of final coro.suspend has more than one predecessor,
+    // then there is one resume path and the others are exceptional paths,
+    // consider these predecessors as terminators.
+    BasicBlock *FinalBB = CoroFinalSuspend->getParent();
+    if (FinalBB->hasNPredecessorsOrMore(2)) {
+      HasMultiPred = true;
+      for (auto *B : predecessors(FinalBB))
+        Terminators.insert(B->getTerminator());
+    } else
+      Terminators.insert(CoroFinalSuspend);
+  } else {
+    for (BasicBlock &B : *F) {
+      auto *TI = B.getTerminator();
+      if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() &&
+          !isa(TI))
+        Terminators.insert(TI);
+    }
   }
 
   // Filter out the coro.destroy that lie along exceptional paths.
   SmallPtrSet DAs;
-  for (CoroSubFnInst *DA : DestroyAddr) {
-    for (Instruction *TI : Terminators) {
-      if (DT.dominates(DA, TI)) {
-        DAs.insert(DA);
-        break;
+  SmallPtrSet TIs;
+  SmallPtrSet ReferencedCoroBegins;
+  for (auto &It : DestroyAddr) {
+    for (CoroSubFnInst *DA : It.second) {
+      for (Instruction *TI : Terminators) {
+        if (!DT.dominates(DA, TI))
+          continue;
+        if (!HasMultiPred) {
+          DAs.insert(DA);
+          break;
+        }
+        TIs.insert(TI);
       }
     }
+    // Record the coro.begin only when every predecessor terminator is
+    // dominated by a coro.destroy referencing it. Always reset TIs so that
+    // a partial match cannot leak into the next coro.begin.
+    if (HasMultiPred && TIs.size() == Terminators.size())
+      ReferencedCoroBegins.insert(It.first);
+    TIs.clear();
   }
 
   // Find all the coro.begin referenced by coro.destroy along happy paths.
-  SmallPtrSet ReferencedCoroBegins;
   for (CoroSubFnInst *DA : DAs) {
     if (auto *CB = dyn_cast(DA->getFrame()))
       ReferencedCoroBegins.insert(CB);
@@ -188,12 +219,22 @@
 
 void Lowerer::collectPostSplitCoroIds(Function *F) {
   CoroIds.clear();
-  for (auto &I : instructions(F))
+  CoroFinalSuspend = nullptr;
+  for (auto &I : instructions(F)) {
     if (auto *CII = dyn_cast(&I))
       if (CII->getInfo().isPostSplit())
         // If it is the coroutine itself, don't touch it.
         if (CII->getCoroutine() != CII->getFunction())
           CoroIds.push_back(CII);
+
+    if (auto *CSI = dyn_cast(&I))
+      if (CSI->isFinal()) {
+        if (!CoroFinalSuspend)
+          CoroFinalSuspend = CSI;
+        else
+          report_fatal_error("Only one suspend point can be marked as final");
+      }
+  }
 }
 
 bool Lowerer::processCoroId(CoroIdInst *CoroId, AAResults &AA,
@@ -226,7 +267,7 @@
           ResumeAddr.push_back(II);
           break;
         case CoroSubFnInst::DestroyIndex:
-          DestroyAddr.push_back(II);
+          DestroyAddr[CB].push_back(II);
           break;
         default:
           llvm_unreachable("unexpected coro.subfn.addr constant");
@@ -249,7 +290,8 @@
       Resumers,
       ShouldElide ? CoroSubFnInst::CleanupIndex : CoroSubFnInst::DestroyIndex);
 
-  replaceWithConstant(DestroyAddrConstant, DestroyAddr);
+  for (auto &It : DestroyAddr)
+    replaceWithConstant(DestroyAddrConstant, It.second);
 
   if (ShouldElide) {
     auto *FrameTy = getFrameType(cast(ResumeAddrConstant));
diff --git a/llvm/test/Transforms/Coroutines/coro-heap-elide.ll b/llvm/test/Transforms/Coroutines/coro-heap-elide.ll
--- a/llvm/test/Transforms/Coroutines/coro-heap-elide.ll
+++ b/llvm/test/Transforms/Coroutines/coro-heap-elide.ll
@@ -84,6 +84,120 @@
   ret void
 }
 
+; CHECK-LABEL: @callResume_with_coro_suspend_1(
+define void @callResume_with_coro_suspend_1() {
+entry:
+; CHECK: alloca %f.frame
+; CHECK-NOT: coro.begin
+; CHECK-NOT: CustomAlloc
+; CHECK: call void @may_throw()
+  %hdl = call i8* @f()
+
+; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame)
+  %0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
+  %1 = bitcast i8* %0 to void (i8*)*
+  call fastcc void %1(i8* %hdl)
+  %2 = call token @llvm.coro.save(i8* %hdl)
+  %3 = call i8 @llvm.coro.suspend(token %2, i1 false)
+  switch i8 %3, label %coro.ret [
+    i8 0, label %final.suspend
+    i8 1, label %cleanups
+  ]
+
+; CHECK-LABEL: final.suspend:
+final.suspend:
+; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
+  %4 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
+  %5 = bitcast i8* %4 to void (i8*)*
+  call fastcc void %5(i8* %hdl)
+  %6 = call token @llvm.coro.save(i8* %hdl)
+  %7 = call i8 @llvm.coro.suspend(token %6, i1 true)
+  switch i8 %7, label %coro.ret [
+    i8 0, label %coro.ret
+    i8 1, label %cleanups
+  ]
+
+; CHECK-LABEL: cleanups:
+cleanups:
+; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
+  %8 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
+  %9 = bitcast i8* %8 to void (i8*)*
+  call fastcc void %9(i8* %hdl)
+  br label %coro.ret
+
+; CHECK-LABEL: coro.ret:
+coro.ret:
+; CHECK-NEXT: ret void
+  ret void
+}
+
+; CHECK-LABEL: @callResume_with_coro_suspend_2(
+define void @callResume_with_coro_suspend_2() personality i8* null {
+entry:
+; CHECK: alloca %f.frame
+; CHECK-NOT: coro.begin
+; CHECK-NOT: CustomAlloc
+; CHECK: call void @may_throw()
+  %hdl = call i8* @f()
+
+  %0 = call token @llvm.coro.save(i8* %hdl)
+; CHECK: invoke fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame)
+  %1 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0)
+  %2 = bitcast i8* %1 to void (i8*)*
+  invoke fastcc void %2(i8* %hdl)
+    to label %invoke.cont1 unwind label %lpad
+
+; CHECK-LABEL: invoke.cont1:
+invoke.cont1:
+  %3 = call i8 @llvm.coro.suspend(token %0, i1 false)
+  switch i8 %3, label %coro.ret [
+    i8 0, label %final.ready
+    i8 1, label %cleanups
+  ]
+
+; CHECK-LABEL: lpad:
+lpad:
+  %4 = landingpad { i8*, i32 }
+          catch i8* null
+; CHECK: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
+  %5 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
+  %6 = bitcast i8* %5 to void (i8*)*
+  call fastcc void %6(i8* %hdl)
+  br label %final.suspend
+
+; CHECK-LABEL: final.ready:
+final.ready:
+; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
+  %7 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
+  %8 = bitcast i8* %7 to void (i8*)*
+  call fastcc void %8(i8* %hdl)
+  br label %final.suspend
+
+; CHECK-LABEL: final.suspend:
+final.suspend:
+  %9 = call token @llvm.coro.save(i8* %hdl)
+  %10 = call i8 @llvm.coro.suspend(token %9, i1 true)
+  switch i8 %10, label %coro.ret [
+    i8 0, label %coro.ret
+    i8 1, label %cleanups
+  ]
+
+; CHECK-LABEL: cleanups:
+cleanups:
+; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
+  %11 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1)
+  %12 = bitcast i8* %11 to void (i8*)*
+  call fastcc void %12(i8* %hdl)
+  br label %coro.ret
+
+; CHECK-LABEL: coro.ret:
+coro.ret:
+; CHECK-NEXT: ret void
+  ret void
+}
+
+
+
 ; CHECK-LABEL: @callResume_PR34897_no_elision(
 define void @callResume_PR34897_no_elision(i1 %cond) {
 ; CHECK-LABEL: entry:
@@ -161,3 +275,5 @@
 declare i8* @llvm.coro.begin(token, i8*)
 declare i8* @llvm.coro.frame(token)
 declare i8* @llvm.coro.subfn.addr(i8*, i8)
+declare i8 @llvm.coro.suspend(token, i1)
+declare token @llvm.coro.save(i8*)