diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -44,7 +44,12 @@
     Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops", "bool",
            /*default=*/"false",
            "Rewrite functions with blocking async.runtime.await as coroutines "
-           "with async.runtime.await_and_resume.">
+           "with async.runtime.await_and_resume.">,
+    ListOption<"funcsAllowedToBlock", "funcs-allowed-to-block", "std::string",
+               "Comma separated list of funcs that allow blocking "
+               "async.runtime.await ops. Only useful in combination with "
+               "'eliminate-blocking-await-ops' option.",
+               "llvm::cl::MiscFlags::CommaSeparated">
   ];
   let dependentDialects = ["async::AsyncDialect"];
 }
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -612,9 +612,9 @@
   oldCall.erase();
 }
 
-static LogicalResult
-funcsToCoroutines(ModuleOp module,
-                  llvm::DenseMap &outlinedFunctions) {
+static LogicalResult funcsToCoroutines(
+    ModuleOp module, llvm::DenseMap &outlinedFunctions,
+    const AsyncToAsyncRuntimePass::ListOption &funcsAllowedToBlock) {
   // The following code supports the general case when 2 functions mutually
   // recurse into each other. Because of this and that we are relying on
   // SymbolUserMap to find pointers to calling FuncOps, we cannot simply erase
@@ -623,15 +623,23 @@
 
   SmallVector funcWorklist;
 
+  // Funcs listed in `funcs-allowed-to-block` may keep their blocking
+  // async.runtime.await ops, so they are never queued for conversion.
+  auto isAllowedToBlock = [&funcsAllowedToBlock](FuncOp func) {
+    return llvm::is_contained(funcsAllowedToBlock, func.getName());
+  };
+
   // Careful, it's okay to add a func to the worklist multiple times if and only
   // if the loop processing the worklist will skip the functions that have
   // already been converted to coroutines.
-  auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
+  auto addToWorklist = [&](FuncOp func) {
+    if (isAllowedToBlock(func))
+      return;
     // N.B. To refactor this code into a separate pass the lookup in
     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
     // func and recognizing if it has a coroutine structure is messy. Passing
     // this dict between the passes is ugly.
     if (outlinedFunctions.find(func) == outlinedFunctions.end()) {
       for (Operation &op : func.body().getOps()) {
         if (dyn_cast(op) || dyn_cast(op)) {
           funcWorklist.push_back(func);
@@ -702,7 +710,8 @@
   });
 
   if (eliminateBlockingAwaitOps &&
-      failed(funcsToCoroutines(module, outlinedFunctions))) {
+      failed(
+          funcsToCoroutines(module, outlinedFunctions, funcsAllowedToBlock))) {
     signalPassFailure();
     return;
   }
@@ -735,7 +744,13 @@
   runtimeTarget.addLegalOp();
 
   if (eliminateBlockingAwaitOps)
-    runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
+    runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
+        [&](RuntimeAwaitOp op) -> bool {
+          // An await that is not nested in a FuncOp is never allowed to block.
+          auto func = op->getParentOfType<FuncOp>();
+          return func &&
+                 llvm::is_contained(funcsAllowedToBlock, func.getName());
+        });
 
   if (failed(applyPartialConversion(module, runtimeTarget,
                                     std::move(asyncPatterns)))) {
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt %s -split-input-file \
-// RUN: -async-to-async-runtime="eliminate-blocking-await-ops=true" \
+// RUN: -async-to-async-runtime="\
+// RUN: eliminate-blocking-await-ops=true \
+// RUN: funcs-allowed-to-block=caller_allowed_to_block" \
 // RUN: | FileCheck %s --dump-input=always
 
 // CHECK-LABEL: func @simple_callee
@@ -302,3 +304,18 @@
 // CHECK: async.coro.end %[[HDL]]
 // CHECK: return %[[TOKEN]] : !async.token
 }
+
+// CHECK-LABEL: func @caller_allowed_to_block
+// CHECK-SAME: () -> f32
+func @caller_allowed_to_block() -> f32 {
+// CHECK: %[[CONSTANT:.*]] = constant
+  %c = constant 1.0 : f32
+// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value)
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#0
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#1
+// CHECK: %[[RETURNED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1
+  %r = call @simple_callee(%c): (f32) -> f32
+
+// CHECK: return %[[RETURNED]] : f32
+  return %r: f32
+}