diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -30,7 +30,7 @@ SmallVector ResumeAddr; DenseMap> DestroyAddr; SmallVector CoroFrees; - CoroSuspendInst *CoroFinalSuspend; + SmallPtrSet<const SwitchInst *, 4> CoroSuspendSwitches; Lowerer(Module &M) : LowererBase(M) {} @@ -38,6 +38,8 @@ bool shouldElide(Function *F, DominatorTree &DT) const; void collectPostSplitCoroIds(Function *F); bool processCoroId(CoroIdInst *, AAResults &AA, DominatorTree &DT); + bool hasEscapePath(const CoroBeginInst *, + const SmallPtrSetImpl<BasicBlock *> &) const; }; } // end anonymous namespace @@ -142,6 +144,52 @@ removeTailCallAttribute(Frame, AA); } +bool Lowerer::hasEscapePath(const CoroBeginInst *CB, + const SmallPtrSetImpl<BasicBlock *> &TIs) const { + auto It = DestroyAddr.find(CB); + assert(It != DestroyAddr.end() && "coro.begin has no entry in DestroyAddr"); + + // Limit the number of blocks we visit. + unsigned Limit = 32 * (1 + It->second.size()); + + SmallVector<const BasicBlock *, 32> Worklist; + Worklist.push_back(CB->getParent()); + + SmallPtrSet<const BasicBlock *, 32> Visited; + // Pre-mark every block containing a coro.destroy as visited, so that paths + // passing through a coro.destroy are never explored any further. + for (auto *DA : It->second) + Visited.insert(DA->getParent()); + + do { + const auto *BB = Worklist.pop_back_val(); + if (!Visited.insert(BB).second) + continue; + if (TIs.count(BB)) + return true; + + // Conservatively say that there is potentially a path. + if (!--Limit) + return true; + + auto TI = BB->getTerminator(); + // The default destination of a coro.suspend switch is the suspend path, + // which escapes to a normal terminator. Skipping it is fine because the + // coroutine frame does not change outside the coroutine body.
+ const auto *SWI = dyn_cast<SwitchInst>(TI); + if (SWI && CoroSuspendSwitches.count(SWI)) { + Worklist.push_back(SWI->getSuccessor(1)); + Worklist.push_back(SWI->getSuccessor(2)); + } else + Worklist.append(succ_begin(BB), succ_end(BB)); + + } while (!Worklist.empty()); + + // Every reachable path has been exhausted: coro.begin cannot reach any + // terminator without first passing through one of its coro.destroys. + return false; +} + bool Lowerer::shouldElide(Function *F, DominatorTree &DT) const { // If no CoroAllocs, we cannot suppress allocation, so elision is not // possible. @@ -154,61 +202,34 @@ // If the value escaped, then coro.destroy would have been referencing a // memory location storing that value and not the virtual register. - SmallPtrSet Terminators; - bool HasMultiPred = false; + SmallPtrSet<BasicBlock *, 8> Terminators; // First gather all of the non-exceptional terminators for the function. - // Consider the final coro.suspend as the real terminator when the current - // function is a coroutine. - if (CoroFinalSuspend) { - // If block of final coro.suspend has more than one predecessor, - // then there is one resume path and the others are exceptional paths, - // consider these predecessors as terminators. - BasicBlock *FinalBB = CoroFinalSuspend->getParent(); - if (FinalBB->hasNPredecessorsOrMore(2)) { - HasMultiPred = true; - for (auto *B : predecessors(FinalBB)) - Terminators.insert(B->getTerminator()); - } else - Terminators.insert(CoroFinalSuspend); - } else { + // Coroutines need no special case here: hasEscapePath skips the default + // (suspend) edge of coro.suspend switches instead. for (BasicBlock &B : *F) { auto *TI = B.getTerminator(); if (TI->getNumSuccessors() == 0 && !TI->isExceptionalTerminator() && !isa(TI)) - Terminators.insert(TI); + Terminators.insert(&B); } - } // Filter out the coro.destroy that lie along exceptional paths.
- SmallPtrSet DAs; - SmallPtrSet TIs; SmallPtrSet ReferencedCoroBegins; for (auto &It : DestroyAddr) { - for (CoroSubFnInst *DA : It.second) { - for (Instruction *TI : Terminators) { - if (DT.dominates(DA, TI)) { - if (HasMultiPred) - TIs.insert(TI); - else - DAs.insert(DA); + for (Instruction *DA : It.second) { + for (BasicBlock *TI : Terminators) { + if (DT.dominates(DA, TI->getTerminator())) { + ReferencedCoroBegins.insert(It.first); break; } } } - // If all the predecessors dominate coro.destroys that reference same - // coro.begin, record the coro.begin - if (TIs.size() == Terminators.size()) { - ReferencedCoroBegins.insert(It.first); - TIs.clear(); - } - } - // Find all the coro.begin referenced by coro.destroy along happy paths. - for (CoroSubFnInst *DA : DAs) { - if (auto *CB = dyn_cast(DA->getFrame())) - ReferencedCoroBegins.insert(CB); - else - return false; + // Otherwise, the coro.begin still counts as referenced when no path from it + // reaches a terminator without passing through one of its coro.destroys. + if (!ReferencedCoroBegins.count(It.first) && + !hasEscapePath(It.first, Terminators)) + ReferencedCoroBegins.insert(It.first); } // If size of the set is the same as total number of coro.begin, that means we @@ -219,7 +240,7 @@ void Lowerer::collectPostSplitCoroIds(Function *F) { CoroIds.clear(); - CoroFinalSuspend = nullptr; + CoroSuspendSwitches.clear(); for (auto &I : instructions(F)) { if (auto *CII = dyn_cast(&I)) if (CII->getInfo().isPostSplit()) @@ -227,12 +248,16 @@ if (CII->getCoroutine() != CII->getFunction()) CoroIds.push_back(CII); + // Consider a case like: + // %0 = call i8 @llvm.coro.suspend(...) + // switch i8 %0, label %suspend [i8 0, label %resume + // i8 1, label %cleanup] + // and collect such SwitchInsts for the escape analysis in hasEscapePath.
if (auto *CSI = dyn_cast(&I)) - if (CSI->isFinal()) { - if (!CoroFinalSuspend) - CoroFinalSuspend = CSI; - else - report_fatal_error("Only one suspend point can be marked as final"); + if (CSI->hasOneUse() && isa<SwitchInst>(CSI->use_begin()->getUser())) { + SwitchInst *SWI = cast<SwitchInst>(CSI->use_begin()->getUser()); + if (SWI->getNumCases() == 2) + CoroSuspendSwitches.insert(SWI); } } } diff --git a/llvm/test/Transforms/Coroutines/coro-heap-elide.ll b/llvm/test/Transforms/Coroutines/coro-heap-elide.ll --- a/llvm/test/Transforms/Coroutines/coro-heap-elide.ll +++ b/llvm/test/Transforms/Coroutines/coro-heap-elide.ll @@ -196,6 +196,58 @@ ret void } +; CHECK-LABEL: @callResume_with_coro_suspend_3( +define void @callResume_with_coro_suspend_3(i8 %cond) { +entry: +; CHECK: alloca %f.frame + switch i8 %cond, label %coro.ret [ + i8 0, label %init.suspend + i8 1, label %coro.ret + ] + +init.suspend: +; CHECK-NOT: llvm.coro.begin +; CHECK-NOT: CustomAlloc +; CHECK: call void @may_throw() + %hdl = call i8* @f() +; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* %vFrame) + %0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0) + %1 = bitcast i8* %0 to void (i8*)* + call fastcc void %1(i8* %hdl) + %2 = call token @llvm.coro.save(i8* %hdl) + %3 = call i8 @llvm.coro.suspend(token %2, i1 false) + switch i8 %3, label %coro.ret [ + i8 0, label %final.suspend + i8 1, label %cleanups + ] + +; CHECK-LABEL: final.suspend: +final.suspend: +; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame) + %4 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1) + %5 = bitcast i8* %4 to void (i8*)* + call fastcc void %5(i8* %hdl) + %6 = call token @llvm.coro.save(i8* %hdl) + %7 = call i8 @llvm.coro.suspend(token %6, i1 true) + switch i8 %7, label %coro.ret [ + i8 0, label %coro.ret + i8 1, label %cleanups + ] + +; CHECK-LABEL: cleanups: +cleanups: +; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* %vFrame)
+ %8 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1) + %9 = bitcast i8* %8 to void (i8*)* + call fastcc void %9(i8* %hdl) + br label %coro.ret + +; CHECK-LABEL: coro.ret: +coro.ret: +; CHECK-NEXT: ret void + ret void +} + ; CHECK-LABEL: @callResume_PR34897_no_elision( @@ -231,6 +283,41 @@ ret void } +; CHECK-LABEL: @callResume_PR34897_elision( +define void @callResume_PR34897_elision(i1 %cond) { +; CHECK-LABEL: entry: +entry: +; CHECK: alloca %f.frame +; CHECK: tail call void @bar( + tail call void @bar(i8* null) + br i1 %cond, label %if.then, label %if.else + +if.then: +; CHECK-NOT: CustomAlloc +; CHECK: call void @may_throw() + %hdl = call i8* @f() +; CHECK: call void @bar( + tail call void @bar(i8* %hdl) +; CHECK: call fastcc void bitcast (void (%f.frame*)* @f.resume to void (i8*)*)(i8* + %0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0) + %1 = bitcast i8* %0 to void (i8*)* + call fastcc void %1(i8* %hdl) +; CHECK-NEXT: call fastcc void bitcast (void (%f.frame*)* @f.cleanup to void (i8*)*)(i8* + %2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1) + %3 = bitcast i8* %2 to void (i8*)* + call fastcc void %3(i8* %hdl) + br label %return + +if.else: + br label %return + +; CHECK-LABEL: return: +return: +; CHECK: ret void + ret void +} + + ; a coroutine start function (cannot elide heap alloc, due to second argument to ; coro.begin not pointint to coro.alloc) define i8* @f_no_elision() personality i8* null {