diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -75,7 +75,6 @@
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
-
   let verifier = [{ return ::verify(*this); }];
 }
 
@@ -94,4 +93,47 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
+def Async_AwaitOp : Async_Op<"await", [NoSideEffect]> {
+  let summary = "waits for the argument to become ready";
+  let description = [{
+    The `async.await` operation waits until the argument becomes ready, and for
+    the `async.value` arguments it unwraps the underlying value.
+
+    Example:
+
+    ```mlir
+    %0 = ... : !async.token
+    async.await %0 : !async.token
+
+    %1 = ... : !async.value<f32>
+    %2 = async.await %1 : !async.value<f32>
+    ```
+  }];
+
+  let arguments = (ins Async_AnyValueOrTokenType:$operand);
+  let results = (outs Optional<AnyType>:$result);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<"mlir::OpBuilder &builder, OperationState &result,"
+              "Value operand, ArrayRef<NamedAttribute> attrs = {}">,
+  ];
+
+  let extraClassDeclaration = [{
+    Optional<Type> getResultType() {
+      if (getResultTypes().empty()) return None;
+      return getResultTypes()[0];
+    }
+  }];
+
+  let assemblyFormat = [{
+    attr-dict $operand `:` custom<AwaitResultType>(
+      type($operand), type($result)
+    )
+  }];
+
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // ASYNC_OPS
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -250,5 +250,69 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+/// AwaitOp
+//===----------------------------------------------------------------------===//
+
+/// Builds an `async.await` of `operand`. Awaiting an `!async.value<T>` yields
+/// a result of type `T`; any other operand (e.g. a token) yields no results.
+void AwaitOp::build(OpBuilder &, OperationState &result, Value operand,
+                    ArrayRef<NamedAttribute> attrs) {
+  result.addOperands({operand});
+  result.attributes.append(attrs.begin(), attrs.end());
+
+  // Add unwrapped async.value type to the returned values types.
+  if (auto valueType = operand.getType().dyn_cast<ValueType>())
+    result.addTypes(valueType.getValueType());
+}
+
+/// Parses `custom<AwaitResultType>`: only the operand type is written in the
+/// assembly; for `!async.value<T>` the result type is inferred as `T`, for any
+/// other operand type `resultType` is left unset.
+static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
+                                        Type &resultType) {
+  if (parser.parseType(operandType))
+    return failure();
+
+  // Add unwrapped async.value type to the returned values types.
+  if (auto valueType = operandType.dyn_cast<ValueType>())
+    resultType = valueType.getValueType();
+
+  return success();
+}
+
+/// Prints `custom<AwaitResultType>`. The result type is implied by the operand
+/// type (see parseAwaitResultType), so only the operand type is printed.
+static void printAwaitResultType(OpAsmPrinter &p, Type operandType,
+                                 Type resultType) {
+  p << operandType;
+}
+
+/// Verifies that awaiting a token produces no result and that awaiting an
+/// `!async.value<T>` produces a result of type `T`.
+static LogicalResult verify(AwaitOp op) {
+  Type argType = op.operand().getType();
+
+  // Awaiting on a token does not have any results.
+  if (argType.isa<TokenType>() && !op.getResultTypes().empty())
+    return op.emitOpError("awaiting on a token must have empty result");
+
+  // Awaiting on a value unwraps the async value type.
+  if (auto value = argType.dyn_cast<ValueType>()) {
+    // The result is declared `Optional<AnyType>`, so the generic op form may
+    // omit it; reject that instead of dereferencing an empty Optional.
+    Optional<Type> resultType = op.getResultType();
+    if (!resultType)
+      return op.emitOpError("awaiting on a value must have a result");
+
+    if (*resultType != value.getValueType())
+      return op.emitOpError()
+             << "result type " << *resultType
+             << " does not match async value type " << value.getValueType();
+  }
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -106,3 +106,17 @@
   %token4 = async.execute [] { async.yield }
   return
 }
+
+// CHECK-LABEL: @await_token
+func @await_token(%arg0: !async.token) {
+  // CHECK: async.await %arg0
+  async.await %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @await_value
+func @await_value(%arg0: !async.value<f32>) -> f32 {
+  // CHECK: async.await %arg0
+  %0 = async.await %arg0 : !async.value<f32>
+  return %0 : f32
+}
diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Async/verify.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// FileCheck test must have at least one CHECK statement.
+// CHECK-LABEL: @no_op
+func @no_op(%arg0: !async.token) {
+  return
+}
+
+// -----
+
+func @wrong_async_await_arg_type(%arg0: f32) {
+  // expected-error @+1 {{'async.await' op operand #0 must be async value type or token type, but got 'f32'}}
+  async.await %arg0 : f32
+}
+
+// -----
+
+func @wrong_async_await_result_type(%arg0: !async.value<f32>) {
+  // expected-error @+1 {{'async.await' op result type 'f64' does not match async value type 'f32'}}
+  %0 = "async.await"(%arg0): (!async.value<f32>) -> f64
+}
+
+// -----
+
+func @wrong_async_await_missing_result(%arg0: !async.value<f32>) {
+  // expected-error @+1 {{'async.await' op awaiting on a value must have a result}}
+  "async.await"(%arg0): (!async.value<f32>) -> ()
+}