diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h --- a/mlir/include/mlir/Dialect/Async/IR/Async.h +++ b/mlir/include/mlir/Dialect/Async/IR/Async.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H #define MLIR_DIALECT_ASYNC_IR_ASYNC_H +#include "mlir/Dialect/Async/IR/AsyncTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -22,70 +23,27 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -namespace mlir { -namespace async { - -namespace detail { -struct ValueTypeStorage; -} // namespace detail - //===----------------------------------------------------------------------===// -// Async dialect types. +// Async Dialect //===----------------------------------------------------------------------===// -/// The token type to represent asynchronous operation completion. -class TokenType : public Type::TypeBase { -public: - using Base::Base; -}; - -/// The value type to represent values returned from asynchronous operations. -class ValueType - : public Type::TypeBase { -public: - using Base::Base; - - /// Get or create an async ValueType with the provided value type. - static ValueType get(Type valueType); - - Type getValueType(); -}; - -/// The group type to represent async tokens or values grouped together. -class GroupType : public Type::TypeBase { -public: - using Base::Base; -}; +#include "mlir/Dialect/Async/IR/AsyncOpsDialect.h.inc" //===----------------------------------------------------------------------===// -// LLVM coroutines types. +// Async Dialect Operations //===----------------------------------------------------------------------===// -/// The type identifying a switched-resume coroutine. -class CoroIdType : public Type::TypeBase { -public: - using Base::Base; -}; - -/// The coroutine handle type which is a pointer to the coroutine frame. -class CoroHandleType - : public Type::TypeBase { -public: - using Base::Base; -}; - -/// The coroutine saved state type. -class CoroStateType : public Type::TypeBase { -public: - using Base::Base; -}; +#define GET_OP_CLASSES +#include "mlir/Dialect/Async/IR/AsyncOps.h.inc" //===----------------------------------------------------------------------===// // Helper functions of Async dialect transformations. //===----------------------------------------------------------------------===// -/// Returns true if the type is reference counted. All async dialect types are -/// reference counted at runtime. +namespace mlir { +namespace async { + +/// Returns true if the type is reference counted at runtime. inline bool isRefCounted(Type type) { return type.isa(); } @@ -93,9 +51,4 @@ } // namespace async } // namespace mlir -#define GET_OP_CLASSES -#include "mlir/Dialect/Async/IR/AsyncOps.h.inc" - -#include "mlir/Dialect/Async/IR/AsyncOpsDialect.h.inc" - #endif // MLIR_DIALECT_ASYNC_IR_ASYNC_H diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncDialect.td @@ -0,0 +1,33 @@ +//===- AsyncDialect.td -------------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Async dialect definition. +// +//===----------------------------------------------------------------------===// + +#ifndef ASYNC_DIALECT_TD +#define ASYNC_DIALECT_TD + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Async dialect definitions +//===----------------------------------------------------------------------===// + +def AsyncDialect : Dialect { + let name = "async"; + + let summary = "Types and operations for async dialect"; + let description = [{ + This dialect contains operations for modeling asynchronous execution. + }]; + + let cppNamespace = "::mlir::async"; +} + +#endif // ASYNC_DIALECT_TD diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -13,7 +13,8 @@ #ifndef ASYNC_OPS #define ASYNC_OPS -include "mlir/Dialect/Async/IR/AsyncBase.td" +include "mlir/Dialect/Async/IR/AsyncDialect.td" +include "mlir/Dialect/Async/IR/AsyncTypes.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -75,7 +76,7 @@ Variadic:$operands); let results = (outs Async_TokenType:$token, - Variadic:$results); + Variadic:$results); let regions = (region SizedRegion<1>:$body); let printer = [{ return ::print(p, *this); }]; @@ -398,7 +399,7 @@ }]; let arguments = (ins AnyType:$value, - Async_AnyValueType:$storage); + Async_ValueType:$storage); let assemblyFormat = "$value `,` $storage attr-dict `:` type($storage)"; } @@ -412,7 +413,7 @@ async.value storage. }]; - let arguments = (ins Async_AnyValueType:$storage); + let arguments = (ins Async_ValueType:$storage); let results = (outs AnyType:$result); let assemblyFormat = "$storage attr-dict `:` type($storage)"; } diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.h b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.h @@ -0,0 +1,25 @@ +//===- AsyncTypes.h - Async Dialect Types -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the types for the Async dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_ +#define MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_ + +#include "mlir/IR/Types.h" + +//===----------------------------------------------------------------------===// +// Async Dialect Types +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Async/IR/AsyncOpsTypes.h.inc" + +#endif // MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_ diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td rename from mlir/include/mlir/Dialect/Async/IR/AsyncBase.td rename to mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncBase.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td @@ -1,4 +1,4 @@ -//===- AsyncBase.td ----------------------------------------*- tablegen -*-===// +//===- AsyncTypes.td - Async dialect types -----------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,59 +6,53 @@ // //===----------------------------------------------------------------------===// // -// Base definitions for the `async` dialect. +// This file declares the Async dialect types. // //===----------------------------------------------------------------------===// -#ifndef ASYNC_BASE_TD -#define ASYNC_BASE_TD +#ifndef MLIR_DIALECT_ASYNC_IR_ASYNCTYPES +#define MLIR_DIALECT_ASYNC_IR_ASYNCTYPES -include "mlir/IR/OpBase.td" +include "mlir/Dialect/Async/IR/AsyncDialect.td" //===----------------------------------------------------------------------===// -// Async dialect definitions +// Async Types //===----------------------------------------------------------------------===// -def AsyncDialect : Dialect { - let name = "async"; - - let summary = "Types and operations for async dialect"; - let description = [{ - This dialect contains operations for modeling asynchronous execution. - }]; - - let cppNamespace = "::mlir::async"; +class Async_Type : TypeDef { + let mnemonic = typeMnemonic; } -def Async_TokenType : DialectType()">, "token type">, - BuildableType<"$_builder.getType<::mlir::async::TokenType>()"> { +def Async_TokenType : Async_Type<"Token", "token"> { + let summary = "async token type"; let description = [{ `async.token` is a type returned by asynchronous operations, and it becomes - `ready` when the asynchronous operations that created it is completed. + `available` when the asynchronous operations that created it is completed. }]; } -class Async_ValueType - : DialectType()">, - SubstLeaves<"$_self", - "$_self.cast<::mlir::async::ValueType>().getValueType()", - type.predicate> - ]>, "async value type with " # type.summary # " underlying type"> { +def Async_ValueType : Async_Type<"Value", "value"> { + let summary = "async value type"; let description = [{ `async.value` represents a value returned by asynchronous operations, which may or may not be available currently, but will be available at some point in the future. }]; - Type valueType = type; + let parameters = (ins "Type":$valueType); + let builders = [ + TypeBuilderWithInferredContext<(ins "Type":$valueType), [{ + return Base::get(valueType.getContext(), valueType); + }], [{ + return Base::getChecked($_loc, valueType); + }]> + ]; + let skipDefaultBuilders = 1; } -def Async_GroupType : DialectType()">, "group type">, - BuildableType<"$_builder.getType<::mlir::async::GroupType>()"> { +def Async_GroupType : Async_Type<"Group", "group"> { + let summary = "async group type"; let description = [{ `async.group` represent a set of async tokens or values and allows to execute async operations on all of them together (e.g. wait for the @@ -66,14 +60,10 @@ }]; } -def Async_AnyValueType : DialectType()">, - "async value type">; - -def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType, +def Async_AnyValueOrTokenType : AnyTypeOf<[Async_ValueType, Async_TokenType]>; -def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType, +def Async_AnyAsyncType : AnyTypeOf<[Async_ValueType, Async_TokenType, Async_GroupType]>; @@ -86,30 +76,27 @@ // build a properly typed intermediate IR during the Async to LLVM lowering we // define a separate types for values that can be produced by LLVM intrinsics. -def Async_CoroIdType : DialectType()">, "coro.id type">, - BuildableType<"$_builder.getType<::mlir::async::CoroIdType>()"> { +def Async_CoroIdType : Async_Type<"CoroId", "coro.id"> { + let summary = "switched-resume coroutine identifier"; let description = [{ `async.coro.id` is a type identifying a switched-resume coroutine. }]; } -def Async_CoroHandleType : DialectType()">, "coro.handle type">, - BuildableType<"$_builder.getType<::mlir::async::CoroHandleType>()"> { +def Async_CoroHandleType : Async_Type<"CoroHandle", "coro.handle"> { + let summary = "coroutine handle"; let description = [{ `async.coro.handle` is a handle to the coroutine (pointer to the coroutine frame) that can be passed around to resume or destroy the coroutine. }]; } -def Async_CoroStateType : DialectType()">, "coro.state type">, - BuildableType<"$_builder.getType<::mlir::async::CoroStateType>()"> { +def Async_CoroStateType : Async_Type<"CoroState", "coro.state"> { + let summary = "saved coroutine state"; let description = [{ `async.coro.state` is a saved coroutine state that should be passed to the coroutine suspension operation. }]; } -#endif // ASYNC_BASE_TD +#endif // MLIR_DIALECT_ASYNC_IR_ASYNCTYPES diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -19,96 +19,12 @@ #define GET_OP_LIST #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" >(); - addTypes(); // async types - addTypes(); // coro types -} - -/// Parse a type registered to this dialect. -Type AsyncDialect::parseType(DialectAsmParser &parser) const { - StringRef keyword; - if (parser.parseKeyword(&keyword)) - return Type(); - - if (keyword == "token") - return TokenType::get(getContext()); - - if (keyword == "group") - return GroupType::get(getContext()); - - if (keyword == "value") { - Type ty; - if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { - parser.emitError(parser.getNameLoc(), "failed to parse async value type"); - return Type(); - } - return ValueType::get(ty); - } - - if (keyword == "coro.id") - return CoroIdType::get(getContext()); - - if (keyword == "coro.handle") - return CoroHandleType::get(getContext()); - - if (keyword == "coro.state") - return CoroStateType::get(getContext()); - - parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword; - return Type(); -} - -/// Print a type registered to this dialect. -void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { - TypeSwitch(type) - .Case([&](TokenType) { os << "token"; }) - .Case([&](ValueType valueTy) { - os << "value<"; - os.printType(valueTy.getValueType()); - os << '>'; - }) - .Case([&](GroupType) { os << "group"; }) - .Case([&](CoroIdType) { os << "coro.id"; }) - .Case([&](CoroHandleType) { os << "coro.handle"; }) - .Case([&](CoroStateType) { os << "coro.state"; }) - .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); }); -} - -//===----------------------------------------------------------------------===// -/// ValueType -//===----------------------------------------------------------------------===// - -namespace mlir { -namespace async { -namespace detail { - -// Storage for `async.value` type, the only member is the wrapped type. -struct ValueTypeStorage : public TypeStorage { - ValueTypeStorage(Type valueType) : valueType(valueType) {} - - /// The hash key used for uniquing. - using KeyTy = Type; - bool operator==(const KeyTy &key) const { return key == valueType; } - - /// Construction. - static ValueTypeStorage *construct(TypeStorageAllocator &allocator, - Type valueType) { - return new (allocator.allocate()) - ValueTypeStorage(valueType); - } - - Type valueType; -}; - -} // namespace detail -} // namespace async -} // namespace mlir - -ValueType ValueType::get(Type valueType) { - return Base::get(valueType.getContext(), valueType); + addTypes< +#define GET_TYPEDEF_LIST +#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" + >(); } -Type ValueType::getValueType() { return getImpl()->valueType; } - //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// @@ -376,5 +292,47 @@ return success(); } +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + #define GET_OP_CLASSES #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// TableGen'd type method definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc" + +void ValueType::print(DialectAsmPrinter &printer) const { + printer << getMnemonic(); + printer << "<"; + printer.printType(getValueType()); + printer << '>'; +} + +Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) { + Type ty; + if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { + parser.emitError(parser.getNameLoc(), "failed to parse async value type"); + return Type(); + } + return ValueType::get(ty); +} + +/// Print a type registered to this dialect. +void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { + if (failed(generatedTypePrinter(type, os))) + llvm_unreachable("unexpected 'async' type kind"); +} + +/// Parse a type registered to this dialect. +Type AsyncDialect::parseType(DialectAsmParser &parser) const { + StringRef mnemonic; + if (parser.parseKeyword(&mnemonic)) + return Type(); + + return generatedTypeParser(getContext(), parser, mnemonic); +} diff --git a/mlir/test/Dialect/Async/verify.mlir b/mlir/test/Dialect/Async/verify.mlir --- a/mlir/test/Dialect/Async/verify.mlir +++ b/mlir/test/Dialect/Async/verify.mlir @@ -9,7 +9,7 @@ // ----- func @wrong_async_await_arg_type(%arg0: f32) { - // expected-error @+1 {{'async.await' op operand #0 must be async value type or token type, but got 'f32'}} + // expected-error @+1 {{'async.await' op operand #0 must be async value type or async token type, but got 'f32'}} async.await %arg0 : f32 }