diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -411,20 +411,16 @@ // ^cleanup: ... // ^suspend: ... // -static Block *addSuspensionPoint(CoroMachinery coro, Value coroState, - Operation *op) { +static void addSuspensionPoint(CoroMachinery coro, Value coroState, + Operation *op, Block *splitBlock, Block *resume, + OpBuilder::Listener *listener = nullptr) { + Location loc = op->getLoc(); MLIRContext *ctx = op->getContext(); auto i1 = LLVM::LLVMType::getInt1Ty(ctx); auto i8 = LLVM::LLVMType::getInt8Ty(ctx); - Location loc = op->getLoc(); - Block *splitBlock = op->getBlock(); - - // Split the block before `op`, newly added block is the resume block. - Block *resume = splitBlock->splitBlock(op); - // Add a coroutine suspension in place of original `op` in the split block. - OpBuilder builder = OpBuilder::atBlockEnd(splitBlock); + OpBuilder builder = OpBuilder::atBlockEnd(splitBlock, listener); auto constFalse = builder.create(loc, i1, builder.getBoolAttr(false)); @@ -457,8 +453,6 @@ loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); builder.create(loc, isZero, /*trueDest=*/resume, /*falseDest=*/coro.cleanup); - - return resume; } // Outline the body region attached to the `async.execute` op into a standalone @@ -518,8 +512,11 @@ builder.create(loc, TypeRange(), kExecute, executeArgs); // Split the entry block before the terminator. - Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), - entryBlock->getTerminator()); + auto *terminatorOp = entryBlock->getTerminator(); + Block *splitBlock = terminatorOp->getBlock(); + Block *resume = splitBlock->splitBlock(terminatorOp); + addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, splitBlock, + resume); // Await on all dependencies before starting to execute the body region. builder.setInsertionPointToStart(resume); @@ -740,7 +737,7 @@ if (isInCoroutine) { const CoroMachinery &coro = outlined->getSecond(); - OpBuilder builder(op); + OpBuilder builder(op, rewriter.getListener()); MLIRContext *ctx = op->getContext(); // A pointer to coroutine resume intrinsic wrapper. @@ -760,8 +757,12 @@ builder.create(loc, TypeRange(), coroAwaitFuncName, awaitAndExecuteArgs); + Block *splitBlock = op->getBlock(); + // Split the entry block before the await operation. - addSuspensionPoint(coro, coroSave.getResult(0), op); + Block *resume = rewriter.splitBlock(splitBlock, Block::iterator(op)); + addSuspensionPoint(coro, coroSave.getResult(0), op, splitBlock, resume, + rewriter.getListener()); } // Original operation was replaced by function call or suspension point.