diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -22,12 +22,29 @@
 namespace mlir {
 namespace async {
 
+namespace detail {
+struct ValueTypeStorage;
+} // namespace detail
+
 /// The token type to represent asynchronous operation completion.
 class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
 public:
   using Base::Base;
 };
 
+/// The value type to represent values returned from asynchronous operations.
+class ValueType
+    : public Type::TypeBase<ValueType, Type, detail::ValueTypeStorage> {
+public:
+  using Base::Base;
+
+  /// Get or create an async ValueType with the provided value type.
+  static ValueType get(Type valueType);
+
+  /// Return the type of the value wrapped by this async value type.
+  Type getValueType();
+};
+
 } // namespace async
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td
@@ -39,4 +39,24 @@
   }];
 }
 
+class Async_ValueType<Type type>
+    : DialectType<AsyncDialect,
+                  And<[
+                    CPred<"$_self.isa<::mlir::async::ValueType>()">,
+                    SubstLeaves<"$_self",
+                                "$_self.cast<::mlir::async::ValueType>().getValueType()",
+                                type.predicate>
+                  ]>, "async value type with " # type.description # " underlying type"> {
+  let typeDescription = [{
+    `async.value` represents a value returned by asynchronous operations,
+    which may or may not be available currently, but will be available at some
+    point in the future.
+  }];
+
+  Type valueType = type;
+}
+
+def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
+                              "async value type">;
+
 #endif // ASYNC_BASE_TD
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -40,24 +40,24 @@
     state). All dependencies must be made explicit with async execute arguments
     (`async.token` or `async.value`).
 
-    Example:
-
     ```mlir
-    %0 = async.execute {
-      "compute0"(...)
-      async.yield
-    } : !async.token
+    %done, %values = async.execute {
+      %0 = "compute0"(...) : !some.type
+      async.yield %0 : !some.type
+    } : !async.token, !async.value<!some.type>
 
-    %1 = "compute1"(...)
+    %1 = "compute1"(...) : !some.type
     ```
   }];
 
   // TODO: Take async.tokens/async.values as arguments.
   let arguments = (ins );
-  let results = (outs Async_TokenType:$done);
+  let results = (outs Async_TokenType:$done,
+                      Variadic<Async_AnyValueType>:$values);
   let regions = (region SizedRegion<1>:$body);
 
-  let assemblyFormat = "$body attr-dict `:` type($done)";
+  let printer = [{ return ::mlir::async::print(p, *this); }];
+  let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
 }
 
 def Async_YieldOp :
@@ -71,6 +71,8 @@
   let arguments = (ins Variadic<AnyType>:$operands);
 
   let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
+
+  let verifier = [{ return ::mlir::async::verify(*this); }];
 }
 
 #endif // ASYNC_OPS
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -19,8 +19,8 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/raw_ostream.h"
 
-using namespace mlir;
-using namespace mlir::async;
+namespace mlir {
+namespace async {
 
 void AsyncDialect::initialize() {
   addOperations<
@@ -28,6 +28,7 @@
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
       >();
   addTypes<TokenType>();
+  addTypes<ValueType>();
 }
 
 /// Parse a type registered to this dialect.
@@ -39,6 +40,15 @@
   if (keyword == "token")
     return TokenType::get(getContext());
 
+  if (keyword == "value") {
+    Type ty;
+    if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
+      parser.emitError(parser.getNameLoc(), "failed to parse async value type");
+      return Type();
+    }
+    return ValueType::get(ty);
+  }
+
   parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
   return Type();
 }
@@ -46,9 +56,127 @@
 /// Print a type registered to this dialect.
 void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
   TypeSwitch<Type>(type)
-      .Case<TokenType>([&](Type) { os << "token"; })
+      .Case<TokenType>([&](TokenType) { os << "token"; })
+      .Case<ValueType>([&](ValueType valueTy) {
+        os << "value<";
+        os.printType(valueTy.getValueType());
+        os << '>';
+      })
       .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
 }
 
