diff --git a/mlir/include/mlir/Dialect/Async/Passes.h b/mlir/include/mlir/Dialect/Async/Passes.h
--- a/mlir/include/mlir/Dialect/Async/Passes.h
+++ b/mlir/include/mlir/Dialect/Async/Passes.h
@@ -27,7 +27,10 @@
     int32_t numWorkerThreads,
     int32_t minTaskSize);
 
-std::unique_ptr> createAsyncToAsyncRuntimePass();
+/// Creates the async-to-async-runtime lowering pass. When `asyncFuncsOnly` is
+/// true, `async.execute` outlining is skipped and only async.func is lowered.
+std::unique_ptr>
+createAsyncToAsyncRuntimePass(bool asyncFuncsOnly = false);
 
 std::unique_ptr createAsyncRuntimeRefCountingPass();
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -45,8 +47,13 @@
 class AsyncToAsyncRuntimePass
     : public impl::AsyncToAsyncRuntimeBase {
 public:
-  AsyncToAsyncRuntimePass() = default;
+  explicit AsyncToAsyncRuntimePass(bool asyncFuncsOnly)
+      : asyncFuncsOnly(asyncFuncsOnly) {}
   void runOnOperation() override;
+
+private:
+  /// If set, skip async.execute outlining and add only async.func patterns.
+  bool asyncFuncsOnly;
 };
 
 } // namespace
@@ -733,9 +740,10 @@
   // `async.execute` body regions and converting async.func.
   llvm::DenseMap coros;
-  module.walk([&](ExecuteOp execute) {
-    coros.insert(outlineExecuteOp(symbolTable, execute));
-  });
+  if (!asyncFuncsOnly)
+    module.walk([&](ExecuteOp execute) {
+      coros.insert(outlineExecuteOp(symbolTable, execute));
+    });
 
   LLVM_DEBUG({
     llvm::dbgs() << "Outlined " << coros.size()
@@ -760,24 +768,27 @@
   // Async lowering does not use type converter because it must preserve all
   // types for async.runtime operations.
-  asyncPatterns.add(ctx);
+  if (!asyncFuncsOnly) {
+    asyncPatterns.add(ctx);
+    asyncPatterns.add(ctx, coros);
+    // Lower assertions to conditional branches into error blocks.
+    asyncPatterns.add(ctx, coros);
+  }
 
   // Lower async.func to func.func with coroutine cfg.
   asyncPatterns.add(ctx);
   asyncPatterns.add(ctx, coros);
 
-  asyncPatterns.add(ctx, coros);
-
-  // Lower assertions to conditional branches into error blocks.
-  asyncPatterns.add(ctx, coros);
-
   // All high level async operations must be lowered to the runtime operations.
   ConversionTarget runtimeTarget(*ctx);
   runtimeTarget.addLegalDialect();
-  runtimeTarget.addIllegalOp();
-  runtimeTarget.addIllegalOp();
+  if (!asyncFuncsOnly) {
+    runtimeTarget.addIllegalOp();
+    runtimeTarget.addIllegalOp();
+  }
+
+  runtimeTarget.addIllegalOp();
 
   // Decide if structured control flow has to be lowered to branch-based CFG.
   runtimeTarget.addDynamicallyLegalDialect([&](Operation *op) {
@@ -786,7 +797,7 @@
       return isAsync && isInCoroutine(nested) ? WalkResult::interrupt()
                                               : WalkResult::advance();
     });
-    return !walkResult.wasInterrupted();
+    return asyncFuncsOnly || !walkResult.wasInterrupted();
   });
   runtimeTarget.addLegalOp();
 
@@ -795,7 +806,7 @@
   runtimeTarget.addDynamicallyLegalOp(
       [&](cf::AssertOp op) -> bool {
         auto func = op->getParentOfType();
-        return coros.find(func) == coros.end();
+        return asyncFuncsOnly || coros.find(func) == coros.end();
       });
 
   if (failed(applyPartialConversion(module, runtimeTarget,
@@ -805,6 +816,7 @@
   }
 }
 
-std::unique_ptr> mlir::createAsyncToAsyncRuntimePass() {
-  return std::make_unique();
+std::unique_ptr>
+mlir::createAsyncToAsyncRuntimePass(bool asyncFuncsOnly) {
+  return std::make_unique(asyncFuncsOnly);
 }