diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -43,7 +43,12 @@
     Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
            /*default=*/"false",
            "Rewrite functions with blocking async.runtime.await as coroutines "
-           "with async.runtime.await_and_resume.">
+           "with async.runtime.await_and_resume.">,
+    ListOption<"funcsAllowedToBlock", "funcs-allowed-to-block", "std::string",
+               "Comma separated list of funcs that allow blocking "
+               "async.runtime.await ops. Only useful in combination with "
+               "'eliminate-blocking-await-ops' option.",
+               "llvm::cl::MiscFlags::CommaSeparated">
   ];
   let dependentDialects = ["async::AsyncDialect"];
 }
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -613,7 +613,8 @@
 }
 
 void funcsToCoroutines(
-    ModuleOp module, llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
+    ModuleOp module, llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions,
+    const AsyncToAsyncRuntimePass::ListOption<std::string> &funcsAllowedToBlock) {
   // The following code supports the general case when 2 functions mutually
   // recurse into each other. Because of this and that we are relying on
   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
@@ -623,10 +623,20 @@
 
   std::vector<FuncOp> funcWorklist;
 
+  // Funcs listed in the `funcs-allowed-to-block` pass option keep their
+  // blocking awaits and are never rewritten as coroutines.
+  auto isAllowedToBlock = [&funcsAllowedToBlock](FuncOp func) {
+    return llvm::is_contained(funcsAllowedToBlock, func.getName());
+  };
+
   // Careful, it's okay to add a func to the worklist multiple times if and only
   // if the loop processing the worklist will skip the functions that have
   // already been converted to coroutines.
-  auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
+  auto addToWorklist = [&outlinedFunctions, &funcWorklist,
+                        &isAllowedToBlock](FuncOp func) {
+    // Never schedule a func that is explicitly allowed to block.
+    if (isAllowedToBlock(func))
+      return;
     // N.B. To refactor this code into a separate pass the lookup in
     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
     // func and recognizing if it has a coroutine structure is messy. Passing
@@ -698,7 +709,7 @@
   });
 
   if (eliminateBlockingAwaitOps)
-    funcsToCoroutines(module, outlinedFunctions);
+    funcsToCoroutines(module, outlinedFunctions, funcsAllowedToBlock);
 
   // Lower async operations to async.runtime operations.
   MLIRContext *ctx = module->getContext();
@@ -729,5 +740,12 @@
 
   if (eliminateBlockingAwaitOps)
-    runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
+    runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
+        [&](RuntimeAwaitOp op) -> bool {
+          // Blocking awaits stay legal only inside funcs explicitly allowed to
+          // block; awaits without a FuncOp parent remain illegal.
+          auto func = op->getParentOfType<FuncOp>();
+          return func &&
+                 llvm::is_contained(funcsAllowedToBlock, func.getName());
+        });
 
   if (failed(applyPartialConversion(module, runtimeTarget,
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -split-input-file \
-// RUN: -async-to-async-runtime="eliminate-blocking-await-ops=true" \
+// RUN: -async-to-async-runtime="\
+// RUN: eliminate-blocking-await-ops=true \
+// RUN: funcs-allowed-to-block=caller_allowed_to_block" \
 // RUN: | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: func @simple_callee
@@ -302,3 +304,18 @@
 // CHECK: async.coro.end %[[HDL]]
 // CHECK: return %[[TOKEN]] : !async.token
 }
+
+// CHECK-LABEL: func @caller_allowed_to_block
+// CHECK-SAME: () -> f32
+func @caller_allowed_to_block() -> f32 {
+// CHECK: %[[CONSTANT:.*]] = constant
+  %c = constant 1.0 : f32
+// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#0
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#1
+// CHECK: %[[RETURNED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1
+  %r = call @simple_callee(%c): (f32) -> f32
+
+// CHECK: return %[[RETURNED]] : f32
+  return %r: f32
+}