diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -51,10 +51,6 @@ } // namespace -//===----------------------------------------------------------------------===// -// async.execute op outlining to the coroutine functions. -//===----------------------------------------------------------------------===// - /// Function targeted for coroutine transformation has two additional blocks at /// the end: coroutine cleanup and coroutine suspension. /// @@ -64,6 +60,12 @@ struct CoroMachinery { func::FuncOp func; + // Async function returns an optional token, followed by some async values + // + // async.func @foo() -> !async.value { + // %cst = arith.constant 42.0 : f32 + // return %cst: f32 + // } // Async execute region returns a completion token, and an async value for // each yielded value. // @@ -71,7 +73,8 @@ // %0 = arith.constant ... : T // async.yield %0 : T // } - Value asyncToken; // token representing completion of the async region + Optional + asyncToken; // token representing completion of the async region llvm::SmallVector returnValues; // returned async values Value coroHandle; // coroutine handle (!async.coro.getHandle value) @@ -87,13 +90,9 @@ /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block /// that branches into preexisting entry block. Also inserts trailing blocks. /// -/// The result types of the passed `func` must start with an `async.token` +/// The result types of the passed `func` start with an optional `async.token` /// and be continued with some number of `async.value`s. /// -/// The func given to this function needs to have been preprocessed to have -/// either branch or yield ops as terminators. Branches to the cleanup block are -/// inserted after each yield. 
-/// /// See LLVM coroutines documentation: https://llvm.org/docs/Coroutines.html /// /// - `entry` block sets up the coroutine. @@ -110,7 +109,7 @@ /// ^entry(): /// %token = : !async.token // create async runtime token /// %value = : !async.value // create async value -/// %id = async.coro.getId // create a coroutine id +/// %id = async.coro.getId // create a coroutine id /// %hdl = async.coro.begin %id // create a coroutine handle /// cf.br ^preexisting_entry_block /// @@ -142,11 +141,19 @@ // ------------------------------------------------------------------------ // // Allocate async token/values that we will return from a ramp function. // ------------------------------------------------------------------------ // - auto retToken = - builder.create(TokenType::get(ctx)).getResult(); + /// We treat TokenType as state update marker to represent side-effects of + /// async computations + bool isStateful = func.getCallableResults().front().isa(); + + Optional retToken; + if (isStateful) + retToken = builder.create(TokenType::get(ctx)).getResult(); llvm::SmallVector retValues; - for (auto resType : func.getCallableResults().drop_front()) + ArrayRef resValueTypes = isStateful + ? func.getCallableResults().drop_front() + : func.getCallableResults(); + for (auto resType : resValueTypes) retValues.emplace_back( builder.create(resType).getResult()); @@ -179,26 +186,17 @@ // Mark the end of a coroutine: async.coro.end builder.create(coroHdlOp.getHandle()); - // Return created `async.token` and `async.values` from the suspend block. - // This will be the return value of a coroutine ramp function. - SmallVector ret{retToken}; + // Return created optional `async.token` and `async.values` from the suspend + // block. This will be the return value of a coroutine ramp function. 
+ SmallVector ret; + if (retToken.has_value()) + ret.push_back(retToken.value()); ret.insert(ret.end(), retValues.begin(), retValues.end()); builder.create(ret); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. - for (Block &block : func.getBody().getBlocks()) { - if (&block == entryBlock || &block == cleanupBlock || - &block == suspendBlock) - continue; - Operation *terminator = block.getTerminator(); - if (auto yield = dyn_cast(terminator)) { - builder.setInsertionPointToEnd(&block); - builder.create(cleanupBlock); - } - } - // The switch-resumed API based coroutine should be marked with // coroutine.presplit attribute to mark the function as a coroutine. func->setAttr("passthrough", builder.getArrayAttr( @@ -229,7 +227,10 @@ ImplicitLocOpBuilder::atBlockBegin(coro.func->getLoc(), coro.setError); // Coroutine set_error block: set error on token and all returned values. - builder.create(coro.asyncToken); + if (coro.asyncToken.has_value()) { + builder.create(coro.asyncToken.value()); + } + for (Value retValue : coro.returnValues) builder.create(retValue); @@ -239,6 +240,10 @@ return coro.setError; } +//===----------------------------------------------------------------------===// +// async.execute op outlining to the coroutine functions. +//===----------------------------------------------------------------------===// + /// Outline the body region attached to the `async.execute` op into a standalone /// function. 
/// @@ -382,6 +387,120 @@ }; } // namespace +//===----------------------------------------------------------------------===// +// Convert async.func, async.return and async.call operations to non-blocking +// operations based on llvm coroutine +//===----------------------------------------------------------------------===// + +namespace { + +//===----------------------------------------------------------------------===// +// Convert async.func operation to func.func +//===----------------------------------------------------------------------===// + +class AsyncFuncOpLowering : public OpConversionPattern { +public: + AsyncFuncOpLowering(MLIRContext *ctx, + llvm::DenseMap &coros) + : OpConversionPattern(ctx), coros(coros) {} + + LogicalResult + matchAndRewrite(async::FuncOp asyncFuncOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = asyncFuncOp->getLoc(); + + auto newFuncOp = rewriter.create( + loc, asyncFuncOp.getName(), asyncFuncOp.getFunctionType()); + + SymbolTable::setSymbolVisibility( + newFuncOp, SymbolTable::getSymbolVisibility(asyncFuncOp)); + // Copy over all attributes other than the name. 
+ for (const auto &namedAttr : asyncFuncOp->getAttrs()) { + if (namedAttr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); + } + + rewriter.inlineRegionBefore(asyncFuncOp.getBody(), newFuncOp.getBody(), + newFuncOp.end()); + + CoroMachinery coro = setupCoroMachinery(newFuncOp); + coros[newFuncOp] = coro; + // no initial suspend, we should hot-start + + rewriter.eraseOp(asyncFuncOp); + return success(); + } + +private: + llvm::DenseMap &coros; +}; + +//===----------------------------------------------------------------------===// +// Convert async.call operation to func.call +//===----------------------------------------------------------------------===// + +class AsyncCallOpLowering : public OpConversionPattern { +public: + AsyncCallOpLowering(MLIRContext *ctx) + : OpConversionPattern(ctx) {} + + LogicalResult + matchAndRewrite(async::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), op.getOperands()); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Convert async.return operation to async.runtime operations. 
+//===----------------------------------------------------------------------===// + +class AsyncReturnOpLowering : public OpConversionPattern { +public: + AsyncReturnOpLowering(MLIRContext *ctx, + llvm::DenseMap &coros) + : OpConversionPattern(ctx), coros(coros) {} + + LogicalResult + matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto func = op->template getParentOfType(); + auto funcCoro = coros.find(func); + if (funcCoro == coros.end()) + return rewriter.notifyMatchFailure( + op, "operation is not inside the async coroutine function"); + + Location loc = op->getLoc(); + const CoroMachinery &coro = funcCoro->getSecond(); + rewriter.setInsertionPointAfter(op); + + // Store return values into the async values storage and switch async + // values state to available. + for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { + Value returnValue = std::get<0>(tuple); + Value asyncValue = std::get<1>(tuple); + rewriter.create(loc, returnValue, asyncValue); + rewriter.create(loc, asyncValue); + } + + if (coro.asyncToken.has_value()) { + // Switch the coroutine completion token to available state. + rewriter.create(op->getLoc(), + coro.asyncToken.value()); + } + + rewriter.eraseOp(op); + rewriter.create(loc, coro.cleanup); + return success(); + } + +private: + llvm::DenseMap &coros; +}; +} // namespace + //===----------------------------------------------------------------------===// // Convert async.await and async.await_all operations to the async.runtime.await // or async.runtime.await_and_resume operations. 
@@ -393,11 +512,9 @@ using AwaitAdaptor = typename AwaitType::Adaptor; public: - AwaitOpLoweringBase( - MLIRContext *ctx, - llvm::DenseMap &outlinedFunctions) - : OpConversionPattern(ctx), - outlinedFunctions(outlinedFunctions) {} + AwaitOpLoweringBase(MLIRContext *ctx, + llvm::DenseMap &coros) + : OpConversionPattern(ctx), coros(coros) {} LogicalResult matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, @@ -409,8 +526,8 @@ // Check if await operation is inside the outlined coroutine function. auto func = op->template getParentOfType(); - auto outlined = outlinedFunctions.find(func); - const bool isInCoroutine = outlined != outlinedFunctions.end(); + auto funcCoro = coros.find(func); + const bool isInCoroutine = funcCoro != coros.end(); Location loc = op->getLoc(); Value operand = adaptor.getOperand(); @@ -436,7 +553,7 @@ // Inside the coroutine we convert await operation into coroutine suspension // point, and resume execution asynchronously. if (isInCoroutine) { - CoroMachinery &coro = outlined->getSecond(); + CoroMachinery &coro = funcCoro->getSecond(); Block *suspended = op->getBlock(); ImplicitLocOpBuilder builder(loc, op, rewriter.getListener()); @@ -488,7 +605,7 @@ } private: - llvm::DenseMap &outlinedFunctions; + llvm::DenseMap &coros; }; /// Lowering for `async.await` with a token operand. @@ -531,24 +648,22 @@ class YieldOpLowering : public OpConversionPattern { public: - YieldOpLowering( - MLIRContext *ctx, - const llvm::DenseMap &outlinedFunctions) - : OpConversionPattern(ctx), - outlinedFunctions(outlinedFunctions) {} + YieldOpLowering(MLIRContext *ctx, + const llvm::DenseMap &coros) + : OpConversionPattern(ctx), coros(coros) {} LogicalResult matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if yield operation is inside the async coroutine function. 
auto func = op->template getParentOfType(); - auto outlined = outlinedFunctions.find(func); - if (outlined == outlinedFunctions.end()) + auto funcCoro = coros.find(func); + if (funcCoro == coros.end()) return rewriter.notifyMatchFailure( op, "operation is not inside the async coroutine function"); Location loc = op->getLoc(); - const CoroMachinery &coro = outlined->getSecond(); + const CoroMachinery &coro = funcCoro->getSecond(); // Store yielded values into the async values storage and switch async // values state to available. @@ -559,14 +674,20 @@ rewriter.create(loc, asyncValue); } - // Switch the coroutine completion token to available state. - rewriter.replaceOpWithNewOp(op, coro.asyncToken); + if (coro.asyncToken.has_value()) { + // Switch the coroutine completion token to available state. + rewriter.create(op->getLoc(), + coro.asyncToken.value()); + } + + rewriter.eraseOp(op); + rewriter.create(loc, coro.cleanup); return success(); } private: - const llvm::DenseMap &outlinedFunctions; + const llvm::DenseMap &coros; }; //===----------------------------------------------------------------------===// @@ -575,24 +696,22 @@ class AssertOpLowering : public OpConversionPattern { public: - AssertOpLowering( - MLIRContext *ctx, - llvm::DenseMap &outlinedFunctions) - : OpConversionPattern(ctx), - outlinedFunctions(outlinedFunctions) {} + AssertOpLowering(MLIRContext *ctx, + llvm::DenseMap &coros) + : OpConversionPattern(ctx), coros(coros) {} LogicalResult matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if assert operation is inside the async coroutine function. 
auto func = op->template getParentOfType(); - auto outlined = outlinedFunctions.find(func); - if (outlined == outlinedFunctions.end()) + auto funcCoro = coros.find(func); + if (funcCoro == coros.end()) return rewriter.notifyMatchFailure( op, "operation is not inside the async coroutine function"); Location loc = op->getLoc(); - CoroMachinery &coro = outlined->getSecond(); + CoroMachinery &coro = funcCoro->getSecond(); Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); @@ -607,7 +726,7 @@ } private: - llvm::DenseMap &outlinedFunctions; + llvm::DenseMap &coros; }; //===----------------------------------------------------------------------===// @@ -615,22 +734,23 @@ ModuleOp module = getOperation(); SymbolTable symbolTable(module); - // Outline all `async.execute` body regions into async functions (coroutines). - llvm::DenseMap outlinedFunctions; + // Functions with coroutine CFG setups, which are results of outlining + // `async.execute` body regions and converting async.func. + llvm::DenseMap coros; module.walk([&](ExecuteOp execute) { - outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); + coros.insert(outlineExecuteOp(symbolTable, execute)); }); LLVM_DEBUG({ - llvm::dbgs() << "Outlined " << outlinedFunctions.size() + llvm::dbgs() << "Outlined " << coros.size() << " functions built from async.execute operations\n"; }); // Returns true if operation is inside the coroutine. auto isInCoroutine = [&](Operation *op) -> bool { auto parentFunc = op->getParentOfType(); - return outlinedFunctions.find(parentFunc) != outlinedFunctions.end(); + return coros.find(parentFunc) != coros.end(); }; // Lower async operations to async.runtime operations. @@ -646,18 +766,23 @@ // Async lowering does not use type converter because it must preserve all // types for async.runtime operations. asyncPatterns.add(ctx); + + // Lower async.func to func.func with coroutine cfg. 
+ asyncPatterns.add(ctx); + asyncPatterns.add(ctx, coros); + asyncPatterns.add(ctx, - outlinedFunctions); + AwaitAllOpLowering, YieldOpLowering>(ctx, coros); // Lower assertions to conditional branches into error blocks. - asyncPatterns.add(ctx, outlinedFunctions); + asyncPatterns.add(ctx, coros); // All high level async operations must be lowered to the runtime operations. ConversionTarget runtimeTarget(*ctx); - runtimeTarget.addLegalDialect(); + runtimeTarget.addLegalDialect(); runtimeTarget.addIllegalOp(); - runtimeTarget.addIllegalOp(); + runtimeTarget.addIllegalOp(); // Decide if structured control flow has to be lowered to branch-based CFG. runtimeTarget.addDynamicallyLegalDialect([&](Operation *op) { @@ -675,7 +800,7 @@ runtimeTarget.addDynamicallyLegalOp( [&](cf::AssertOp op) -> bool { auto func = op->getParentOfType(); - return outlinedFunctions.find(func) == outlinedFunctions.end(); + return coros.find(func) == coros.end(); }); if (failed(applyPartialConversion(module, runtimeTarget, diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -433,3 +433,25 @@ // CHECK-SAME: ) -> !async.token // CHECK: %[[CST:.*]] = arith.constant 0 : index // CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]] + +// ----- +// Async Functions should be non-blocking + +// CHECK-LABEL: @async_func_await +async.func @async_func_await(%arg0: f32, %arg1: !async.value) + -> !async.token { + %0 = async.await %arg1 : !async.value + return +} +// Create token for return op, and mark a function as a coroutine. 
+// CHECK: %[[TOKEN:.*]] = async.runtime.create : !async.token +// CHECK: %[[ID:.*]] = async.coro.id +// CHECK: %[[HDL:.*]] = async.coro.begin +// CHECK: cf.br ^[[ORIGIN_ENTRY:.*]] + +// CHECK: ^[[ORIGIN_ENTRY]]: +// CHECK: %[[SAVED:.*]] = async.coro.save %[[HDL]] +// CHECK: async.runtime.await_and_resume %[[arg1:.*]], %[[HDL]] : +// CHECK-SAME: !async.value +// CHECK: async.coro.suspend %[[SAVED]] +// CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] diff --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/async-func.mlir @@ -0,0 +1,149 @@ +// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \ +// RUN: | mlir-cpu-runner \ +// RUN: -e main -entry-point-result=void -O0 \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_lib_dir/libmlir_async_runtime%shlibext \ +// RUN: | FileCheck %s --dump-input=always + +// FIXME: https://github.com/llvm/llvm-project/issues/57231 +// UNSUPPORTED: hwasan + +async.func @async_func_empty() -> !async.token { + return +} + +async.func @async_func_assert() -> !async.token { + %false = arith.constant 0 : i1 + cf.assert %false, "error" + return +} + +async.func @async_func_nested_assert() -> !async.token { + %token0 = async.call @async_func_assert() : () -> !async.token + async.await %token0 : !async.token + return +} + +async.func @async_func_value_assert() -> !async.value { + %false = arith.constant 0 : i1 + cf.assert %false, "error" + %0 = arith.constant 123.45 : f32 + return %0 : f32 +} + +async.func 
@async_func_value_nested_assert() -> !async.value { + %value0 = async.call @async_func_value_assert() : () -> !async.value + %ret = async.await %value0 : !async.value + return %ret : f32 +} + +async.func @async_func_return_value() -> !async.value { + %0 = arith.constant 456.789 : f32 + return %0 : f32 +} + +async.func @async_func_non_blocking_await() -> !async.value { + %value0 = async.call @async_func_return_value() : () -> !async.value + %1 = async.await %value0 : !async.value + return %1 : f32 +} + +async.func @async_func_inside_memref() -> !async.value> { + %0 = memref.alloc() : memref + %c0 = arith.constant 0.25 : f32 + memref.store %c0, %0[] : memref + return %0 : memref +} + +async.func @async_func_passed_memref(%arg0 : !async.value>) -> !async.token { + %unwrapped = async.await %arg0 : !async.value> + %0 = memref.load %unwrapped[] : memref + %1 = arith.addf %0, %0 : f32 + memref.store %1, %unwrapped[] : memref + return +} + + +func.func @main() { + %false = arith.constant 0 : i1 + + // ------------------------------------------------------------------------ // + // Check that simple async.func completes without errors. + // ------------------------------------------------------------------------ // + %token0 = async.call @async_func_empty() : () -> !async.token + async.runtime.await %token0 : !async.token + + // CHECK: 0 + %err0 = async.runtime.is_error %token0 : !async.token + vector.print %err0 : i1 + + // ------------------------------------------------------------------------ // + // Check that assertion in the async.func converted to async error. 
+ // ------------------------------------------------------------------------ // + %token1 = async.call @async_func_assert() : () -> !async.token + async.runtime.await %token1 : !async.token + + // CHECK: 1 + %err1 = async.runtime.is_error %token1 : !async.token + vector.print %err1 : i1 + + // ------------------------------------------------------------------------ // + // Check error propagation from the nested async.func. + // ------------------------------------------------------------------------ // + %token2 = async.call @async_func_nested_assert() : () -> !async.token + async.runtime.await %token2 : !async.token + + // CHECK: 1 + %err2 = async.runtime.is_error %token2 : !async.token + vector.print %err2 : i1 + + // ------------------------------------------------------------------------ // + // Check error propagation from the nested async.func with async values. + // ------------------------------------------------------------------------ // + %value3 = async.call @async_func_value_nested_assert() : () -> !async.value + async.runtime.await %value3 : !async.value + + // CHECK: 1 + %err3_0 = async.runtime.is_error %value3 : !async.value + vector.print %err3_0 : i1 + +// ------------------------------------------------------------------------ // + // Non-blocking async.await inside the async.func + // ------------------------------------------------------------------------ // + %result0 = async.call @async_func_non_blocking_await() : () -> !async.value + %4 = async.await %result0 : !async.value + + // CHECK: 456.789 + vector.print %4 : f32 + + // ------------------------------------------------------------------------ // + // Memref allocated inside async.func. 
+ // ------------------------------------------------------------------------ // + %result1 = async.call @async_func_inside_memref() : () -> !async.value> + %5 = async.await %result1 : !async.value> + %6 = memref.cast %5 : memref to memref<*xf32> + + // CHECK: Unranked Memref + // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = [] + // CHECK-NEXT: [0.25] + call @printMemrefF32(%6) : (memref<*xf32>) -> () + + // ------------------------------------------------------------------------ // + // Memref passed as async.func parameter + // ------------------------------------------------------------------------ // + %token3 = async.call @async_func_passed_memref(%result1) : (!async.value>) -> !async.token + async.await %token3 : !async.token + + // CHECK: Unranked Memref + // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = [] + // CHECK-NEXT: [0.5] + call @printMemrefF32(%6) : (memref<*xf32>) -> () + + memref.dealloc %5 : memref + + return +} + +func.func private @printMemrefF32(memref<*xf32>) + attributes { llvm.emit_c_interface }