diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -586,7 +586,7 @@ // Collect all outlined function inputs. llvm::SetVector functionInputs(execute.dependencies().begin(), execute.dependencies().end()); - assert(execute.operands().empty() && "operands are not supported"); + functionInputs.insert(execute.operands().begin(), execute.operands().end()); getUsedValuesDefinedAbove(execute.body(), functionInputs); // Collect types for the outlined function inputs and outputs. @@ -636,15 +636,26 @@ addSuspensionPoint(coro, coroSave.getResult(0), terminatorOp, suspended, resume, builder); + size_t numDependencies = execute.dependencies().size(); + size_t numOperands = execute.operands().size(); + // Await on all dependencies before starting to execute the body region. builder.setInsertionPointToStart(resume); - for (size_t i = 0; i < execute.dependencies().size(); ++i) + for (size_t i = 0; i < numDependencies; ++i) builder.create(func.getArgument(i)); + // Await on all async value operands and unwrap the payload. + SmallVector unwrappedOperands(numOperands); + for (size_t i = 0; i < numOperands; ++i) { + Value operand = func.getArgument(numDependencies + i); + unwrappedOperands[i] = builder.create(loc, operand).result(); + } + // Map from function inputs defined above the execute op to the function // arguments. BlockAndValueMapping valueMapping; valueMapping.map(functionInputs, func.getArguments()); + valueMapping.map(execute.body().getArguments(), unwrappedOperands); // Clone all operations from the execute operation body into the outlined // function body. @@ -1069,14 +1080,6 @@ return WalkResult::interrupt(); } - // We currently do not support execute operations that have async value - // operands or produce async results. - if (!execute.operands().empty()) { - execute.emitOpError( - "can't outline async.execute op with async value operands"); - return WalkResult::interrupt(); - } - outlinedFunctions.insert(outlineExecuteOp(symbolTable, execute)); return WalkResult::advance(); diff --git a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AsyncToLLVM/convert-to-llvm.mlir @@ -252,3 +252,54 @@ // Emplace result token. // CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]]) +// ----- + +// RUN: mlir-opt %s -split-input-file -convert-async-to-llvm | FileCheck %s + +func @async_value_operands() { + // CHECK: %[[RET:.*]]:2 = call @async_execute_fn + %token, %result = async.execute -> !async.value { + %c0 = constant 123.0 : f32 + async.yield %c0 : f32 + } + + // CHECK: %[[TOKEN:.*]] = call @async_execute_fn_0(%[[RET]]#1) + %token0 = async.execute(%result as %value: !async.value) { + %0 = addf %value, %value : f32 + async.yield + } + + // CHECK: call @mlirAsyncRuntimeAwaitToken(%[[TOKEN]]) + async.await %token0 : !async.token + + return +} + +// Function outlined from the first async.execute operation. +// CHECK-LABEL: func private @async_execute_fn() + +// Function outlined from the second async.execute operation. +// CHECK-LABEL: func private @async_execute_fn_0(%arg0: !llvm.ptr) +// CHECK: %[[TOKEN:.*]] = call @mlirAsyncRuntimeCreateToken() +// CHECK: %[[HDL:.*]] = llvm.call @llvm.coro.begin + +// Suspend coroutine in the beginning. +// CHECK: call @mlirAsyncRuntimeExecute(%[[HDL]], +// CHECK: llvm.call @llvm.coro.suspend + +// Suspend coroutine second time waiting for the async operand. +// CHECK: llvm.call @llvm.coro.save +// CHECK: call @mlirAsyncRuntimeAwaitValueAndExecute(%arg0, %[[HDL]], +// CHECK: llvm.call @llvm.coro.suspend + +// Get the operand value storage, cast to f32 and add the value. +// CHECK: %[[STORAGE:.*]] = call @mlirAsyncRuntimeGetValueStorage(%arg0) +// CHECK: %[[ST_F32:.*]] = llvm.bitcast %[[STORAGE]] +// CHECK: %[[LOADED:.*]] = llvm.load %[[ST_F32]] : !llvm.ptr +// CHECK: %[[CASTED:.*]] = llvm.mlir.cast %[[LOADED]] : !llvm.float to f32 +// CHECK: addf %[[CASTED]], %[[CASTED]] : f32 + +// Emplace result token. +// CHECK: call @mlirAsyncRuntimeEmplaceToken(%[[TOKEN]]) + + diff --git a/mlir/test/mlir-cpu-runner/async-value.mlir b/mlir/test/mlir-cpu-runner/async-value.mlir --- a/mlir/test/mlir-cpu-runner/async-value.mlir +++ b/mlir/test/mlir-cpu-runner/async-value.mlir @@ -44,7 +44,7 @@ // ------------------------------------------------------------------------ // %token2, %result2 = async.execute[%token0] -> !async.value> { %5 = alloc() : memref - %c0 = constant 987.654 : f32 + %c0 = constant 0.25 : f32 store %c0, %5[]: memref async.yield %5 : memref } @@ -53,8 +53,25 @@ // CHECK: Unranked Memref // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = [] - // CHECK-NEXT: [987.654] + // CHECK-NEXT: [0.25] call @print_memref_f32(%7): (memref<*xf32>) -> () + + // ------------------------------------------------------------------------ // + // Memref passed as async.execute operand. + // ------------------------------------------------------------------------ // + %token3 = async.execute(%result2 as %unwrapped : !async.value>) { + %8 = load %unwrapped[]: memref + %9 = addf %8, %8 : f32 + store %9, %unwrapped[]: memref + async.yield + } + async.await %token3 : !async.token + + // CHECK: Unranked Memref + // CHECK-SAME: rank = 0 offset = 0 sizes = [] strides = [] + // CHECK-NEXT: [0.5] + call @print_memref_f32(%7): (memref<*xf32>) -> () + dealloc %6 : memref return