diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td
@@ -28,6 +28,15 @@
   }];
 
   let cppNamespace = "::mlir::async";
+
+  let extraClassDeclaration = [{
+    // The name of a unit attribute on funcs that are allowed to contain
+    // blocking async.runtime.await ops. Only useful in combination with the
+    // 'eliminate-blocking-await-ops' option, which, in the absence of this
+    // attribute, might convert a func to a coroutine.
+    static constexpr StringRef kAllowedToBlockAttrName = "async.allowed_to_block";
+  }];
+
 }
 
 #endif // ASYNC_DIALECT_TD
diff --git a/mlir/include/mlir/Dialect/Async/Passes.td b/mlir/include/mlir/Dialect/Async/Passes.td
--- a/mlir/include/mlir/Dialect/Async/Passes.td
+++ b/mlir/include/mlir/Dialect/Async/Passes.td
@@ -44,7 +44,7 @@
     Option<"eliminateBlockingAwaitOps", "eliminate-blocking-await-ops",
            "bool", /*default=*/"false",
            "Rewrite functions with blocking async.runtime.await as coroutines "
-           "with async.runtime.await_and_resume.">
+           "with async.runtime.await_and_resume.">,
   ];
   let dependentDialects = ["async::AsyncDialect"];
 }
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -16,6 +16,8 @@
 
 #include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"
 
+constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;
+
 void AsyncDialect::initialize() {
   addOperations<
 #define GET_OP_LIST
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -614,6 +614,11 @@
   oldCall.erase();
 }
 
+// Returns true if `func` carries the `async.allowed_to_block` unit attribute.
+static bool isAllowedToBlock(FuncOp func) {
+  return !!func->getAttrOfType<UnitAttr>(AsyncDialect::kAllowedToBlockAttrName);
+}
+
 static LogicalResult
 funcsToCoroutines(ModuleOp module,
                   llvm::DenseMap<FuncOp, CoroMachinery> &outlinedFunctions) {
@@ -628,7 +633,10 @@
   // Careful, it's okay to add a func to the worklist multiple times if and only
   // if the loop processing the worklist will skip the functions that have
   // already been converted to coroutines.
-  auto addToWorklist = [&outlinedFunctions, &funcWorklist](FuncOp func) {
+  auto addToWorklist = [&](FuncOp func) {
+    // Never convert funcs that are allowed to block into coroutines.
+    if (isAllowedToBlock(func))
+      return;
     // N.B. To refactor this code into a separate pass the lookup in
     // outlinedFunctions is the most obvious obstacle. Looking at an arbitrary
     // func and recognizing if it has a coroutine structure is messy. Passing
@@ -759,7 +767,10 @@
       });
 
   if (eliminateBlockingAwaitOps)
-    runtimeTarget.addIllegalOp<RuntimeAwaitOp>();
+    runtimeTarget.addDynamicallyLegalOp<RuntimeAwaitOp>(
+        [&](RuntimeAwaitOp op) -> bool {
+          return isAllowedToBlock(op->getParentOfType<FuncOp>());
+        });
 
   if (failed(applyPartialConversion(module, runtimeTarget,
                                     std::move(asyncPatterns)))) {
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
--- a/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime-eliminate-blocking.mlir
@@ -302,3 +302,18 @@
 // CHECK: async.coro.end %[[HDL]]
 // CHECK: return %[[TOKEN]] : !async.token
 }
+
+// CHECK-LABEL: func @caller_allowed_to_block
+// CHECK-SAME: () -> f32
+func @caller_allowed_to_block() -> f32 attributes { async.allowed_to_block } {
+// CHECK: %[[CONSTANT:.*]] = constant
+  %c = constant 1.0 : f32
+// CHECK: %[[RETURNED_TO_CALLER:.*]]:2 = call @simple_callee(%[[CONSTANT]]) : (f32) -> (!async.token, !async.value<f32>)
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#0
+// CHECK: async.runtime.await %[[RETURNED_TO_CALLER]]#1
+// CHECK: %[[RETURNED:.*]] = async.runtime.load %[[RETURNED_TO_CALLER]]#1
+  %r = call @simple_callee(%c): (f32) -> f32
+
+// CHECK: return %[[RETURNED]] : f32
+  return %r: f32
+}