diff --git a/llvm/lib/Transforms/Coroutines/CoroElide.cpp b/llvm/lib/Transforms/Coroutines/CoroElide.cpp --- a/llvm/lib/Transforms/Coroutines/CoroElide.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroElide.cpp @@ -232,17 +232,22 @@ // Filter out the coro.destroy that lie along exceptional paths. SmallPtrSet ReferencedCoroBegins; for (auto &It : DestroyAddr) { + // If any coro.destroy dominates all of the terminators for the + // coro.begin, we know the corresponding coro.begin wouldn't escape. for (Instruction *DA : It.second) { - for (BasicBlock *TI : Terminators) { - if (DT.dominates(DA, TI->getTerminator())) { - ReferencedCoroBegins.insert(It.first); - break; - } + if (llvm::all_of(Terminators, [&](auto *TI) { + return DT.dominates(DA, TI->getTerminator()); + })) { + ReferencedCoroBegins.insert(It.first); + break; } } // Whether there is any paths from coro.begin to Terminators which not pass // through any of the coro.destroys. + // + // hasEscapePath is relatively slow, so we avoid running it as much as + // possible. 
if (!ReferencedCoroBegins.count(It.first) && !hasEscapePath(It.first, Terminators)) ReferencedCoroBegins.insert(It.first); diff --git a/llvm/test/Transforms/Coroutines/coro-elide.ll b/llvm/test/Transforms/Coroutines/coro-elide.ll --- a/llvm/test/Transforms/Coroutines/coro-elide.ll +++ b/llvm/test/Transforms/Coroutines/coro-elide.ll @@ -19,15 +19,23 @@ ret void } -@f.resumers = internal constant [2 x void (i8*)*] [void (i8*)* @f.resume, - void (i8*)* @f.destroy] +; cleanup part of the coroutine +define fastcc void @f.cleanup(i8*) { + tail call void @print(i32 2) + ret void +} + +@f.resumers = internal constant [3 x void (i8*)*] [void (i8*)* @f.resume, + void (i8*)* @f.destroy, + void (i8*)* @f.cleanup] ; a coroutine start function define i8* @f() { entry: %id = call token @llvm.coro.id(i32 0, i8* null, i8* bitcast (i8*()* @f to i8*), - i8* bitcast ([2 x void (i8*)*]* @f.resumers to i8*)) + i8* bitcast ([3 x void (i8*)*]* @f.resumers to i8*)) + %alloc = call i1 @llvm.coro.alloc(token %id) %hdl = call i8* @llvm.coro.begin(token %id, i8* null) ret i8* %hdl } @@ -42,7 +50,7 @@ %1 = bitcast i8* %0 to void (i8*)* call fastcc void %1(i8* %hdl) -; CHECK-NEXT: call void @print(i32 1) +; CHECK-NEXT: call void @print(i32 2) %2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1) %3 = bitcast i8* %2 to void (i8*)* call fastcc void %3(i8* %hdl) @@ -51,6 +59,50 @@ ret void } +; CHECK-LABEL: @callResumeMultiRet( +define void @callResumeMultiRet(i1 %b) { +entry: + %hdl = call i8* @f() +; CHECK: %alloc.i = call i1 @llvm.coro.alloc +; CHECK: call void @print(i32 0) + %0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0) + %1 = bitcast i8* %0 to void (i8*)* + call fastcc void %1(i8* %hdl) + br i1 %b, label %destroy, label %ret + +destroy: +; CHECK: call void @print(i32 1) + %2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1) + %3 = bitcast i8* %2 to void (i8*)* + call fastcc void %3(i8* %hdl) + ret void + +ret: + ret void +} + +; CHECK-LABEL: @callResumeMultiRetDommmed( +define void 
@callResumeMultiRetDommmed(i1 %b) { +entry: + %hdl = call i8* @f() +; CHECK-NOT: %alloc.i = call i1 @llvm.coro.alloc +; CHECK: call void @print(i32 0) + %0 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 0) + %1 = bitcast i8* %0 to void (i8*)* + call fastcc void %1(i8* %hdl) +; CHECK: call void @print(i32 2) + %2 = call i8* @llvm.coro.subfn.addr(i8* %hdl, i8 1) + %3 = bitcast i8* %2 to void (i8*)* + call fastcc void %3(i8* %hdl) + br i1 %b, label %destroy, label %ret + +destroy: + ret void + +ret: + ret void +} + ; CHECK-LABEL: @eh( define void @eh() personality i8* null { entry: @@ -113,3 +165,4 @@ declare i8* @llvm.coro.begin(token, i8*) declare i8* @llvm.coro.frame() declare i8* @llvm.coro.subfn.addr(i8*, i8) +declare i1 @llvm.coro.alloc(token)