diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -75,7 +75,6 @@
 
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
-
   let verifier = [{ return ::verify(*this); }];
 }
 
@@ -94,4 +93,48 @@
   let verifier = [{ return ::verify(*this); }];
 }
 
+// The op is intentionally not `NoSideEffect`: awaiting a token produces no
+// results, so a side-effect-free await would be erased as trivially dead and
+// the synchronization point it represents would be lost.
+def Async_AwaitOp : Async_Op<"await"> {
+  let summary = "waits for the argument to become ready";
+  let description = [{
+    The `async.await` operation waits until the argument becomes ready, and for
+    `async.value` arguments it also unwraps the underlying value.
+
+    Example:
+
+    ```mlir
+    %0 = ... : !async.token
+    async.await %0 : !async.token
+
+    %1 = ... : !async.value<f32>
+    %2 = async.await %1 : !async.value<f32>
+    ```
+  }];
+
+  let arguments = (ins Async_AnyValueOrTokenType:$operand);
+  let results = (outs Variadic<AnyType>:$value);
+
+  let skipDefaultBuilders = 1;
+
+  let builders = [
+    OpBuilder<"mlir::OpBuilder &builder, OperationState &result,"
+              "Value token, ArrayRef<NamedAttribute> attrs = {}">,
+    OpBuilder<"mlir::OpBuilder &builder, OperationState &result,"
+              "Type type, Value value, ArrayRef<NamedAttribute> attrs = {}">,
+  ];
+
+  let extraClassDeclaration = [{
+    Optional<Type> getResultType() {
+      if (getResultTypes().empty()) return None;
+      return getResultTypes()[0];
+    }
+  }];
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+  let verifier = [{ return ::verify(*this); }];
+}
+
 #endif // ASYNC_OPS
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -250,5 +250,72 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+/// AwaitOp
+//===----------------------------------------------------------------------===//
+
+void AwaitOp::build(OpBuilder &, OperationState &result, Value token,
+                    ArrayRef<NamedAttribute> attrs) {
+  result.addOperands({token});
+  result.attributes.append(attrs.begin(), attrs.end());
+}
+
+void AwaitOp::build(OpBuilder &, OperationState &result, Type type, Value value,
+                    ArrayRef<NamedAttribute> attrs) {
+  result.addOperands({value});
+  result.addTypes({type});
+  result.attributes.append(attrs.begin(), attrs.end());
+}
+
+static void print(OpAsmPrinter &p, AwaitOp op) {
+  p << op.getOperationName();
+  p << " ";
+  p.printOperand(op.operand());
+  p << " : ";
+  p.printType(op.operand().getType());
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
+static ParseResult parseAwaitOp(OpAsmParser &parser, OperationState &result) {
+  // Parse the single async operand and its (token or value) type.
+  OpAsmParser::OperandType arg;
+  Type argType;
+  if (parser.parseOperand(arg) || parser.parseColonType(argType))
+    return failure();
+  if (parser.resolveOperand(arg, argType, result.operands))
+    return failure();
+
+  // Awaiting an async.value yields the unwrapped payload type as the result.
+  if (auto value = argType.dyn_cast<ValueType>()) {
+    parser.addTypeToList(value.getValueType(), result.types);
+  }
+
+  // Parse the optional attribute dictionary that follows the type.
+  NamedAttrList attrs;
+  if (parser.parseOptionalAttrDict(attrs))
+    return failure();
+  result.addAttributes(attrs);
+
+  return success();
+}
+
+static LogicalResult verify(AwaitOp op) {
+  Type argType = op.operand().getType();
+
+  // Awaiting on a token does not have any results.
+  if (argType.isa<TokenType>() && !op.getResultTypes().empty())
+    return op.emitOpError("awaiting on a token must have empty result");
+
+  // Awaiting on a value unwraps the async value type into one result.
+  if (auto value = argType.dyn_cast<ValueType>()) {
+    if (op.getResultTypes().size() != 1)
+      return op.emitOpError("awaiting on a value must have one result");
+    if (*op.getResultType() != value.getValueType())
+      return op.emitOpError("result type does not match async value type");
+  }
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -106,3 +106,17 @@
   %token4 = async.execute [] { async.yield }
   return
 }
+
+// CHECK-LABEL: @await_token
+func @await_token(%arg0: !async.token) {
+  // CHECK: async.await %arg0
+  async.await %arg0 : !async.token
+  return
+}
+
+// CHECK-LABEL: @await_value
+func @await_value(%arg0: !async.value<f32>) -> f32 {
+  // CHECK: async.await %arg0
+  %0 = async.await %arg0 : !async.value<f32>
+  return %0 : f32
+}
diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Async/verify.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
+
+// FileCheck test must have at least one CHECK statement.
+// CHECK-LABEL: @no_op
+func @no_op(%arg0: !async.token) {
+  return
+}
+
+// -----
+
+func @wrong_async_await_arg_type(%arg0: f32) {
+  // expected-error @+1 {{'async.await' op operand #0 must be async value type or token type, but got 'f32'}}
+  async.await %arg0 : f32
+}