diff --git a/mlir/include/mlir/IR/BuiltinDialect.td b/mlir/include/mlir/IR/BuiltinDialect.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinDialect.td @@ -0,0 +1,27 @@ +//===-- BuiltinDialect.td - Builtin dialect definition -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definition of the Builtin dialect. This dialect +// contains all of the attributes, operations, and types that are core to MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef BUILTIN_BASE +#define BUILTIN_BASE + +include "mlir/IR/OpBase.td" + +def Builtin_Dialect : Dialect { + let summary = + "A dialect containing the builtin Attributes, Operations, and Types"; + + let name = ""; + let cppNamespace = "::mlir"; +} + +#endif // BUILTIN_BASE diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td --- a/mlir/include/mlir/IR/BuiltinOps.td +++ b/mlir/include/mlir/IR/BuiltinOps.td @@ -14,17 +14,10 @@ #ifndef BUILTIN_OPS #define BUILTIN_OPS +include "mlir/IR/BuiltinDialect.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" -def Builtin_Dialect : Dialect { - let summary = - "A dialect containing the builtin Attributes, Operations, and Types"; - - let name = ""; - let cppNamespace = "::mlir"; -} - // Base class for Builtin dialect ops. class Builtin_Op traits = []> : Op; diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -72,23 +72,6 @@ Type getElementType(); }; -//===----------------------------------------------------------------------===// -// IndexType -//===----------------------------------------------------------------------===// - -/// Index is a special integer-like type with unknown platform-dependent bit -/// width. -class IndexType : public Type::TypeBase { -public: - using Base::Base; - - /// Get an instance of the IndexType. - static IndexType get(MLIRContext *context); - - /// Storage bit width used for IndexType by internal compiler data structures. - static constexpr unsigned kInternalStorageBitWidth = 64; -}; - //===----------------------------------------------------------------------===// // IntegerType //===----------------------------------------------------------------------===// @@ -178,67 +161,6 @@ const llvm::fltSemantics &getFloatSemantics(); }; -//===----------------------------------------------------------------------===// -// BFloat16Type - -class BFloat16Type - : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the bfloat16 type. - static BFloat16Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getBF16(MLIRContext *ctx) { - return BFloat16Type::get(ctx); -} - -//===----------------------------------------------------------------------===// -// Float16Type - -class Float16Type : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the float16 type. - static Float16Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getF16(MLIRContext *ctx) { - return Float16Type::get(ctx); -} - -//===----------------------------------------------------------------------===// -// Float32Type - -class Float32Type : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the float32 type. - static Float32Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getF32(MLIRContext *ctx) { - return Float32Type::get(ctx); -} - -//===----------------------------------------------------------------------===// -// Float64Type - -class Float64Type : public Type::TypeBase { -public: - using Base::Base; - - /// Return an instance of the float64 type. - static Float64Type get(MLIRContext *context); -}; - -inline FloatType FloatType::getF64(MLIRContext *ctx) { - return Float64Type::get(ctx); -} - //===----------------------------------------------------------------------===// // FunctionType //===----------------------------------------------------------------------===// @@ -267,20 +189,6 @@ ArrayRef resultIndices); }; -//===----------------------------------------------------------------------===// -// NoneType -//===----------------------------------------------------------------------===// - -/// NoneType is a unit type, i.e. a type with exactly one possible value, where -/// its value does not have a defined dynamic representation. -class NoneType : public Type::TypeBase { -public: - using Base::Base; - - /// Get an instance of the NoneType. - static NoneType get(MLIRContext *context); -}; - //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===// @@ -706,11 +614,20 @@ return getTypes()[index]; } }; +} // end namespace mlir + +//===----------------------------------------------------------------------===// +// Tablegen Type Declarations +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/IR/BuiltinTypes.h.inc" //===----------------------------------------------------------------------===// // Deferred Method Definitions //===----------------------------------------------------------------------===// +namespace mlir { inline bool BaseMemRefType::classof(Type type) { return type.isa(); } @@ -719,6 +636,22 @@ return type.isa(); } +inline FloatType FloatType::getBF16(MLIRContext *ctx) { + return BFloat16Type::get(ctx); +} + +inline FloatType FloatType::getF16(MLIRContext *ctx) { + return Float16Type::get(ctx); +} + +inline FloatType FloatType::getF32(MLIRContext *ctx) { + return Float32Type::get(ctx); +} + +inline FloatType FloatType::getF64(MLIRContext *ctx) { + return Float64Type::get(ctx); +} + inline bool ShapedType::classof(Type type) { return type.isa(); diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -0,0 +1,114 @@ +//===- BuiltinTypes.td - Builtin type definitions ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the set of builtin MLIR types, or the set of types necessary for the +// validity of and defining the IR. +// +//===----------------------------------------------------------------------===// + +#ifndef BUILTIN_TYPES +#define BUILTIN_TYPES + +include "mlir/IR/BuiltinDialect.td" + +// TODO: Currently the types defined in this file are prefixed with `Builtin_`. +// This is to differentiate the types here with the ones in OpBase.td. We should +// remove the definitions in OpBase.td, and repoint users to this file instead. + +// Base class for Builtin dialect types. +class Builtin_Type : TypeDef { + let mnemonic = ?; +} + +//===----------------------------------------------------------------------===// +// FloatType +//===----------------------------------------------------------------------===// + +// Base class for Builtin dialect float types. +class Builtin_FloatType : TypeDef { + let extraClassDeclaration = [{ + static }] # name # [{Type get(MLIRContext *context); + }]; +} + +//===----------------------------------------------------------------------===// +// BFloat16Type + +def Builtin_BFloat16 : Builtin_FloatType<"BFloat16"> { + let summary = "bfloat16 floating-point type"; +} + +//===----------------------------------------------------------------------===// +// Float16Type + +def Builtin_Float16 : Builtin_FloatType<"Float16"> { + let summary = "16-bit floating-point type"; +} + +//===----------------------------------------------------------------------===// +// Float32Type + +def Builtin_Float32 : Builtin_FloatType<"Float32"> { + let summary = "32-bit floating-point type"; +} + +//===----------------------------------------------------------------------===// +// Float64Type + +def Builtin_Float64 : Builtin_FloatType<"Float64"> { + let summary = "64-bit floating-point type"; +} + +//===----------------------------------------------------------------------===// +// IndexType +//===----------------------------------------------------------------------===// + +def Builtin_Index : Builtin_Type<"Index"> { + let summary = "Integer-like type with unknown platform-dependent bit width"; + let description = [{ + Syntax: + + ``` + // Target word-sized integer. + index-type ::= `index` + ``` + + The index type is a signless integer whose size is equal to the natural + machine word of the target ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#integer-signedness-semantics) ) + and is used by the affine constructs in MLIR. Unlike fixed-size integers, + it cannot be used as an element of vector ( [rationale](https://mlir.llvm.org/docs/Rationale/Rationale/#index-type-disallowed-in-vector-types) ). + + **Rationale:** integers of platform-specific bit widths are practical to + express sizes, dimensionalities and subscripts. + }]; + let extraClassDeclaration = [{ + static IndexType get(MLIRContext *context); + + /// Storage bit width used for IndexType by internal compiler data + /// structures. + static constexpr unsigned kInternalStorageBitWidth = 64; + }]; +} + +//===----------------------------------------------------------------------===// +// NoneType +//===----------------------------------------------------------------------===// + +def Builtin_None : Builtin_Type<"None"> { + let summary = "A unit type"; + let description = [{ + NoneType is a unit type, i.e. a type with exactly one possible value, where + its value does not have a defined dynamic representation. + }]; + let extraClassDeclaration = [{ + static NoneType get(MLIRContext *context); + }]; +} + +#endif // BUILTIN_TYPES diff --git a/mlir/include/mlir/IR/CMakeLists.txt b/mlir/include/mlir/IR/CMakeLists.txt --- a/mlir/include/mlir/IR/CMakeLists.txt +++ b/mlir/include/mlir/IR/CMakeLists.txt @@ -2,10 +2,18 @@ add_mlir_interface(SymbolInterfaces) add_mlir_interface(RegionKindInterface) +set(LLVM_TARGET_DEFINITIONS BuiltinDialect.td) +mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls) +add_public_tablegen_target(MLIRBuiltinDialectIncGen) + set(LLVM_TARGET_DEFINITIONS BuiltinOps.td) mlir_tablegen(BuiltinOps.h.inc -gen-op-decls) mlir_tablegen(BuiltinOps.cpp.inc -gen-op-defs) -mlir_tablegen(BuiltinDialect.h.inc -gen-dialect-decls) add_public_tablegen_target(MLIRBuiltinOpsIncGen) +set(LLVM_TARGET_DEFINITIONS BuiltinTypes.td) +mlir_tablegen(BuiltinTypes.h.inc -gen-typedef-decls) +mlir_tablegen(BuiltinTypes.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(MLIRBuiltinTypesIncGen) + add_mlir_doc(BuiltinOps -gen-op-doc Builtin Dialects/) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2415,10 +2415,16 @@ // Data type generation //===----------------------------------------------------------------------===// -// Define a new type belonging to a dialect and called 'name'. -class TypeDef { - Dialect dialect = owningdialect; +// Define a new type belonging to a dialect, named 'name', that inherits from +// the given C++ base class. +class TypeDef + : DialectType()">> { + // The name of the C++ Type class. string cppClassName = name # "Type"; + // The name of the C++ base class to use for this Type. + string cppBaseClassName = baseCppClass; // Short summary of the type. string summary = ?; diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h --- a/mlir/include/mlir/TableGen/TypeDef.h +++ b/mlir/include/mlir/TableGen/TypeDef.h @@ -48,6 +48,9 @@ // Returns the name of the C++ class to generate. StringRef getCppClassName() const; + // Returns the name of the C++ base class to use when generating this type. + StringRef getCppBaseClassName() const; + // Returns the name of the storage class for this type. StringRef getStorageClassName() const; diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -20,6 +20,13 @@ using namespace mlir; using namespace mlir::detail; +//===----------------------------------------------------------------------===// +/// Tablegen Type Definitions +//===----------------------------------------------------------------------===// + +#define GET_TYPEDEF_CLASSES +#include "mlir/IR/BuiltinTypes.cpp.inc" + //===----------------------------------------------------------------------===// /// ComplexType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -33,7 +33,9 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR DEPENDS + MLIRBuiltinDialectIncGen MLIRBuiltinOpsIncGen + MLIRBuiltinTypesIncGen MLIRCallInterfacesIncGen MLIROpAsmInterfaceIncGen MLIRRegionKindInterfaceIncGen diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -13,6 +13,7 @@ #include "mlir/TableGen/Constraint.h" #include "llvm/TableGen/Record.h" +using namespace mlir; using namespace mlir::tblgen; Constraint::Constraint(const llvm::Record *record) @@ -56,11 +57,18 @@ return getPredicate().getCondition(); } -llvm::StringRef Constraint::getDescription() const { - auto doc = def->getValueAsString("description"); - if (doc.empty()) - return def->getName(); - return doc; +StringRef Constraint::getDescription() const { + // If a summary is found, we use that given that it is a focused single line + // comment. + if (Optional summary = def->getValueAsOptionalString("summary")) + return *summary; + // If a summary can't be found, look for a specific description field to use + // for the constraint. + StringRef desc = def->getValueAsString("description"); + if (!desc.empty()) + return desc; + // Otherwise, fallback to the name of the constraint definition. + return def->getName(); } AppliedConstraint::AppliedConstraint(Constraint &&constraint, diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp --- a/mlir/lib/TableGen/TypeDef.cpp +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -31,6 +31,10 @@ return def->getValueAsString("cppClassName"); } +StringRef TypeDef::getCppBaseClassName() const { + return def->getValueAsString("cppBaseClassName"); +} + bool TypeDef::hasDescription() const { const llvm::RecordVal *s = def->getValue("description"); return s != nullptr && isa(s->getValue()); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -15,7 +15,6 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/InliningUtils.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringSwitch.h" using namespace mlir; @@ -183,77 +182,6 @@ return builder.create(loc, type, value); } -static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, - llvm::SetVector &stack) { - StringRef typeTag; - if (failed(parser.parseKeyword(&typeTag))) - return Type(); - - auto genType = generatedTypeParser(ctxt, parser, typeTag); - if (genType != Type()) - return genType; - - if (typeTag == "test_type") - return TestType::get(parser.getBuilder().getContext()); - - if (typeTag != "test_rec") - return Type(); - - StringRef name; - if (parser.parseLess() || parser.parseKeyword(&name)) - return Type(); - auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); - - // If this type already has been parsed above in the stack, expect just the - // name. - if (stack.contains(rec)) { - if (failed(parser.parseGreater())) - return Type(); - return rec; - } - - // Otherwise, parse the body and update the type. - if (failed(parser.parseComma())) - return Type(); - stack.insert(rec); - Type subtype = parseTestType(ctxt, parser, stack); - stack.pop_back(); - if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) - return Type(); - - return rec; -} - -Type TestDialect::parseType(DialectAsmParser &parser) const { - llvm::SetVector stack; - return parseTestType(getContext(), parser, stack); -} - -static void printTestType(Type type, DialectAsmPrinter &printer, - llvm::SetVector &stack) { - if (succeeded(generatedTypePrinter(type, printer))) - return; - if (type.isa()) { - printer << "test_type"; - return; - } - - auto rec = type.cast(); - printer << "test_rec<" << rec.getName(); - if (!stack.contains(rec)) { - printer << ", "; - stack.insert(rec); - printTestType(rec.getBody(), printer, stack); - stack.pop_back(); - } - printer << ">"; -} - -void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { - llvm::SetVector stack; - printTestType(type, printer, stack); -} - LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -12,9 +12,12 @@ //===----------------------------------------------------------------------===// #include "TestTypes.h" +#include "TestDialect.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Types.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; @@ -116,5 +119,84 @@ return success(); } +//===----------------------------------------------------------------------===// +// Tablegen Generated Definitions +//===----------------------------------------------------------------------===// + #define GET_TYPEDEF_CLASSES #include "TestTypeDefs.cpp.inc" + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + +static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, + llvm::SetVector &stack) { + StringRef typeTag; + if (failed(parser.parseKeyword(&typeTag))) + return Type(); + + auto genType = generatedTypeParser(ctxt, parser, typeTag); + if (genType != Type()) + return genType; + + if (typeTag == "test_type") + return TestType::get(parser.getBuilder().getContext()); + + if (typeTag != "test_rec") + return Type(); + + StringRef name; + if (parser.parseLess() || parser.parseKeyword(&name)) + return Type(); + auto rec = TestRecursiveType::get(parser.getBuilder().getContext(), name); + + // If this type already has been parsed above in the stack, expect just the + // name. + if (stack.contains(rec)) { + if (failed(parser.parseGreater())) + return Type(); + return rec; + } + + // Otherwise, parse the body and update the type. + if (failed(parser.parseComma())) + return Type(); + stack.insert(rec); + Type subtype = parseTestType(ctxt, parser, stack); + stack.pop_back(); + if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) + return Type(); + + return rec; +} + +Type TestDialect::parseType(DialectAsmParser &parser) const { + llvm::SetVector stack; + return parseTestType(getContext(), parser, stack); +} + +static void printTestType(Type type, DialectAsmPrinter &printer, + llvm::SetVector &stack) { + if (succeeded(generatedTypePrinter(type, printer))) + return; + if (type.isa()) { + printer << "test_type"; + return; + } + + auto rec = type.cast(); + printer << "test_rec<" << rec.getName(); + if (!stack.contains(rec)) { + printer << ", "; + stack.insert(rec); + printTestType(rec.getBody(), printer, stack); + stack.pop_back(); + } + printer << ">"; +} + +void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { + llvm::SetVector stack; + printTestType(type, printer, stack); +} diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -11,9 +11,6 @@ // DECL: class DialectAsmPrinter; // DECL: } // namespace mlir -// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic); -// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer); - // DEF: #ifdef GET_TYPEDEF_LIST // DEF: #undef GET_TYPEDEF_LIST // DEF: ::mlir::test::SimpleAType, diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -92,7 +92,7 @@ /// llvm::formatv will call this function when using an instance as a /// replacement value. void format(raw_ostream &os, StringRef options) override { - if (params.size() && prependComma) + if (!params.empty() && prependComma) os << ", "; switch (emitFormat) { @@ -146,8 +146,9 @@ /// case. /// /// {0}: The name of the typeDef class. +/// {1}: The name of the type base class. static const char *const typeDefDeclSingletonBeginStr = R"( - class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{ + class {0}: public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{ public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; @@ -158,15 +159,16 @@ /// case. /// /// {0}: The name of the typeDef class. -/// {1}: The typeDef storage class namespace. -/// {2}: The storage class name. -/// {3}: The list of parameters with types. +/// {1}: The name of the type base class. +/// {2}: The typeDef storage class namespace. +/// {3}: The storage class name. +/// {4}: The list of parameters with types. static const char *const typeDefDeclParametricBeginStr = R"( - namespace {1} { - struct {2}; + namespace {2} { + struct {3}; } - class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, - {1}::{2}> {{ + class {0}: public ::mlir::Type::TypeBase<{0}, {1}, + {2}::{3}> {{ public: /// Inherit some necessary constructors from 'TypeBase'. using Base::Base; @@ -196,10 +198,11 @@ // template. if (typeDef.getNumParameters() == 0) os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(), - typeDef.getStorageNamespace(), typeDef.getStorageClassName()); + typeDef.getCppBaseClassName()); else os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(), - typeDef.getStorageNamespace(), typeDef.getStorageClassName()); + typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(), + typeDef.getStorageClassName()); // Emit the extra declarations first in case there's a type definition in // there. @@ -208,8 +211,10 @@ TypeParamCommaFormatter emitTypeNamePairsAfterComma( TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params); - os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", - typeDef.getCppClassName(), emitTypeNamePairsAfterComma); + if (!params.empty()) { + os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", + typeDef.getCppClassName(), emitTypeNamePairsAfterComma); + } // Emit the verify invariants declaration. if (typeDef.genVerifyInvariantsDecl()) @@ -252,17 +257,9 @@ // Output the common "header". os << typeDefDeclHeader; - if (typeDefs.size() > 0) { + if (!typeDefs.empty()) { NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect()); - // Well known print/parse dispatch function declarations. These are called - // from Dialect::parseType() and Dialect::printType() methods. - os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, " - "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n"; - os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, " - "::mlir::DialectAsmPrinter& printer);\n"; - os << "\n"; - // Declare all the type classes first (in case they reference each other). for (const TypeDef &typeDef : typeDefs) os << " class " << typeDef.getCppClassName() << ";\n"; @@ -488,14 +485,16 @@ if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0) emitStorageClass(typeDef, os); - os << llvm::formatv( - "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" - " return Base::get(ctxt{2});\n}\n", - typeDef.getCppClassName(), - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), - TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, - parameters)); + if (!parameters.empty()) { + os << llvm::formatv( + "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" + " return Base::get(ctxt{2});\n}\n", + typeDef.getCppClassName(), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), + TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, + parameters)); + } // Emit the parameter accessors. if (typeDef.genAccessors()) @@ -526,38 +525,40 @@ /// Emit the dialect printer/parser dispatcher. User's code should call these /// functions from their dialect's print/parse methods. -static void emitParsePrintDispatch(SmallVectorImpl &typeDefs, - raw_ostream &os) { - if (typeDefs.size() == 0) +static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { + if (llvm::none_of(types, [](const TypeDef &type) { + return type.getMnemonic().hasValue(); + })) { return; - const Dialect &dialect = typeDefs.begin()->getDialect(); - NamespaceEmitter ns(os, dialect); + } - // The parser dispatch is just a list of if-elses, matching on the mnemonic - // and calling the class's parse function. - os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, " + // The parser dispatch is just a list of if-elses, matching on the + // mnemonic and calling the class's parse function. + os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext* " + "ctxt, " "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n"; - for (const TypeDef &typeDef : typeDefs) - if (typeDef.getMnemonic()) + for (const TypeDef &type : types) + if (type.getMnemonic()) os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return " "{0}::{1}::parse(ctxt, parser);\n", - typeDef.getDialect().getCppNamespace(), - typeDef.getCppClassName()); + type.getDialect().getCppNamespace(), + type.getCppClassName()); os << " return ::mlir::Type();\n"; os << "}\n\n"; // The printer dispatch uses llvm::TypeSwitch to find and call the correct // printer. - os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, " + os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type " + "type, " "::mlir::DialectAsmPrinter& printer) {\n" << " ::mlir::LogicalResult found = ::mlir::success();\n" << " ::llvm::TypeSwitch<::mlir::Type>(type)\n"; - for (auto typeDef : typeDefs) - if (typeDef.getMnemonic()) + for (const TypeDef &type : types) + if (type.getMnemonic()) os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ " "t.dyn_cast<{0}::{1}>().print(printer); })\n", - typeDef.getDialect().getCppNamespace(), - typeDef.getCppClassName()); + type.getDialect().getCppNamespace(), + type.getCppClassName()); os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); " "});\n" << " return found;\n"