diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
@@ -56,7 +56,11 @@
   Type valueType = type;
 }
 
-def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
-                              "async value type">;
+def Async_AnyValueType : DialectType<AsyncDialect,
+                           CPred<"$_self.isa<::mlir::async::ValueType>()">,
+                                 "async value type">;
+
+def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
+                                           Async_TokenType]>;
 
 #endif // ASYNC_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -40,24 +40,40 @@
     state). All dependencies must be made explicit with async execute
     arguments (`async.token` or `async.value`).
 
+    The `async.execute` operation unwraps `async.value<T>` arguments and passes
+    values of type `T` to the body region; `async.token` arguments are skipped.
+
+    Example:
+
     ```mlir
-    %done, %values = async.execute {
-      %0 = "compute0"(...) : !some.type
-      async.yield %1 : f32
-    } : !async.token, !async.value<f32>
+    %token = ... : !async.token
+    %value = ... : !async.value<f32>
+
+    %done, %values =
+      async.execute (%token: !async.token, %value: !async.value<f32>)
+        -> (!async.token, !async.value<!some.type>)
+      {
+        ^bb0(%arg0: f32):
+          %0 = "compute0"(...) : !some.type
+          async.yield %0 : !some.type
+      }
 
     %1 = "compute1"(...) : !some.type
     ```
+
+    In the example above, asynchronous execution starts only after the token
+    and value arguments become ready.
   }];
 
-  // TODO: Take async.tokens/async.values as arguments.
-  let arguments = (ins );
+  let arguments = (ins Variadic<Async_AnyValueOrTokenType>:$arguments);
   let results = (outs Async_TokenType:$done,
                       Variadic<Async_AnyValueType>:$values);
   let regions = (region SizedRegion<1>:$body);
 
   let printer = [{ return ::mlir::async::print(p, *this); }];
   let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
+
+  let verifier = [{ return ::mlir::async::verify(*this); }];
 }
 
 def Async_YieldOp :
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Traits.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/FunctionImplementation.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -121,43 +122,101 @@
 //===----------------------------------------------------------------------===//
 
 static void print(OpAsmPrinter &p, ExecuteOp op) {
-  p << "async.execute ";
-  p.printRegion(op.body());
+  p << "async.execute";
   p.printOptionalAttrDict(op.getAttrs());
-  p << " : ";
-  p.printType(op.done().getType());
-  if (!op.values().empty())
-    p << ", ";
-  llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
-    p.printType(result.getType());
-  });
+  p << " (";
+  llvm::interleaveComma(llvm::zip(op.arguments(), op.getOperandTypes()), p,
+                        [&](const auto &tuple) {
+                          p << std::get<0>(tuple) << ": " << std::get<1>(tuple);
+                        });
+  p << ")";
+  p.printArrowTypeList(op.getResultTypes());
+  p.printRegion(op.body());
 }
 
 static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
   MLIRContext *ctx = result.getContext();
 
-  // Parse asynchronous region.
-  Region *body = result.addRegion();
-  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
-                         /*enableNameShadowing=*/false))
-    return failure();
-
   // Parse operation attributes.
   NamedAttrList attrs;
   if (parser.parseOptionalAttrDict(attrs))
     return failure();
   result.addAttributes(attrs);
 
-  // Parse result types.
+  SmallVector<OpAsmParser::OperandType, 4> entryArgs;
+  SmallVector<NamedAttrList, 4> argAttrs;
+  SmallVector<NamedAttrList, 4> resultAttrs;
+  SmallVector<Type, 4> argTypes;
   SmallVector<Type, 4> resultTypes;
-  if (parser.parseColonTypeList(resultTypes))
+
+  // Parse the async.execute signature which is similar to function signature.
+  auto signatureLocation = parser.getCurrentLocation();
+  bool allowVariadic = false;
+  bool isVariadic = false;
+  if (impl::parseFunctionSignature(parser, allowVariadic, entryArgs, argTypes,
+                                   argAttrs, isVariadic, resultTypes,
+                                   resultAttrs))
     return failure();
 
+  // async.execute does not support arg or result attributes, verify that they
+  // are empty.
+  auto emptyAttrs = [](const NamedAttrList &list) { return list.empty(); };
+
+  if (!llvm::all_of(argAttrs, emptyAttrs))
+    return parser.emitError(signatureLocation)
+           << "argument attributes are not supported";
+
+  if (!llvm::all_of(resultAttrs, emptyAttrs))
+    return parser.emitError(signatureLocation)
+           << "result attributes are not supported";
+
+  // Operand resolution can fail (e.g. on a type mismatch); don't ignore it.
+  if (parser.resolveOperands(entryArgs, argTypes, signatureLocation,
+                             result.operands))
+    return failure();
+
   // First result type must be an async token type.
   if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
-    return failure();
+    return parser.emitError(signatureLocation)
+           << "first return type must be an async.token";
   parser.addTypesToList(resultTypes, result.types);
 
