diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -840,8 +840,12 @@
 
   target.addDynamicallyLegalOp<AwaitOp, AwaitAllOp, YieldOp, cf::AssertOp>(
       [coros](Operation *op) {
+        // Ops nested in an async.execute region are left legal here: the
+        // execute body is outlined into its own function, so these ops are
+        // not part of the enclosing async.func's coroutine.
+        auto exec = op->getParentOfType<ExecuteOp>();
         auto func = op->getParentOfType<func::FuncOp>();
-        return coros->find(func) == coros->end();
+        return exec || coros->find(func) == coros->end();
       });
 }
 
diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
--- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir
+++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir
@@ -455,3 +455,31 @@
 // CHECK-SAME: !async.value<f32>
 // CHECK: async.coro.suspend %[[SAVED]]
 // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]]
+
+// -----
+// Async execute inside async func
+
+// CHECK-LABEL: @execute_in_async_func
+async.func @execute_in_async_func(%arg0: f32, %arg1: memref<1xf32>)
+    -> !async.token {
+  %token = async.execute {
+    %c0 = arith.constant 0 : index
+    memref.store %arg0, %arg1[%c0] : memref<1xf32>
+    async.yield
+  }
+  async.await %token : !async.token
+  return
+}
+// Call the outlined async.execute function.
+// CHECK: %[[RES:.*]] = call @async_execute_fn(
+// CHECK-SAME: %[[VALUE:arg[0-9]+]],
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]
+// CHECK-SAME: ) : (f32, memref<1xf32>) -> !async.token
+
+// Function outlined from the async.execute operation.
+// CHECK-LABEL: func private @async_execute_fn(
+// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32,
+// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32>
+// CHECK-SAME: ) -> !async.token
+// CHECK: %[[CST:.*]] = arith.constant 0 : index
+// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]
diff --git a/mlir/test/mlir-cpu-runner/async-func.mlir b/mlir/test/mlir-cpu-runner/async-func.mlir
--- a/mlir/test/mlir-cpu-runner/async-func.mlir
+++ b/mlir/test/mlir-cpu-runner/async-func.mlir
@@ -64,6 +64,19 @@
   return
 }
 
+async.func @async_execute_in_async_func(%arg0 : !async.value<memref<f32>>) -> !async.token {
+  %token0 = async.execute {
+    %unwrapped = async.await %arg0 : !async.value<memref<f32>>
+    %0 = memref.load %unwrapped[] : memref<f32>
+    %1 = arith.addf %0, %0 : f32
+    memref.store %1, %unwrapped[] : memref<f32>
+    async.yield
+  }
+
+  async.await %token0 : !async.token
+  return
+}
+
 func.func @main() {
   %false = arith.constant 0 : i1
 
@@ -140,6 +153,17 @@
   // CHECK-NEXT: [0.5]
   call @printMemrefF32(%6) : (memref<*xf32>) -> ()
 
+  // ------------------------------------------------------------------------ //
+  // async.execute inside async.func
+  // ------------------------------------------------------------------------ //
+  %token4 = async.call @async_execute_in_async_func(%result1) : (!async.value<memref<f32>>) -> !async.token
+  async.await %token4 : !async.token
+
+  // CHECK: Unranked Memref
+  // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = []
+  // CHECK-NEXT: [1]
+  call @printMemrefF32(%6) : (memref<*xf32>) -> ()
+
   memref.dealloc %5 : memref<f32>
 
   return