diff --git a/mlir/cmake/modules/AddMLIR.cmake b/mlir/cmake/modules/AddMLIR.cmake --- a/mlir/cmake/modules/AddMLIR.cmake +++ b/mlir/cmake/modules/AddMLIR.cmake @@ -9,6 +9,8 @@ set(LLVM_TARGET_DEFINITIONS ${dialect}.td) mlir_tablegen(${dialect}.h.inc -gen-op-decls) mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) + mlir_tablegen(${dialect}Types.h.inc -gen-typedef-decls) + mlir_tablegen(${dialect}Types.cpp.inc -gen-typedef-defs) mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace}) add_public_tablegen_target(MLIR${dialect}IncGen) add_dependencies(mlir-headers MLIR${dialect}IncGen) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2351,4 +2351,92 @@ // so to replace the matched DAG with an existing SSA value. def replaceWithValue; + +//===----------------------------------------------------------------------===// +// Data type generation +//===----------------------------------------------------------------------===// + +// Define a new type belonging to a dialect and called 'name'. +class TypeDef { + Dialect dialect = owningdialect; + string cppClassName = name # "Type"; + + // Short summary of the type + string summary = ?; + // The longer description of this type + string description = ?; + + // Name of storage class to generate or use + string storageClass = name # "TypeStorage"; + // Namespace (within the dialect's C++ namespace) of the storage class + string storageNamespace = "detail"; + // Should we generate the storage class? (Or use an existing one?) + bit genStorageClass = 1; + // Does the storage class have a custom (user-written) constructor? + bit hasStorageCustomConstructor = 0; + + // This is the list of fields in the storage class (and list of parameters + // in the creation functions). If empty, don't use or generate a storage class. 
+ dag parameters = (ins); + + // The keyword identifying this type in the dialect's parser/printer. Specify + // only if you want tblgen to generate decls and/or defs of printer/parser for + // this type. + string mnemonic = ?; + + // Only used if 'mnemonic' is specified: + // - if unset, generate just the declarations; + // - an empty code block is an error; + // - a non-empty code block is used verbatim as the definition code. + code printer = ?; + code parser = ?; + + // If set, generate accessors for each Type parameter. + bit genAccessors = 1; + // Generate the verifyConstructionInvariants declaration and getChecked method. + bit genVerifyInvariantsDecl = 0; + // Extra code to include in the class declaration + code extraClassDeclaration = [{}]; +} + +// 'Parameters' should be subclasses of this or simple strings (which is a +// shorthand for TypeParameter<"C++Type">). +class TypeParameter { + // Custom memory allocation code for storage constructor + code allocator = ?; + // The C++ type of this parameter + string cppType = type; + // A description of this parameter + string description = desc; + // The format string for the asm syntax (documentation only) + string syntax = ?; +} + +// For StringRefs, which require allocation +class StringRefParameter : TypeParameter<"::llvm::StringRef", desc> { + let allocator = [{$_dst = $_allocator.copyInto($_self);}]; +} + +// For standard ArrayRefs, which require allocation +class ArrayRefParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + let allocator = [{$_dst = $_allocator.copyInto($_self);}]; +} + +// For classes which require allocation and have their own allocateInto method +class SelfAllocationParameter : TypeParameter { + let allocator = [{$_dst = $_self.allocateInto($_allocator);}]; +} + +// For ArrayRefs which contain things which allocate themselves +class ArrayRefOfSelfAllocationParameter : TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + let allocator = [{ + llvm::SmallVector<}] # arrayOf # [{, 4> 
tmpFields; + for (size_t i = 0; i < $_self.size(); i++) { + tmpFields.push_back($_self[i].allocateInto($_allocator)); + } + $_dst = $_allocator.copyInto(::llvm::ArrayRef<}] # arrayOf # [{>(tmpFields)); + }]; +} + + #endif // OP_BASE diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/TypeDef.h @@ -0,0 +1,142 @@ +//===- TypeDef.h - TypeDef wrapper class ------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDef wrapper to simplify using TableGen Record defining a MLIR type. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_TYPEDEF_H +#define MLIR_TABLEGEN_TYPEDEF_H + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Dialect.h" +#include "llvm/TableGen/Record.h" +#include +#include + +namespace mlir { +namespace tblgen { + +class TypeParameter; + +// Wrapper class that contains a TableGen TypeDef's record and provides helper +// methods for accessing them. +class TypeDef { +public: + explicit TypeDef(const llvm::Record *def) : def(def) {} + + // Get the dialect to which this type belongs + Dialect getDialect() const; + + // Returns the name of this TypeDef record + StringRef getName() const; + + // Query functions for the documentation of the type. 
+ bool hasDescription() const; + StringRef getDescription() const; + bool hasSummary() const; + StringRef getSummary() const; + + // Returns the name of the C++ class to generate + StringRef getCppClassName() const; + + // Returns the name of the storage class for this type + StringRef getStorageClassName() const; + + // Returns the C++ namespace for this type's storage class + StringRef getStorageNamespace() const; + + // Returns true if we should generate the storage class + bool genStorageClass() const; + + // Returns true if the storage class has a custom (user-written) constructor + bool hasStorageCustomConstructor() const; + + // Return the list of fields for the storage class and constructors + void getParameters(SmallVectorImpl &) const; + unsigned getNumParameters() const; + + // Iterate through parameters, applying a map function before adding to list + template + void getParametersAs(SmallVectorImpl &parameters, + llvm::function_ref map) const; + + // Return the keyword/mnemonic to use in the printer/parser methods if we are + // supposed to auto-generate them + llvm::Optional getMnemonic() const; + + // Returns the code to use as the type's printer method. If not specified, + // return a non-value. Otherwise, return the contents of that code block. + llvm::Optional getPrinterCode() const; + + // Returns the code to use as the type's parser method. If not specified, + // return a non-value. Otherwise, return the contents of that code block. + llvm::Optional getParserCode() const; + + // Should we generate accessors based on the type's parameters? + bool genAccessors() const; + + // Return true if we need to generate the verifyConstructionInvariants + // declaration and getChecked method + bool genVerifyInvariantsDecl() const; + + // Returns the type's extra class declaration code, or None if it is empty. 
+ llvm::Optional getExtraDecls() const; + + // Get the code location (for error printing) + llvm::ArrayRef getLoc() const; + + // Returns whether two TypeDefs are equal by checking the equality of the + // underlying record. + bool operator==(const TypeDef &other) const; + + // Compares two TypeDefs by comparing their record names. + bool operator<(const TypeDef &other) const; + + // Returns whether the TypeDef is defined. + operator bool() const { return def != nullptr; } + +private: + const llvm::Record *def; +}; + +// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs +// to parameterize them. +class TypeParameter { +public: + explicit TypeParameter(const llvm::DagInit *def, unsigned num) + : def(def), num(num) {} + + // Get the parameter name + StringRef getName() const; + // If specified, get the custom allocator code for this parameter + llvm::Optional getAllocator() const; + // Get the C++ type of this parameter + StringRef getCppType() const; + // Get a description of this parameter for documentation purposes + llvm::Optional getDescription() const; + // Get the assembly syntax documentation + StringRef getSyntax() const; + +private: + const llvm::DagInit *def; + const unsigned num; +}; + +template +void TypeDef::getParametersAs(SmallVectorImpl &parameters, + llvm::function_ref map) const { + auto *parametersDag = def->getValueAsDag("parameters"); + if (parametersDag != nullptr) + for (unsigned i = 0; i < parametersDag->getNumArgs(); i++) + parameters.push_back(map(TypeParameter(parametersDag, i))); +} + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_TYPEDEF_H diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -25,6 +25,7 @@ SideEffects.cpp Successor.cpp Type.cpp + TypeDef.cpp DISABLE_LLVM_LINK_LLVM_DYLIB diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp new file 
mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -0,0 +1,168 @@ +//===- TypeDef.cpp - TypeDef wrapper class --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDef wrapper to simplify using TableGen Record defining a MLIR type. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/TypeDef.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +Dialect TypeDef::getDialect() const { + auto *dialectDef = + dyn_cast(def->getValue("dialect")->getValue()); + if (dialectDef == nullptr) + return Dialect(nullptr); + return Dialect(dialectDef->getDef()); +} + +StringRef TypeDef::getName() const { return def->getName(); } +StringRef TypeDef::getCppClassName() const { + return def->getValueAsString("cppClassName"); +} + +bool TypeDef::hasDescription() const { + const llvm::RecordVal *s = def->getValue("description"); + return s != nullptr && isa(s->getValue()); +} + +StringRef TypeDef::getDescription() const { + return def->getValueAsString("description"); +} + +bool TypeDef::hasSummary() const { + const llvm::RecordVal *s = def->getValue("summary"); + return s != nullptr && isa(s->getValue()); +} + +StringRef TypeDef::getSummary() const { + return def->getValueAsString("summary"); +} + +StringRef TypeDef::getStorageClassName() const { + return def->getValueAsString("storageClass"); +} +StringRef TypeDef::getStorageNamespace() const { + return def->getValueAsString("storageNamespace"); +} + +bool TypeDef::genStorageClass() const { + return def->getValueAsBit("genStorageClass"); +} +bool 
TypeDef::hasStorageCustomConstructor() const { + return def->getValueAsBit("hasStorageCustomConstructor"); +} +void TypeDef::getParameters(SmallVectorImpl &parameters) const { + auto *parametersDag = def->getValueAsDag("parameters"); + if (parametersDag != nullptr) { + unsigned numParams = parametersDag->getNumArgs(); + for (unsigned i = 0; i < numParams; i++) { + parameters.push_back(TypeParameter(parametersDag, i)); + } + } +} +unsigned TypeDef::getNumParameters() const { + auto *parametersDag = def->getValueAsDag("parameters"); + return parametersDag ? parametersDag->getNumArgs() : 0; +} +llvm::Optional TypeDef::getMnemonic() const { + return def->getValueAsOptionalString("mnemonic"); +} +llvm::Optional TypeDef::getPrinterCode() const { + return def->getValueAsOptionalCode("printer"); +} +llvm::Optional TypeDef::getParserCode() const { + return def->getValueAsOptionalCode("parser"); +} +bool TypeDef::genAccessors() const { + return def->getValueAsBit("genAccessors"); +} +bool TypeDef::genVerifyInvariantsDecl() const { + return def->getValueAsBit("genVerifyInvariantsDecl"); +} +llvm::Optional TypeDef::getExtraDecls() const { + auto value = def->getValueAsString("extraClassDeclaration"); + return value.empty() ? 
llvm::Optional() : value; +} +llvm::ArrayRef TypeDef::getLoc() const { return def->getLoc(); } +bool TypeDef::operator==(const TypeDef &other) const { + return def == other.def; +} + +bool TypeDef::operator<(const TypeDef &other) const { + return getName() < other.getName(); +} + +StringRef TypeParameter::getName() const { + return def->getArgName(num)->getValue(); +} +llvm::Optional TypeParameter::getAllocator() const { + auto *parameterType = def->getArg(num); + if (auto *stringType = dyn_cast(parameterType)) { + return llvm::Optional(); + } + + if (auto *typeParameter = dyn_cast(parameterType)) { + auto *code = typeParameter->getDef()->getValue("allocator"); + if (llvm::CodeInit *ci = dyn_cast(code->getValue())) + return ci->getValue(); + if (isa(code->getValue())) + return llvm::Optional(); + + llvm::PrintFatalError( + typeParameter->getDef()->getLoc(), + "Record `" + def->getArgName(num)->getValue() + + "', field `allocator' does not have a code initializer!"); + } + + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from TypeParameter\n"); +} +StringRef TypeParameter::getCppType() const { + auto *parameterType = def->getArg(num); + if (auto *stringType = dyn_cast(parameterType)) { + return stringType->getValue(); + } + if (auto *typeParameter = dyn_cast(parameterType)) { + return typeParameter->getDef()->getValueAsString("cppType"); + } + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from TypeParameter\n"); +} +llvm::Optional TypeParameter::getDescription() const { + + auto *parameterType = def->getArg(num); + if (auto *typeParameter = dyn_cast(parameterType)) { + const auto *desc = typeParameter->getDef()->getValue("description"); + if (llvm::StringInit *ci = dyn_cast(desc->getValue())) + return ci->getValue(); + } + return llvm::Optional(); +} +StringRef TypeParameter::getSyntax() const { + auto *parameterType = def->getArg(num); + if (auto *stringType = 
dyn_cast(parameterType)) { + return stringType->getValue(); + } else if (auto *typeParameter = dyn_cast(parameterType)) { + const auto *syntax = typeParameter->getDef()->getValue("syntax"); + if (syntax != nullptr && isa(syntax->getValue())) + return dyn_cast(syntax->getValue())->getValue(); + return getCppType(); + } else { + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from TypeParameter\n"); + } +} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -8,6 +8,12 @@ mlir_tablegen(TestTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRTestInterfaceIncGen) +set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) +mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls) +mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs) +add_public_tablegen_target(MLIRTestDefIncGen) + + set(LLVM_TARGET_DEFINITIONS TestOps.td) mlir_tablegen(TestOps.h.inc -gen-op-decls) mlir_tablegen(TestOps.cpp.inc -gen-op-defs) @@ -23,11 +29,13 @@ add_mlir_library(MLIRTestDialect TestDialect.cpp TestPatterns.cpp + TestTypes.cpp EXCLUDE_FROM_LIBMLIR DEPENDS MLIRTestInterfaceIncGen + MLIRTestDefIncGen MLIRTestOpsIncGen LINK_LIBS PUBLIC diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -141,16 +141,23 @@ >(); addInterfaces(); - addTypes(); + addTypes(); allowUnknownOperations(); } -static Type parseTestType(DialectAsmParser &parser, +static Type parseTestType(MLIRContext *ctxt, DialectAsmParser &parser, llvm::SetVector &stack) { StringRef typeTag; if (failed(parser.parseKeyword(&typeTag))) return Type(); + auto genType = generatedTypeParser(ctxt, parser, typeTag); + if (genType != Type()) + return genType; + if (typeTag == 
"test_type") return TestType::get(parser.getBuilder().getContext()); @@ -174,7 +181,7 @@ if (failed(parser.parseComma())) return Type(); stack.insert(rec); - Type subtype = parseTestType(parser, stack); + Type subtype = parseTestType(ctxt, parser, stack); stack.pop_back(); if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) return Type(); @@ -184,11 +191,13 @@ Type TestDialect::parseType(DialectAsmParser &parser) const { llvm::SetVector stack; - return parseTestType(parser, stack); + return parseTestType(getContext(), parser, stack); } static void printTestType(Type type, DialectAsmPrinter &printer, llvm::SetVector &stack) { + if (!generatedTypePrinter(type, printer)) + return; if (type.isa()) { printer << "test_type"; return; diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -0,0 +1,128 @@ +//===-- TestTypeDefs.td - Test dialect type definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_TYPEDEFS +#define TEST_TYPEDEFS + +// To get the test dialect def +include "TestOps.td" + +class Test_Type : TypeDef { } + +def SimpleTypeA : Test_Type<"SimpleA"> { + let mnemonic = "smpla"; + + let printer = [{ $_printer << "smpla"; }]; + let parser = [{ return get($_ctxt); }]; +} + +// A more complex parameterized type +def CompoundTypeA : Test_Type<"CompoundA"> { + // Mnemonic used by the generated parser/printer dispatch + let mnemonic = "cmpnd_a"; + + // What types do we contain? 
+ let parameters = ( + ins + "int":$widthOfSomething, + "::mlir::Type":$oneType, + ArrayRefParameter<"int", "An example of an array of ints">: $arrayOfInts + ); + + let extraClassDeclaration = [{ + struct SomeCppStruct {}; + }]; +} + +def IntegerType : Test_Type<"TestInteger"> { + let mnemonic = "int"; + let genVerifyInvariantsDecl = 1; + let parameters = ( + ins + "::mlir::TestIntegerType::SignednessSemantics":$signedness, + "unsigned":$width + ); + + let printer = [{ + $_printer << "int<"; + Print($_printer, getImpl()->signedness); + $_printer << ", " << getImpl()->width << ">"; + }]; + + let parser = [{ + if ($_parser.parseLess()) return Type(); + SignednessSemantics signedness; + if (Parse($_parser, signedness)) return mlir::Type(); + if ($_parser.parseComma()) return Type(); + unsigned width; + if ($_parser.parseInteger(width)) return Type(); + if ($_parser.parseGreater()) return Type(); + return get($_ctxt, signedness, width); + }]; + + let extraClassDeclaration = [{ + /// Signedness semantics. + enum SignednessSemantics { + Signless, /// No signedness semantics + Signed, /// Signed integer + Unsigned, /// Unsigned integer + }; + + /// This extra function is necessary since it doesn't include signedness + static TestIntegerType getChecked(unsigned width, Location location); + + /// Return true if this is a signless integer type. + bool isSignless() const { return getSignedness() == Signless; } + /// Return true if this is a signed integer type. + bool isSigned() const { return getSignedness() == Signed; } + /// Return true if this is an unsigned integer type. 
+ bool isUnsigned() const { return getSignedness() == Unsigned; } + }]; +} + +class FieldInfo_Type : Test_Type { + let parameters = ( + ins + ArrayRefOfSelfAllocationParameter<"::mlir::FieldInfo", "Models struct fields">: $fields + ); + + let printer = [{ + $_printer << "struct" << "<"; + for (size_t i=0; i < getImpl()->fields.size(); i++) { + const auto& field = getImpl()->fields[i]; + $_printer << "{" << field.name << "," << field.type << "}"; + if (i < getImpl()->fields.size() - 1) + $_printer << ","; + } + $_printer << ">"; + }]; + + let parser = [{ + llvm::SmallVector parameters; + if ($_parser.parseLess()) return Type(); + while (mlir::succeeded($_parser.parseOptionalLBrace())) { + StringRef name; + if ($_parser.parseKeyword(&name)) return Type(); + if ($_parser.parseComma()) return Type(); + Type type; + if ($_parser.parseType(type)) return Type(); + if ($_parser.parseRBrace()) return Type(); + parameters.push_back(FieldInfo {name, type}); + if ($_parser.parseOptionalComma()) break; + } + if ($_parser.parseGreater()) return Type(); + return get($_ctxt, parameters); + }]; +} + +def StructType : FieldInfo_Type<"Struct"> { + let mnemonic = "struct"; +} + + +#endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -14,11 +14,36 @@ #ifndef MLIR_TESTTYPES_H #define MLIR_TESTTYPES_H +#include + #include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" namespace mlir { +struct FieldInfo { +public: + StringRef name; + Type type; + + FieldInfo allocateInto(TypeStorageAllocator &alloc) const { + return FieldInfo{alloc.copyInto(name), type}; + } +}; + +bool operator==(const FieldInfo &a, const FieldInfo &b); +llvm::hash_code hash_value(const FieldInfo &fi); + +} // namespace mlir + +#define GET_TYPEDEF_CLASSES +#include 
"TestTypeDefs.h.inc" + +namespace mlir { + #include "TestTypeInterfaces.h.inc" /// This class is a simple test type that uses a generated interface. diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -0,0 +1,127 @@ +//===- TestTypes.cpp - MLIR Test Dialect Types ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains types defined by the TestDialect for testing various +// features of MLIR. +// +//===----------------------------------------------------------------------===// + +#include "TestTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/TypeSwitch.h" + +namespace mlir { + +// Custom parser for SignednessSemantics +static ParseResult Parse(DialectAsmParser &parser, + TestIntegerType::SignednessSemantics &result) { + StringRef signStr; + auto loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&signStr)) + return mlir::failure(); + if (signStr.equals_lower("u") || signStr.equals_lower("unsigned")) + result = TestIntegerType::SignednessSemantics::Unsigned; + else if (signStr.equals_lower("s") || signStr.equals_lower("signed")) + result = TestIntegerType::SignednessSemantics::Signed; + else if (signStr.equals_lower("n") || signStr.equals_lower("none")) + result = TestIntegerType::SignednessSemantics::Signless; + else { + parser.emitError(loc, "expected signed, unsigned, or none"); + return mlir::failure(); + } + return mlir::success(); +} + +// Custom printer for SignednessSemantics +static void Print(DialectAsmPrinter &printer, + const 
TestIntegerType::SignednessSemantics &ss) { + switch (ss) { + case TestIntegerType::SignednessSemantics::Unsigned: + printer << "unsigned"; + break; + case TestIntegerType::SignednessSemantics::Signed: + printer << "signed"; + break; + case TestIntegerType::SignednessSemantics::Signless: + printer << "none"; + break; + } +} + +Type CompoundAType::parse(::mlir::MLIRContext *ctxt, + ::mlir::DialectAsmParser &parser) { + int widthOfSomething; + Type oneType; + SmallVector arrayOfInts; + if (parser.parseLess()) + return Type(); + if (parser.parseInteger(widthOfSomething)) + return Type(); + if (parser.parseComma()) + return Type(); + if (parser.parseType(oneType)) + return Type(); + if (parser.parseComma()) + return Type(); + + if (parser.parseLSquare()) + return Type(); + int i; + OptionalParseResult intResult = parser.parseOptionalInteger(i); + while (intResult.hasValue() && succeeded(*intResult)) { + arrayOfInts.push_back(i); + // A comma must be followed by another integer. 
+ if (failed(parser.parseOptionalComma())) + break; + intResult = OptionalParseResult(parser.parseInteger(i)); + } + if (intResult.hasValue() && failed(*intResult)) + return Type(); + if (parser.parseRSquare()) + return Type(); + if (parser.parseGreater()) + return Type(); + + return get(ctxt, widthOfSomething, oneType, arrayOfInts); +} +void CompoundAType::print(::mlir::DialectAsmPrinter &printer) const { + printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType() + << ", ["; + auto intArray = getArrayOfInts(); + for (size_t idx = 0; idx < intArray.size(); idx++) { + printer << intArray[idx]; + if (idx < intArray.size() - 1) + printer << ", "; + } + printer << "]>"; +} + +bool operator==(const FieldInfo &a, const FieldInfo &b) { + return a.name == b.name && a.type == b.type; +} + +llvm::hash_code hash_value(const FieldInfo &fi) { + return llvm::hash_combine(fi.name, fi.type); +} + +// Example type validity checker +LogicalResult TestIntegerType::verifyConstructionInvariants( + mlir::Location loc, mlir::TestIntegerType::SignednessSemantics ss, + unsigned int width) { + + if (width > 8) + return emitError(loc) << "width must be at most 8, got " << width; + return mlir::success(); +} + +} // end namespace mlir + +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.cpp.inc" diff --git 
a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s + +////////////// +// Tests the types in the 'Test' dialect, not the ones in 'typedefs.mlir' + +// CHECK: @simpleA(%arg0: !test.smpla) +func @simpleA(%A : !test.smpla) -> () { + return +} + +// CHECK: @compoundA(%arg0: !test.cmpnd_a<1, !test.smpla, [5, 6]>) +func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () { + return +} + +// CHECK: @testInt(%arg0: !test.int, %arg1: !test.int, %arg2: !test.int) +func @testInt(%A : !test.int, %B : !test.int, %C : !test.int) { + return +} + +// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int}>) +func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int} > ) { + return +} diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -0,0 +1,132 @@ +// RUN: mlir-tblgen -gen-typedef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL +// RUN: mlir-tblgen -gen-typedef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF + +include "mlir/IR/OpBase.td" + +// DECL: #ifdef GET_TYPEDEF_CLASSES +// DECL: #undef GET_TYPEDEF_CLASSES + +// DECL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic); +// DECL: bool generatedTypePrinter(::mlir::Type type, ::mlir::DialectAsmPrinter& printer); + +// DEF: #ifdef GET_TYPEDEF_LIST +// DEF: #undef GET_TYPEDEF_LIST +// DEF: ::mlir::test::SimpleAType, +// DEF: ::mlir::test::CompoundAType, +// DEF: ::mlir::test::IndexType, +// DEF: ::mlir::test::SingleParameterType, +// DEF: ::mlir::test::IntegerType + +// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, 
::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) +// DEF: if (mnemonic == ::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(ctxt, parser); +// DEF: return ::mlir::Type(); + +def Test_Dialect: Dialect { +// DECL-NOT: TestDialect +// DEF-NOT: TestDialect + let name = "TestDialect"; + let cppNamespace = "::mlir::test"; +} + +class TestType : TypeDef { } + +def A_SimpleTypeA : TestType<"SimpleA"> { +// DECL: class SimpleAType: public ::mlir::Type +} + +// A more complex parameterized type +def B_CompoundTypeA : TestType<"CompoundA"> { + let summary = "A more complex parameterized type"; + let description = "This type is to test a reasonably complex type"; + let mnemonic = "cmpnd_a"; + let parameters = ( + ins + "int":$widthOfSomething, + "::mlir::test::SimpleTypeA": $exampleTdType, + "SomeCppStruct": $exampleCppType, + ArrayRefParameter<"int", "Matrix dimensions">:$dims + ); + + let genVerifyInvariantsDecl = 1; + +// DECL-LABEL: class CompoundAType: public ::mlir::Type +// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); +// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims); +// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; } +// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); +// DECL: void print(::mlir::DialectAsmPrinter& printer) const; +// DECL: int getWidthOfSomething() const; +// DECL: ::mlir::test::SimpleTypeA getExampleTdType() const; +// DECL: SomeCppStruct getExampleCppType() const; +} + +def C_IndexType : TestType<"Index"> { + let mnemonic = "index"; + + let parameters = ( + ins + StringRefParameter<"Label for index">:$label + ); + +// DECL-LABEL: class IndexType: public ::mlir::Type +// 
DECL: static ::llvm::StringRef getMnemonic() { return "index"; } +// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser); +// DECL: void print(::mlir::DialectAsmPrinter& printer) const; +} + +def D_SingleParameterType : TestType<"SingleParameter"> { + let parameters = ( + ins + "int": $num + ); +// DECL-LABEL: struct SingleParameterTypeStorage; +// DECL-LABEL: class SingleParameterType +// DECL-NEXT: detail::SingleParameterTypeStorage +} + +def E_IntegerType : TestType<"Integer"> { + let mnemonic = "int"; + let genVerifyInvariantsDecl = 1; + let parameters = ( + ins + "SignednessSemantics":$signedness, + TypeParameter<"unsigned", "Bitwidth of integer">:$width + ); + +// DECL-LABEL: IntegerType: public ::mlir::Type + + let extraClassDeclaration = [{ + /// Signedness semantics. + enum SignednessSemantics { + Signless, /// No signedness semantics + Signed, /// Signed integer + Unsigned, /// Unsigned integer + }; + + /// This extra function is necessary since it doesn't include signedness + static IntegerType getChecked(unsigned width, Location location); + + /// Return true if this is a signless integer type. + bool isSignless() const { return getSignedness() == Signless; } + /// Return true if this is a signed integer type. + bool isSigned() const { return getSignedness() == Signed; } + /// Return true if this is an unsigned integer type. + bool isUnsigned() const { return getSignedness() == Unsigned; } + }]; + +// DECL: /// Signedness semantics. +// DECL-NEXT: enum SignednessSemantics { +// DECL-NEXT: Signless, /// No signedness semantics +// DECL-NEXT: Signed, /// Signed integer +// DECL-NEXT: Unsigned, /// Unsigned integer +// DECL-NEXT: }; +// DECL: /// This extra function is necessary since it doesn't include signedness +// DECL-NEXT: static IntegerType getChecked(unsigned width, Location location); + +// DECL: /// Return true if this is a signless integer type. 
+// DECL-NEXT: bool isSignless() const { return getSignedness() == Signless; } +// DECL-NEXT: /// Return true if this is a signed integer type. +// DECL-NEXT: bool isSigned() const { return getSignedness() == Signed; } +// DECL-NEXT: /// Return true if this is an unsigned integer type. +// DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; } +} diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -20,6 +20,7 @@ RewriterGen.cpp SPIRVUtilsGen.cpp StructsGen.cpp + TypeDefGen.cpp ) set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning") diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -14,6 +14,7 @@ #include "DocGenUtilities.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/TypeDef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -22,6 +23,8 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#include + using namespace llvm; using namespace mlir; using namespace mlir::tblgen; @@ -185,12 +188,66 @@ os << "\n"; } +//===----------------------------------------------------------------------===// +// TypeDef Documentation +//===----------------------------------------------------------------------===// + +/// Emit the assembly format of a type. +static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) { + SmallVector parameters; + td.getParameters(parameters); + if (parameters.empty()) { + os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic() + << "`\n"; + } else { + os << "\nSyntax:\n\n```\n!" << td.getDialect().getName() << "." 
+ << td.getMnemonic() << "<\n"; + for (auto *it = parameters.begin(); it < parameters.end(); it++) { + os << " " << it->getSyntax(); + if (it < parameters.end() - 1) + os << ","; + os << " # " << it->getName() << "\n"; + } + os << ">\n```\n"; + } +} + +static void emitTypeDefDoc(TypeDef td, raw_ostream &os) { + os << llvm::formatv("### `{0}` ({1})\n", td.getName(), td.getCppClassName()); + + // Emit the summary, syntax, and description if present. + if (td.hasSummary()) + os << "\n" << td.getSummary() << "\n"; + if (td.getMnemonic() && td.getPrinterCode() && *td.getPrinterCode() == "" && + td.getParserCode() && *td.getParserCode() == "") + emitTypeAssemblyFormat(td, os); + if (td.hasDescription()) + mlir::tblgen::emitDescription(td.getDescription(), os); + + // Emit the parameter table, listing each parameter's own C++ type. + SmallVector parameters; + td.getParameters(parameters); + if (!parameters.empty()) { + os << "\n#### Type parameters:\n\n"; + os << "| Parameter | C++ type | Description |\n" + << "| :-------: | :-------: | ----------- |\n"; + for (const auto &param : parameters) { + auto desc = param.getDescription(); + os << "| " << param.getName() << " | `" << param.getCppType() << "` | " + << (desc ? 
*desc : "") << " |\n"; + } + } + + os << "\n"; +} + //===----------------------------------------------------------------------===// // Dialect Documentation //===----------------------------------------------------------------------===// static void emitDialectDoc(const Dialect &dialect, ArrayRef ops, - ArrayRef types, raw_ostream &os) { + ArrayRef types, ArrayRef typeDefs, + raw_ostream &os) { os << "# '" << dialect.getName() << "' Dialect\n\n"; emitIfNotEmpty(dialect.getSummary(), os); emitIfNotEmpty(dialect.getDescription(), os); @@ -199,7 +256,7 @@ // TODO: Add link between use and def for types if (!types.empty()) { - os << "## Type definition\n\n"; + os << "## Type constraint definition\n\n"; for (const Type &type : types) emitTypeDoc(type, os); } @@ -209,28 +266,44 @@ for (const Operator &op : ops) emitOpDoc(op, os); } + + if (!typeDefs.empty()) { + os << "## Type definition\n\n"; + for (const TypeDef &td : typeDefs) { + emitTypeDefDoc(td, os); + } + } } static void emitDialectDoc(const RecordKeeper &recordKeeper, raw_ostream &os) { const auto &opDefs = recordKeeper.getAllDerivedDefinitions("Op"); const auto &typeDefs = recordKeeper.getAllDerivedDefinitions("DialectType"); + const auto &typeDefDefs = recordKeeper.getAllDerivedDefinitions("TypeDef"); + std::set dialectsWithDocs; std::map> dialectOps; std::map> dialectTypes; + std::map> dialectTypeDefs; for (auto *opDef : opDefs) { Operator op(opDef); dialectOps[op.getDialect()].push_back(op); + dialectsWithDocs.insert(op.getDialect()); } for (auto *typeDef : typeDefs) { Type type(typeDef); if (auto dialect = type.getDialect()) dialectTypes[dialect].push_back(type); } + for (auto *typeDef : typeDefDefs) { + TypeDef type(typeDef); + dialectTypeDefs[type.getDialect()].push_back(type); + dialectsWithDocs.insert(type.getDialect()); + } os << "\n"; - for (auto dialectWithOps : dialectOps) - emitDialectDoc(dialectWithOps.first, dialectWithOps.second, - dialectTypes[dialectWithOps.first], os); + for (auto dialect : 
dialectsWithDocs) + emitDialectDoc(dialect, dialectOps[dialect], dialectTypes[dialect], + dialectTypeDefs[dialect], os); } //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -0,0 +1,568 @@ +//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TypeDefGen uses the description of typeDefs to generate C++ definitions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Interfaces.h" +#include "mlir/TableGen/OpClass.h" +#include "mlir/TableGen/OpTrait.h" +#include "mlir/TableGen/Operator.h" +#include "mlir/TableGen/TypeDef.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" + +#define DEBUG_TYPE "mlir-tblgen-typedefgen" + +using namespace mlir; +using namespace mlir::tblgen; + +static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*"); +static llvm::cl::opt + selectedDialect("typedefs-dialect", + llvm::cl::desc("Gen types for this dialect"), + llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); + +/// Find all the TypeDefs for the specified dialect. 
+/// If no dialect was selected via '--typedefs-dialect' and every TypeDef found
+/// belongs to the same dialect, that dialect is used. Matching TypeDefs are
+/// appended to 'typeDefs'; on failure a diagnostic is printed to stderr.
+static mlir::LogicalResult
+findAllTypeDefs(const llvm::RecordKeeper &recordKeeper,
+                SmallVectorImpl<TypeDef> &typeDefs) {
+  auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef");
+  if (recDefs.empty())
+    return mlir::success();
+  auto defs = llvm::map_range(
+      recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); });
+
+  StringRef dialectName;
+  if (selectedDialect.getNumOccurrences() == 0) {
+    // No dialect was requested: infer it, but only when it is unambiguous.
+    llvm::SmallSet<Dialect, 4> dialects;
+    for (const TypeDef &typeDef : defs)
+      dialects.insert(typeDef.getDialect());
+    if (dialects.size() != 1) {
+      llvm::errs() << "TypeDefs belonging to more than one dialect. Must "
+                      "select one via '--typedefs-dialect'\n";
+      return mlir::failure();
+    }
+
+    dialectName = (*dialects.begin()).getName();
+  } else if (selectedDialect.getNumOccurrences() == 1) {
+    dialectName = selectedDialect.getValue();
+  } else {
+    // Note the trailing space: the two literals are concatenated.
+    llvm::errs()
+        << "cannot select multiple dialects for which to generate types "
+           "via '--typedefs-dialect'\n";
+    return mlir::failure();
+  }
+
+  for (const TypeDef &typeDef : defs)
+    if (typeDef.getDialect().getName().equals(dialectName))
+      typeDefs.push_back(typeDef);
+  return mlir::success();
+}
+
+/// Build the "CppType name" list of a TypeDef's parameters for generated
+/// function signatures, e.g. "parameter1Type parameter1Name,
+/// parameter2Type parameter2Name". When 'prependComma' is set, every entry
+/// (including the first) is preceded by ", " so the result can follow a
+/// leading argument; with no parameters the result is empty.
+static std::string constructParameterParameters(TypeDef &typeDef,
+                                                bool prependComma) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+  std::string result;
+  for (const TypeParameter &parameter : parameters) {
+    if (prependComma || !result.empty())
+      result += ", ";
+    result += (parameter.getCppType() + " " + parameter.getName()).str();
+  }
+  return result;
+}
+
+/// Build the member-initializer list for the generated storage class
+/// constructor, e.g. "parameter1(parameter1), parameter2(parameter2)".
+static std::string constructParametersInitializers(TypeDef &typeDef) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+  SmallVector<std::string, 4> initializers;
+  for (const TypeParameter &parameter : parameters)
+    initializers.push_back(
+        (parameter.getName() + "(" + parameter.getName() + ")").str());
+  return llvm::join(initializers, ", ");
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef declarations
+//===----------------------------------------------------------------------===//
+
+/// The code block for the start of a typeDef class declaration -- singleton
+/// case.
+///
+/// {0}: The name of the typeDef class.
+static const char *const typeDefDeclSingletonBeginStr = R"(
+  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type, ::mlir::TypeStorage> {{
+public:
+  /// Inherit some necessary constructors from 'TypeBase'.
+  using Base::Base;
+
+)";
+
+/// The code block for the start of a typeDef class declaration -- parametric
+/// case. The storage class is forward-declared so the class can name it.
+///
+/// {0}: The name of the typeDef class.
+/// {1}: The typeDef storage class namespace.
+/// {2}: The storage class name.
+static const char *const typeDefDeclParametricBeginStr = R"(
+  namespace {1} {{
+    struct {2};
+  }
+  class {0}: public ::mlir::Type::TypeBase<{0}, ::mlir::Type,
+                                        {1}::{2}> {{
+public:
+  /// Inherit some necessary constructors from 'TypeBase'.
+  using Base::Base;
+
+)";
+
+/// Declarations of the parse/print hooks, emitted when a mnemonic is set.
+static const char *const typeDefParsePrint = R"(
+    static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
+    void print(::mlir::DialectAsmPrinter& printer) const;
+)";
+
+/// The code block for the verifyConstructionInvariants and getChecked
+/// declarations.
+///
+/// {0}: List of parameters, parameters style (each entry preceded by ", ").
+/// {1}: C++ type class name.
+static const char *const typeDefDeclVerifyStr = R"(
+    static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
+    static {1} getChecked(Location loc{0});
+)";
+
+/// Generate the declaration for the given typeDef class.
+static void emitTypeDefDecl(TypeDef &typeDef, raw_ostream &os) {
+  // Emit the beginning string template: either the singleton or parametric
+  // template.
+  if (typeDef.getNumParameters() == 0)
+    os << llvm::formatv(typeDefDeclSingletonBeginStr,
+                        typeDef.getCppClassName());
+  else
+    os << llvm::formatv(typeDefDeclParametricBeginStr,
+                        typeDef.getCppClassName(),
+                        typeDef.getStorageNamespace(),
+                        typeDef.getStorageClassName());
+
+  // Emit the extra declarations first in case there's a type definition in
+  // there that the members below refer to.
+  if (auto extraDecl = typeDef.getExtraDecls())
+    os << *extraDecl;
+
+  // ", CppType1 param1, CppType2 param2" argument list.
+  std::string parameterParameters = constructParameterParameters(typeDef, true);
+
+  os << llvm::formatv("    static {0} get(::mlir::MLIRContext* ctxt{1});\n",
+                      typeDef.getCppClassName(), parameterParameters);
+
+  // Verify invariants and the checked builder.
+  if (typeDef.genVerifyInvariantsDecl())
+    os << llvm::formatv(typeDefDeclVerifyStr, parameterParameters,
+                        typeDef.getCppClassName());
+
+  // The mnemonic, if specified, plus the parse/print declarations that the
+  // generated dispatch functions call.
+  if (auto mnemonic = typeDef.getMnemonic()) {
+    os << "    static ::llvm::StringRef getMnemonic() { return \"" << mnemonic
+       << "\"; }\n";
+    os << typeDefParsePrint;
+  }
+
+  // One accessor declaration per parameter: getFoo() for parameter 'foo'.
+  if (typeDef.genAccessors()) {
+    SmallVector<TypeParameter, 4> parameters;
+    typeDef.getParameters(parameters);
+    for (const TypeParameter &parameter : parameters) {
+      SmallString<16> name = parameter.getName();
+      name[0] = llvm::toUpper(name[0]);
+      os << llvm::formatv("    {0} get{1}() const;\n", parameter.getCppType(),
+                          name);
+    }
+  }
+
+  // End the typeDef decl.
+  os << "  };\n";
+}
+
+/// Main entry point for decls.
+static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper,
+                             raw_ostream &os) {
+  emitSourceFileHeader("TypeDef Declarations", os);
+
+  SmallVector<TypeDef, 16> typeDefs;
+  if (mlir::failed(findAllTypeDefs(recordKeeper, typeDefs)))
+    return true;
+
+  IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+  if (typeDefs.empty())
+    return false;
+
+  NamespaceEmitter nsEmitter(os, typeDefs.front().getDialect());
+
+  // Well known print/parse dispatch function declarations.
+  // NOTE(review): 'mnenomic' below misspells the generated parameter name; it
+  // is left unchanged because FileCheck tests may match this exact line.
+  os << "  ::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+        "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnenomic);\n";
+  os << "  bool generatedTypePrinter(::mlir::Type type, "
+        "::mlir::DialectAsmPrinter& "
+        "printer);\n";
+  os << "\n";
+
+  // Forward-declare all the type classes first, in case they reference each
+  // other.
+  for (const TypeDef &typeDef : typeDefs)
+    os << "  class " << typeDef.getCppClassName() << ";\n";
+
+  for (TypeDef &typeDef : typeDefs)
+    emitTypeDefDecl(typeDef, os);
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef list
+//===----------------------------------------------------------------------===//
+
+/// Emit the fully qualified TypeDef class names under GET_TYPEDEF_LIST, one
+/// per line, comma-separated with no trailing comma.
+static mlir::LogicalResult emitTypeDefList(SmallVectorImpl<TypeDef> &typeDefs,
+                                           raw_ostream &os) {
+  IfDefScope scope("GET_TYPEDEF_LIST", os);
+  for (size_t i = 0, e = typeDefs.size(); i != e; ++i) {
+    const TypeDef &typeDef = typeDefs[i];
+    os << typeDef.getDialect().getCppNamespace()
+       << "::" << typeDef.getCppClassName() << (i + 1 != e ? ",\n" : "\n");
+  }
+  return mlir::success();
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef definitions
+//===----------------------------------------------------------------------===//
+
+/// Beginning of storage class
+/// {0}: Storage class namespace
+/// {1}: Storage class c++ name
+/// {2}: Parameters parameters
+/// {3}: Parameter initializer string
+/// {4}: Parameter name list
+/// {5}: Parameter types
+static const char *const typeDefStorageClassBegin = R"(
+namespace {0} {{
+  struct {1} : public ::mlir::TypeStorage {{
+    {1} ({2})
+      : {3} {{ }
+
+    /// The hash key for this storage is a pair of the integer and type params.
+    using KeyTy = std::tuple<{5}>;
+
+    /// Define the comparison function for the key type.
+    bool operator==(const KeyTy &key) const {{
+      return key == KeyTy({4});
+    }
+
+    static ::llvm::hash_code hashKey(const KeyTy &key) {{
+)";
+
+/// The storage class' constructor template
+/// {0}: storage class name
+static const char *const typeDefStorageClassConstructorBegin = R"(
+    /// Define a construction method for creating a new instance of this storage.
+    static {0} *construct(::mlir::TypeStorageAllocator &allocator,
+                          const KeyTy &key) {{
+)";
+
+/// The storage class' constructor return template
+/// {0}: storage class name
+/// {1}: list of parameters
+static const char *const typeDefStorageClassConstructorReturn = R"(
+      return new (allocator.allocate<{0}>())
+          {0}({1});
+    }
+)";
+
+/// Emit one "auto name = std::get<i>(key);" line per parameter, unpacking the
+/// storage KeyTy tuple into locals named after the parameters.
+static void emitKeyUnpacking(ArrayRef<TypeParameter> parameters,
+                             raw_ostream &os) {
+  for (size_t i = 0, e = parameters.size(); i != e; ++i)
+    os << llvm::formatv("      auto {0} = std::get<{1}>(key);\n",
+                        parameters[i].getName(), i);
+}
+
+/// Use tgfmt to emit custom allocation code for each parameter that has some.
+/// In the snippet, $_self and $_dst name the unpacked local and $_allocator
+/// the storage allocator.
+static mlir::LogicalResult emitCustomAllocationCode(TypeDef &typeDef,
+                                                    raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+  auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator");
+  for (const TypeParameter &parameter : parameters) {
+    auto allocCode = parameter.getAllocator();
+    if (!allocCode)
+      continue;
+    fmtCtxt.withSelf(parameter.getName());
+    fmtCtxt.addSubst("_dst", parameter.getName());
+    os << "      ";
+    tgfmt(*allocCode, &fmtCtxt).format(os);
+    os << "\n";
+  }
+  return mlir::success();
+}
+
+/// Emit the storage class of a parametric TypeDef: constructor, KeyTy,
+/// comparison, hashKey, construct() (or only its declaration when the user
+/// provides it) and one data member per parameter.
+static mlir::LogicalResult emitStorageClass(TypeDef typeDef, raw_ostream &os) {
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+
+  auto parameterNames =
+      llvm::map_range(parameters, [](const TypeParameter &parameter) {
+        return parameter.getName();
+      });
+  auto parameterTypes =
+      llvm::map_range(parameters, [](const TypeParameter &parameter) {
+        return parameter.getCppType();
+      });
+  std::string parameterList = llvm::join(parameterNames, ", ");
+  std::string parameterTypeList = llvm::join(parameterTypes, ", ");
+  std::string parameterParameters = constructParameterParameters(typeDef, false);
+  std::string parameterInits = constructParametersInitializers(typeDef);
+
+  // Emit most of the storage class up until the hashKey body.
+  os << llvm::formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
+                      typeDef.getStorageClassName(), parameterParameters,
+                      parameterInits, parameterList, parameterTypeList);
+
+  // hashKey: unpack the key, then combine every element. This requires each
+  // parameter type to have a hash_value overload.
+  emitKeyUnpacking(parameters, os);
+  os << "      return ::llvm::hash_combine(\n";
+  for (size_t i = 0, e = parameters.size(); i != e; ++i)
+    os << "        " << parameters[i].getName() << (i + 1 != e ? ",\n" : "");
+  os << ");\n";
+  os << "    }\n";
+
+  if (typeDef.hasStorageCustomConstructor()) {
+    // The user writes the definition of construct() elsewhere; only declare
+    // it here.
+    os << "    static " << typeDef.getStorageClassName()
+       << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy "
+          "&key);\n";
+  } else {
+    os << llvm::formatv(typeDefStorageClassConstructorBegin,
+                        typeDef.getStorageClassName());
+    emitKeyUnpacking(parameters, os);
+    // Reassign the unpacked locals with allocation code, if it's specified.
+    if (mlir::failed(emitCustomAllocationCode(typeDef, os)))
+      return mlir::failure();
+    // Return an allocated copy.
+    os << llvm::formatv(typeDefStorageClassConstructorReturn,
+                        typeDef.getStorageClassName(), parameterList);
+  }
+
+  // Emit one data member per parameter.
+  for (const TypeParameter &parameter : parameters)
+    os << "    " << parameter.getCppType() << " " << parameter.getName()
+       << ";\n";
+  os << "  };\n";
+  os << "} // namespace " << typeDef.getStorageNamespace() << "\n";
+
+  return mlir::success();
+}
+
+/// Emit the print() and parse() definitions of a TypeDef with a mnemonic from
+/// its 'printer'/'parser' code blocks, when those are set. A code block that
+/// is set but empty is reported as an error before anything is emitted for it.
+static mlir::LogicalResult emitParserPrinter(TypeDef &typeDef,
+                                             raw_ostream &os) {
+  if (auto printerCode = typeDef.getPrinterCode()) {
+    if (printerCode->empty()) {
+      llvm::PrintError(typeDef.getLoc(),
+                       typeDef.getName() +
+                           ": printer (if specified) must have non-empty code");
+      return mlir::failure();
+    }
+    os << "void " << typeDef.getCppClassName()
+       << "::print(mlir::DialectAsmPrinter& printer) const {\n";
+    auto fmtCtxt = FmtContext().addSubst("_printer", "printer");
+    tgfmt(*printerCode, &fmtCtxt).format(os);
+    os << "\n}\n";
+  }
+
+  if (auto parserCode = typeDef.getParserCode()) {
+    if (parserCode->empty()) {
+      llvm::PrintError(typeDef.getLoc(),
+                       typeDef.getName() +
+                           ": parser (if specified) must have non-empty code");
+      return mlir::failure();
+    }
+    os << "::mlir::Type " << typeDef.getCppClassName()
+       << "::parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& "
+          "parser) {\n";
+    auto fmtCtxt =
+        FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "ctxt");
+    tgfmt(*parserCode, &fmtCtxt).format(os);
+    os << "\n}\n";
+  }
+  return mlir::success();
+}
+
+/// Print all the typedef-specific definition code.
+static mlir::LogicalResult emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
+  NamespaceEmitter ns(os, typeDef.getDialect());
+  SmallVector<TypeParameter, 4> parameters;
+  typeDef.getParameters(parameters);
+
+  // Emit the storage class, if requested and necessary.
+  if (typeDef.genStorageClass() && typeDef.getNumParameters() > 0)
+    if (mlir::failed(emitStorageClass(typeDef, os)))
+      return mlir::failure();
+
+  // Builder argument lists: ", Type1 name1, Type2 name2" for the signatures
+  // and ",name1,name2" for forwarding to the base class.
+  std::string paramFuncParams = constructParameterParameters(typeDef, true);
+  std::string paramArgs;
+  for (const TypeParameter &parameter : parameters)
+    paramArgs += "," + parameter.getName().str();
+
+  os << llvm::formatv("{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
+                      "  return Base::get(ctxt{2});\n"
+                      "}\n",
+                      typeDef.getCppClassName(), paramFuncParams, paramArgs);
+
+  // The header declares getChecked() whenever genVerifyInvariantsDecl is set,
+  // so it must be defined here. It forwards to Base::getChecked (presumably the
+  // base runs the user-written verifyConstructionInvariants -- confirm against
+  // StorageUserBase).
+  if (typeDef.genVerifyInvariantsDecl())
+    os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc{1}) {{\n"
+                        "  return Base::getChecked(loc{2});\n"
+                        "}\n",
+                        typeDef.getCppClassName(), paramFuncParams, paramArgs);
+
+  // Emit the accessors.
+  if (typeDef.genAccessors()) {
+    for (const TypeParameter &parameter : parameters) {
+      SmallString<16> name = parameter.getName();
+      name[0] = llvm::toUpper(name[0]);
+      os << llvm::formatv(
+          "{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n",
+          parameter.getCppType(), name, parameter.getName(),
+          typeDef.getCppClassName());
+    }
+  }
+
+  // Only types with a mnemonic get print/parse definitions.
+  if (typeDef.getMnemonic())
+    return emitParserPrinter(typeDef, os);
+  return mlir::success();
+}
+
+/// Emit the dialect printer/parser dispatch. Client code should call these
+/// functions from their dialect's print/parse methods.
+static mlir::LogicalResult
+emitParsePrintDispatch(SmallVectorImpl<TypeDef> &typeDefs, raw_ostream &os) {
+  if (typeDefs.empty())
+    return mlir::success();
+  const Dialect &dialect = typeDefs.begin()->getDialect();
+  NamespaceEmitter ns(os, dialect);
+
+  // generatedTypeParser: dispatch on the mnemonic; unknown mnemonics yield a
+  // null Type.
+  os << "::mlir::Type generatedTypeParser(::mlir::MLIRContext* ctxt, "
+        "::mlir::DialectAsmParser& parser, ::llvm::StringRef mnemonic) {\n";
+  for (const TypeDef &typeDef : typeDefs) {
+    if (typeDef.getMnemonic())
+      os << llvm::formatv("  if (mnemonic == {0}::{1}::getMnemonic()) return "
+                          "{0}::{1}::parse(ctxt, parser);\n",
+                          typeDef.getDialect().getCppNamespace(),
+                          typeDef.getCppClassName());
+  }
+  os << "  return ::mlir::Type();\n";
+  os << "}\n\n";
+
+  // generatedTypePrinter: returns true when no generated printer matched.
+  os << "bool generatedTypePrinter(::mlir::Type type, "
+        "::mlir::DialectAsmPrinter& "
+        "printer) {\n"
+     << "  bool notfound = false;\n"
+     << "  ::llvm::TypeSwitch<::mlir::Type>(type)\n";
+  for (const TypeDef &typeDef : typeDefs) {
+    if (typeDef.getMnemonic())
+      os << llvm::formatv("    .Case<{0}::{1}>([&](::mlir::Type t) {{ "
+                          "t.dyn_cast<{0}::{1}>().print(printer); })\n",
+                          typeDef.getDialect().getCppNamespace(),
+                          typeDef.getCppClassName());
+  }
+  // The lambda must capture 'notfound' by reference to update the flag.
+  os << "    .Default([&notfound](::mlir::Type) { notfound = true; });\n"
+     << "  return notfound;\n"
+     << "}\n\n";
+  return mlir::success();
+}
+
+/// Entry point for typedef definitions.
+static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper,
+                            raw_ostream &os) {
+  emitSourceFileHeader("TypeDef Definitions", os);
+
+  SmallVector<TypeDef, 16> typeDefs;
+  if (mlir::failed(findAllTypeDefs(recordKeeper, typeDefs)))
+    return true;
+
+  if (mlir::failed(emitTypeDefList(typeDefs, os)))
+    return true;
+
+  IfDefScope scope("GET_TYPEDEF_CLASSES", os);
+  if (mlir::failed(emitParsePrintDispatch(typeDefs, os)))
+    return true;
+  for (const TypeDef &typeDef : typeDefs)
+    if (mlir::failed(emitTypeDefDef(typeDef, os)))
+      return true;
+
+  return false;
+}
+
+//===----------------------------------------------------------------------===//
+// GEN: TypeDef registration hooks
+//===----------------------------------------------------------------------===//
+
+static mlir::GenRegistration
+    genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions",
+                   [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                     return emitTypeDefDefs(records, os);
+                   });
+
+static mlir::GenRegistration
+    genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations",
+                    [](const llvm::RecordKeeper &records, raw_ostream &os) {
+                      return emitTypeDefDecls(records, os);
+                    });