diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -462,14 +462,15 @@ OpBuilder moduleBuilder(module.getBody()->getTerminator()); - // Get values captured by the async region - llvm::SetVector usedAbove; - getUsedValuesDefinedAbove(execute.body(), usedAbove); - - // Collect types of the captured values. - auto usedAboveTypes = - llvm::map_range(usedAbove, [](Value value) { return value.getType(); }); - SmallVector inputTypes(usedAboveTypes.begin(), usedAboveTypes.end()); + // Collect all outlined function inputs. + llvm::SetVector functionInputs(execute.dependencies().begin(), + execute.dependencies().end()); + getUsedValuesDefinedAbove(execute.body(), functionInputs); + + // Collect types for the outlined function inputs and outputs. + auto typesRange = llvm::map_range( + functionInputs, [](Value value) { return value.getType(); }); + SmallVector inputTypes(typesRange.begin(), typesRange.end()); auto outputTypes = execute.getResultTypes(); auto funcType = moduleBuilder.getFunctionType(inputTypes, outputTypes); @@ -510,14 +511,19 @@ Block *resume = addSuspensionPoint(coro, coroSave.getResult(0), entryBlock->getTerminator()); - // Map from values defined above the execute op to the function arguments. + // Await on all dependencies before starting to execute the body region. + builder.setInsertionPointToStart(resume); + for (size_t i = 0; i < execute.dependencies().size(); ++i) + builder.create(loc, func.getArgument(i)); + + // Map from function inputs defined above the execute op to the function + // arguments. 
BlockAndValueMapping valueMapping; - valueMapping.map(usedAbove, func.getArguments()); + valueMapping.map(functionInputs, func.getArguments()); // Clone all operations from the execute operation body into the outlined // function body, and replace all `async.yield` operations with a call // to async runtime to emplace the result token. - builder.setInsertionPointToStart(resume); for (Operation &op : execute.body().getOps()) { if (isa(op)) { builder.create(loc, kEmplaceToken, Type(), coro.asyncToken); @@ -528,9 +534,9 @@ // Replace the original `async.execute` with a call to outlined function. OpBuilder callBuilder(execute); - SmallVector usedAboveArgs(usedAbove.begin(), usedAbove.end()); - auto callOutlinedFunc = callBuilder.create( - loc, func.getName(), execute.getResultTypes(), usedAboveArgs); + auto callOutlinedFunc = + callBuilder.create(loc, func.getName(), execute.getResultTypes(), + functionInputs.getArrayRef()); execute.replaceAllUsesWith(callOutlinedFunc.getResults()); execute.erase(); @@ -673,13 +679,11 @@ llvm::DenseMap outlinedFunctions; WalkResult outlineResult = module.walk([&](ExecuteOp execute) { - // We currently do not support execute operations that take async - // token dependencies, async value arguments or produce async results. - if (!execute.dependencies().empty() || !execute.operands().empty() || - !execute.results().empty()) { - execute.emitOpError( - "Can't outline async.execute op with async dependencies, arguments " - "or returned async results"); + // We currently do not support execute operations that have async value + // operands or produce async results. 
+ if (!execute.operands().empty() || !execute.results().empty()) { + execute.emitOpError("can't outline async.execute op with async value " + "operands or returned async results"); return WalkResult::interrupt(); } diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -15,7 +15,7 @@ } // Function outlined from the async.execute operation. -// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>) +// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>) // CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} // Create token for return op, and mark a function as a coroutine. @@ -79,7 +79,7 @@ } // Function outlined from the inner async.execute operation. -// CHECK: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index) +// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>, %arg2: index) // CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} // CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken() // CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin @@ -89,7 +89,7 @@ // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]]) // Function outlined from the outer async.execute operation. 
-// CHECK: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32) +// CHECK-LABEL: func @async_execute_fn_0(%arg0: f32, %arg1: memref<1xf32>, %arg2: f32) // CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} // CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken() // CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin @@ -108,4 +108,52 @@ // CHECK: store %arg2, %arg1[%c0] : memref<1xf32> // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]]) +// ----- + +// CHECK-LABEL: async_execute_token_dependency +func @async_execute_token_dependency(%arg0: f32, %arg1: memref<1xf32>) { + // CHECK: %0 = call @async_execute_fn(%arg0, %arg1) + %token = async.execute { + %c0 = constant 0 : index + store %arg0, %arg1[%c0] : memref<1xf32> + async.yield + } + // CHECK: %1 = call @async_execute_fn_0(%0, %arg0, %arg1) + %token_0 = async.execute [%token] { + %c0 = constant 0 : index + store %arg0, %arg1[%c0] : memref<1xf32> + async.yield + } + return +} + +// Function outlined from the first async.execute operation. +// CHECK-LABEL: func @async_execute_fn(%arg0: f32, %arg1: memref<1xf32>) +// CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} +// CHECK: %[[RET_0:.*]] = call @mlirAsyncRuntimeCreateToken() +// CHECK: %[[HDL_0:.*]] = llvm.call @llvm.coro.begin +// CHECK: call @mlirAsyncRuntimeExecute +// CHECK: llvm.call @llvm.coro.suspend +// CHECK: store %arg0, %arg1[%c0] : memref<1xf32> +// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_0]]) + +// Function outlined from the second async.execute operation with a token dependency. +// CHECK-LABEL: func @async_execute_fn_0(%arg0: !llvm.ptr, %arg1: f32, %arg2: memref<1xf32>) +// CHECK-SAME: -> !llvm.ptr attributes {sym_visibility = "private"} +// CHECK: %[[RET_1:.*]] = call @mlirAsyncRuntimeCreateToken() +// CHECK: %[[HDL_1:.*]] = llvm.call @llvm.coro.begin + +// Suspend coroutine at the beginning.
+// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL_1]], +// CHECK: llvm.call @llvm.coro.suspend + +// Suspend coroutine a second time, waiting for the token dependency to complete. +// CHECK: llvm.call @llvm.coro.save +// CHECK: call @mlirAsyncRuntimeAwaitTokenAndExecute(%arg0, %[[HDL_1]], +// CHECK: llvm.call @llvm.coro.suspend + +// Emplace the result token after the second resumption. +// CHECK: store %arg1, %arg2[%c0] : memref<1xf32> +// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[RET_1]]) + diff --git a/mlir/test/mlir-cpu-runner/async.mlir b/mlir/test/mlir-cpu-runner/async.mlir --- a/mlir/test/mlir-cpu-runner/async.mlir +++ b/mlir/test/mlir-cpu-runner/async.mlir @@ -41,8 +41,15 @@ call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () call @print_memref_f32(%U): (memref<*xf32>) -> () - %inner = async.execute { + // No-op async region that creates a token for testing async dependencies. + %noop = async.execute { // CHECK: Current thread id: [[THREAD1:.*]] + call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () + async.yield + } + + %inner = async.execute [%noop] { + // CHECK: Current thread id: [[THREAD2:.*]] // CHECK: [1, 2, 3, 0] store %c3, %A[%i2]: memref<4xf32> call @mlirAsyncRuntimePrintCurrentThreadId(): () -> () @@ -52,7 +59,7 @@ } async.await %inner : !async.token - // CHECK: Current thread id: [[THREAD2:.*]] + // CHECK: Current thread id: [[THREAD3:.*]] // CHECK: [1, 2, 3, 4] store %c4, %A[%i3]: memref<4xf32> call @mlirAsyncRuntimePrintCurrentThreadId(): () -> ()