+//===----------------------------------------------------------------------===//
+// ValueType
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+
+// Storage for `async.value<T>` type, the only member is the wrapped type.
+struct ValueTypeStorage : public TypeStorage {
+  ValueTypeStorage(Type valueType) : valueType(valueType) {}
+
+  /// The hash key used for uniquing.
+  using KeyTy = Type;
+  bool operator==(const KeyTy &key) const { return key == valueType; }
+
+  /// Construction.
+  static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
+                                     Type valueType) {
+    return new (allocator.allocate<ValueTypeStorage>())
+        ValueTypeStorage(valueType);
+  }
+
+  Type valueType;
+};
+
+} // namespace detail
+
+ValueType ValueType::get(Type valueType) {
+  return Base::get(valueType.getContext(), valueType);
+}
+
+Type ValueType::getValueType() { return getImpl()->valueType; }
+
+//===----------------------------------------------------------------------===//
+// YieldOp
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the operands of `async.yield` have exactly the types wrapped
+/// by the `!async.value` results of the parent `async.execute` operation.
+static LogicalResult verify(YieldOp op) {
+  // Get the underlying value types from async values returned from the
+  // parent `async.execute` operation.
+  auto execute = op.getParentOfType<ExecuteOp>();
+  auto types = llvm::map_range(execute.values(), [](const OpResult &result) {
+    return result.getType().cast<ValueType>().getValueType();
+  });
+
+  // Compare with both ranges bounded: the three-iterator form of std::equal
+  // would read past the last operand when fewer operands are yielded than
+  // the parent returns, and would ignore any surplus operands.
+  auto operandTypes = op.getOperandTypes();
+  if (!std::equal(types.begin(), types.end(), operandTypes.begin(),
+                  operandTypes.end()))
+    return op.emitOpError("Operand types do not match the types returned from "
+                          "the parent ExecuteOp");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ExecuteOp
+//===----------------------------------------------------------------------===//
+
+/// Prints `async.execute` as: region, optional attribute dictionary, then
+/// ` : ` followed by the token type and the comma-separated value types.
+static void print(OpAsmPrinter &p, ExecuteOp op) {
+  p << "async.execute ";
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : ";
+  p.printType(op.done().getType());
+  if (!op.values().empty())
+    p << ", ";
+  llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
+    p.printType(result.getType());
+  });
+}
+
+/// Parses `async.execute`: a region, an optional attribute dictionary, and a
+/// colon-separated result type list whose first entry must be `!async.token`.
+static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
+  MLIRContext *ctx = result.getContext();
+
+  // Parse asynchronous region.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
+                         /*enableNameShadowing=*/false))
+    return failure();
+
+  // Parse operation attributes.
+  NamedAttrList attrs;
+  if (parser.parseOptionalAttrDict(attrs))
+    return failure();
+  result.addAttributes(attrs);
+
+  // Parse result types.
+  SmallVector<Type, 4> resultTypes;
+  llvm::SMLoc typesLoc = parser.getCurrentLocation();
+  if (parser.parseColonTypeList(resultTypes))
+    return failure();
+
+  // First result type must be an async token type. Emit a diagnostic rather
+  // than failing silently so the user sees why the operation was rejected.
+  if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
+    return parser.emitError(typesLoc,
+                            "first result type must be !async.token");
+  parser.addTypesToList(resultTypes, result.types);
+
+  return success();
+}
+
+} // namespace async
+} // namespace mlir
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
diff --git a/mlir/test/Dialect/Async/ops.mlir b/mlir/test/Dialect/Async/ops.mlir
--- a/mlir/test/Dialect/Async/ops.mlir
+++ b/mlir/test/Dialect/Async/ops.mlir
@@ -1,16 +1,46 @@
 // RUN: mlir-opt %s | FileCheck %s
 
-// CHECK-LABEL: @identity
-func @identity(%arg0 : !async.token) -> !async.token {
+// CHECK-LABEL: @identity_token
+func @identity_token(%arg0 : !async.token) -> !async.token {
   // CHECK: return %arg0 : !async.token
   return %arg0 : !async.token
 }
 
+// CHECK-LABEL: @identity_value
+func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
+  // CHECK: return %arg0 : !async.value<f32>
+  return %arg0 : !async.value<f32>
+}
+
 // CHECK-LABEL: @empty_async_execute
 func @empty_async_execute() -> !async.token {
-  %0 = async.execute {
+  %done = async.execute {
     async.yield
   } : !async.token
 
-  return %0 : !async.token
+  // CHECK: return %done : !async.token
+  return %done : !async.token
+}
+
+// CHECK-LABEL: @return_async_value
+func @return_async_value() -> !async.value<f32> {
+  %done, %values = async.execute {
+    %cst = constant 1.000000e+00 : f32
+    async.yield %cst : f32
+  } : !async.token, !async.value<f32>
+
+  // CHECK: return %values : !async.value<f32>
+  return %values : !async.value<f32>
+}
+
+// CHECK-LABEL: @return_async_values
+func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
+  %done, %values:2 = async.execute {
+    %cst1 = constant 1.000000e+00 : f32
+    %cst2 = constant 2.000000e+00 : f32
+    async.yield %cst1, %cst2 : f32, f32
+  } : !async.token, !async.value<f32>, !async.value<f32>
+
+  // CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
+  return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
 }