diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -386,8 +386,10 @@ suspendBlock}; } -// Adds a suspension point before the `op`, and moves `op` and all operations -// after it into the resume block. Returns a pointer to the resume block. +// Add an LLVM coroutine suspension point to the end of suspended block, to +// resume execution in resume block. The caller is responsible for creating the +// two suspended/resume blocks with the desired ops contained in each block. +// This function merely provides the required control flow logic. // // `coroState` must be a value returned from the call to @llvm.coro.save(...) // intrinsic (saved coroutine state). @@ -399,6 +401,8 @@ // "op"(...) // ^cleanup: ... // ^suspend: ... +// ^resume: +// "op"(...) // // After: // @@ -411,20 +415,17 @@ // ^cleanup: ... // ^suspend: ... // -static Block *addSuspensionPoint(CoroMachinery coro, Value coroState, - Operation *op) { +static void addSuspensionPoint(CoroMachinery coro, Value coroState, + Operation *op, Block *suspended, Block *resume, + OpBuilder &builder) { + Location loc = op->getLoc(); MLIRContext *ctx = op->getContext(); auto i1 = LLVM::LLVMType::getInt1Ty(ctx); auto i8 = LLVM::LLVMType::getInt8Ty(ctx); - Location loc = op->getLoc(); - Block *splitBlock = op->getBlock(); - - // Split the block before `op`, newly added block is the resume block. - Block *resume = splitBlock->splitBlock(op); - // Add a coroutine suspension in place of original `op` in the split block. - OpBuilder builder = OpBuilder::atBlockEnd(splitBlock); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToEnd(suspended); auto constFalse = builder.create(loc, i1, builder.getBoolAttr(false)); @@ -445,7 +446,7 @@ Block *resumeOrCleanup = builder.createBlock(resume); // Suspend the coroutine ...? 
- builder.setInsertionPointToEnd(splitBlock); + builder.setInsertionPointToEnd(suspended); auto isNegOne = builder.create( loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constNegOne); builder.create(loc, isNegOne, /*trueDest=*/coro.suspend, @@ -457,8 +458,6 @@ loc, LLVM::ICmpPredicate::eq, coroSuspend.getResult(0), constZero); builder.create(loc, isZero, /*trueDest=*/resume, /*falseDest=*/coro.cleanup); - - return resume; } // Outline the body region attached to the `async.execute` op into a standalone @@ -518,8 +517,11 @@ builder.create(loc, TypeRange(), kExecute, executeArgs); // Split the entry block before the terminator. - Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), - entryBlock->getTerminator()); + auto *terminatorOp = entryBlock->getTerminator(); + Block *suspended = terminatorOp->getBlock(); + Block *resume = suspended->splitBlock(terminatorOp); + addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, + resume, builder); // Await on all dependencies before starting to execute the body region. builder.setInsertionPointToStart(resume); @@ -740,7 +742,7 @@ if (isInCoroutine) { const CoroMachinery &coro = outlined->getSecond(); - OpBuilder builder(op); + OpBuilder builder(op, rewriter.getListener()); MLIRContext *ctx = op->getContext(); // A pointer to coroutine resume intrinsic wrapper. @@ -760,8 +762,12 @@ builder.create(loc, TypeRange(), coroAwaitFuncName, awaitAndExecuteArgs); + Block *suspended = op->getBlock(); + // Split the entry block before the await operation. - addSuspensionPoint(coro, coroSave.getResult(0), op); + Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); + addSuspensionPoint(coro, coroSave.getResult(0), op, suspended, resume, + builder); } // Original operation was replaced by function call or suspension point.