diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1370,10 +1370,10 @@ ## Type Definitions -MLIR defines the TypeDef class hierarchy to enable generation of data types -from their specifications. A type is defined by specializing the TypeDef -class with concrete contents for all the fields it requires. For example, an -integer type could be defined as: +MLIR defines the TypeDef class hierarchy to enable generation of data types from +their specifications. A type is defined by specializing the TypeDef class with +concrete contents for all the fields it requires. For example, an integer type +could be defined as: ```tablegen // All of the types will extend this class. @@ -1414,45 +1414,43 @@ ### Type name The name of the C++ class which gets generated defaults to -`Type` (e.g. `TestIntegerType` in the above example). This -can be overridden via the `cppClassName` field. The field `mnemonic` is -to specify the asm name for parsing. It is optional and not specifying it -will imply that no parser or printer methods are attached to this class. +`Type` (e.g. `TestIntegerType` in the above example). This can +be overridden via the `cppClassName` field. The `mnemonic` field specifies +the assembly name used for parsing. It is optional; if it is not specified, +no parser or printer methods are attached to this class. ### Type documentation -The `summary` and `description` fields exist and are to be used the same way -as in Operations. Namely, the summary should be a one-liner and `description` +The `summary` and `description` fields exist and are to be used the same way as +in Operations. Namely, the summary should be a one-liner and `description` should be a longer explanation. ### Type parameters -The `parameters` field is a list of the types parameters. If no parameters -are specified (the default), this type is considered a singleton type. 
-Parameters are in the `"c++Type":$paramName` format. -To use C++ types as parameters which need allocation in the storage -constructor, there are two options: +The `parameters` field is a list of the type's parameters. If no parameters are +specified (the default), this type is considered a singleton type. Parameters +are in the `"c++Type":$paramName` format. To use C++ types as parameters which +need allocation in the storage constructor, there are two options: -- Set `hasCustomStorageConstructor` to generate the TypeStorage class with -a constructor which is just declared -- no definition -- so you can write it -yourself. -- Use the `TypeParameter` tablegen class instead of the "c++Type" string. +- Set `hasCustomStorageConstructor` to generate the TypeStorage class with a + constructor which is just declared -- no definition -- so you can write it + yourself. +- Use the `TypeParameter` tablegen class instead of the "c++Type" string. ### TypeParameter tablegen class -This is used to further specify attributes about each of the types -parameters. It includes documentation (`description` and `syntax`), the C++ -type to use, and a custom allocator to use in the storage constructor method. +This is used to further specify attributes about each of the type's parameters. +It includes documentation (`description` and `syntax`), the C++ type to use, and +a custom allocator to use in the storage constructor method. ```tablegen // DO NOT DO THIS! -let parameters = (ins - "ArrayRef":$dims); +let parameters = (ins "ArrayRef":$dims); ``` -The default storage constructor blindly copies fields by value. It does not -know anything about the types. In this case, the ArrayRef requires -allocation with `dims = allocator.copyInto(dims)`. +The default storage constructor blindly copies fields by value. It does not know +anything about the types. In this case, the ArrayRef requires allocation +with `dims = allocator.copyInto(dims)`. 
You can specify the necessary constructor by specializing the `TypeParameter` tblgen class: @@ -1460,28 +1458,29 @@ ```tablegen class ArrayRefIntParam : TypeParameter<"::llvm::ArrayRef", "Array of ints"> { - let allocator = [{$_dst = $_allocator.copyInto($_self);}]; + let allocator = "$_dst = $_allocator.copyInto($_self);"; } ... -let parameters = (ins - ArrayRefIntParam:$dims); +let parameters = (ins ArrayRefIntParam:$dims); ``` The `allocator` code block has the following substitutions: -- `$_allocator` is the TypeStorageAllocator in which to allocate objects. -- `$_dst` is the variable in which to place the allocated data. + +- `$_allocator` is the TypeStorageAllocator in which to allocate objects. +- `$_dst` is the variable in which to place the allocated data. MLIR includes several specialized classes for common situations: -- `StringRefParameter` for StringRefs. -- `ArrayRefParameter` for ArrayRefs of value -types -- `SelfAllocationParameter` for C++ classes which contain -a method called `allocateInto(StorageAllocator &allocator)` to allocate -itself into `allocator`. -- `ArrayRefOfSelfAllocationParameter` for arrays -of objects which self-allocate as per the last specialization. + +- `StringRefParameter` for StringRefs. +- `ArrayRefParameter` for ArrayRefs of value + types. +- `SelfAllocationParameter` for C++ classes which contain + a method called `allocateInto(StorageAllocator &allocator)` to allocate + itself into `allocator`. +- `ArrayRefOfSelfAllocationParameter` for arrays + of objects which self-allocate as per the last specialization. If we were to use one of these included specializations: @@ -1495,45 +1494,46 @@ If a mnemonic is specified, the `printer` and `parser` code fields are active. The rules for both are: -- If null, generate just the declaration. -- If non-null and non-empty, use the code in the definition. The `$_printer` -or `$_parser` substitutions are valid and should be used. -- It is an error to have an empty code block. 
- -For each dialect, two "dispatch" functions will be created: one for parsing -and one for printing. You should add calls to these in your -`Dialect::printType` and `Dialect::parseType` methods. They are created in -the dialect's namespace and their function signatures are: + +- If null, generate just the declaration. +- If non-null and non-empty, use the code in the definition. The `$_printer` + or `$_parser` substitutions are valid and should be used. +- It is an error to have an empty code block. + +For each dialect, two "dispatch" functions will be created: one for parsing and +one for printing. You should add calls to these in your `Dialect::printType` and +`Dialect::parseType` methods. They are static functions placed alongside the +type class definitions and have the following function signatures: + ```c++ -Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser, - StringRef mnemonic); +static Type generatedTypeParser(MLIRContext* ctxt, DialectAsmParser& parser, StringRef mnemonic); LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer); ``` -The mnemonic, parser, and printer fields are optional. If they're not -defined, the generated code will not include any parsing or printing code and -omit the type from the dispatch functions above. In this case, the dialect -author is responsible for parsing/printing the types in `Dialect::printType` -and `Dialect::parseType`. +The mnemonic, parser, and printer fields are optional. If they're not defined, +the generated code will not include any parsing or printing code and will omit +the type from the dispatch functions above. In this case, the dialect author is +responsible for parsing/printing the types in `Dialect::printType` and +`Dialect::parseType`. ### Other fields -- If the `genStorageClass` field is set to 1 (the default) a storage class is -generated with member variables corresponding to each of the specified -`parameters`. 
-- If the `genAccessors` field is 1 (the default) accessor methods will be -generated on the Type class (e.g. `int getWidth() const` in the example -above). -- If the `genVerifyInvariantsDecl` field is set, a declaration for a method -`static LogicalResult verifyConstructionInvariants(Location, parameters...)` -is added to the class as well as a `getChecked(Location, parameters...)` -method which gets the result of `verifyConstructionInvariants` before calling -`get`. -- The `storageClass` field can be used to set the name of the storage class. -- The `storageNamespace` field is used to set the namespace where the storage -class should sit. Defaults to "detail". -- The `extraClassDeclaration` field is used to include extra code in the -class declaration. +- If the `genStorageClass` field is set to 1 (the default), a storage class is + generated with member variables corresponding to each of the specified + `parameters`. +- If the `genAccessors` field is 1 (the default), accessor methods will be + generated on the Type class (e.g. `int getWidth() const` in the example + above). +- If the `genVerifyInvariantsDecl` field is set, a declaration for a method + `static LogicalResult verifyConstructionInvariants(Location, parameters...)` + is added to the class as well as a `getChecked(Location, parameters...)` + method which checks the result of `verifyConstructionInvariants` before + calling `get`. +- The `storageClass` field can be used to set the name of the storage class. +- The `storageNamespace` field is used to set the namespace where the storage + class should sit. Defaults to "detail". +- The `extraClassDeclaration` field is used to include extra code in the class + declaration. 
## Debugging Tips diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinDialect.td @@ -0,0 +1,27 @@ +//===-- BuiltinDialect.td - Builtin dialect definition -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the Builtin dialect. This dialect +// contains all of the attributes, operations, and types that are core to MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef BUILTIN_BASE +#define BUILTIN_BASE + +include "mlir/IR/OpBase.td" + +def Builtin_Dialect : Dialect { + let summary = + "A dialect containing the builtin Attributes, Operations, and Types"; + + let name = ""; + let cppNamespace = "::mlir"; +} + +#endif // BUILTIN_BASE diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -14,17 +14,10 @@ #ifndef BUILTIN_OPS #define BUILTIN_OPS +include "mlir/IR/BuiltinDialect.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" -def Builtin_Dialect : Dialect { - let summary = - "A dialect containing the builtin Attributes, Operations, and Types"; - - let name = ""; - let cppNamespace = "::mlir"; -} - // Base class for Builtin dialect ops. 
class Builtin_Op traits = []> : Op; diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -72,23 +72,6 @@ Type getElementType(); }; -//===----------------------------------------------------------------------===// -// IndexType -//===----------------------------------------------------------------------===// - -/// Index is a special integer-like type with unknown platform-dependent bit -/// width. -class IndexType : public Type::TypeBase { -public: - using Base::Base; - - /// Get an instance of the IndexType. - static IndexType get(MLIRContext *context); - - /// Storage bit width used for IndexType by internal compiler data structures. - static constexpr unsigned kInternalStorageBitWidth = 64; -}; - //===----------------------------------------------------------------------===// // IntegerType //===----------------------------------------------------------------------===// @@ -187,67 +170,6 @@ const llvm::fltSemantics &getFloatSemantics(); }; -//===----------------------------------------------------------------------===// -// BFloat16Type - -class BFloat16Type - : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the bfloat16 type. - static BFloat16Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getBF16(MLIRContext *ctx) { - return BFloat16Type::get(ctx); -} - -//===----------------------------------------------------------------------===// -// Float16Type - -class Float16Type : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the float16 type. 
- static Float16Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getF16(MLIRContext *ctx) { - return Float16Type::get(ctx); -} - -//===----------------------------------------------------------------------===// -// Float32Type - -class Float32Type : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the float32 type. - static Float32Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getF32(MLIRContext *ctx) { - return Float32Type::get(ctx); -} - -//===----------------------------------------------------------------------===// -// Float64Type - -class Float64Type : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the float64 type. - static Float64Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getF64(MLIRContext *ctx) { - return Float64Type::get(ctx); -} - //===----------------------------------------------------------------------===// // FunctionType //===----------------------------------------------------------------------===// @@ -276,20 +198,6 @@ ArrayRef resultIndices); }; -//===----------------------------------------------------------------------===// -// NoneType -//===----------------------------------------------------------------------===// - -/// NoneType is a unit type, i.e. a type with exactly one possible value, where -/// its value does not have a defined dynamic representation. -class NoneType : public Type::TypeBase { -public: - using Base::Base; - - /// Get an instance of the NoneType. 
- static NoneType get(MLIRContext *context); -}; - //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// @@ -720,11 +628,20 @@ return getTypes()[index]; } }; +} // end namespace mlir + +//===----------------------------------------------------------------------===// +// Tablegen Type Declarations +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/IR/BuiltinTypes.h.inc" //===----------------------------------------------------------------------===// // Deferred Method Definitions //===----------------------------------------------------------------------===// +namespace mlir { inline bool BaseMemRefType::classof(Type type) { return type.isa(); } @@ -733,6 +650,22 @@ return type.isa(); } +inline FloatType FloatType::getBF16(MLIRContext *ctx) { + return BFloat16Type::get(ctx); +} + +inline FloatType FloatType::getF16(MLIRContext *ctx) { + return Float16Type::get(ctx); +} + +inline FloatType FloatType::getF32(MLIRContext *ctx) { + return Float32Type::get(ctx); +} + +inline FloatType FloatType::getF64(MLIRContext *ctx) { + return Float64Type::get(ctx); +} + inline bool ShapedType::classof(Type type) { return type.isa(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -0,0 +1,114 @@ +//===- BuiltinTypes.td - Builtin type definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the set of builtin MLIR types, or the set of types necessary for the +// validity of and defining the IR. +// +//===----------------------------------------------------------------------===// + +#ifndef BUILTIN_TYPES +#define BUILTIN_TYPES + +include "mlir/IR/BuiltinDialect.td" + +// TODO: Currently the types defined in this file are prefixed with `Builtin_`. +// This is to differentiate the types here with the ones in OpBase.td. We should +// remove the definitions in OpBase.td, and repoint users to this file instead. + +// Base class for Builtin dialect types. +class Builtin_Type : TypeDef { + let mnemonic = ?; +} + +//===----------------------------------------------------------------------===// +// FloatType +//===----------------------------------------------------------------------===// + +// Base class for Builtin dialect float types. 
+class Builtin_FloatType : TypeDef { + let extraClassDeclaration = [{ + static }] # name # [{Type get(MLIRContext *context); + }]; +} + +//===----------------------------------------------------------------------===// +// BFloat16Type + +def Builtin_BFloat16 : Builtin_FloatType<"BFloat16"> { + let summary = "bfloat16 floating-point type"; +} + +//===----------------------------------------------------------------------===// +// Float16Type + +def Builtin_Float16 : Builtin_FloatType<"Float16"> { + let summary = "16-bit floating-point type"; +} + +//===----------------------------------------------------------------------===// +// Float32Type + +def Builtin_Float32 : Builtin_FloatType<"Float32"> { + let summary = "32-bit floating-point type"; +} + +//===----------------------------------------------------------------------===// +// Float64Type + +def Builtin_Float64 : Builtin_FloatType<"Float64"> { + let summary = "64-bit floating-point type"; +} + +//===----------------------------------------------------------------------===// +// IndexType +//===----------------------------------------------------------------------===// + +def Builtin_Index : Builtin_Type<"Index"> { + let summary = "Integer-like type with unknown platform-dependent bit width"; + let description = [{ + Syntax: + + ``` + // Target word-sized integer. + index-type ::= `index` + ``` + + The index type is a signless integer whose size is equal to the natural + machine word of the target ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#integer-signedness-semantics) ) + and is used by the affine constructs in MLIR. Unlike fixed-size integers, + it cannot be used as an element of vector ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#index-type-disallowed-in-vector-types) ). + + **Rationale:** integers of platform-specific bit widths are practical to + express sizes, dimensionalities and subscripts. 
+ }]; + let extraClassDeclaration = [{ + static IndexType get(MLIRContext *context); + + /// Storage bit width used for IndexType by internal compiler data + /// structures. + static constexpr unsigned kInternalStorageBitWidth = 64; + }]; +} + +//===----------------------------------------------------------------------===// +// NoneType +//===----------------------------------------------------------------------===// + +def Builtin_None : Builtin_Type<"None"> { + let summary = "A unit type"; + let description = [{ + NoneType is a unit type, i.e. a type with exactly one possible value, where + its value does not have a defined dynamic representation. + }]; + let extraClassDeclaration = [{ + static NoneType get(MLIRContext *context); + }]; +} + +#endif // BUILTIN_TYPES diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -2,10 +2,18 @@ add_mlir_interface(SymbolInterfaces) add_mlir_interface(RegionKindInterface) +set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td) +mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls) +add_public_tablegen_target(MLIRBuiltinDialectIncGen) + set(LLVM_TARGET_DEFINITIONS BuiltinOps.td) mlir_tablegen(BuiltinOps.h.inc -gen-op-decls) mlir_tablegen(BuiltinOps.cpp.inc -gen-op-defs) -mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls) add_public_tablegen_target(MLIRBuiltinOpsIncGen) +set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td) +mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls) +mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(MLIRBuiltinTypesIncGen) + add_mlir_doc(BuiltinOps -gen-op-doc Builtin Dialects/) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2415,15 +2415,18 @@ // Data type generation //===----------------------------------------------------------------------===// -// Define 
a new type belonging to a dialect and called 'name'. -class TypeDef { - Dialect dialect = owningdialect; +// Define a new type, named `name`, belonging to `dialect` that inherits from +// the given C++ base class. +class TypeDef + : DialectType> { + // The name of the C++ Type class. string cppClassName = name # "Type"; + // The name of the C++ base class to use for this Type. + string cppBaseClassName = baseCppClass; // Short summary of the type. string summary = ?; - // The longer description of this type. - string description = ?; // Name of storage class to generate or use. string storageClass = name # "TypeStorage"; @@ -2477,6 +2480,15 @@ bit genVerifyInvariantsDecl = 0; // Extra code to include in the class declaration. code extraClassDeclaration = [{}]; + + // The predicate for when this type is used as a type constraint. + let predicate = CPred<"$_self.isa<" # dialect.cppNamespace # + "::" # cppClassName # ">()">; + // A constant builder provided when the type has no parameters. + let builderCall = !if(!empty(parameters), + "$_builder.getType<" # dialect.cppNamespace # + "::" # cppClassName # ">()", + ""); } // 'Parameters' should be subclasses of this or simple strings (which is a diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h --- a/mlir/include/mlir/TableGen/TypeDef.h +++ b/mlir/include/mlir/TableGen/TypeDef.h @@ -48,6 +48,9 @@ // Returns the name of the C++ class to generate. StringRef getCppClassName() const; + // Returns the name of the C++ base class to use when generating this type. + StringRef getCppBaseClassName() const; + // Returns the name of the storage class for this type. 
StringRef getStorageClassName() const; diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -20,6 +20,13 @@ using namespace mlir; using namespace mlir::detail; +//===----------------------------------------------------------------------===// +/// Tablegen Type Definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/IR/BuiltinTypes.cpp.inc" + //===----------------------------------------------------------------------===// /// ComplexType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -33,7 +33,9 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR DEPENDS + MLIRBuiltinDialectIncGen MLIRBuiltinOpsIncGen + MLIRBuiltinTypesIncGen MLIRCallInterfacesIncGen MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -13,6 +13,7 @@ #include "mlir/TableGen/Constraint.h" #include "llvm/TableGen/Record.h" +using namespace mlir; using namespace mlir::tblgen; Constraint::Constraint(const llvm::Record *record) @@ -56,11 +57,18 @@ return getPredicate().getCondition(); } -llvm::StringRef Constraint::getDescription() const { - auto doc = def->getValueAsString("description"); - if (doc.empty()) - return def->getName(); - return doc; +StringRef Constraint::getDescription() const { + // If a summary is found, we use that given that it is a focused single line + // comment. + if (Optional summary = def->getValueAsOptionalString("summary")) + return *summary; + // If a summary can't be found, look for a specific description field to use + // for the constraint. 
+ StringRef desc = def->getValueAsString("description"); + if (!desc.empty()) + return desc; + // Otherwise, fall back to the name of the constraint definition. + return def->getName(); } AppliedConstraint::AppliedConstraint(Constraint &&constraint, diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp --- a/mlir/lib/TableGen/TypeDef.cpp +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -31,6 +31,10 @@ return def->getValueAsString("cppClassName"); } +StringRef TypeDef::getCppBaseClassName() const { + return def->getValueAsString("cppBaseClassName"); +} + bool TypeDef::hasDescription() const { const llvm::RecordVal *s = def->getValue("description"); return s != nullptr && isa(s->getValue()); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -15,7 +15,6 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; @@ -183,77 +182,6 @@ return builder.create(loc, type, value); } -static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, - llvm::SetVector &stack) { - StringRef typeTag; - if (failed(parser.parseKeyword(&typeTag))) - return Type(); - - auto genType = generatedTypeParser(ctxt, parser, typeTag); - if (genType != Type()) - return genType; - - if (typeTag == "test_type") - return TestType::get(parser.getBuilder().getContext()); - - if (typeTag != "test_rec") - return Type(); - - StringRef name; - if (parser.parseLess() || parser.parseKeyword(&name)) - return Type(); - auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); - - // If this type already has been parsed above in the stack, expect just the - // name. 
- if (stack.contains(rec)) { - if (failed(parser.parseGreater())) - return Type(); - return rec; - } - - // Otherwise, parse the body and update the type. - if (failed(parser.parseComma())) - return Type(); - stack.insert(rec); - Type subtype = parseTestType(ctxt, parser, stack); - stack.pop_back(); - if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) - return Type(); - - return rec; -} - -Type TestDialect::parseType(DialectAsmParser &parser) const { - llvm::SetVector stack; - return parseTestType(getContext(), parser, stack); -} - -static void printTestType(Type type, DialectAsmPrinter &printer, - llvm::SetVector &stack) { - if (succeeded(generatedTypePrinter(type, printer))) - return; - if (type.isa()) { - printer << "test_type"; - return; - } - - auto rec = type.cast(); - printer << "test_rec<" << rec.getName(); - if (!stack.contains(rec)) { - printer << ", "; - stack.insert(rec); - printTestType(rec.getBody(), printer, stack); - stack.pop_back(); - } - printer << ">"; -} - -void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { - llvm::SetVector stack; - printTestType(type, printer, stack); -} - LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -12,9 +12,12 @@ //===----------------------------------------------------------------------===// #include "TestTypes.h" +#include "TestDialect.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Types.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -116,5 +119,84 @@ return success(); } +//===----------------------------------------------------------------------===// +// Tablegen 
Generated Definitions +//===----------------------------------------------------------------------===// + #define GET_TYPEDEF_CLASSES #include "TestTypeDefs.cpp.inc" + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + +static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, + llvm::SetVector &stack) { + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + auto genType = generatedTypeParser(ctxt, parser, typeTag); + if (genType != Type()) + return genType; + + if (typeTag == "test_type") + return TestType::get(parser.getBuilder().getContext()); + + if (typeTag != "test_rec") + return Type(); + + StringRef name; + if (parser.parseLess() || parser.parseKeyword(&name)) + return Type(); + auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); + + // If this type already has been parsed above in the stack, expect just the + // name. + if (stack.contains(rec)) { + if (failed(parser.parseGreater())) + return Type(); + return rec; + } + + // Otherwise, parse the body and update the type. 
+ if (failed(parser.parseComma())) + return Type(); + stack.insert(rec); + Type subtype = parseTestType(ctxt, parser, stack); + stack.pop_back(); + if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) + return Type(); + + return rec; +} + +Type TestDialect::parseType(DialectAsmParser &parser) const { + llvm::SetVector stack; + return parseTestType(getContext(), parser, stack); +} + +static void printTestType(Type type, DialectAsmPrinter &printer, + llvm::SetVector &stack) { + if (succeeded(generatedTypePrinter(type, printer))) + return; + if (type.isa()) { + printer << "test_type"; + return; + } + + auto rec = type.cast(); + printer << "test_rec<" << rec.getName(); + if (!stack.contains(rec)) { + printer << ", "; + stack.insert(rec); + printTestType(rec.getBody(), printer, stack); + stack.pop_back(); + } + printer << ">"; +} + +void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { + llvm::SetVector stack; + printTestType(type, printer, stack); +} diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -11,9 +11,6 @@ // DECL: class DialectAsmPrinter; // DECL: } // namespace mlir -// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic); -// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer); - // DEF: #ifdef GET_TYPEDEF_LIST // DEF: #undef GET_TYPEDEF_LIST // DEF: ::mlir::test::SimpleAType, diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -92,7 +92,7 @@ /// llvm::formatv will call this function when using an instance as a /// replacement value. 
void format(raw_ostream &os, StringRef options) override { - if (params.size() && prependComma) + if (!params.empty() && prependComma) os << ", "; switch (emitFormat) { @@ -146,8 +146,9 @@ /// case. /// /// {0}: The name of the typeDef class. +/// {1}: The name of the type base class. static const char *const typeDefDeclSingletonBeginStr = R"( - class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{ + class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{ public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; @@ -158,15 +159,16 @@ /// case. /// /// {0}: The name of the typeDef class. -/// {1}: The typeDef storage class namespace. -/// {2}: The storage class name. -/// {3}: The list of parameters with types. +/// {1}: The name of the type base class. +/// {2}: The typeDef storage class namespace. +/// {3}: The storage class name. +/// {4}: The list of parameters with types. static const char *const typeDefDeclParametricBeginStr = R"( - namespace {1} { - struct {2}; + namespace {2} { + struct {3}; } - class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, - {1}::{2}> {{ + class {0}: public ::mlir::Type::TypeBase<{0}, {1}, + {2}::{3}> {{ public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; @@ -196,10 +198,11 @@ // template. if (typeDef.getNumParameters() == 0) os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(), - typeDef.getStorageNamespace(), typeDef.getStorageClassName()); + typeDef.getCppBaseClassName()); else os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(), - typeDef.getStorageNamespace(), typeDef.getStorageClassName()); + typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(), + typeDef.getStorageClassName()); // Emit the extra declarations first in case there's a type definition in // there. 
@@ -208,8 +211,10 @@ TypeParamCommaFormatter emitTypeNamePairsAfterComma( TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params); - os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", - typeDef.getCppClassName(), emitTypeNamePairsAfterComma); + if (!params.empty()) { + os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", + typeDef.getCppClassName(), emitTypeNamePairsAfterComma); + } // Emit the verify invariants declaration. if (typeDef.genVerifyInvariantsDecl()) @@ -252,17 +257,9 @@ // Output the common "header". os << typeDefDeclHeader; - if (typeDefs.size() > 0) { + if (!typeDefs.empty()) { NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect()); - // Well known print/parse dispatch function declarations. These are called - // from Dialect::parseType() and Dialect::printType() methods. - os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, " - "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n"; - os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, " - "::mlir::DialectAsmPrinter& printer);\n"; - os << "\n"; - // Declare all the type classes first (in case they reference each other). 
for (const TypeDef &typeDef : typeDefs) os << " class " << typeDef.getCppClassName() << ";\n"; @@ -488,14 +485,16 @@ if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0) emitStorageClass(typeDef, os); - os << llvm::formatv( - "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" - " return Base::get(ctxt{2});\n}\n", - typeDef.getCppClassName(), - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), - TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, - parameters)); + if (!parameters.empty()) { + os << llvm::formatv( + "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" + " return Base::get(ctxt{2});\n}\n", + typeDef.getCppClassName(), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), + TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, + parameters)); + } // Emit the parameter accessors. if (typeDef.genAccessors()) @@ -526,38 +525,40 @@ /// Emit the dialect printer/parser dispatcher. User's code should call these /// functions from their dialect's print/parse methods. -static void emitParsePrintDispatch(SmallVectorImpl &typeDefs, - raw_ostream &os) { - if (typeDefs.size() == 0) +static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { + if (llvm::none_of(types, [](const TypeDef &type) { + return type.getMnemonic().hasValue(); + })) { return; - const Dialect &dialect = typeDefs.begin()->getDialect(); - NamespaceEmitter ns(os, dialect); + } - // The parser dispatch is just a list of if-elses, matching on the mnemonic - // and calling the class's parse function. - os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, " + // The parser dispatch is just a list of if-elses, matching on the + // mnemonic and calling the class's parse function. 
+ os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* " + "ctxt, " "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n"; - for (const TypeDef &typeDef : typeDefs) - if (typeDef.getMnemonic()) + for (const TypeDef &type : types) + if (type.getMnemonic()) os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return " "{0}::{1}::parse(ctxt, parser);\n", - typeDef.getDialect().getCppNamespace(), - typeDef.getCppClassName()); + type.getDialect().getCppNamespace(), + type.getCppClassName()); os << " return ::mlir::Type();\n"; os << "}\n\n"; // The printer dispatch uses llvm::TypeSwitch to find and call the correct // printer. - os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, " + os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type " + "type, " "::mlir::DialectAsmPrinter& printer) {\n" << " ::mlir::LogicalResult found = ::mlir::success();\n" << " ::llvm::TypeSwitch<::mlir::Type>(type)\n"; - for (auto typeDef : typeDefs) - if (typeDef.getMnemonic()) + for (const TypeDef &type : types) + if (type.getMnemonic()) os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ " "t.dyn_cast<{0}::{1}>().print(printer); })\n", - typeDef.getDialect().getCppNamespace(), - typeDef.getCppClassName()); + type.getDialect().getCppNamespace(), + type.getCppClassName()); os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); " "});\n" << " return found;\n"