diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -400,13 +400,14 @@ auto fnAttr = (*this)->getAttrOfType("callee"); if (!fnAttr) return emitOpError("requires a 'callee' symbol reference attribute"); - FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); + auto fn = + symbolTable.lookupNearestSymbolFrom(*this, fnAttr); if (!fn) return emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid async function"; // Verify that the operand and result types match the callee. - auto fnType = fn.getFunctionType(); + auto fnType = fn.getFunctionType().cast(); if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -840,8 +840,9 @@ target.addDynamicallyLegalOp( [coros](Operation *op) { + auto exec = op->getParentOfType(); auto func = op->getParentOfType(); - return coros->find(func) == coros->end(); + return exec || coros->find(func) == coros->end(); }); } diff --git a/mlir/test/Dialect/Async/async-to-async-runtime.mlir b/mlir/test/Dialect/Async/async-to-async-runtime.mlir --- a/mlir/test/Dialect/Async/async-to-async-runtime.mlir +++ b/mlir/test/Dialect/Async/async-to-async-runtime.mlir @@ -455,3 +455,44 @@ // CHECK-SAME: !async.value // CHECK: async.coro.suspend %[[SAVED]] // CHECK-SAME: ^[[SUSPEND:.*]], ^[[RESUME:.*]], ^[[CLEANUP:.*]] + +// ----- +// Async functions can call regular (non-async) functions via async.call. + +// CHECK-LABEL: @async_func_call_func +func.func private @regular_func() -> !async.value + +async.func @async_func_call_func() -> !async.token { + %0 = async.call @regular_func() : () -> !async.value + %1 = async.await %0 : 
!async.value + return +} +// CHECK: %[[RES:.*]] = call @regular_func() : () -> !async.value + +// ----- +// async.execute inside an async.func is outlined into a separate function. + +// CHECK-LABEL: @execute_in_async_func +async.func @execute_in_async_func(%arg0: f32, %arg1: memref<1xf32>) + -> !async.token { + %token = async.execute { + %c0 = arith.constant 0 : index + memref.store %arg0, %arg1[%c0] : memref<1xf32> + async.yield + } + async.await %token : !async.token + return +} +// Call to the function outlined from async.execute. +// CHECK: %[[RES:.*]] = call @async_execute_fn( +// CHECK-SAME: %[[VALUE:arg[0-9]+]], +// CHECK-SAME: %[[MEMREF:arg[0-9]+]] +// CHECK-SAME: ) : (f32, memref<1xf32>) -> !async.token + +// Function outlined from the async.execute operation. +// CHECK-LABEL: func private @async_execute_fn( +// CHECK-SAME: %[[VALUE:arg[0-9]+]]: f32, +// CHECK-SAME: %[[MEMREF:arg[0-9]+]]: memref<1xf32> +// CHECK-SAME: ) -> !async.token +// CHECK: %[[CST:.*]] = arith.constant 0 : index +// CHECK: memref.store %[[VALUE]], %[[MEMREF]][%[[CST]]]