+  // Parse asynchronous region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
+                         /*enableNameShadowing=*/false))
+    return failure();
+
+  return success();
+}
+
+static LogicalResult verify(ExecuteOp op) {
+  // Unwrap async.execute operation arguments:
+  //   - `async.token` arguments are skipped
+  //   - `async.value<T>` arguments unwrapped to `T`
+  SmallVector<Type, 4> unwrappedTypes;
+  for (Type type : op.getOperandTypes()) {
+    if (type.isa<TokenType>())
+      continue;
+
+    if (auto value = type.dyn_cast<ValueType>()) {
+      unwrappedTypes.push_back(value.getValueType());
+      continue;
+    }
+
+    llvm_unreachable("unexpected 'async' type kind");
+  }
+
+  // Verify that unwrapped argument types match the body region arguments.
+  if (llvm::size(unwrappedTypes) != llvm::size(op.body().getArgumentTypes()))
+    return op.emitOpError("the number of async body region arguments does not "
+                          "match the number of execute operation arguments");
+
+  if (!std::equal(unwrappedTypes.begin(), unwrappedTypes.end(),
+                  op.body().getArgumentTypes().begin()))
+    return op.emitOpError("async body region argument types do not match the "
+                          "execute operation arguments types");
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s | FileCheck %s 
+// RUN: mlir-opt %s | FileCheck %s
 
 // CHECK-LABEL: @identity_token
-func @identity_token(%arg0 : !async.token) -> !async.token {
+func @identity_token(%arg0: !async.token) -> !async.token {
   // CHECK: return %arg0 : !async.token
   return %arg0 : !async.token
 }
@@ -14,9 +14,9 @@
 
 // CHECK-LABEL: @empty_async_execute
 func @empty_async_execute() -> !async.token {
-  %done = async.execute {
+  %done = async.execute () -> !async.token {
     async.yield
-  } : !async.token
+  }
 
   // CHECK: return %done : !async.token
   return %done : !async.token
@@ -24,23 +24,71 @@
 
 // CHECK-LABEL: @return_async_value
 func @return_async_value() -> !async.value<f32> {
-  %done, %values = async.execute {
+  %done, %values = async.execute () -> (!async.token, !async.value<f32>) {
     %cst = constant 1.000000e+00 : f32
     async.yield %cst : f32
-  } : !async.token, !async.value<f32>
+  }
 
   // CHECK: return %values : !async.value<f32>
   return %values : !async.value<f32>
 }
 
+// CHECK-LABEL: @return_captured_value
+func @return_captured_value() -> !async.token {
+  %cst = constant 1.000000e+00 : f32
+
+  %done, %value = async.execute () -> (!async.token, !async.value<f32>) {
+    async.yield %cst : f32
+  }
+
+  // CHECK: return %done : !async.token
+  return %done : !async.token
+}
+
 // CHECK-LABEL: @return_async_values
 func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
-  %done, %values:2 = async.execute {
+  %done, %values:2 = async.execute () -> (!async.token, !async.value<f32>, !async.value<f32>) {
     %cst1 = constant 1.000000e+00 : f32
     %cst2 = constant 2.000000e+00 : f32
     async.yield %cst1, %cst2 : f32, f32
-  } : !async.token, !async.value<f32>, !async.value<f32>
+  }
 
   // CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
   return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
 }
+
+// CHECK-LABEL: @async_token_operands
+func @async_token_operands(%arg0: !async.token) -> !async.token {
+  // CHECK: async.execute (%arg0: !async.token) -> !async.token {
+  %done = async.execute (%arg0 : !async.token) -> !async.token {
+    async.yield
+  }
+
+  // CHECK: return %done : !async.token
+  return %done : !async.token
+}
+
+// CHECK-LABEL: @async_value_operands
+func @async_value_operands(%arg0: !async.value<f32>) -> !async.token {
+  // CHECK: async.execute (%arg0: !async.value<f32>) -> !async.token
+  %done = async.execute (%arg0: !async.value<f32>) -> !async.token {
+    ^bb0(%0: f32):
+    async.yield
+  }
+
+  // CHECK: return %done : !async.token
+  return %done : !async.token
+}
+
+// CHECK-LABEL: @async_token_and_value_operands
+func @async_token_and_value_operands(%arg0: !async.token, %arg1: !async.value<f32>) -> !async.token {
+  // CHECK: async.execute (%arg0: !async.token, %arg1: !async.value<f32>)
+  // CHECK-SAME: -> (!async.token, !async.value<f32>)
+  %done, %value = async.execute (%arg0: !async.token, %arg1: !async.value<f32>) -> (!async.token, !async.value<f32>) {
+    ^bb0(%0: f32):
+    async.yield %0 : f32
+  }
+
+  // CHECK: return %done : !async.token
+  return %done : !async.token
+}
diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Async/verify.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+// FileCheck test must have at least one CHECK statement.
+func @no_op(%arg0: !async.token) {
+  // CHECK: @no_op
+  return
+}
+
+// -----
+
+func @wrong_execute_arg(%arg0: f32) {
+  // expected-error @+1 {{'async.execute' op operand #0 must be async value type or token type, but got 'f32'}}
+  %done = async.execute (%arg0: f32) -> !async.token {
+    ^bb0(%0: f32):
+    async.yield
+  }
+  return
+}
+
+// -----
+
+func @wrong_async_body_num_args(%arg0: !async.value<f32>) {
+  // expected-error @+1 {{'async.execute' op the number of async body region arguments does not match the number of execute operation arguments}}
+  %done = async.execute (%arg0: !async.value<f32>) -> !async.token {
+    ^bb0(%0: f32, %1: f32):
+    async.yield
+  }
+  return
+}
+
+// -----
+
+func @wrong_async_body_arg_type(%arg0: !async.value<f32>) {
+  // expected-error @+1 {{'async.execute' op async body region argument types do not match the execute operation arguments types}}
+  %done = async.execute (%arg0: !async.value<f32>) -> !async.token {
+    ^bb0(%0 : f64):
+    async.yield
+  }
+  return
+}