diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -17,6 +17,7 @@
 
 namespace mlir {
 class ModuleOp;
+class ConversionTarget;
 
 #define GEN_PASS_DECL
 #include "mlir/Dialect/Async/Passes.h.inc"
@@ -27,6 +28,16 @@
                                                  int32_t numWorkerThreads,
                                                  int32_t minTaskSize);
 
+/// Populates `patterns` with rewrites that lower `async.func` operations to
+/// `func.func` operations with an explicit coroutine CFG, and registers in
+/// `target` the dynamic legality rules these rewrites rely on. `patterns` and
+/// `target` must be used together in the same conversion.
+void populateAsyncFuncToCoroutineConversionPatterns(RewritePatternSet &patterns,
+                                                    ConversionTarget &target);
+
+/// Creates a pass that lowers `async.func` operations to coroutines.
+std::unique_ptr<OperationPass<ModuleOp>> createAsyncFuncToCoroutinePass();
+
 std::unique_ptr<OperationPass<ModuleOp>> createAsyncToAsyncRuntimePass();
 
 std::unique_ptr<Pass> createAsyncRuntimeRefCountingPass();
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -47,6 +47,19 @@
   let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
 }
 
+def AsyncFuncToCoroutine : Pass<"async-func-to-coroutine", "ModuleOp"> {
+  let summary = "Lower high level async.func operations to the explicit "
+                "async.runtime and async.coro operations";
+  let description = [{
+    Converts `async.func` operations to `func.func` operations with an
+    explicit coroutine CFG. Operations nested in `async.execute` regions are
+    left untouched and are lowered by `-async-to-async-runtime`, which is
+    expected to run after this pass.
+  }];
+  let constructor = "mlir::createAsyncFuncToCoroutinePass()";
+  let dependentDialects = ["async::AsyncDialect", "func::FuncDialect"];
+}
+
 def AsyncRuntimeRefCounting : Pass<"async-runtime-ref-counting"> {
   let summary = "Automatic reference counting for Async runtime operations";
   let description = [{
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -30,6 +30,7 @@
 
 namespace mlir {
 #define GEN_PASS_DEF_ASYNCTOASYNCRUNTIME
+#define GEN_PASS_DEF_ASYNCFUNCTOCOROUTINE
 #include "mlir/Dialect/Async/Passes.h.inc"
 } // namespace mlir
 
@@ -51,6 +52,19 @@
 
 } // namespace
 
+namespace {
+
+/// Pass that lowers `async.func` operations to `func.func` operations with an
+/// explicit coroutine CFG.
+class AsyncFuncToCoroutinePass
+    : public impl::AsyncFuncToCoroutineBase<AsyncFuncToCoroutinePass> {
+public:
+  AsyncFuncToCoroutinePass() = default;
+  void runOnOperation() override;
+};
+
+} // namespace
+
 /// Function targeted for coroutine transformation has two additional blocks at
 /// the end: coroutine cleanup and coroutine suspension.
 ///
@@ -84,6 +98,13 @@
 };
 } // namespace
 
+/// Coroutine machinery of every function converted to a coroutine, keyed by
+/// that function. Held through a shared_ptr because the conversion patterns
+/// and the dynamic legality callbacks that use the map can outlive the
+/// function that creates it.
+using FuncCoroMapPtr =
+    std::shared_ptr<llvm::DenseMap<func::FuncOp, CoroMachinery>>;
+
 /// Utility to partially update the regular function CFG to the coroutine CFG
 /// compatible with LLVM coroutines switched-resume lowering using
 /// `async.runtime.*` and `async.coro.*` operations. Adds a new entry block
