Index: llvm/include/llvm/TableGen/Record.h =================================================================== --- llvm/include/llvm/TableGen/Record.h +++ llvm/include/llvm/TableGen/Record.h @@ -1634,6 +1634,16 @@ /// or if the value is not a string. StringRef getValueAsString(StringRef FieldName) const; + /// This method looks up the specified field and returns its value as a + /// string (a code block is also accepted), emitting a fatal error for any + /// other value. Returns None if the field does not exist or is unset. + Optional<StringRef> getValueAsOptionalString(StringRef FieldName) const; + + /// This method looks up the specified field and returns its value as a + /// string, emitting a fatal error if the value is not a code block. + /// Returns None if the field does not exist. + Optional<StringRef> getValueAsOptionalCode(StringRef FieldName) const; + /// This method looks up the specified field and returns /// its value as a BitsInit, throwing an exception if the field does not exist /// or if the value is not the right type.
Index: llvm/lib/TableGen/Record.cpp =================================================================== --- llvm/lib/TableGen/Record.cpp +++ llvm/lib/TableGen/Record.cpp @@ -2218,6 +2218,37 @@ PrintFatalError(getLoc(), "Record `" + getName() + "', field `" + FieldName + "' does not have a string initializer!"); } +Optional<StringRef> +Record::getValueAsOptionalString(StringRef FieldName) const { + const RecordVal *R = getValue(FieldName); + if (!R || !R->getValue()) + return None; + if (isa<UnsetInit>(R->getValue())) + return None; + + if (StringInit *SI = dyn_cast<StringInit>(R->getValue())) + return SI->getValue(); + if (CodeInit *CI = dyn_cast<CodeInit>(R->getValue())) + return CI->getValue(); + + PrintFatalError(getLoc(), "Record `" + getName() + "', field `" + FieldName + + "' does not have a string initializer!"); +} +Optional<StringRef> +Record::getValueAsOptionalCode(StringRef FieldName) const { + const RecordVal *R = getValue(FieldName); + if (!R || !R->getValue()) + return None; + // Like getValueAsOptionalString, an unset ('?') field yields None. + if (isa<UnsetInit>(R->getValue())) + return None; + + if (CodeInit *CI = dyn_cast<CodeInit>(R->getValue())) + return CI->getValue(); + + PrintFatalError(getLoc(), "Record `" + getName() + "', field `" + FieldName + + "' does not have a code initializer!"); +} BitsInit *Record::getValueAsBitsInit(StringRef FieldName) const { const RecordVal *R = getValue(FieldName); Index: mlir/cmake/modules/AddMLIR.cmake =================================================================== --- mlir/cmake/modules/AddMLIR.cmake +++ mlir/cmake/modules/AddMLIR.cmake @@ -9,6 +9,8 @@ set(LLVM_TARGET_DEFINITIONS ${dialect}.td) mlir_tablegen(${dialect}.h.inc -gen-op-decls) mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) + mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls) + mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs) mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace}) add_public_tablegen_target(MLIR${dialect}IncGen) add_dependencies(mlir-headers MLIR${dialect}IncGen) Index: mlir/include/mlir/IR/OpBase.td
=================================================================== --- mlir/include/mlir/IR/OpBase.td +++ mlir/include/mlir/IR/OpBase.td @@ -2347,4 +2347,94 @@ // so to replace the matched DAG with an existing SSA value. def replaceWithValue; + +//===----------------------------------------------------------------------===// +// Data type generation +//===----------------------------------------------------------------------===// + +// Define a new type, named 'name', belonging to the dialect 'owningdialect'. +class TypeDef { + Dialect dialect = owningdialect; + string cppClassName = name # "Type"; + + // Short summary of the type. + string summary = ?; + // The longer description of this type. + string description = ?; + + // Name of the storage class to generate or use. + string storageClass = name # "TypeStorage"; + // Namespace (within the dialect's C++ namespace) in which the storage class resides. + string storageNamespace = "detail"; + // Should we generate the storage class? (Or use an existing one?) + bit genStorageClass = 1; + // Does the storage class use a custom (user-provided) constructor? NOTE(review): inferred from the field name; confirm how the generator interprets this bit. + bit hasStorageCustomConstructor = 0; + + // This is the list of fields in the storage class (and list of parameters + // in the creation functions). If empty, don't use or generate a storage class. + dag parameters = (ins); + + // The keyword used to identify this type when parsing/printing. Specify only + // if you want tblgen to automatically generate the printer/parser for this + // type. + string mnemonic = ?; + + // If unset (?), generate just the declarations. + // If an empty code block, generate printer/parser methods only if 'mnemonic' is specified. + // If a non-empty code block, just use that code as the definition code. + code printer = [{}]; + code parser = [{}]; + + // If set, generate accessors for each Type parameter. + bit genAccessors = 1; + // If set, generate the verifyConstructionInvariants declaration and getChecked method.
+ bit genVerifyInvariantsDecl = 0; + // Extra code to include in the class declaration. + code extraClassDeclaration = [{}]; +} + +// 'Parameters' should be subclasses of this or simple strings (which is a +// shorthand for TypeParameter<"C++Type">). +class TypeParameter { + // Custom memory allocation code for the storage constructor (see the $_dst, $_allocator and $_self placeholders used by the subclasses below). + code allocator = ?; + // The C++ type of this parameter. + string cppType = type; + // A description of this parameter. + string description = desc; + // The format string for the asm syntax (documentation only). + string syntax = ?; +} + +// For StringRefs, whose character data must be copied into the storage allocator. +class StringRefParameter : TypeParameter<"::llvm::StringRef", desc> { + let allocator = [{$_dst = $_allocator.copyInto($_self);}]; + let syntax = "\"foo bar\""; +} + +// For standard ArrayRefs, whose elements must be copied into the storage allocator. +class ArrayRefParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + let allocator = [{$_dst = $_allocator.copyInto($_self);}]; + let syntax = "[ " # arrayOf # ", " # arrayOf # ", ... ]"; +} + +// For classes which require allocation and have their own allocateInto method. +class SelfAllocationParameter : TypeParameter { + let allocator = [{$_dst = $_self.allocateInto($_allocator);}]; +} + +// For ArrayRefs whose elements allocate themselves: each element's allocateInto result is collected, then the array itself is copied into the allocator. +class ArrayRefOfSelfAllocationParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + let allocator = [{ + llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields; + for (size_t i = 0; i < $_self.size(); i++) { + tmpFields.push_back($_self[i].allocateInto($_allocator)); + } + $_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields)); + }]; + let syntax = "[ " # arrayOf # ", " # arrayOf # ", ...
]"; +} + + #endif // OP_BASE Index: mlir/include/mlir/TableGen/CodeGenHelpers.h =================================================================== --- /dev/null +++ mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -0,0 +1,61 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines common utilities for generating C++ from tablegen +// structures. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H +#define MLIR_TABLEGEN_CODEGENHELPERS_H + +#include "mlir/TableGen/Dialect.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace tblgen { + +// RAII helper: emits #ifdef/#undef NAME, then #endif NAME when destroyed. +class IfDefScope { +public: + IfDefScope(llvm::StringRef name, llvm::raw_ostream &os) : name(name), os(os) { + os << "#ifdef " << name << "\n" + << "#undef " << name << "\n\n"; + } + ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; } + +private: + llvm::StringRef name; // Not owned; must outlive this object. + llvm::raw_ostream &os; +}; + +// RAII helper that opens a dialect's C++ namespaces and closes them in reverse.
+class NamespaceEmitter { +public: + NamespaceEmitter(raw_ostream &os, const Dialect &dialect) : os(os) { + if (!dialect) + return; + llvm::SplitString(dialect.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + } + + ~NamespaceEmitter() { + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; + } + +private: + raw_ostream &os; + SmallVector namespaces; +}; + +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TABLEGEN_CODEGENHELPERS_H Index: mlir/include/mlir/TableGen/Operator.h =================================================================== --- mlir/include/mlir/TableGen/Operator.h +++ mlir/include/mlir/TableGen/Operator.h @@ -242,17 +242,6 @@ // debugging purposes. void print(llvm::raw_ostream &os) const; - // A helper RAII class to emit nested namespaces for this op. - class NamespaceEmitter { - public: - NamespaceEmitter(raw_ostream &os, Operator &op); - ~NamespaceEmitter(); - - private: - raw_ostream &os; - SmallVector namespaces; - }; - // Return whether all the result types are known. bool allResultTypesKnown() const { return allResultsHaveKnownTypes; }; Index: mlir/include/mlir/TableGen/TypeDef.h =================================================================== --- /dev/null +++ mlir/include/mlir/TableGen/TypeDef.h @@ -0,0 +1,141 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDef wrapper to simplify using a TableGen Record that defines an MLIR type.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_TYPEDEF_H +#define MLIR_TABLEGEN_TYPEDEF_H + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Dialect.h" +#include "llvm/TableGen/Record.h" +#include +#include + +namespace mlir { +namespace tblgen { + +class TypeParameter; + +// Wrapper class that contains a TableGen TypeDef's record and provides helper +// methods for accessing them. +class TypeDef { +public: + explicit TypeDef(const llvm::Record *def) : def(def) {} + + // Get the dialect for which this type belongs + Dialect getDialect() const; + + // Returns the name of this TypeDef record + StringRef getName() const; + + // Query functions for the documentation of the operator. + bool hasDescription() const; + StringRef getDescription() const; + bool hasSummary() const; + StringRef getSummary() const; + + // Returns the name of the C++ class to generate + StringRef getCppClassName() const; + + // Returns the name of the storage class for this type + StringRef getStorageClassName() const; + + // Returns the C++ namespace for this types storage class + StringRef getStorageNamespace() const; + + // Returns true if we should generate the storage class + bool genStorageClass() const; + + // Should we generate the storage class constructor? + bool hasStorageCustomConstructor() const; + + // Return the list of fields for the storage class and constructors + void getParameters(SmallVectorImpl &) const; + unsigned getNumParameters() const; + + // Iterate though parameters, applying a map function before adding to list + template + void getParametersAs(SmallVectorImpl ¶meters, + llvm::function_ref map) const; + + // Return the keyword/mnemonic to use in the printer/parser methods if we are + // supposed to auto-generate them + llvm::Optional getMnemonic() const; + + // Returns the code to use as the types printer method. If empty, generate + // just the declaration. 
If null and mnemonic is non-null, generate the + // declaration and definition. + llvm::Optional getPrinterCode() const; + + // Returns the code to use as the types parser method. If empty, generate + // just the declaration. If null and mnemonic is non-null, generate the + // declaration and definition. + llvm::Optional getParserCode() const; + + // Should we generate accessors based on the types parameters? + bool genAccessors() const; + + // Return true if we need to generate the verifyConstructionInvariants + // declaration and getChecked method + bool genVerifyInvariantsDecl() const; + + // Returns the dialects extra class declaration code. + llvm::Optional getExtraDecls() const; + + // Returns whether two TypeDefs are equal by checking the equality of the + // underlying record. + bool operator==(const TypeDef &other) const; + + // Compares two TypeDefs by comparing the names of the dialects. + bool operator<(const TypeDef &other) const; + + // Returns whether the TypeDef is defined. + operator bool() const { return def != nullptr; } + +private: + const llvm::Record *def; +}; + +// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs +// to parameterize them. 
+class TypeParameter { +public: + explicit TypeParameter(const llvm::DagInit *def, unsigned num) + : def(def), num(num) {} + + // Get the parameter name + StringRef getName() const; + // If specified, get the custom allocator code for this parameter + llvm::Optional getAllocator() const; + // Get the C++ type of this parameter + StringRef getCppType() const; + // Get a description of this parameter for documentation purposes + llvm::Optional getDescription() const; + // Get the assembly syntax documentation + StringRef getSyntax() const; + +private: + const llvm::DagInit *def; + const unsigned num; +}; + +template +void TypeDef::getParametersAs(SmallVectorImpl ¶meters, + llvm::function_ref map) const { + auto parametersDag = def->getValueAsDag("parameters"); + if (parametersDag != nullptr) + for (unsigned i = 0; i < parametersDag->getNumArgs(); i++) + parameters.push_back(map(TypeParameter(parametersDag, i))); +} + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_TYPEDEF_H Index: mlir/include/mlir/TableGen/TypeDefGenHelpers.h =================================================================== --- /dev/null +++ mlir/include/mlir/TableGen/TypeDefGenHelpers.h @@ -0,0 +1,239 @@ +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Accessory functions / templates to assist autogenerated code. The print/parse +// struct templates define standard serializations which can be overridden with +// custom printers/parsers. These structs can be used for temporary stack +// storage also. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_PARSER_HELPERS_H +#define MLIR_TABLEGEN_PARSER_HELPERS_H + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/DialectImplementation.h" +#include + +namespace mlir { +namespace tblgen { +namespace parser_helpers { + +//===----------------------------------------------------------------------===// +// +// Template enables identify various types for which we have specializations +// +//===----------------------------------------------------------------------===// + +template +using void_t = void; + +template +using remove_constref = + typename std::remove_const::type>::type; + +template +using enable_if_type = typename std::enable_if< + std::is_same, TestType>::value>::type; + +template +using is_not_type = + std::is_same, TestType>::type, + typename std::false_type::type>; + +template +using get_indexable_type = remove_constref()[0])>; + +template +using enable_if_arrayref = + enable_if_type>>; + +//===----------------------------------------------------------------------===// +// +// These structs handle Type parameters' parsing for common types +// +//===----------------------------------------------------------------------===// + +template +struct Parse { + ParseResult go(MLIRContext *ctxt, // The context, should it be needed + DialectAsmParser &parser, // The parser + StringRef parameterName, // Type parameter name, for error + // printing (if necessary) + T &result); // Put the parsed value here +}; + +// Int specialization +template +using enable_if_integral_type = + typename std::enable_if::value && + is_not_type::value>::type; +template +struct Parse> { + ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, + StringRef parameterName, T &result) { + return parser.parseInteger(result); + } +}; + +// Bool specialization -- 'true' / 'false' instead of 0/1 +template +struct Parse> { + ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, + 
StringRef parameterName, bool &result) { + StringRef boolStr; + if (parser.parseKeyword(&boolStr)) + return mlir::failure(); + if (!boolStr.compare_lower("false")) { + result = false; + return mlir::success(); + } + if (!boolStr.compare_lower("true")) { + result = true; + return mlir::success(); + } + llvm::errs() << "Parser expected true/false, not '" << boolStr << "'\n"; + return mlir::failure(); + } +}; + +// Float specialization +template +using enable_if_float_type = + typename std::enable_if::value>::type; +template +struct Parse> { + ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, + StringRef parameterName, T &result) { + double d; + if (parser.parseFloat(d)) + return mlir::failure(); + result = d; + return mlir::success(); + } +}; + +// mlir::Type specialization +template +using enable_if_mlir_type = + typename std::enable_if::value>::type; +template +struct Parse> { + ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, + StringRef parameterName, T &result) { + Type type; + auto loc = parser.getCurrentLocation(); + if (parser.parseType(type)) + return mlir::failure(); + if ((result = type.dyn_cast_or_null()) == nullptr) { + parser.emitError(loc, "expected type '" + parameterName + "'"); + return mlir::failure(); + } + return mlir::success(); + } +}; + +// StringRef specialization +template +struct Parse> { + ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, + StringRef parameterName, StringRef &result) { + StringAttr a; + if (parser.parseAttribute(a)) + return mlir::failure(); + result = a.getValue(); + return mlir::success(); + } +}; + +// ArrayRef specialization +template +struct Parse> { + using inner_t = get_indexable_type; + Parse innerParser; + llvm::SmallVector parameters; + + ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser, + StringRef parameterName, ArrayRef &result) { + if (parser.parseLSquare()) + return mlir::failure(); + if (failed(parser.parseOptionalRSquare())) { + do { + inner_t parameter; // = 
std::declval(); + innerParser.go(ctxt, parser, parameterName, parameter); + parameters.push_back(parameter); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRSquare()) + return mlir::failure(); + } + result = ArrayRef(parameters); + return mlir::success(); + } +}; + +//===----------------------------------------------------------------------===// +// +// These structs handle Type parameters' printing for common types +// +//===----------------------------------------------------------------------===// + +template +struct Print { + static void go(DialectAsmPrinter &printer, const T &obj); +}; + +// Several C++ types can just be piped into the printer +template +using enable_if_trivial_print = + typename std::enable_if::value || + (std::is_integral::value && + is_not_type::value) || + std::is_floating_point::value>::type; +template +struct Print>> { + static void go(DialectAsmPrinter &printer, const T &obj) { printer << obj; } +}; + +// StringRef has to be quoted to match the parse specialization above +template +struct Print> { + static void go(DialectAsmPrinter &printer, const T &obj) { + printer << "\"" << obj << "\""; + } +}; + +// bool specialization +template +struct Print> { + static void go(DialectAsmPrinter &printer, const bool &obj) { + if (obj) + printer << "true"; + else + printer << "false"; + } +}; + +// ArrayRef specialization +template +struct Print> { + static void go(DialectAsmPrinter &printer, + const ArrayRef> &obj) { + printer << "["; + for (size_t i = 0; i < obj.size(); i++) { + Print>::go(printer, obj[i]); + if (i < obj.size() - 1) + printer << ", "; + } + printer << "]"; + } +}; + +} // end namespace parser_helpers +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_PARSER_HELPERS_H Index: mlir/lib/TableGen/CMakeLists.txt =================================================================== --- mlir/lib/TableGen/CMakeLists.txt +++ mlir/lib/TableGen/CMakeLists.txt @@ -25,6 +25,7 @@ SideEffects.cpp 
Successor.cpp Type.cpp + TypeDef.cpp DISABLE_LLVM_LINK_LLVM_DYLIB Index: mlir/lib/TableGen/Operator.cpp =================================================================== --- mlir/lib/TableGen/Operator.cpp +++ mlir/lib/TableGen/Operator.cpp @@ -566,21 +566,6 @@ } } -Operator::NamespaceEmitter::NamespaceEmitter(raw_ostream &os, Operator &op) - : os(os) { - auto dialect = op.getDialect(); - if (!dialect) - return; - llvm::SplitString(dialect.getCppNamespace(), namespaces, "::"); - for (StringRef ns : namespaces) - os << "namespace " << ns << " {\n"; -} - -Operator::NamespaceEmitter::~NamespaceEmitter() { - for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; -} - auto Operator::VariableDecoratorIterator::unwrap(llvm::Init *init) -> VariableDecorator { return VariableDecorator(cast(init)->getDef()); Index: mlir/lib/TableGen/TypeDef.cpp =================================================================== --- /dev/null +++ mlir/lib/TableGen/TypeDef.cpp @@ -0,0 +1,169 @@ +//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDef wrapper to simplify using a TableGen Record that defines an MLIR type.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/TypeDef.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +Dialect TypeDef::getDialect() const { + auto dialectDef = + dyn_cast(def->getValue("dialect")->getValue()); + if (dialectDef == nullptr) + return Dialect(nullptr); + return Dialect(dialectDef->getDef()); +} + +StringRef TypeDef::getName() const { return def->getName(); } +StringRef TypeDef::getCppClassName() const { + return def->getValueAsString("cppClassName"); +} + +bool TypeDef::hasDescription() const { + const llvm::RecordVal *s = def->getValue("description"); + return s != nullptr && isa(s->getValue()); +} + +StringRef TypeDef::getDescription() const { + return def->getValueAsString("description"); +} + +bool TypeDef::hasSummary() const { + const llvm::RecordVal *s = def->getValue("summary"); + return s != nullptr && isa(s->getValue()); +} + +StringRef TypeDef::getSummary() const { + return def->getValueAsString("summary"); +} + +StringRef TypeDef::getStorageClassName() const { + return def->getValueAsString("storageClass"); +} +StringRef TypeDef::getStorageNamespace() const { + return def->getValueAsString("storageNamespace"); +} + +bool TypeDef::genStorageClass() const { + return def->getValueAsBit("genStorageClass"); +} +bool TypeDef::hasStorageCustomConstructor() const { + return def->getValueAsBit("hasStorageCustomConstructor"); +} +void TypeDef::getParameters(SmallVectorImpl ¶meters) const { + auto *parametersDag = def->getValueAsDag("parameters"); + if (parametersDag != nullptr) { + size_t numParams = parametersDag->getNumArgs(); + for (unsigned i = 0; i < numParams; i++) { + parameters.push_back(TypeParameter(parametersDag, i)); + } + } +} +unsigned TypeDef::getNumParameters() const { + auto *parametersDag = def->getValueAsDag("parameters"); + return parametersDag ?
parametersDag->getNumArgs() : 0; +} +llvm::Optional TypeDef::getMnemonic() const { + return def->getValueAsOptionalString("mnemonic"); +} +llvm::Optional TypeDef::getPrinterCode() const { + return def->getValueAsOptionalCode("printer"); +} +llvm::Optional TypeDef::getParserCode() const { + return def->getValueAsOptionalCode("parser"); +} +bool TypeDef::genAccessors() const { + return def->getValueAsBit("genAccessors"); +} +bool TypeDef::genVerifyInvariantsDecl() const { + return def->getValueAsBit("genVerifyInvariantsDecl"); +} + +llvm::Optional TypeDef::getExtraDecls() const { + auto value = def->getValueAsString("extraClassDeclaration"); + return value.empty() ? llvm::Optional() : value; +} + +bool TypeDef::operator==(const TypeDef &other) const { + return def == other.def; +} + +bool TypeDef::operator<(const TypeDef &other) const { + return getName() < other.getName(); +} + +StringRef TypeParameter::getName() const { + return def->getArgName(num)->getValue(); +} +llvm::Optional TypeParameter::getAllocator() const { + auto *parameterType = def->getArg(num); + if (isa<llvm::StringInit>(parameterType)) { + return llvm::Optional(); + } + + if (auto *typeParameter = dyn_cast(parameterType)) { + auto *code = typeParameter->getDef()->getValue("allocator"); + if (llvm::CodeInit *ci = code ? dyn_cast<llvm::CodeInit>(code->getValue()) : nullptr) + return ci->getValue(); + if (code && isa<llvm::UnsetInit>(code->getValue())) + return llvm::Optional(); + + llvm::PrintFatalError( + typeParameter->getDef()->getLoc(), + "Record `" + typeParameter->getDef()->getName() + + "', field `allocator' does not have a code initializer!"); + } + + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from TypeParameter\n"); +} +StringRef TypeParameter::getCppType() const { + auto *parameterType = def->getArg(num); + if (auto *stringType = dyn_cast(parameterType)) { + return stringType->getValue(); + } + if (auto *typeParameter = dyn_cast(parameterType)) { + return
typeParameter->getDef()->getValueAsString("cppType"); + } + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from TypeParameter\n"); +} +llvm::Optional TypeParameter::getDescription() const { + + auto *parameterType = def->getArg(num); + if (auto *typeParameter = dyn_cast(parameterType)) { + const auto *desc = typeParameter->getDef()->getValue("description"); + if (llvm::StringInit *ci = desc ? dyn_cast<llvm::StringInit>(desc->getValue()) : nullptr) + return ci->getValue(); + } + return llvm::Optional(); +} +StringRef TypeParameter::getSyntax() const { + auto *parameterType = def->getArg(num); + if (auto *stringType = dyn_cast(parameterType)) { + return stringType->getValue(); + } else if (auto *typeParameter = dyn_cast(parameterType)) { + const auto *syntax = typeParameter->getDef()->getValue("syntax"); + if (syntax != nullptr && isa(syntax->getValue())) + return dyn_cast(syntax->getValue())->getValue(); + return getCppType(); + } else { + llvm::errs() << "Parameters DAG arguments must be either strings or defs " + "which inherit from TypeParameter\n"; + return StringRef(); + } +} Index: mlir/test/lib/Dialect/Test/CMakeLists.txt =================================================================== --- mlir/test/lib/Dialect/Test/CMakeLists.txt +++ mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -8,6 +8,12 @@ mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRTestInterfaceIncGen) +set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) +mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls) +mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(MLIRTestDefIncGen) + + set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) @@ -23,11 +29,13 @@ add_mlir_library(MLIRTestDialect TestDialect.cpp TestPatterns.cpp + TestTypes.cpp EXCLUDE_FROM_LIBMLIR DEPENDS MLIRTestInterfaceIncGen + MLIRTestDefIncGen MLIRTestOpsIncGen LINK_LIBS
PUBLIC Index: mlir/test/lib/Dialect/Test/TestDialect.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.cpp +++ mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -141,16 +141,23 @@ >(); addInterfaces(); - addTypes(); + addTypes(); allowUnknownOperations(); } -static Type parseTestType(DialectAsmParser &parser, +static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); + auto genType = generatedTypeParser(ctxt, parser, typeTag); + if (genType != Type()) + return genType; + if (typeTag == "test_type") return TestType::get(parser.getBuilder().getContext()); @@ -174,7 +181,7 @@ if (failed(parser.parseComma())) return Type(); stack.insert(rec); - Type subtype = parseTestType(parser, stack); + Type subtype = parseTestType(ctxt, parser, stack); stack.pop_back(); if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) return Type(); @@ -184,11 +191,13 @@ Type TestDialect::parseType(DialectAsmParser &parser) const { llvm::SetVector stack; - return parseTestType(parser, stack); + return parseTestType(getContext(), parser, stack); } static void printTestType(Type type, DialectAsmPrinter &printer, llvm::SetVector &stack) { + if (!generatedTypePrinter(type, printer)) + return; if (type.isa()) { printer << "test_type"; return; Index: mlir/test/lib/Dialect/Test/TestTypeDefs.td =================================================================== --- /dev/null +++ mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -0,0 +1,111 @@ +//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_TYPEDEFS +#define TEST_TYPEDEFS + +// Included for the test dialect definition used by Test_Type. +include "TestOps.td" + +class Test_Type : TypeDef { } // Base class for the TypeDefs in this file. + +def SimpleTypeA : Test_Type<"SimpleA"> { + let mnemonic = "smpla"; +} + +// A more complex parameterized type +def CompoundTypeA : Test_Type<"CompoundA"> { + // The keyword used to parse and print this type. + let mnemonic = "cmpnd_a"; + + // The type's parameters; see TypeDef's 'parameters' field in OpBase.td. + let parameters = ( + ins + "int":$widthOfSomething, + "::mlir::SimpleAType": $exampleTdType, + ArrayRefParameter<"int", "">: $arrayOfInts, + ArrayRefParameter<"Type", "An example of an array of types as a type parameter">: $arrayOfTypes, + "::llvm::StringRef": $simpleString, + ArrayRefParameter<"::llvm::StringRef", "">: $arrayOfStrings + ); + + let extraClassDeclaration = [{ + struct SomeCppStruct {}; + }]; +} + +def IntegerType : Test_Type<"TestInteger"> { // Integer-like test type with a hand-written verifier (genVerifyInvariantsDecl). + let mnemonic = "int"; + let genVerifyInvariantsDecl = 1; + let parameters = ( + ins + "::mlir::TestIntegerType::SignednessSemantics":$signedness, + "unsigned":$width + ); + + let extraClassDeclaration = [{ + /// Signedness semantics. + enum SignednessSemantics { + Signless, /// No signedness semantics + Signed, /// Signed integer + Unsigned, /// Unsigned integer + }; + + /// This extra function is necessary since it doesn't include signedness + static IntegerType getChecked(unsigned width, Location location); + + /// Return true if this is a signless integer type. + bool isSignless() const { return getSignedness() == Signless; } + /// Return true if this is a signed integer type. + bool isSigned() const { return getSignedness() == Signed; } + /// Return true if this is an unsigned integer type.
+ bool isUnsigned() const { return getSignedness() == Unsigned; } + }]; +} + +class FieldInfo_Type : Test_Type { // Base for struct-like types: an ArrayRef of FieldInfo with hand-written printer/parser. +let parameters = ( + ins + ArrayRefOfSelfAllocationParameter<"::mlir::FieldInfo", "Models struct fields">: $fields +); + +let printer = [{ + printer << "struct" << "<"; + for (size_t i=0; ifields.size(); i++) { + const auto& field = getImpl()->fields[i]; + printer << "{" << field.name << "," << field.type << "}"; + if (i < getImpl()->fields.size() - 1) + printer << ","; + } + printer << ">"; +}]; + +let parser = [{ + llvm::SmallVector parameters; + if (parser.parseLess()) return Type(); + while (mlir::succeeded(parser.parseOptionalLBrace())) { + StringRef name; + if (parser.parseKeyword(&name)) return Type(); + if (parser.parseComma()) return Type(); + Type type; + if (parser.parseType(type)) return Type(); + if (parser.parseRBrace()) return Type(); + parameters.push_back(FieldInfo {name, type}); + if (parser.parseOptionalComma()) break; + } + if (parser.parseGreater()) return Type(); + return get(ctxt, parameters); +}]; +} + +def StructType : FieldInfo_Type<"Struct"> { + let mnemonic = "struct"; +} + + +#endif // TEST_TYPEDEFS Index: mlir/test/lib/Dialect/Test/TestTypes.h =================================================================== --- mlir/test/lib/Dialect/Test/TestTypes.h +++ mlir/test/lib/Dialect/Test/TestTypes.h @@ -14,11 +14,36 @@ #ifndef MLIR_TESTTYPES_H #define MLIR_TESTTYPES_H +#include + #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" namespace mlir { +struct FieldInfo { // A (name, type) field of a struct-like test type. +public: + StringRef name; + Type type; + + FieldInfo allocateInto(TypeStorageAllocator &alloc) const { // Copies `name` into `alloc`; `type` is copied as-is. + return FieldInfo{alloc.copyInto(name), type}; + } +}; + +bool operator==(const FieldInfo &a, const FieldInfo &b); +llvm::hash_code hash_value(const FieldInfo &fi); + +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.h.inc" +
+namespace mlir {
+
 #include "TestTypeInterfaces.h.inc"
 
 /// This class is a simple test type that uses a generated interface.
Index: mlir/test/lib/Dialect/Test/TestTypes.cpp
===================================================================
--- /dev/null
+++ mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -0,0 +1,93 @@
+//===- TestTypes.cpp - MLIR Test Dialect Types ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains types defined by the TestDialect for testing various
+// features of MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestTypes.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/Types.h"
+#include "mlir/TableGen/TypeDefGenHelpers.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+namespace tblgen {
+namespace parser_helpers {
+
+// Custom parser for SignednessSemantics: accepts u/unsigned, s/signed, n/none.
+template <>
+struct Parse<TestIntegerType::SignednessSemantics> {
+  static ParseResult go(MLIRContext *ctxt, DialectAsmParser &parser,
+                        StringRef parameterName,
+                        TestIntegerType::SignednessSemantics &result) {
+    StringRef signStr;
+    auto loc = parser.getCurrentLocation();
+    if (parser.parseKeyword(&signStr))
+      return mlir::failure();
+    if (signStr.equals_lower("u") || signStr.equals_lower("unsigned"))
+      result = TestIntegerType::SignednessSemantics::Unsigned;
+    else if (signStr.equals_lower("s") || signStr.equals_lower("signed"))
+      result = TestIntegerType::SignednessSemantics::Signed;
+    else if (signStr.equals_lower("n") || signStr.equals_lower("none"))
+      result = TestIntegerType::SignednessSemantics::Signless;
+    else {
+      parser.emitError(loc, "expected signed, unsigned, or none");
+      return mlir::failure();
+    }
+    return mlir::success();
+  }
+};
+
+// Custom printer for SignednessSemantics
+template <>
+struct Print<TestIntegerType::SignednessSemantics> {
+  static void go(DialectAsmPrinter &printer,
+                 const TestIntegerType::SignednessSemantics &ss) {
+    switch (ss) {
+    case TestIntegerType::SignednessSemantics::Unsigned:
+      printer << "unsigned";
+      break;
+    case TestIntegerType::SignednessSemantics::Signed:
+      printer << "signed";
+      break;
+    case TestIntegerType::SignednessSemantics::Signless:
+      printer << "none";
+      break;
+    }
+  }
+};
+
+} // namespace parser_helpers
+} // namespace tblgen
+
+bool operator==(const FieldInfo &a, const FieldInfo &b) {
+  return a.name == b.name && a.type == b.type;
+}
+
+llvm::hash_code hash_value(const FieldInfo &fi) {
+  return llvm::hash_combine(fi.name, fi.type);
+}
+
+// Example type validity checker
+LogicalResult TestIntegerType::verifyConstructionInvariants(
+    mlir::Location loc, mlir::TestIntegerType::SignednessSemantics ss,
+    unsigned int width) {
+
+  if (width > 8)
+    return mlir::emitError(loc) << "test integer width must be <= 8";
+  return mlir::success();
+}
+
+struct TestType;
+} // end namespace mlir
+
+#define GET_TYPEDEF_CLASSES
+#include "TestTypeDefs.cpp.inc"
Index: mlir/test/mlir-tblgen/testdialect-typedefs.mlir
===================================================================
--- /dev/null
+++ mlir/test/mlir-tblgen/testdialect-typedefs.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s
+
+//////////////
+// Tests the types in the 'Test' dialect, not the ones in 'typedefs.td'
+
+// CHECK: @simpleA(%arg0: !test.smpla)
+func @simpleA(%A : !test.smpla) -> () {
+  return
+}
+
+// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6], [i1, i2], "example str", ["array", "of", "strings"]>)
+func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6], [i1, i2], "example str", ["array","of","strings"]>) -> () {
+  return
+}
+
+// CHECK: @testInt(%arg0: !test.int, %arg1: !test.int, %arg2: !test.int)
+func @testInt(%A : !test.int, %B : !test.int, %C : !test.int) {
+  return
+}
+
+// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int}>)
+func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int} > ) {
+  return
+}
Index: mlir/test/mlir-tblgen/typedefs.td
===================================================================
--- /dev/null
+++ mlir/test/mlir-tblgen/typedefs.td
@@ -0,0 +1,136 @@
+// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL
+// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF
+
+include "mlir/IR/OpBase.td"
+
+// DECL: #ifdef GET_TYPEDEF_CLASSES
+// DECL: #undef GET_TYPEDEF_CLASSES
+
+// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);
+// DECL: bool generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer);
+
+// DEF: #ifdef GET_TYPEDEF_LIST
+// DEF: #undef GET_TYPEDEF_LIST
+// DEF: ::mlir::test::SimpleAType,
+// DEF: ::mlir::test::CompoundAType,
+// DEF: ::mlir::test::IndexType,
+// DEF: ::mlir::test::SingleParameterType,
+// DEF: ::mlir::test::IntegerType
+
+// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic)
+// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser);
+// DEF: return ::mlir::Type();
+
+def Test_Dialect: Dialect {
+// DECL-NOT: TestDialect
+// DEF-NOT: TestDialect
+  let name = "TestDialect";
+  let cppNamespace = "::mlir::test";
+}
+
+class TestType<string name> : TypeDef<Test_Dialect, name> { }
+
+def A_SimpleTypeA : TestType<"SimpleA"> {
+// DECL: class SimpleAType: public ::mlir::Type
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+}
+
+// A more complex parameterized type
+def B_CompoundTypeA : TestType<"CompoundA"> {
+  let summary = "A more complex parameterized type";
+  let description = "This type is to test a reasonably complex type";
+  let mnemonic = "cmpnd_a";
+  let parameters = (
+    ins
+    "int":$widthOfSomething,
+    "::mlir::test::SimpleTypeA": $exampleTdType,
+    "SomeCppStruct": $exampleCppType,
+    ArrayRefParameter<"int", "Matrix dimensions">:$dims
+  );
+
+  let genVerifyInvariantsDecl = 1;
+
+// DECL-LABEL: class CompoundAType: public ::mlir::Type
+// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
+// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
+// DECL: int getWidthOfSomething() const;
+// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const;
+// DECL: SomeCppStruct getExampleCppType() const;
+}
+
+def C_IndexType : TestType<"Index"> {
+  let mnemonic = "index";
+
+  let parameters = (
+    ins
+    StringRefParameter<"Label for index">:$label
+  );
+
+// DECL-LABEL: class IndexType: public ::mlir::Type
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+// DECL: static ::llvm::StringRef getMnemonic() { return "index"; }
+}
+
+def D_SingleParameterType : TestType<"SingleParameter"> {
+  let parameters = (
+    ins
+    "int": $num
+  );
+// DECL-LABEL: struct SingleParameterTypeStorage;
+// DECL-LABEL: class SingleParameterType
+// DECL-NEXT: detail::SingleParameterTypeStorage
+// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+// DECL: void print(::mlir::DialectAsmPrinter& printer) const;
+}
+
+def E_IntegerType : TestType<"Integer"> {
+  let parser = [{}];
+  let printer = [{}];
+  let mnemonic = "int";
+  let genVerifyInvariantsDecl = 1;
+  let parameters = (
+    ins
+    "SignednessSemantics":$signedness,
+    TypeParameter<"unsigned", "Bitwidth of integer">:$width
+  );
+
+// DECL-LABEL: IntegerType: public ::mlir::Type
+
+  let extraClassDeclaration = [{
+  /// Signedness semantics.
+  enum SignednessSemantics {
+    Signless, /// No signedness semantics
+    Signed, /// Signed integer
+    Unsigned, /// Unsigned integer
+  };
+
+  /// This extra function is necessary since it doesn't include signedness
+  static IntegerType getChecked(unsigned width, Location location);
+
+  /// Return true if this is a signless integer type.
+  bool isSignless() const { return getSignedness() == Signless; }
+  /// Return true if this is a signed integer type.
+  bool isSigned() const { return getSignedness() == Signed; }
+  /// Return true if this is an unsigned integer type.
+  bool isUnsigned() const { return getSignedness() == Unsigned; }
+  }];
+
+// DECL: /// Signedness semantics.
+// DECL-NEXT: enum SignednessSemantics {
+// DECL-NEXT: Signless, /// No signedness semantics
+// DECL-NEXT: Signed, /// Signed integer
+// DECL-NEXT: Unsigned, /// Unsigned integer
+// DECL-NEXT: };
+// DECL: /// This extra function is necessary since it doesn't include signedness
+// DECL-NEXT: static IntegerType getChecked(unsigned width, Location location);
+
+// DECL: /// Return true if this is a signless integer type.
+// DECL-NEXT: bool isSignless() const { return getSignedness() == Signless; }
+// DECL-NEXT: /// Return true if this is a signed integer type.
+// DECL-NEXT: bool isSigned() const { return getSignedness() == Signed; }
+// DECL-NEXT: /// Return true if this is an unsigned integer type.
+// DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; }
+}
Index: mlir/tools/mlir-tblgen/CMakeLists.txt
===================================================================
--- mlir/tools/mlir-tblgen/CMakeLists.txt
+++ mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -20,6 +20,7 @@
   RewriterGen.cpp
   SPIRVUtilsGen.cpp
   StructsGen.cpp
+  TypeDefGen.cpp
   )
 
 set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning")
Index: mlir/tools/mlir-tblgen/DialectGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/DialectGen.cpp
+++ mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Interfaces.h"
@@ -155,12 +156,7 @@
   }
 
   // Emit all nested namespaces.
-  StringRef cppNamespace = dialect.getCppNamespace();
-  llvm::SmallVector<StringRef, 2> namespaces;
-  llvm::SplitString(cppNamespace, namespaces, "::");
-
-  for (auto ns : namespaces)
-    os << "namespace " << ns << " {\n";
+  NamespaceEmitter nsEmitter(os, dialect);
 
   // Emit the start of the decl.
   std::string cppName = dialect.getCppClassName();
@@ -188,10 +184,6 @@
 
   // End the dialect decl.
   os << "};\n";
-
-  // Close all nested namespaces in reverse order.
-  for (auto ns : llvm::reverse(namespaces))
-    os << "} // namespace " << ns << "\n";
 }
 
 static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
Index: mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "OpFormatGen.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
 #include "mlir/TableGen/Format.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Interfaces.h"
@@ -158,23 +159,6 @@
 // Op emitter
 //===----------------------------------------------------------------------===//
 
-namespace {
-// Simple RAII helper for defining ifdef-undef-endif scopes.
-class IfDefScope {
-public:
-  IfDefScope(StringRef name, raw_ostream &os) : name(name), os(os) {
-    os << "#ifdef " << name << "\n"
-       << "#undef " << name << "\n\n";
-  }
-
-  ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; }
-
-private:
-  StringRef name;
-  raw_ostream &os;
-};
-} // end anonymous namespace
-
 namespace {
 // Helper class to emit a record into the given output stream.
 class OpEmitter {
@@ -2178,7 +2162,7 @@
   os << "#undef GET_OP_FWD_DEFINES\n";
   for (auto *def : defs) {
     Operator op(*def);
-    Operator::NamespaceEmitter emitter(os, op);
+    NamespaceEmitter emitter(os, op.getDialect());
     os << "class " << op.getCppClassName() << ";\n";
   }
   os << "#endif\n\n";
@@ -2187,7 +2171,7 @@
   IfDefScope scope("GET_OP_CLASSES", os);
   for (auto *def : defs) {
     Operator op(*def);
-    Operator::NamespaceEmitter emitter(os, op);
+    NamespaceEmitter emitter(os, op.getDialect());
     if (emitDecl) {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
       OpOperandAdaptorEmitter::emitDecl(op, os);
Index: mlir/tools/mlir-tblgen/OpDocGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/OpDocGen.cpp
+++ mlir/tools/mlir-tblgen/OpDocGen.cpp
@@ -14,6 +14,7 @@
 #include "DocGenUtilities.h"
 #include "mlir/TableGen/GenInfo.h"
 #include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/TypeDef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -22,6 +23,8 @@
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
 
+#include <set>
+
 using namespace llvm;
 using namespace mlir;
 using namespace mlir::tblgen;
@@ -185,12 +188,66 @@
   os << "\n";
 }
 
+//===----------------------------------------------------------------------===//
+// TypeDef Documentation
+//===----------------------------------------------------------------------===//
+
+/// Emit the assembly format of a type.
+static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  td.getParameters(parameters);
+  if (parameters.size() == 0) {
+    os << "\nSyntax: `!" << td.getDialect().getName() << "." << *td.getMnemonic()
+       << "`\n";
+  } else {
+    os << "\nSyntax:\n\n```\n!" << td.getDialect().getName() << "."
+       << *td.getMnemonic() << "<\n";
+    for (auto it = parameters.begin(); it < parameters.end(); it++) {
+      os << " " << it->getSyntax();
+      if (it < parameters.end() - 1)
+        os << ",";
+      os << " # " << it->getName() << "\n";
+    }
+    os << ">\n```\n";
+  }
+}
+
+static void emitTypeDefDoc(TypeDef td, raw_ostream &os) {
+  os << llvm::formatv("### `{0}` ({1})\n", td.getName(), td.getCppClassName());
+
+  // Emit the summary, syntax, and description if present.
+  if (td.hasSummary())
+    os << "\n" << td.getSummary() << "\n";
+  if (td.getMnemonic() && td.getPrinterCode() && *td.getPrinterCode() == "" &&
+      td.getParserCode() && *td.getParserCode() == "")
+    emitTypeAssemblyFormat(td, os);
+  if (td.hasDescription())
+    mlir::tblgen::emitDescription(td.getDescription(), os);
+
+  // Emit a table of the type parameters with their C++ types.
+  SmallVector<TypeParameter, 4> parameters;
+  td.getParameters(parameters);
+  if (parameters.size() != 0) {
+    os << "\n#### Type parameters:\n\n";
+    os << "| Parameter | C++ type | Description |\n"
+       << "| :-------: | :-------: | ----------- |\n";
+    for (const auto &it : parameters) {
+      auto desc = it.getDescription();
+      os << "| " << it.getName() << " | `" << it.getCppType() << "` | "
+         << (desc ? *desc : "") << " |\n";
+    }
+  }
+
+  os << "\n";
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect Documentation
 //===----------------------------------------------------------------------===//
 
 static void emitDialectDoc(const Dialect &dialect, ArrayRef<Operator> ops,
-                           ArrayRef<Type> types, raw_ostream &os) {
+                           ArrayRef<Type> types, ArrayRef<TypeDef> typeDefs,
+                           raw_ostream &os) {
   os << "# '" << dialect.getName() << "' Dialect\n\n";
   emitIfNotEmpty(dialect.getSummary(), os);
   emitIfNotEmpty(dialect.getDescription(), os);
@@ -199,7 +256,7 @@
 
   // TODO: Add link between use and def for types
   if (!types.empty()) {
-    os << "## Type definition\n\n";
+    os << "## Type constraint definition\n\n";
     for (const Type &type : types)
       emitTypeDoc(type, os);
   }
@@ -209,28 +266,44 @@
     for (const Operator &op : ops)
       emitOpDoc(op, os);
   }
+
+  if (!typeDefs.empty()) {
+    os << "## Type definition\n\n";
+    for (const TypeDef &td : typeDefs) {
+      emitTypeDefDoc(td, os);
+    }
+  }
 }
 
 static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) {
   const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op");
   const auto &typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType");
+  const auto &typeDefDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
 
+  std::set<Dialect> dialectsWithDocs;
   std::map<Dialect, std::vector<Operator>> dialectOps;
   std::map<Dialect, std::vector<Type>> dialectTypes;
+  std::map<Dialect, std::vector<TypeDef>> dialectTypeDefs;
   for (auto *opDef : opDefs) {
     Operator op(opDef);
     dialectOps[op.getDialect()].push_back(op);
+    dialectsWithDocs.insert(op.getDialect());
   }
   for (auto *typeDef : typeDefs) {
     Type type(typeDef);
     if (auto dialect = type.getDialect())
       dialectTypes[dialect].push_back(type);
   }
+  for (auto *typeDef : typeDefDefs) {
+    TypeDef type(typeDef);
+    dialectTypeDefs[type.getDialect()].push_back(type);
+    dialectsWithDocs.insert(type.getDialect());
+  }
 
   os << "<!-- Autogenerated by mlir-tblgen; don't edit -->\n";
-  for (auto dialectWithOps : dialectOps)
-    emitDialectDoc(dialectWithOps.first, dialectWithOps.second,
-                   dialectTypes[dialectWithOps.first], os);
+  for (const Dialect &dialect : dialectsWithDocs)
+    emitDialectDoc(dialect, dialectOps[dialect], dialectTypes[dialect],
+                   dialectTypeDefs[dialect], os);
 }
 
 //===----------------------------------------------------------------------===//
Index: mlir/tools/mlir-tblgen/TypeDefGen.cpp
===================================================================
--- /dev/null
+++ mlir/tools/mlir-tblgen/TypeDefGen.cpp
@@ -0,0 +1,621 @@
+//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TypeDefGen uses the description of typeDefs to generate C++ definitions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/TableGen/CodeGenHelpers.h"
+#include "mlir/TableGen/Format.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Interfaces.h"
+#include "mlir/TableGen/OpClass.h"
+#include "mlir/TableGen/OpTrait.h"
+#include "mlir/TableGen/Operator.h"
+#include "mlir/TableGen/TypeDef.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+#define DEBUG_TYPE "mlir-tblgen-typedefgen"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*");
+static llvm::cl::opt<std::string>
+    selectedDialect("typedefs-dialect",
+                    llvm::cl::desc("Gen types for this dialect"),
+                    llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated);
+
+/// Find all the TypeDefs for the specified dialect. If no dialect specified and
+/// can only find one dialect's types, use that.
+static mlir::LogicalResult
+findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
+                SmallVectorImpl<TypeDef> &typeDefs) {
+  auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
+  auto defs = llvm::map_range(
+      recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
+  if (defs.empty())
+    return mlir::success();
+
+  StringRef dialectName;
+  if (selectedDialect.getNumOccurrences() == 0) {
+    llvm::SmallSet<Dialect, 4> dialects;
+    for (auto typeDef : defs) {
+      dialects.insert(typeDef.getDialect());
+    }
+    if (dialects.size() != 1) {
+      llvm::errs() << "TypeDefs belonging to more than one dialect. Must "
+                      "select one via '--typedefs-dialect'\n";
+      return mlir::failure();
+    }
+
+    dialectName = (*dialects.begin()).getName();
+  } else if (selectedDialect.getNumOccurrences() == 1) {
+    dialectName = selectedDialect.getValue();
+  } else {
+    llvm::errs()
+        << "cannot select multiple dialects for which to generate types "
+           "via '--typedefs-dialect'\n";
+    return mlir::failure();
+  }
+
+  for (auto typeDef : defs) {
+    if (typeDef.getDialect().getName().equals(dialectName))
+      typeDefs.push_back(typeDef);
+  }
+  return mlir::success();
+}
+
+/// Build the "Type1 name1, Type2 name2, ..." parameter list used in function
+/// declarations. If `prependComma` is true and there is at least one
+/// parameter, the result is prefixed with ", " so it can follow other args.
+static std::string constructParameterParameters(TypeDef &typeDef,
+                                                bool prependComma) {
+  SmallVector<std::string, 4> parameters;
+  if (prependComma)
+    parameters.push_back("");
+  typeDef.getParametersAs<std::string>(parameters, [](auto parameter) {
+    return (parameter.getCppType() + " " + parameter.getName()).str();
+  });
+  return llvm::join(parameters, ", ");
+}
+
+/// Create an initializer for the storage class
+/// String construction helper function: parameter1(parameter1),
+/// parameter2(parameter2),
+/// [...]
+static std::string constructParametersInitializers(TypeDef &typeDef) {
+  SmallVector<std::string, 4> parameters;
+  typeDef.getParametersAs<std::string>(parameters, [](auto parameter) {
+    return (parameter.getName() + "(" + parameter.getName() + ")").str();
+  });
+  return llvm::join(parameters, ", ");
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef declarations
+//===----------------------------------------------------------------------===//
+
+/// The code block for the start of a typeDef class declaration -- singleton
+/// case
+///
+/// {0}: The name of the typeDef class.
+static const char *const typeDefDeclSingletonBeginStr = R"(
+  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
+public:
+  /// Inherit some necessary constructors from 'TypeBase'.
+  using Base::Base;
+
+)";
+
+/// The code block for the start of a typeDef class declaration -- parametric
+/// case
+///
+/// {0}: The name of the typeDef class.
+/// {1}: The typeDef storage class namespace.
+/// {2}: The storage class name
+static const char *const typeDefDeclParametricBeginStr = R"(
+  namespace {1} {{
+    struct {2};
+  }
+  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
+                                        {1}::{2}> {{
+public:
+  /// Inherit some necessary constructors from 'TypeBase'.
+  using Base::Base;
+
+)";
+
+// snippet for print/parse
+static const char *const typeDefParsePrint = R"(
+    static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+    void print(::mlir::DialectAsmPrinter& printer) const;
+)";
+
+/// The code block for the verifyConstructionInvariants and getChecked
+///
+/// {0}: List of parameters, parameters style
+/// {1}: C++ type class name
+static const char *const typeDefDeclVerifyStr = R"(
+    static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
+    static {1} getChecked(Location loc{0});
+)";
+
+/// Generate the declaration for the given typeDef class.
+static void emitTypeDefDecl(TypeDef &typeDef, raw_ostream &os) {
+  // Emit the beginning string template: either the singleton or parametric
+  // template
+  if (typeDef.getNumParameters() == 0)
+    os << llvm::formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(),
+                        typeDef.getStorageNamespace(),
+                        typeDef.getStorageClassName());
+  else
+    os << llvm::formatv(
+        typeDefDeclParametricBeginStr, typeDef.getCppClassName(),
+        typeDef.getStorageNamespace(), typeDef.getStorageClassName());
+
+  // Emit the extra declarations first in case there's a type definition in
+  // there
+  if (llvm::Optional<StringRef> extraDecl = typeDef.getExtraDecls())
+    os << *extraDecl;
+
+  // Get the CppType1 param1, CppType2 param2 argument list
+  std::string parameterParameters = constructParameterParameters(typeDef, true);
+
+  os << llvm::formatv("    static {0} get(::mlir::MLIRContext* ctxt{1});\n",
+                      typeDef.getCppClassName(), parameterParameters);
+
+  // parse/print
+  os << typeDefParsePrint;
+
+  // verify invariants
+  if (typeDef.genVerifyInvariantsDecl())
+    os << llvm::formatv(typeDefDeclVerifyStr, parameterParameters,
+                        typeDef.getCppClassName());
+
+  // mnemonic, if specified
+  if (auto mnemonic = typeDef.getMnemonic()) {
+    os << "    static ::llvm::StringRef getMnemonic() { return \"" << *mnemonic
+       << "\"; }\n";
+  }
+
+  if (typeDef.genAccessors()) {
+    SmallVector<TypeParameter, 4> parameters;
+    typeDef.getParameters(parameters);
+
+    for (const auto &parameter : parameters) {
+      SmallString<16> name = parameter.getName();
+      name[0] = llvm::toUpper(name[0]);
+      os << llvm::formatv("    {0} get{1}() const;\n", parameter.getCppType(),
+                          name);
+    }
+  }
+
+  // End the typeDef decl.
+  os << "  };\n";
+}
+
+/// Main entry point for decls
+static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
+                             raw_ostream &os) {
+  emitSourceFileHeader("TypeDef Declarations", os);
+
+  SmallVector<TypeDef, 16> typeDefs;
+  if (mlir::failed(findAllTypeDefs(recordKeeper, typeDefs)))
+    return true;
+
+  IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+  if (typeDefs.size() > 0) {
+    NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect());
+    // well known print/parse dispatch function declarations
+    os << "  ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+          "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
+    os << "  bool generatedTypePrinter(::mlir::Type type, "
+          "::mlir::DialectAsmPrinter& "
+          "printer);\n";
+    os << "\n";
+
+    // declare all the type classes first (in case they reference each other)
+    for (auto typeDef : typeDefs) {
+      os << "  class " << typeDef.getCppClassName() << ";\n";
+    }
+
+    // declare all the typedefs
+    for (auto typeDef : typeDefs) {
+      emitTypeDefDecl(typeDef, os);
+    }
+  }
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef list
+//===----------------------------------------------------------------------===//
+
+static mlir::LogicalResult emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
+                                           raw_ostream &os) {
+  IfDefScope scope("GET_TYPEDEF_LIST", os);
+  for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) {
+    os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName();
+    if (i < typeDefs.end() - 1)
+      os << ",\n";
+    else
+      os << "\n";
+  }
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef definitions
+//===----------------------------------------------------------------------===//
+
+/// Beginning of storage class
+/// {0}: Storage class namespace
+/// {1}: Storage class c++ name
+/// {2}: Parameter list with types ("Type1 name1, Type2 name2")
+/// {3}: Parameter initializer string
+/// {4}: Parameter name list
+/// {5}: Parameter types
+static const char *const typeDefStorageClassBegin = R"(
+namespace {0} {{
+  struct {1} : public ::mlir::TypeStorage {{
+    {1} ({2})
+      : {3} {{ }
+
+    /// The hash key for this storage is a tuple of all the parameters.
+    using KeyTy = std::tuple<{5}>;
+
+    /// Define the comparison function for the key type.
+    bool operator==(const KeyTy &key) const {{
+      return key == KeyTy({4});
+    }
+
+    static ::llvm::hash_code hashKey(const KeyTy &key) {{
+)";
+
+/// The storage class' constructor template
+/// {0}: storage class name
+static const char *const typeDefStorageClassConstructorBegin = R"(
+    /// Define a construction method for creating a new instance of this storage.
+    static {0} *construct(::mlir::TypeStorageAllocator &allocator,
+                          const KeyTy &key) {{
+)";
+
+/// The storage class' constructor return template
+/// {0}: storage class name
+/// {1}: list of parameters
+static const char *const typeDefStorageClassConstructorReturn = R"(
+      return new (allocator.allocate<{0}>())
+          {0}({1});
+    }
+)";
+
+/// use tgfmt to emit custom allocation code for each parameter, if necessary
+static mlir::LogicalResult emitCustomAllocationCode(TypeDef &typeDef,
+                                                    raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+  auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
+  for (const auto &parameter : parameters) {
+    auto allocCode = parameter.getAllocator();
+    if (allocCode) {
+      fmtCtxt.withSelf(parameter.getName());
+      fmtCtxt.addSubst("_dst", parameter.getName());
+      auto fmtObj = tgfmt(*allocCode, &fmtCtxt);
+      os << "      ";
+      fmtObj.format(os);
+      os << "\n";
+    }
+  }
+  return mlir::success();
+}
+
+static mlir::LogicalResult emitStorageClass(TypeDef typeDef, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+
+  // Initialize a bunch of variables to be used later on
+  auto parameterNames = llvm::map_range(
+      parameters, [](TypeParameter parameter) { return parameter.getName(); });
+  auto parameterTypes =
+      llvm::map_range(parameters, [](TypeParameter parameter) {
+        return parameter.getCppType();
+      });
+  auto parameterList = llvm::join(parameterNames, ", ");
+  auto parameterTypeList = llvm::join(parameterTypes, ", ");
+  auto parameterParameters = constructParameterParameters(typeDef, false);
+  auto parameterInits = constructParametersInitializers(typeDef);
+
+  // emit most of the storage class up until the hashKey body
+  os << llvm::formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
+                      typeDef.getStorageClassName(), parameterParameters,
+                      parameterInits, parameterList, parameterTypeList);
+
+  // Unpack each tuple element into a local; structured bindings need C++17.
+  for (size_t i = 0; i < parameters.size(); i++) {
+    os << llvm::formatv("      auto {0} = std::get<{1}>(key);\n",
+                        parameters[i].getName(), i);
+  }
+  // then combine them all. this requires all the parameters types to have a
+  // hash_value defined
+  os << "      return ::llvm::hash_combine(\n";
+  for (auto *parameterIter = parameters.begin();
+       parameterIter < parameters.end(); parameterIter++) {
+    os << "        " << parameterIter->getName();
+    if (parameterIter < parameters.end() - 1) {
+      os << ",\n";
+    }
+  }
+  os << ");\n";
+  os << "    }\n";
+
+  // if user wants to build the storage constructor themselves, declare it here
+  // and then they can write the definition elsewhere
+  if (typeDef.hasStorageCustomConstructor())
+    os << "    static " << typeDef.getStorageClassName()
+       << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
+          "&key);\n";
+  else {
+    os << llvm::formatv(typeDefStorageClassConstructorBegin,
+                        typeDef.getStorageClassName());
+    // Unpack the key elements again for use by the allocation code below.
+    for (size_t i = 0; i < parameters.size(); i++) {
+      os << llvm::formatv("      auto {0} = std::get<{1}>(key);\n",
+                          parameters[i].getName(), i);
+    }
+    // Reassign the parameter variables with allocation code, if it's specified
+    if (mlir::failed(emitCustomAllocationCode(typeDef, os)))
+      return mlir::failure();
+    // return an allocated copy
+    os << llvm::formatv(typeDefStorageClassConstructorReturn,
+                        typeDef.getStorageClassName(), parameterList);
+  }
+
+  // Emit the parameters' class parameters
+  for (const auto &parameter : parameters) {
+    os << "    " << parameter.getCppType() << " " << parameter.getName()
+       << ";\n";
+  }
+  os << "  };\n";
+  os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
+
+  return mlir::success();
+}
+
+/// Emit the body of an autogenerated printer
+static mlir::LogicalResult emitPrinterAutogen(TypeDef typeDef,
+                                              raw_ostream &os) {
+  if (auto mnemonic = typeDef.getMnemonic()) {
+    SmallVector<TypeParameter, 4> parameters;
+    typeDef.getParameters(parameters);
+
+    os << "  printer << \"" << *mnemonic << "\";\n";
+
+    // if non-parametric, we're done
+    if (parameters.size() > 0) {
+      os << "  printer << \"<\";\n";
+
+      // emit a printer for each parameter separated by ','.
+      // printer structs for common C++ types are defined in
+      // TypeDefGenHelpers.h, which must be #included by the consuming code.
+      for (auto *parameterIter = parameters.begin();
+           parameterIter < parameters.end(); parameterIter++) {
+        // Print<T>::go is called statically; no printer object is created.
+        os << "  ::mlir::tblgen::parser_helpers::Print<"
+           << parameterIter->getCppType() << ">::go(printer, getImpl()->"
+           << parameterIter->getName() << ");\n";
+
+        // emit the comma unless we're the last parameter
+        if (parameterIter < parameters.end() - 1) {
+          os << "  printer << \", \";\n";
+        }
+      }
+      os << "  printer << \">\";\n";
+    }
+  }
+  return mlir::success();
+}
+
+/// Emit the body of an autogenerated parser
+static mlir::LogicalResult emitParserAutogen(TypeDef typeDef, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+
+  // by the time we get to this function, the mnemonic has already been parsed
+  if (parameters.size() > 0) {
+    os << "  if (parser.parseLess()) return ::mlir::Type();\n";
+
+    // emit a parser for each parameter separated by ','.
+    // parse structs for common C++ types are defined in
+    // TypeDefGenHelpers.h, which must be #included by the consuming code.
+    for (auto *parameterIter = parameters.begin();
+         parameterIter < parameters.end(); parameterIter++) {
+      os << "  " << parameterIter->getCppType() << " "
+         << parameterIter->getName() << ";\n";
+      os << llvm::formatv(
+          "  ::mlir::tblgen::parser_helpers::Parse<{0}> {1}Parser;\n",
+          parameterIter->getCppType(), parameterIter->getName());
+      // Pass the parameter's name (not its C++ type) as `parameterName`.
+      os << llvm::formatv("  if ({0}Parser.go(ctxt, parser, \"{0}\", {0})) "
+                          "return ::mlir::Type();\n",
+                          parameterIter->getName());
+
+      // parse a comma unless we're the last parameter
+      if (parameterIter < parameters.end() - 1) {
+        os << "  if (parser.parseComma()) return ::mlir::Type();\n";
+      }
+    }
+    os << "  if (parser.parseGreater()) return ::mlir::Type();\n";
+    // done with the parsing
+
+    // all the parameters are now in variables named the same as the parameters
+    auto parameterNames =
+        llvm::map_range(parameters, [](TypeParameter parameter) {
+          return parameter.getName();
+        });
+    os << "  return get(ctxt, " << llvm::join(parameterNames, ", ") << ");\n";
+  } else {
+    os << "  return get(ctxt);\n";
+  }
+  return mlir::success();
+}
+
+/// Print all the typedef-specific definition code
+static mlir::LogicalResult emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
+  NamespaceEmitter ns(os, typeDef.getDialect());
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+
+  // emit the storage class, if requested and necessary
+  if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
+    if (mlir::failed(emitStorageClass(typeDef, os)))
+      return mlir::failure();
+
+  std::string paramFuncParams = constructParameterParameters(typeDef, true);
+  SmallVector<StringRef, 4> paramNames;
+  paramNames.push_back("");
+  typeDef.getParametersAs<StringRef>(
+      paramNames, [](TypeParameter param) { return param.getName(); });
+  os << llvm::formatv("{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
+                      "  return Base::get(ctxt{2});\n"
+                      "}\n",
+                      typeDef.getCppClassName(), paramFuncParams,
+                      llvm::join(paramNames, ","));
+  // emit the accessors
+  if (typeDef.genAccessors()) {
+    for (const auto &parameter : parameters) {
+      SmallString<16> name = parameter.getName();
+      name[0] = llvm::toUpper(name[0]);
+      os << llvm::formatv(
+          "{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
+          parameter.getCppType(), name, parameter.getName(),
+          typeDef.getCppClassName());
+    }
+  }
+
+  // emit the printer code, if appropriate
+  auto printerCode = typeDef.getPrinterCode();
+  if (printerCode && typeDef.getMnemonic()) {
+    // Both the mnemonic and printerCode must be defined (for parity with
+    // parserCode)
+    os << "void " << typeDef.getCppClassName()
+       << "::print(::mlir::DialectAsmPrinter& printer) const {\n";
+    if (*printerCode == "") {
+      // if no code specified, autogenerate a printer
+      if (mlir::failed(emitPrinterAutogen(typeDef, os)))
+        return mlir::failure();
+    } else {
+      os << *printerCode << "\n";
+    }
+    os << "}\n";
+  }
+
+  // emit a parser, if appropriate
+  auto parserCode = typeDef.getParserCode();
+  if (parserCode && typeDef.getMnemonic()) {
+    // The mnemonic must be defined so the dispatcher knows how to dispatch
+    os << "::mlir::Type " << typeDef.getCppClassName()
+       << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
+          "parser) "
+          "{\n";
+    if (*parserCode == "") {
+      if (mlir::failed(emitParserAutogen(typeDef, os)))
+        return mlir::failure();
+    } else
+      os << *parserCode << "\n";
+    os << "}\n";
+  }
+
+  return mlir::success();
+}
+
+/// Emit the dialect printer/parser dispatch. Client code should call these
+/// functions from their dialect's print/parse methods.
+static mlir::LogicalResult
+emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs, raw_ostream &os) {
+  // With no typedefs there is nothing to dispatch to; emit nothing.
+  if (typeDefs.empty())
+    return mlir::success();
+  // Both dispatch functions are emitted inside the namespace of the first
+  // typedef's dialect; the type names referenced below are still qualified
+  // with each typedef's own C++ namespace.
+  const Dialect &dialect = typeDefs.front().getDialect();
+  NamespaceEmitter ns(os, dialect);
+
+  // Parser dispatch: compare the already-parsed mnemonic with each typedef's
+  // mnemonic and forward to its static parse method; a null type is returned
+  // when nothing matches.
+  // NOTE(review): dispatch is emitted for every typedef with a mnemonic, while
+  // emitTypeDefDef only defines parse/print when parser/printer code is also
+  // set -- confirm the declarations guarantee these methods always exist.
+  os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+        "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
+  for (auto typeDef : typeDefs) {
+    if (typeDef.getMnemonic())
+      os << llvm::formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return "
+                          "{0}::{1}::parse(ctxt, parser);\n",
+                          typeDef.getDialect().getCppNamespace(),
+                          typeDef.getCppClassName());
+  }
+  os << " return ::mlir::Type();\n";
+  os << "}\n\n";
+
+  // Printer dispatch: TypeSwitch over the typedefs that have a mnemonic and
+  // call their print method. The generated function returns true when no
+  // case matched (the Default lambda sets `notfound`).
+  os << "bool generatedTypePrinter(::mlir::Type type, "
+        "::mlir::DialectAsmPrinter& "
+        "printer) {\n"
+     << " bool notfound = false;\n"
+     << " ::llvm::TypeSwitch<::mlir::Type>(type)\n";
+  for (auto typeDef : typeDefs) {
+    if (typeDef.getMnemonic())
+      os << llvm::formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ "
+                          "t.dyn_cast<{0}::{1}>().print(printer); })\n",
+                          typeDef.getDialect().getCppNamespace(),
+                          typeDef.getCppClassName());
+  }
+  // The Default lambda captures `notfound` by reference so it can set it.
+  os << " .Default([&notfound](::mlir::Type) { notfound = true; });\n"
+     << " return notfound;\n"
+     << "}\n\n";
+  return mlir::success();
+}
+
+/// Entry point for `-gen-typedef-defs`. Emits the typedef list, then, guarded
+/// by GET_TYPEDEF_CLASSES, the parse/print dispatch functions followed by the
+/// definitions of every typedef. Returns true on failure, false on success.
+static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
+                            raw_ostream &os) {
+  emitSourceFileHeader("TypeDef Definitions", os);
+
+  SmallVector<TypeDef, 16> typeDefs;
+  if (mlir::failed(findAllTypeDefs(recordKeeper, typeDefs)))
+    return true;
+
+  if (mlir::failed(emitTypeDefList(typeDefs, os)))
+    return true;
+
+  IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+  if (mlir::failed(emitParsePrintDispatch(typeDefs, os)))
+    return true;
+  for (auto typeDef : typeDefs) {
+    if (mlir::failed(emitTypeDefDef(typeDef, os)))
+      return true;
+  }
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef registration hooks
+//===----------------------------------------------------------------------===//
+
+// Register `-gen-typedef-defs`: emits the TypeDef definitions (.cpp.inc).
+static mlir::GenRegistration genTypeDefDefs(
+    "gen-typedef-defs", "Generate TypeDef definitions",
+    [](const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
+      return emitTypeDefDefs(recordKeeper, os);
+    });
+
+// Register `-gen-typedef-decls`: emits the TypeDef declarations (.h.inc).
+static mlir::GenRegistration genTypeDefDecls(
+    "gen-typedef-decls", "Generate TypeDef declarations",
+    [](const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
+      return emitTypeDefDecls(recordKeeper, os);
+    });