diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -9,6 +9,8 @@ set(LLVM_TARGET_DEFINITIONS ${dialect}.td) mlir_tablegen(${dialect}.h.inc -gen-op-decls) mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) + mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls) + mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs) mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace}) add_public_tablegen_target(MLIR${dialect}IncGen) add_dependencies(mlir-headers MLIR${dialect}IncGen) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2364,4 +2364,116 @@ // so to replace the matched DAG with an existing SSA value. def replaceWithValue; + +//===----------------------------------------------------------------------===// +// Data type generation +//===----------------------------------------------------------------------===// + +// Define a new type belonging to a dialect and called 'name'. +class TypeDef { + Dialect dialect = owningdialect; + string cppClassName = name # "Type"; + + // Short summary of the type. + string summary = ?; + // The longer description of this type. + string description = ?; + + // Name of storage class to generate or use. + string storageClass = name # "TypeStorage"; + // Namespace (within the dialect's C++ namespace) in which the storage class + // resides. + string storageNamespace = "detail"; + // Specify if the storage class is to be generated. + bit genStorageClass = 1; + // Specify that the generated storage class has a constructor which is written + // in C++. + bit hasStorageCustomConstructor = 0; + + // The list of parameters for this type. Parameters will become both + // parameters to the get() method and storage class member variables.
+ // + // The format of this dag is: + // (ins + // "":$param1Name, + // "":$param2Name, + // TypeParameter<"c++ type", "param description">:$param3Name) + // TypeParameters (or more likely one of their subclasses) are required to add + // more information about the parameter, specifically: + // - Documentation + // - Code to allocate the parameter (if allocation is needed in the storage + // class constructor) + // + // For example: + // (ins + // "int":$width, + // ArrayRefParameter<"bool", "list of bools">:$yesNoArray) + // + // (ArrayRefParameter is a subclass of TypeParameter which has allocation code + // for re-allocating ArrayRefs. It is defined below.) + dag parameters = (ins); + + // The keyword used for parsing/printing this type. Specify only + // if you want tblgen to generate declarations and/or definitions of the + // printer/parser for this type. + string mnemonic = ?; + // If 'mnemonic' is specified: + // - If unset (null), generate just the declarations. + // - If a non-empty code block, use that code as the definition. + // - An empty code block is an error. + code printer = ?; + code parser = ?; + + // If set, generate accessors for each Type parameter. + bit genAccessors = 1; + // Generate the verifyConstructionInvariants declaration and getChecked + // method. + bit genVerifyInvariantsDecl = 0; + // Extra code to include in the class declaration. + code extraClassDeclaration = [{}]; +} + +// 'Parameters' should be subclasses of this or simple strings (which is a +// shorthand for TypeParameter<"C++Type">). +class TypeParameter { + // Custom memory allocation code for storage constructor. + code allocator = ?; + // The C++ type of this parameter. + string cppType = type; + // A description of this parameter. + string description = desc; + // The format string for the asm syntax (documentation only). + string syntax = ?; +} + +// For StringRefs, which require allocation.
+class StringRefParameter : + TypeParameter<"::llvm::StringRef", desc> { + let allocator = [{$_dst = $_allocator.copyInto($_self);}]; +} + +// For standard ArrayRefs, which require allocation. +class ArrayRefParameter : + TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + let allocator = [{$_dst = $_allocator.copyInto($_self);}]; +} + +// For classes which require allocation and have their own allocateInto method. +class SelfAllocationParameter : + TypeParameter { + let allocator = [{$_dst = $_self.allocateInto($_allocator);}]; +} + +// For ArrayRefs which contain things which allocate themselves. +class ArrayRefOfSelfAllocationParameter : + TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + let allocator = [{ + llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields; + for (size_t i = 0, e = $_self.size(); i < e; ++i) + tmpFields.push_back($_self[i].allocateInto($_allocator)); + $_dst = $_allocator.copyInto(ArrayRef<}] # arrayOf # [{>(tmpFields)); + }]; +} + + #endif // OP_BASE diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/TypeDef.h @@ -0,0 +1,135 @@ +//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDef wrapper to simplify using TableGen Record defining a MLIR type.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_TYPEDEF_H +#define MLIR_TABLEGEN_TYPEDEF_H + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Dialect.h" + +namespace llvm { +class Record; +class DagInit; +class SMLoc; +} // namespace llvm + +namespace mlir { +namespace tblgen { + +class TypeParameter; + +/// Wrapper class that contains a TableGen TypeDef's record and provides helper +/// methods for accessing them. +class TypeDef { +public: + explicit TypeDef(const llvm::Record *def) : def(def) {} + + // Get the dialect to which this type belongs. + Dialect getDialect() const; + + // Returns the name of this TypeDef record. + StringRef getName() const; + + // Query functions for the documentation of the type. + bool hasDescription() const; + StringRef getDescription() const; + bool hasSummary() const; + StringRef getSummary() const; + + // Returns the name of the C++ class to generate. + StringRef getCppClassName() const; + + // Returns the name of the storage class for this type. + StringRef getStorageClassName() const; + + // Returns the C++ namespace for this type's storage class. + StringRef getStorageNamespace() const; + + // Returns true if we should generate the storage class. + bool genStorageClass() const; + + // Indicates whether the storage class has a custom constructor written in C++. + bool hasStorageCustomConstructor() const; + + // Fill a list with this type's parameters. See TypeDef in OpBase.td for + // documentation of parameter usage. + void getParameters(SmallVectorImpl &) const; + // Return the number of type parameters. + unsigned getNumParameters() const; + + // Return the keyword/mnemonic to use in the printer/parser methods if we are + // supposed to auto-generate them. + Optional getMnemonic() const; + + // Returns the code to use as the type's printer method. If not specified, + // return None. Otherwise, return the contents of that code block.
+ Optional getPrinterCode() const; + + // Returns the code to use as the type's parser method. If not specified, + // return None. Otherwise, return the contents of that code block. + Optional getParserCode() const; + + // Returns true if the accessors based on the type's parameters should be + // generated. + bool genAccessors() const; + + // Return true if we need to generate the verifyConstructionInvariants + // declaration and getChecked method. + bool genVerifyInvariantsDecl() const; + + // Returns the type's extra class declaration code, or None if it is empty. + Optional getExtraDecls() const; + + // Get the code location (for error printing). + ArrayRef getLoc() const; + + // Returns whether two TypeDefs are equal by checking the equality of the + // underlying record. + bool operator==(const TypeDef &other) const; + + // Compares two TypeDefs by comparing their record names. + bool operator<(const TypeDef &other) const; + + // Returns whether the TypeDef is defined. + operator bool() const { return def != nullptr; } + +private: + const llvm::Record *def; +}; + +// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs +// to parameterize them. +class TypeParameter { +public: + explicit TypeParameter(const llvm::DagInit *def, unsigned num) + : def(def), num(num) {} + + // Get the parameter name. + StringRef getName() const; + // If specified, get the custom allocator code for this parameter. + llvm::Optional getAllocator() const; + // Get the C++ type of this parameter. + StringRef getCppType() const; + // Get a description of this parameter for documentation purposes. + llvm::Optional getDescription() const; + // Get the assembly syntax documentation.
+ StringRef getSyntax() const; + +private: + const llvm::DagInit *def; + const unsigned num; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_TYPEDEF_H diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -25,6 +25,7 @@ SideEffects.cpp Successor.cpp Type.cpp + TypeDef.cpp DISABLE_LLVM_LINK_LLVM_DYLIB diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -0,0 +1,160 @@ +//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDef wrapper to simplify using TableGen Record defining a MLIR type.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/TypeDef.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +Dialect TypeDef::getDialect() const { + auto *dialectDef = + dyn_cast(def->getValue("dialect")->getValue()); + if (dialectDef == nullptr) + return Dialect(nullptr); + return Dialect(dialectDef->getDef()); +} + +StringRef TypeDef::getName() const { return def->getName(); } +StringRef TypeDef::getCppClassName() const { + return def->getValueAsString("cppClassName"); +} + +bool TypeDef::hasDescription() const { + const llvm::RecordVal *s = def->getValue("description"); + return s != nullptr && isa(s->getValue()); +} + +StringRef TypeDef::getDescription() const { + return def->getValueAsString("description"); +} + +bool TypeDef::hasSummary() const { + const llvm::RecordVal *s = def->getValue("summary"); + return s != nullptr && isa(s->getValue()); +} + +StringRef TypeDef::getSummary() const { + return def->getValueAsString("summary"); +} + +StringRef TypeDef::getStorageClassName() const { + return def->getValueAsString("storageClass"); +} +StringRef TypeDef::getStorageNamespace() const { + return def->getValueAsString("storageNamespace"); +} + +bool TypeDef::genStorageClass() const { + return def->getValueAsBit("genStorageClass"); +} +bool TypeDef::hasStorageCustomConstructor() const { + return def->getValueAsBit("hasStorageCustomConstructor"); +} +void TypeDef::getParameters(SmallVectorImpl &parameters) const { + auto *parametersDag = def->getValueAsDag("parameters"); + if (parametersDag != nullptr) { + size_t numParams = parametersDag->getNumArgs(); + for (unsigned i = 0; i < numParams; i++) + parameters.push_back(TypeParameter(parametersDag, i)); + } +} +unsigned TypeDef::getNumParameters() const { + auto *parametersDag = def->getValueAsDag("parameters"); + return parametersDag ?
parametersDag->getNumArgs() : 0; +} +llvm::Optional TypeDef::getMnemonic() const { + return def->getValueAsOptionalString("mnemonic"); +} +llvm::Optional TypeDef::getPrinterCode() const { + return def->getValueAsOptionalCode("printer"); +} +llvm::Optional TypeDef::getParserCode() const { + return def->getValueAsOptionalCode("parser"); +} +bool TypeDef::genAccessors() const { + return def->getValueAsBit("genAccessors"); +} +bool TypeDef::genVerifyInvariantsDecl() const { + return def->getValueAsBit("genVerifyInvariantsDecl"); +} +llvm::Optional TypeDef::getExtraDecls() const { + auto value = def->getValueAsString("extraClassDeclaration"); + return value.empty() ? llvm::Optional() : value; +} +llvm::ArrayRef TypeDef::getLoc() const { return def->getLoc(); } +bool TypeDef::operator==(const TypeDef &other) const { + return def == other.def; +} + +bool TypeDef::operator<(const TypeDef &other) const { + return getName() < other.getName(); +} + +StringRef TypeParameter::getName() const { + return def->getArgName(num)->getValue(); +} +llvm::Optional TypeParameter::getAllocator() const { + llvm::Init *parameterType = def->getArg(num); + if (auto *stringType = dyn_cast(parameterType)) + return llvm::Optional(); // Plain strings have no custom allocator. + + if (auto *typeParameter = dyn_cast(parameterType)) { + llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator"); + if (llvm::CodeInit *ci = dyn_cast(code->getValue())) + return ci->getValue(); + if (isa(code->getValue())) + return llvm::Optional(); + + llvm::PrintFatalError( + typeParameter->getDef()->getLoc(), + "Record `" + def->getArgName(num)->getValue() + + "', field `allocator' does not have a code initializer!"); + } + + llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from TypeParameter\n"); +} +StringRef TypeParameter::getCppType() const { + auto *parameterType = def->getArg(num); + if (auto *stringType = dyn_cast(parameterType)) + return stringType->getValue(); + if (auto *typeParameter =
dyn_cast(parameterType)) + return typeParameter->getDef()->getValueAsString("cppType"); + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from TypeParameter\n"); +} +llvm::Optional TypeParameter::getDescription() const { + auto *parameterType = def->getArg(num); + if (auto *typeParameter = dyn_cast(parameterType)) { + const auto *desc = typeParameter->getDef()->getValue("description"); + if (llvm::StringInit *ci = dyn_cast(desc->getValue())) + return ci->getValue(); + } + return llvm::Optional(); +} +StringRef TypeParameter::getSyntax() const { + auto *parameterType = def->getArg(num); + if (auto *stringType = dyn_cast(parameterType)) + return stringType->getValue(); + if (auto *typeParameter = dyn_cast(parameterType)) { + const auto *syntax = typeParameter->getDef()->getValue("syntax"); + if (syntax && isa(syntax->getValue())) + return dyn_cast(syntax->getValue())->getValue(); + return getCppType(); // No explicit syntax: fall back to the C++ type. + } + llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from TypeParameter"); +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -9,6 +9,12 @@ mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRTestInterfaceIncGen) +set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) +mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls) +mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(MLIRTestDefIncGen) + + set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) @@ -25,11 +31,13 @@ TestDialect.cpp TestPatterns.cpp TestTraits.cpp + TestTypes.cpp EXCLUDE_FROM_LIBMLIR DEPENDS MLIRTestInterfaceIncGen + MLIRTestDefIncGen MLIRTestOpsIncGen LINK_LIBS PUBLIC diff --git
a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -141,16 +141,23 @@ >(); addInterfaces(); - addTypes(); + addTypes(); allowUnknownOperations(); } -static Type parseTestType(DialectAsmParser &parser, +static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); + auto genType = generatedTypeParser(ctxt, parser, typeTag); + if (genType != Type()) + return genType; + if (typeTag == "test_type") return TestType::get(parser.getBuilder().getContext()); @@ -174,7 +181,7 @@ if (failed(parser.parseComma())) return Type(); stack.insert(rec); - Type subtype = parseTestType(parser, stack); + Type subtype = parseTestType(ctxt, parser, stack); stack.pop_back(); if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) return Type(); @@ -184,11 +191,13 @@ Type TestDialect::parseType(DialectAsmParser &parser) const { llvm::SetVector stack; - return parseTestType(parser, stack); + return parseTestType(getContext(), parser, stack); } static void printTestType(Type type, DialectAsmPrinter &printer, llvm::SetVector &stack) { + if (succeeded(generatedTypePrinter(type, printer))) + return; if (type.isa()) { printer << "test_type"; return; diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -0,0 +1,150 @@ +//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TableGen data type definitions for Test dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_TYPEDEFS +#define TEST_TYPEDEFS + +// To get the test dialect def. +include "TestOps.td" + +// All of the types will extend this class. +class Test_Type : TypeDef { } + +def SimpleTypeA : Test_Type<"SimpleA"> { + let mnemonic = "smpla"; + + let printer = [{ $_printer << "smpla"; }]; + let parser = [{ return get($_ctxt); }]; +} + +// A more complex parameterized type. +def CompoundTypeA : Test_Type<"CompoundA"> { + let mnemonic = "cmpnd_a"; + + // List of type parameters. + let parameters = ( + ins + "int":$widthOfSomething, + "::mlir::Type":$oneType, + // This is special syntax since ArrayRefs require allocation in the + // constructor. + ArrayRefParameter< + "int", // The parameter C++ type. + "An example of an array of ints" // Parameter description. + >: $arrayOfInts + ); + + let extraClassDeclaration = [{ + struct SomeCppStruct {}; + }]; +} + +// An example of how one could implement a standard integer. +def IntegerType : Test_Type<"TestInteger"> { + let mnemonic = "int"; + let genVerifyInvariantsDecl = 1; + let parameters = ( + ins + // SignednessSemantics is defined below. + "::mlir::TestIntegerType::SignednessSemantics":$signedness, + "unsigned":$width + ); + + // We define the printer inline. + let printer = [{ + $_printer << "int<"; + printSignedness($_printer, getImpl()->signedness); + $_printer << ", " << getImpl()->width << ">"; + }]; + + // The parser is defined inline too; it uses the $_parser/$_ctxt placeholders.
+ let parser = [{ + if ($_parser.parseLess()) return Type(); + SignednessSemantics signedness; + if (parseSignedness($_parser, signedness)) return mlir::Type(); + if ($_parser.parseComma()) return Type(); + int width; + if ($_parser.parseInteger(width)) return Type(); + if ($_parser.parseGreater()) return Type(); + return get($_ctxt, signedness, width); + }]; + + // Any extra code one wants in the type's class declaration. + let extraClassDeclaration = [{ + /// Signedness semantics. + enum SignednessSemantics { + Signless, /// No signedness semantics + Signed, /// Signed integer + Unsigned, /// Unsigned integer + }; + + /// This extra function is necessary since it doesn't include signedness + static IntegerType getChecked(unsigned width, Location location); + + /// Return true if this is a signless integer type. + bool isSignless() const { return getSignedness() == Signless; } + /// Return true if this is a signed integer type. + bool isSigned() const { return getSignedness() == Signed; } + /// Return true if this is an unsigned integer type. + bool isUnsigned() const { return getSignedness() == Unsigned; } + }]; +} + +// A parent type for any type which is just a list of fields (e.g. structs, +// unions). +class FieldInfo_Type : Test_Type { + let parameters = ( + ins + // An ArrayRef of something which requires allocation in the storage + // constructor. + ArrayRefOfSelfAllocationParameter< + "::mlir::FieldInfo", // FieldInfo is defined/declared in TestTypes.h.
+ "Models struct fields">: $fields + ); + + // Prints the type in this format: + // struct<{field1Name,field1Type},{field2Name,field2Type}> + let printer = [{ + $_printer << "struct" << "<"; + for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) { + const auto& field = getImpl()->fields[i]; + $_printer << "{" << field.name << "," << field.type << "}"; + if (i < getImpl()->fields.size() - 1) + $_printer << ","; + } + $_printer << ">"; + }]; + + // Parses the format printed above. + let parser = [{ + llvm::SmallVector parameters; + if ($_parser.parseLess()) return Type(); + while (mlir::succeeded($_parser.parseOptionalLBrace())) { + StringRef name; + if ($_parser.parseKeyword(&name)) return Type(); + if ($_parser.parseComma()) return Type(); + Type type; + if ($_parser.parseType(type)) return Type(); + if ($_parser.parseRBrace()) return Type(); + parameters.push_back(FieldInfo {name, type}); + if ($_parser.parseOptionalComma()) break; + } + if ($_parser.parseGreater()) return Type(); + return get($_ctxt, parameters); + }]; +} + +def StructType : FieldInfo_Type<"Struct"> { + let mnemonic = "struct"; +} + +#endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -14,11 +14,35 @@ #ifndef MLIR_TESTTYPES_H #define MLIR_TESTTYPES_H +#include + #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" namespace mlir { +/// FieldInfo represents a field in the StructType data type. It is used as a
+struct FieldInfo { + StringRef name; + Type type; + + // Custom allocation called from generated constructor code + FieldInfo allocateInto(TypeStorageAllocator &alloc) const { + return FieldInfo{alloc.copyInto(name), type}; + } +}; + +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.h.inc" + +namespace mlir { + #include "TestTypeInterfaces.h.inc" /// This class is a simple test type that uses a generated interface. diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -0,0 +1,117 @@ +//===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains types defined by the TestDialect for testing various +// features of MLIR. +// +//===----------------------------------------------------------------------===// + +#include "TestTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +// Custom parser for SignednessSemantics. 
+static ParseResult +parseSignedness(DialectAsmParser &parser, + TestIntegerType::SignednessSemantics &result) { + StringRef signStr; + auto loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&signStr)) + return failure(); + if (signStr.compare_lower("u") || signStr.compare_lower("unsigned")) + result = TestIntegerType::SignednessSemantics::Unsigned; + else if (signStr.compare_lower("s") || signStr.compare_lower("signed")) + result = TestIntegerType::SignednessSemantics::Signed; + else if (signStr.compare_lower("n") || signStr.compare_lower("none")) + result = TestIntegerType::SignednessSemantics::Signless; + else + return parser.emitError(loc, "expected signed, unsigned, or none"); + return success(); +} + +// Custom printer for SignednessSemantics. +static void printSignedness(DialectAsmPrinter &printer, + const TestIntegerType::SignednessSemantics &ss) { + switch (ss) { + case TestIntegerType::SignednessSemantics::Unsigned: + printer << "unsigned"; + break; + case TestIntegerType::SignednessSemantics::Signed: + printer << "signed"; + break; + case TestIntegerType::SignednessSemantics::Signless: + printer << "none"; + break; + } +} + +Type CompoundAType::parse(MLIRContext *ctxt, DialectAsmParser &parser) { + int widthOfSomething; + Type oneType; + SmallVector arrayOfInts; + if (parser.parseLess() || parser.parseInteger(widthOfSomething) || + parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || + parser.parseLSquare()) + return Type(); + + int i; + while (!*parser.parseOptionalInteger(i)) { + arrayOfInts.push_back(i); + if (parser.parseOptionalComma()) + break; + } + + if (parser.parseRSquare() || parser.parseGreater()) + return Type(); + + return get(ctxt, widthOfSomething, oneType, arrayOfInts); +} +void CompoundAType::print(DialectAsmPrinter &printer) const { + printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType() + << ", ["; + auto intArray = getArrayOfInts(); + llvm::interleaveComma(intArray, printer); + 
printer << "]>"; +} + +// The functions don't need to be in the header file, but need to be in the mlir +// namespace. Declare them here, then define them immediately below. Separating +// the declaration and definition adheres to the LLVM coding standards. +namespace mlir { +// FieldInfo is used as part of a parameter, so equality comparison is +// compulsory. +static bool operator==(const FieldInfo &a, const FieldInfo &b); +// FieldInfo is used as part of a parameter, so a hash will be computed. +static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT +} // namespace mlir + +// FieldInfo is used as part of a parameter, so equality comparison is +// compulsory. +static bool mlir::operator==(const FieldInfo &a, const FieldInfo &b) { + return a.name == b.name && a.type == b.type; +} + +// FieldInfo is used as part of a parameter, so a hash will be computed. +static llvm::hash_code mlir::hash_value(const FieldInfo &fi) { // NOLINT + return llvm::hash_combine(fi.name, fi.type); +} + +// Example type validity checker. 
+LogicalResult TestIntegerType::verifyConstructionInvariants( + Location loc, TestIntegerType::SignednessSemantics ss, unsigned int width) { + if (width > 8) // Report why construction failed instead of failing silently. + return emitError(loc) << "width must be at most 8, but got " << width; + return success(); +} + +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.cpp.inc" diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s + +////////////// +// Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir' + +// CHECK: @simpleA(%arg0: !test.smpla) +func @simpleA(%A : !test.smpla) -> () { + return +} + +// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6]>) +func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () { + return +} + +// CHECK: @testInt(%arg0: !test.int, %arg1: !test.int, %arg2: !test.int) +func @testInt(%A : !test.int, %B : !test.int, %C : !test.int) { + return +} + +// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int}>) +func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int} > ) { + return +} diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -0,0 +1,132 @@ +// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL +// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF + +include "mlir/IR/OpBase.td" + +// DECL: #ifdef GET_TYPEDEF_CLASSES +// DECL: #undef GET_TYPEDEF_CLASSES + +// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic); +// DECL: ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer); + +// DEF: #ifdef GET_TYPEDEF_LIST +// DEF: #undef
GET_TYPEDEF_LIST +// DEF: ::mlir::test::SimpleAType, +// DEF: ::mlir::test::CompoundAType, +// DEF: ::mlir::test::IndexType, +// DEF: ::mlir::test::SingleParameterType, +// DEF: ::mlir::test::IntegerType + +// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) +// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser); +// DEF return ::mlir::Type(); + +def Test_Dialect: Dialect { +// DECL-NOT: TestDialect +// DEF-NOT: TestDialect + let name = "TestDialect"; + let cppNamespace = "::mlir::test"; +} + +class TestType : TypeDef { } + +def A_SimpleTypeA : TestType<"SimpleA"> { +// DECL: class SimpleAType: public ::mlir::Type +} + +// A more complex parameterized type. +def B_CompoundTypeA : TestType<"CompoundA"> { + let summary = "A more complex parameterized type"; + let description = "This type is to test a reasonably complex type"; + let mnemonic = "cmpnd_a"; + let parameters = ( + ins + "int":$widthOfSomething, + "::mlir::test::SimpleTypeA": $exampleTdType, + "SomeCppStruct": $exampleCppType, + ArrayRefParameter<"int", "Matrix dimensions">:$dims + ); + + let genVerifyInvariantsDecl = 1; + +// DECL-LABEL: class CompoundAType: public ::mlir::Type +// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); +// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); +// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; } +// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); +// DECL: void print(::mlir::DialectAsmPrinter& printer) const; +// DECL: int getWidthOfSomething() const; +// DECL: ::mlir::test::SimpleTypeA
getExampleTdType() const; +// DECL: SomeCppStruct getExampleCppType() const; +} + +def C_IndexType : TestType<"Index"> { + let mnemonic = "index"; + + let parameters = ( + ins + StringRefParameter<"Label for index">:$label + ); + +// DECL-LABEL: class IndexType: public ::mlir::Type +// DECL: static ::llvm::StringRef getMnemonic() { return "index"; } +// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); +// DECL: void print(::mlir::DialectAsmPrinter& printer) const; +} + +def D_SingleParameterType : TestType<"SingleParameter"> { + let parameters = ( + ins + "int": $num + ); +// DECL-LABEL: struct SingleParameterTypeStorage; +// DECL-LABEL: class SingleParameterType +// DECL-NEXT: detail::SingleParameterTypeStorage +} + +def E_IntegerType : TestType<"Integer"> { + let mnemonic = "int"; + let genVerifyInvariantsDecl = 1; + let parameters = ( + ins + "SignednessSemantics":$signedness, + TypeParameter<"unsigned", "Bitwdith of integer">:$width + ); + +// DECL-LABEL: IntegerType: public ::mlir::Type + + let extraClassDeclaration = [{ + /// Signedness semantics. + enum SignednessSemantics { + Signless, /// No signedness semantics + Signed, /// Signed integer + Unsigned, /// Unsigned integer + }; + + /// This extra function is necessary since it doesn't include signedness + static IntegerType getChecked(unsigned width, Location location); + + /// Return true if this is a signless integer type. + bool isSignless() const { return getSignedness() == Signless; } + /// Return true if this is a signed integer type. + bool isSigned() const { return getSignedness() == Signed; } + /// Return true if this is an unsigned integer type. + bool isUnsigned() const { return getSignedness() == Unsigned; } + }]; + +// DECL: /// Signedness semantics.
+// DECL-NEXT: enum SignednessSemantics { +// DECL-NEXT: Signless, /// No signedness semantics +// DECL-NEXT: Signed, /// Signed integer +// DECL-NEXT: Unsigned, /// Unsigned integer +// DECL-NEXT: }; +// DECL: /// This extra function is necessary since it doesn't include signedness +// DECL-NEXT: static IntegerType getChecked(unsigned width, Location location); + +// DECL: /// Return true if this is a signless integer type. +// DECL-NEXT: bool isSignless() const { return getSignedness() == Signless; } +// DECL-NEXT: /// Return true if this is a signed integer type. +// DECL-NEXT: bool isSigned() const { return getSignedness() == Signed; } +// DECL-NEXT: /// Return true if this is an unsigned integer type. +// DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; } +} diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -20,6 +20,7 @@ RewriterGen.cpp SPIRVUtilsGen.cpp StructsGen.cpp + TypeDefGen.cpp ) set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning") diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -15,6 +15,7 @@ #include "mlir/Support/IndentedOstream.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/TypeDef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -23,6 +24,8 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#include + using namespace llvm; using namespace mlir; using namespace mlir::tblgen; @@ -155,12 +158,67 @@ os << "\n"; } +//===----------------------------------------------------------------------===// +// TypeDef Documentation +//===----------------------------------------------------------------------===// + +/// Emit the
assembly format of a type. +static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) { + SmallVector parameters; + td.getParameters(parameters); + if (parameters.size() == 0) { + os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic() + << "`\n"; + return; + } + + os << "\nSyntax:\n\n```\n!" << td.getDialect().getName() << "." + << td.getMnemonic() << "<\n"; + for (auto *it = parameters.begin(), *e = parameters.end(); it < e; ++it) { + os << " " << it->getSyntax(); + if (it < parameters.end() - 1) + os << ","; + os << " # " << it->getName() << "\n"; + } + os << ">\n```\n"; +} + +static void emitTypeDefDoc(TypeDef td, raw_ostream &os) { + os << llvm::formatv("### `{0}` ({1})\n", td.getName(), td.getCppClassName()); + + // Emit the summary, syntax, and description if present. + if (td.hasSummary()) + os << "\n" << td.getSummary() << "\n"; + if (td.getMnemonic() && td.getPrinterCode() && *td.getPrinterCode() == "" && + td.getParserCode() && *td.getParserCode() == "") + emitTypeAssemblyFormat(td, os); + if (td.hasDescription()) + mlir::tblgen::emitDescription(td.getDescription(), os); + + // Emit attribute documentation. + SmallVector parameters; + td.getParameters(parameters); + if (parameters.size() != 0) { + os << "\n#### Type parameters:\n\n"; + os << "| Parameter | C++ type | Description |\n" + << "| :-------: | :-------: | ----------- |\n"; + for (const auto &it : parameters) { + auto desc = it.getDescription(); + os << "| " << it.getName() << " | `" << td.getCppClassName() << "` | " + << (desc ? 
*desc : "") << " |\n"; + } + } + + os << "\n"; +} + //===----------------------------------------------------------------------===// // Dialect Documentation //===----------------------------------------------------------------------===// static void emitDialectDoc(const Dialect &dialect, ArrayRef ops, - ArrayRef types, raw_ostream &os) { + ArrayRef types, ArrayRef typeDefs, + raw_ostream &os) { os << "# '" << dialect.getName() << "' Dialect\n\n"; emitIfNotEmpty(dialect.getSummary(), os); emitIfNotEmpty(dialect.getDescription(), os); @@ -169,7 +227,7 @@ // TODO: Add link between use and def for types if (!types.empty()) { - os << "## Type definition\n\n"; + os << "## Type constraint definition\n\n"; for (const Type &type : types) emitTypeDoc(type, os); } @@ -179,28 +237,43 @@ for (const Operator &op : ops) emitOpDoc(op, os); } + + if (!typeDefs.empty()) { + os << "## Type definition\n\n"; + for (const TypeDef &td : typeDefs) + emitTypeDefDoc(td, os); + } } static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op"); const auto &typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType"); + const auto &typeDefDefs = recordKeeper.getAllDerivedDefinitions("TypeDef"); + std::set dialectsWithDocs; std::map> dialectOps; std::map> dialectTypes; + std::map> dialectTypeDefs; for (auto *opDef : opDefs) { Operator op(opDef); dialectOps[op.getDialect()].push_back(op); + dialectsWithDocs.insert(op.getDialect()); } for (auto *typeDef : typeDefs) { Type type(typeDef); if (auto dialect = type.getDialect()) dialectTypes[dialect].push_back(type); } + for (auto *typeDef : typeDefDefs) { + TypeDef type(typeDef); + dialectTypeDefs[type.getDialect()].push_back(type); + dialectsWithDocs.insert(type.getDialect()); + } os << "\n"; - for (const auto &dialectWithOps : dialectOps) - emitDialectDoc(dialectWithOps.first, dialectWithOps.second, - dialectTypes[dialectWithOps.first], os); + for (auto dialect : 
dialectsWithDocs) + emitDialectDoc(dialect, dialectOps[dialect], dialectTypes[dialect], + dialectTypeDefs[dialect], os); } //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -0,0 +1,561 @@ +//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDefGen uses the description of typeDefs to generate C++ definitions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/TypeDef.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/TableGenBackend.h" + +#define DEBUG_TYPE "mlir-tblgen-typedefgen" + +using namespace mlir; +using namespace mlir::tblgen; + +static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*"); +static llvm::cl::opt + selectedDialect("typedefs-dialect", + llvm::cl::desc("Gen types for this dialect"), + llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); + +/// Find all the TypeDefs for the specified dialect. If no dialect specified and +/// can only find one dialect's types, use that. 
+static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper, + SmallVectorImpl &typeDefs) { + auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef"); + auto defs = llvm::map_range( + recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); }); + if (defs.empty()) + return; + + StringRef dialectName; + if (selectedDialect.getNumOccurrences() == 0) { + if (defs.empty()) + return; + + llvm::SmallSet dialects; + for (const TypeDef &typeDef : defs) + dialects.insert(typeDef.getDialect()); + if (dialects.size() != 1) + llvm::PrintFatalError("TypeDefs belonging to more than one dialect. Must " + "select one via '--typedefs-dialect'"); + + dialectName = (*dialects.begin()).getName(); + } else if (selectedDialect.getNumOccurrences() == 1) { + dialectName = selectedDialect.getValue(); + } else { + llvm::PrintFatalError("Cannot select multiple dialects for which to " + "generate types via '--typedefs-dialect'."); + } + + for (const TypeDef &typeDef : defs) + if (typeDef.getDialect().getName().equals(dialectName)) + typeDefs.push_back(typeDef); +} + +namespace { + +/// Pass an instance of this class to llvm::formatv() to emit a comma separated +/// list of parameters in the format by 'EmitFormat'. +class TypeParamCommaFormatter : public llvm::detail::format_adapter { +public: + /// Choose the output format + enum EmitFormat { + /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name, + /// [...]". + TypeNamePairs, + + /// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name, + /// [...]". + TypeNamePairsPrependComma, + + /// Emit "parameter1(parameter1), parameter2(parameter2), [...]". + TypeNameInitializer + }; + + TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef params) + : emitFormat(emitFormat), params(params) {} + + /// llvm::formatv will call this function when using an instance as a + /// replacement value. 
+ void format(raw_ostream &os, StringRef options) { + if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma) + os << ", "; + switch (emitFormat) { + case EmitFormat::TypeNamePairs: + case EmitFormat::TypeNamePairsPrependComma: + interleaveComma(params, os, + [&](const TypeParameter &p) { emitTypeNamePair(p, os); }); + break; + case EmitFormat::TypeNameInitializer: + interleaveComma(params, os, [&](const TypeParameter &p) { + emitTypeNameInitializer(p, os); + }); + break; + } + } + +private: + // Emit "paramType paramName". + static void emitTypeNamePair(const TypeParameter ¶m, raw_ostream &os) { + os << param.getCppType() << " " << param.getName(); + } + // Emit "paramName(paramName)" + void emitTypeNameInitializer(const TypeParameter ¶m, raw_ostream &os) { + os << param.getName() << "(" << param.getName() << ")"; + } + + EmitFormat emitFormat; + ArrayRef params; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// GEN: TypeDef declarations +//===----------------------------------------------------------------------===// + +/// The code block for the start of a typeDef class declaration -- singleton +/// case. +/// +/// {0}: The name of the typeDef class. +static const char *const typeDefDeclSingletonBeginStr = R"( + class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{ + public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + +)"; + +/// The code block for the start of a typeDef class declaration -- parametric +/// case. +/// +/// {0}: The name of the typeDef class. +/// {1}: The typeDef storage class namespace. +/// {2}: The storage class name. +/// {3}: The list of parameters with types. 
+static const char *const typeDefDeclParametricBeginStr = R"( + namespace {1} { + struct {2}; + } + class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, + {1}::{2}> {{ + public: + /// Inherit some necessary constructors from 'TypeBase'. + using Base::Base; + +)"; + +/// The snippet for print/parse. +static const char *const typeDefParsePrint = R"( + static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); + void print(::mlir::DialectAsmPrinter& printer) const; +)"; + +/// The code block for the verifyConstructionInvariants and getChecked. +/// +/// {0}: List of parameters, parameters style. +/// {1}: C++ type class name. +static const char *const typeDefDeclVerifyStr = R"( + static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0}); + static {1} getChecked(Location loc{0}); +)"; + +/// Generate the declaration for the given typeDef class. +static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) { + SmallVector params; + typeDef.getParameters(params); + + // Emit the beginning string template: either the singleton or parametric + // template. + if (typeDef.getNumParameters() == 0) + os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(), + typeDef.getStorageNamespace(), typeDef.getStorageClassName()); + else + os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(), + typeDef.getStorageNamespace(), typeDef.getStorageClassName()); + + // Emit the extra declarations first in case there's a type definition in + // there. + if (Optional extraDecl = typeDef.getExtraDecls()) + os << *extraDecl << "\n"; + + TypeParamCommaFormatter emitTypeNamePairsAfterComma( + TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params); + os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n", + typeDef.getCppClassName(), emitTypeNamePairsAfterComma); + + // Emit the verify invariants declaration. 
+ if (typeDef.genVerifyInvariantsDecl()) + os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma, + typeDef.getCppClassName()); + + // Emit the mnenomic, if specified. + if (auto mnenomic = typeDef.getMnemonic()) { + os << " static ::llvm::StringRef getMnemonic() { return \"" << mnenomic + << "\"; }\n"; + + // If mnemonic specified, emit print/parse declarations. + os << typeDefParsePrint; + } + + if (typeDef.genAccessors()) { + SmallVector parameters; + typeDef.getParameters(parameters); + + for (TypeParameter ¶meter : parameters) { + SmallString<16> name = parameter.getName(); + name[0] = llvm::toUpper(name[0]); + os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name); + } + } + + // End the typeDef decl. + os << " };\n"; +} + +/// Main entry point for decls. +static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper, + raw_ostream &os) { + emitSourceFileHeader("TypeDef Declarations", os); + + SmallVector typeDefs; + findAllTypeDefs(recordKeeper, typeDefs); + + IfDefScope scope("GET_TYPEDEF_CLASSES", os); + if (typeDefs.size() > 0) { + NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect()); + + // Well known print/parse dispatch function declarations. These are called + // from Dialect::parseType() and Dialect::printType() methods. + os << " ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, " + "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n"; + os << " ::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, " + "::mlir::DialectAsmPrinter& printer);\n"; + os << "\n"; + + // Declare all the type classes first (in case they reference each other). + for (const TypeDef &typeDef : typeDefs) + os << " class " << typeDef.getCppClassName() << ";\n"; + + // Declare all the typedefs. 
+ for (const TypeDef &typeDef : typeDefs) + emitTypeDefDecl(typeDef, os); + } + + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: TypeDef list +//===----------------------------------------------------------------------===// + +static void emitTypeDefList(SmallVectorImpl &typeDefs, + raw_ostream &os) { + IfDefScope scope("GET_TYPEDEF_LIST", os); + for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) { + os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName(); + if (i < typeDefs.end() - 1) + os << ",\n"; + else + os << "\n"; + } +} + +//===----------------------------------------------------------------------===// +// GEN: TypeDef definitions +//===----------------------------------------------------------------------===// + +/// Beginning of storage class. +/// {0}: Storage class namespace. +/// {1}: Storage class c++ name. +/// {2}: Parameters parameters. +/// {3}: Parameter initialzer string. +/// {4}: Parameter name list. +/// {5}: Parameter types. +static const char *const typeDefStorageClassBegin = R"( +namespace {0} {{ + struct {1} : public ::mlir::TypeStorage {{ + {1} ({2}) + : {3} {{ } + + /// The hash key for this storage is a pair of the integer and type params. + using KeyTy = std::tuple<{5}>; + + /// Define the comparison function for the key type. + bool operator==(const KeyTy &key) const {{ + return key == KeyTy({4}); + } +)"; + +/// The storage class' constructor template. +/// {0}: storage class name. +static const char *const typeDefStorageClassConstructorBegin = R"( + /// Define a construction method for creating a new instance of this storage. + static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{ +)"; + +/// The storage class' constructor return template. +/// {0}: storage class name. +/// {1}: list of parameters. 
+static const char *const typeDefStorageClassConstructorReturn = R"( + return new (allocator.allocate<{0}>()) + {0}({1}); + } +)"; + +/// Use tgfmt to emit custom allocation code for each parameter, if necessary. +static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) { + SmallVector parameters; + typeDef.getParameters(parameters); + auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator"); + for (TypeParameter ¶meter : parameters) { + auto allocCode = parameter.getAllocator(); + if (allocCode) { + fmtCtxt.withSelf(parameter.getName()); + fmtCtxt.addSubst("_dst", parameter.getName()); + os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n"; + } + } +} + +/// Emit the storage class code for type 'typeDef'. +/// This includes (in-order): +/// 1) typeDefStorageClassBegin, which includes: +/// - The class constructor. +/// - The KeyTy definition. +/// - The equality (==) operator. +/// 2) The hashKey method. +/// 3) The construct method. +/// 4) The list of parameters as the storage class member variables. +static void emitStorageClass(TypeDef typeDef, raw_ostream &os) { + SmallVector parameters; + typeDef.getParameters(parameters); + + // Initialize a bunch of variables to be used later on. + auto parameterNames = map_range( + parameters, [](TypeParameter parameter) { return parameter.getName(); }); + auto parameterTypes = map_range(parameters, [](TypeParameter parameter) { + return parameter.getCppType(); + }); + auto parameterList = join(parameterNames, ", "); + auto parameterTypeList = join(parameterTypes, ", "); + + // 1) Emit most of the storage class up until the hashKey body. 
+ os << formatv( + typeDefStorageClassBegin, typeDef.getStorageNamespace(), + typeDef.getStorageClassName(), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters), + parameterList, parameterTypeList); + + // 2) Emit the haskKey method. + os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n"; + // Extract each parameter from the key. + for (size_t i = 0, e = parameters.size(); i < e; ++i) + os << formatv(" const auto &{0} = std::get<{1}>(key);\n", + parameters[i].getName(), i); + // Then combine them all. This requires all the parameters types to have a + // hash_value defined. + os << " return ::llvm::hash_combine("; + interleaveComma(parameterNames, os); + os << ");\n"; + os << " }\n"; + + // 3) Emit the construct method. + if (typeDef.hasStorageCustomConstructor()) + // If user wants to build the storage constructor themselves, declare it + // here and then they can write the definition elsewhere. + os << " static " << typeDef.getStorageClassName() + << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy " + "&key);\n"; + else { + // If not, autogenerate one. + + // First, unbox the parameters. + os << formatv(typeDefStorageClassConstructorBegin, + typeDef.getStorageClassName()); + for (size_t i = 0; i < parameters.size(); ++i) { + os << formatv(" auto {0} = std::get<{1}>(key);\n", + parameters[i].getName(), i); + } + // Second, reassign the parameter variables with allocation code, if it's + // specified. + emitParameterAllocationCode(typeDef, os); + + // Last, return an allocated copy. + os << formatv(typeDefStorageClassConstructorReturn, + typeDef.getStorageClassName(), parameterList); + } + + // 4) Emit the parameters as storage class members. 
+ for (auto parameter : parameters) { + os << " " << parameter.getCppType() << " " << parameter.getName() + << ";\n"; + } + os << " };\n"; + + os << "} // namespace " << typeDef.getStorageNamespace() << "\n"; +} + +/// Emit the parser and printer for a particular type, if they're specified. +void emitParserPrinter(TypeDef typeDef, raw_ostream &os) { + // Emit the printer code, if specified. + if (auto printerCode = typeDef.getPrinterCode()) { + // Both the mnenomic and printerCode must be defined (for parity with + // parserCode). + os << "void " << typeDef.getCppClassName() + << "::print(::mlir::DialectAsmPrinter& printer) const {\n"; + if (*printerCode == "") { + // If no code specified, emit error. + PrintFatalError(typeDef.getLoc(), + typeDef.getName() + + ": printer (if specified) must have non-empty code"); + } + auto fmtCtxt = FmtContext().addSubst("_printer", "printer"); + os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n"; + } + + // emit a parser, if specified. + if (auto parserCode = typeDef.getParserCode()) { + // The mnenomic must be defined so the dispatcher knows how to dispatch. + os << "::mlir::Type " << typeDef.getCppClassName() + << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& " + "parser) " + "{\n"; + if (*parserCode == "") { + // if no code specified, emit error. + PrintFatalError(typeDef.getLoc(), + typeDef.getName() + + ": parser (if specified) must have non-empty code"); + } + auto fmtCtxt = + FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt"); + os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n"; + } +} + +/// Print all the typedef-specific definition code. +static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) { + NamespaceEmitter ns(os, typeDef.getDialect()); + SmallVector parameters; + typeDef.getParameters(parameters); + + // Emit the storage class, if requested and necessary. 
+ if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0) + emitStorageClass(typeDef, os); + + os << llvm::formatv( + "{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n" + " return Base::get(ctxt", + typeDef.getCppClassName(), + TypeParamCommaFormatter( + TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, + parameters)); + for (TypeParameter ¶m : parameters) + os << ", " << param.getName(); + os << ");\n}\n"; + + // Emit the parameter accessors. + if (typeDef.genAccessors()) + for (const TypeParameter ¶meter : parameters) { + SmallString<16> name = parameter.getName(); + name[0] = llvm::toUpper(name[0]); + os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n", + parameter.getCppType(), name, parameter.getName(), + typeDef.getCppClassName()); + } + + // If mnemonic is specified maybe print definitions for the parser and printer + // code, if they're specified. + if (typeDef.getMnemonic()) + emitParserPrinter(typeDef, os); +} + +/// Emit the dialect printer/parser dispatcher. User's code should call these +/// functions from their dialect's print/parse methods. +static void emitParsePrintDispatch(SmallVectorImpl &typeDefs, + raw_ostream &os) { + if (typeDefs.size() == 0) + return; + const Dialect &dialect = typeDefs.begin()->getDialect(); + NamespaceEmitter ns(os, dialect); + + // The parser dispatch is just a list of if-elses, matching on the mnemonic + // and calling the class's parse function. + os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, " + "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n"; + for (const TypeDef &typeDef : typeDefs) + if (typeDef.getMnemonic()) + os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return " + "{0}::{1}::parse(ctxt, parser);\n", + typeDef.getDialect().getCppNamespace(), + typeDef.getCppClassName()); + os << " return ::mlir::Type();\n"; + os << "}\n\n"; + + // The printer dispatch uses llvm::TypeSwitch to find and call the correct + // printer. 
+ os << "::mlir::LogicalResult generatedTypePrinter(::mlir::Type type, " + "::mlir::DialectAsmPrinter& printer) {\n" + << " ::mlir::LogicalResult found = ::mlir::success();\n" + << " ::llvm::TypeSwitch<::mlir::Type>(type)\n"; + for (auto typeDef : typeDefs) + if (typeDef.getMnemonic()) + os << formatv(" .Case<{0}::{1}>([&](::mlir::Type t) {{ " + "t.dyn_cast<{0}::{1}>().print(printer); })\n", + typeDef.getDialect().getCppNamespace(), + typeDef.getCppClassName()); + os << " .Default([&found](::mlir::Type) { found = ::mlir::failure(); " + "});\n" + << " return found;\n" + << "}\n\n"; +} + +/// Entry point for typedef definitions. +static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper, + raw_ostream &os) { + emitSourceFileHeader("TypeDef Definitions", os); + + SmallVector typeDefs; + findAllTypeDefs(recordKeeper, typeDefs); + emitTypeDefList(typeDefs, os); + + IfDefScope scope("GET_TYPEDEF_CLASSES", os); + emitParsePrintDispatch(typeDefs, os); + for (auto typeDef : typeDefs) + emitTypeDefDef(typeDef, os); + + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: TypeDef registration hooks +//===----------------------------------------------------------------------===// + +static mlir::GenRegistration + genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + return emitTypeDefDefs(records, os); + }); + +static mlir::GenRegistration + genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + return emitTypeDefDecls(records, os); + });