@@ -399,9 +420,8 @@
 
 class AsyncFuncOpLowering : public OpConversionPattern<async::FuncOp> {
 public:
-  AsyncFuncOpLowering(MLIRContext *ctx,
-                      llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<async::FuncOp>(ctx), coros(coros) {}
+  AsyncFuncOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<async::FuncOp>(ctx), coros(std::move(coros)) {}
 
   LogicalResult
   matchAndRewrite(async::FuncOp op, OpAdaptor adaptor,
@@ -423,7 +443,7 @@
                                 newFuncOp.end());
 
     CoroMachinery coro = setupCoroMachinery(newFuncOp);
-    coros[newFuncOp] = coro;
+    (*coros)[newFuncOp] = coro;
     // no initial suspend, we should hot-start
 
     rewriter.eraseOp(op);
@@ -431,7 +451,7 @@
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros;
 };
 
 //===----------------------------------------------------------------------===//
@@ -458,16 +478,15 @@
 
 class AsyncReturnOpLowering : public OpConversionPattern<async::ReturnOp> {
 public:
-  AsyncReturnOpLowering(MLIRContext *ctx,
-                        llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<async::ReturnOp>(ctx), coros(coros) {}
+  AsyncReturnOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<async::ReturnOp>(ctx), coros(std::move(coros)) {}
 
   LogicalResult
   matchAndRewrite(async::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    if (funcCoro == coros.end())
+    auto funcCoro = coros->find(func);
+    if (funcCoro == coros->end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
@@ -494,7 +513,7 @@
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros;
 };
 
 } // namespace
@@ -509,9 +528,12 @@
   using AwaitAdaptor = typename AwaitType::Adaptor;
 
 public:
-  AwaitOpLoweringBase(MLIRContext *ctx,
-                      llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<AwaitType>(ctx), coros(coros) {}
+  /// When `shouldLowerBlockingWait` is false, awaits outside of coroutine
+  /// functions are left unmatched so that a later pass can lower them.
+  AwaitOpLoweringBase(MLIRContext *ctx, FuncCoroMapPtr coros,
+                      bool shouldLowerBlockingWait)
+      : OpConversionPattern<AwaitType>(ctx), coros(std::move(coros)),
+        shouldLowerBlockingWait(shouldLowerBlockingWait) {}
 
   LogicalResult
   matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor,
@@ -521,16 +543,22 @@
     if (!op.getOperand().getType().template isa<AwaitableType>())
       return rewriter.notifyMatchFailure(op, "unsupported awaitable type");
 
-    // Check if await operation is inside the outlined coroutine function.
+    // Check if await operation is inside the coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    const bool isInCoroutine = funcCoro != coros.end();
+    auto funcCoro = coros->find(func);
+    const bool isInCoroutine = funcCoro != coros->end();
 
     Location loc = op->getLoc();
     Value operand = adaptor.getOperand();
 
     Type i1 = rewriter.getI1Type();
 
+    // When blocking waits are disabled, awaits outside of coroutine functions
+    // are not rewritten here and are left for a later lowering.
+    if (!isInCoroutine && !shouldLowerBlockingWait)
+      return rewriter.notifyMatchFailure(
+          op, "blocking wait lowering is disabled for this pattern");
+
     // Inside regular functions we use the blocking wait operation to wait for
     // the async object (token, value or group) to become available.
     if (!isInCoroutine) {
@@ -602,7 +630,8 @@
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros;
+  bool shouldLowerBlockingWait;
 };
 
 /// Lowering for `async.await` with a token operand.
@@ -645,17 +674,16 @@
 
 class YieldOpLowering : public OpConversionPattern<async::YieldOp> {
 public:
-  YieldOpLowering(MLIRContext *ctx,
-                  const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<async::YieldOp>(ctx), coros(coros) {}
+  YieldOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<async::YieldOp>(ctx), coros(std::move(coros)) {}
 
   LogicalResult
   matchAndRewrite(async::YieldOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if yield operation is inside the async coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    if (funcCoro == coros.end())
+    auto funcCoro = coros->find(func);
+    if (funcCoro == coros->end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
@@ -682,7 +710,7 @@
   }
 
 private:
-  const llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros;
 };
 
 //===----------------------------------------------------------------------===//
@@ -691,17 +719,16 @@
 
 class AssertOpLowering : public OpConversionPattern<cf::AssertOp> {
 public:
-  AssertOpLowering(MLIRContext *ctx,
-                   llvm::DenseMap<func::FuncOp, CoroMachinery> &coros)
-      : OpConversionPattern<cf::AssertOp>(ctx), coros(coros) {}
+  AssertOpLowering(MLIRContext *ctx, FuncCoroMapPtr coros)
+      : OpConversionPattern<cf::AssertOp>(ctx), coros(std::move(coros)) {}
 
   LogicalResult
   matchAndRewrite(cf::AssertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     // Check if assert operation is inside the async coroutine function.
     auto func = op->template getParentOfType<func::FuncOp>();
-    auto funcCoro = coros.find(func);
-    if (funcCoro == coros.end())
+    auto funcCoro = coros->find(func);
+    if (funcCoro == coros->end())
       return rewriter.notifyMatchFailure(
           op, "operation is not inside the async coroutine function");
 
@@ -721,7 +748,7 @@
   }
 
 private:
-  llvm::DenseMap<func::FuncOp, CoroMachinery> &coros;
+  FuncCoroMapPtr coros;
 };
 
 //===----------------------------------------------------------------------===//
@@ -730,22 +757,23 @@
   SymbolTable symbolTable(module);
 
   // Functions with coroutine CFG setups, which are results of outlining
-  // `async.execute` body regions and converting async.func.
-  llvm::DenseMap<func::FuncOp, CoroMachinery> coros;
+  // `async.execute` body regions.
+  FuncCoroMapPtr coros =
+      std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
 
   module.walk([&](ExecuteOp execute) {
-    coros.insert(outlineExecuteOp(symbolTable, execute));
+    coros->insert(outlineExecuteOp(symbolTable, execute));
   });
 
   LLVM_DEBUG({
-    llvm::dbgs() << "Outlined " << coros.size()
+    llvm::dbgs() << "Outlined " << coros->size()
                  << " functions built from async.execute operations\n";
   });
 
   // Returns true if operation is inside the coroutine.
   auto isInCoroutine = [&](Operation *op) -> bool {
     auto parentFunc = op->getParentOfType<func::FuncOp>();
-    return coros.find(parentFunc) != coros.end();
+    return coros->find(parentFunc) != coros->end();
   };
 
   // Lower async operations to async.runtime operations.
@@ -762,22 +790,19 @@
   // types for async.runtime operations.
   asyncPatterns.add<CreateGroupOpLowering, AddToGroupOpLowering>(ctx);
 
-  // Lower async.func to func.func with coroutine cfg.
-  asyncPatterns.add<AsyncCallOpLowering>(ctx);
-  asyncPatterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
-
-  asyncPatterns.add<AwaitTokenOpLowering, AwaitValueOpLowering,
-                    AwaitAllOpLowering, YieldOpLowering>(ctx, coros);
+  asyncPatterns
+      .add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+          ctx, coros, /*shouldLowerBlockingWait=*/true);
 
-  // Lower assertions to conditional branches into error blocks.
-  asyncPatterns.add<AssertOpLowering>(ctx, coros);
+  // Lower yields of the outlined coroutines, and lower assertions to
+  // conditional branches into error blocks.
+  asyncPatterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
 
   // All high level async operations must be lowered to the runtime operations.
   ConversionTarget runtimeTarget(*ctx);
   runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
   runtimeTarget.addIllegalOp<CreateGroupOp, AddToGroupOp>();
-  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp,
-                             async::FuncOp, async::CallOp, async::ReturnOp>();
+  runtimeTarget.addIllegalOp<ExecuteOp, AwaitOp, AwaitAllOp, async::YieldOp>();
 
   // Decide if structured control flow has to be lowered to branch-based CFG.
   runtimeTarget.addDynamicallyLegalDialect<scf::SCFDialect>([&](Operation *op) {
@@ -795,7 +820,7 @@
   runtimeTarget.addDynamicallyLegalOp<cf::AssertOp>(
       [&](cf::AssertOp op) -> bool {
         auto func = op->getParentOfType<func::FuncOp>();
-        return coros.find(func) == coros.end();
+        return coros->find(func) == coros->end();
       });
 
   if (failed(applyPartialConversion(module, runtimeTarget,
@@ -805,6 +830,68 @@
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Convert async.func operations to coroutines.
+//===----------------------------------------------------------------------===//
+
+/// Populates `patterns` with rewrites that lower `async.func` operations to
+/// `func.func` operations with an explicit coroutine CFG, and registers in
+/// `target` the dynamic legality rules for the operations nested in them.
+/// The rules depend on state filled in by the `async.func` lowering pattern,
+/// so `patterns` and `target` must be used together in one conversion.
+void mlir::populateAsyncFuncToCoroutineConversionPatterns(
+    RewritePatternSet &patterns, ConversionTarget &target) {
+  // Functions with coroutine CFG setups, which are results of converting
+  // async.func. Shared by the patterns and the legality callback below.
+  FuncCoroMapPtr coros =
+      std::make_shared<llvm::DenseMap<func::FuncOp, CoroMachinery>>();
+  MLIRContext *ctx = patterns.getContext();
+  // Lower async.func to func.func with coroutine cfg.
+  patterns.add<AsyncCallOpLowering>(ctx);
+  patterns.add<AsyncFuncOpLowering, AsyncReturnOpLowering>(ctx, coros);
+
+  patterns.add<AwaitTokenOpLowering, AwaitValueOpLowering, AwaitAllOpLowering>(
+      ctx, coros, /*shouldLowerBlockingWait=*/false);
+  patterns.add<YieldOpLowering, AssertOpLowering>(ctx, coros);
+
+  target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, async::YieldOp,
+                               cf::AssertOp>([coros](Operation *op) {
+    // Operations nested in `async.execute` belong to the execute region, not
+    // to the enclosing coroutine: they stay legal here and are lowered once
+    // -async-to-async-runtime outlines the region.
+    if (op->getParentOfType<ExecuteOp>())
+      return true;
+    auto func = op->getParentOfType<func::FuncOp>();
+    return coros->find(func) == coros->end();
+  });
+}
+
+void AsyncFuncToCoroutinePass::runOnOperation() {
+  ModuleOp module = getOperation();
+
+  // Lower async.func operations to func.func operations with coroutine CFG.
+  MLIRContext *ctx = module->getContext();
+  RewritePatternSet asyncPatterns(ctx);
+  ConversionTarget runtimeTarget(*ctx);
+
+  populateAsyncFuncToCoroutineConversionPatterns(asyncPatterns, runtimeTarget);
+
+  runtimeTarget.addLegalDialect<AsyncDialect, func::FuncDialect>();
+  runtimeTarget.addIllegalOp<async::FuncOp, async::CallOp, async::ReturnOp>();
+
+  runtimeTarget.addLegalOp<arith::XOrIOp, arith::ConstantOp, func::ConstantOp,
+                           cf::BranchOp, cf::CondBranchOp>();
+
+  if (failed(applyPartialConversion(module, runtimeTarget,
+                                    std::move(asyncPatterns))))
+    signalPassFailure();
+}
+
 std::unique_ptr<OperationPass<ModuleOp>> mlir::createAsyncToAsyncRuntimePass() {
   return std::make_unique<AsyncToAsyncRuntimePass>();
 }
+
+std::unique_ptr<OperationPass<ModuleOp>>
+mlir::createAsyncFuncToCoroutinePass() {
+  return std::make_unique<AsyncFuncToCoroutinePass>();
+}
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -split-input-file -async-to-async-runtime \
-// RUN:   | FileCheck %s --dump-input=always
+// RUN: mlir-opt %s -split-input-file -async-func-to-coroutine \
+// RUN:   -async-to-async-runtime | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: @execute_no_async_args
 func.func @execute_no_async_args(%arg0: f32, %arg1: memref<1xf32>) {
diff --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
--- a/mlir/test/mlir-cpu-runner/async-func.mlir
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
+// RUN: mlir-opt %s -pass-pipeline="builtin.module(async-func-to-coroutine,async-to-async-runtime,func.func(async-runtime-ref-counting,async-runtime-ref-counting-opt),convert-async-to-llvm,func.func(convert-linalg-to-loops,convert-scf-to-cf),convert-linalg-to-llvm,convert-vector-to-llvm,func.func(convert-arith-to-llvm),convert-func-to-llvm,reconcile-unrealized-casts)" \
 // RUN: | mlir-cpu-runner \
 // RUN:     -e main -entry-point-result=void -O0 \
 // RUN:     -shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext \