diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2465,20 +2465,20 @@ //===----------------------------------------------------------------------===// -// Data type generation +// Attribute and Type generation //===----------------------------------------------------------------------===// -// Class for defining a custom type getter. +// Class for defining a custom getter. // -// TableGen generates several generic getter methods for each type by default, -// corresponding to the specified dag parameters. If the default generated ones -// cannot cover some use case, custom getters can be defined using instances of -// this class. +// TableGen generates several generic getter methods for each attribute and type +// by default, corresponding to the specified dag parameters. If the default +// generated ones cannot cover some use case, custom getters can be defined +// using instances of this class. // // The signature of the `get` is always either: // // ```c++ -// static get(MLIRContext *context, ...) { +// static get(MLIRContext *context, ...) { // ... // } // ``` @@ -2486,7 +2486,7 @@ // or: // // ```c++ -// static get(MLIRContext *context, ...); +// static get(MLIRContext *context, ...); // ``` // // To define a custom getter, the parameter list and body should be passed @@ -2503,7 +2503,7 @@ // type. For example, the following signature specification // // ``` -// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)> +// AttrOrTypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg)> // ``` // // has an integer parameter and a float parameter with a default value. 
@@ -2514,7 +2514,7 @@ // method should be invoked using `$_get`, e.g.: // // ``` -// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{ +// AttrOrTypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{ // return $_get($_ctxt, integerArg, floatArg); // }]> // ``` @@ -2522,7 +2522,7 @@ // This is necessary because the `body` is also used to generate `getChecked` // methods, which have a different underlying `Base::get*` call. // -class TypeBuilder { +class AttrOrTypeBuilder { dag dagParams = parameters; code body = bodyCode; @@ -2530,33 +2530,42 @@ // is not implicitly added to the parameter list. bit hasInferredContextParam = 0; } +class AttrBuilder + : AttrOrTypeBuilder; +class TypeBuilder + : AttrOrTypeBuilder; -// A class of TypeBuilder that is able to infer the MLIRContext parameter from -// one of the other builder parameters. Instances of this builder do not have -// `MLIRContext *` implicitly added to the parameter list. -class TypeBuilderWithInferredContext +// A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter +// from one of the other builder parameters. Instances of this builder do not +// have `MLIRContext *` implicitly added to the parameter list. +class AttrOrTypeBuilderWithInferredContext - : TypeBuilder { + : AttrOrTypeBuilder { let hasInferredContextParam = 1; } +class AttrBuilderWithInferredContext + : AttrOrTypeBuilderWithInferredContext; +class TypeBuilderWithInferredContext + : AttrOrTypeBuilderWithInferredContext; -// Define a new type, named `name`, belonging to `dialect` that inherits from -// the given C++ base class. -class TypeDef - : DialectType, /*descr*/"", name # "Type"> { - // The name of the C++ base class to use for this Type. +// Define a new attribute or type, named `name`, that inherits from the given +// C++ base class. +class AttrOrTypeDef { + // The name of the C++ base class to use for this def.
string cppBaseClassName = baseCppClass; - // Additional, longer human-readable description of what the op does. + // Additional, longer human-readable description of what the def does. string description = ""; // Name of storage class to generate or use. - string storageClass = name # "TypeStorage"; + string storageClass = name # valueType # "Storage"; + // Namespace (withing dialect c++ namespace) in which the storage class // resides. string storageNamespace = "detail"; + // Specify if the storage class is to be generated. bit genStorageClass = 1; + // Specify that the generated storage class has a constructor which is written // in C++. bit hasStorageCustomConstructor = 0; @@ -2568,38 +2577,38 @@ // (ins // "":$param1Name, // "":$param2Name, - // TypeParameter<"c++ type", "param description">:$param3Name) - // TypeParameters (or more likely one of their subclasses) are required to add - // more information about the parameter, specifically: + // AttrOrTypeParameter<"c++ type", "param description">:$param3Name) + // AttrOrTypeParameters (or more likely one of their subclasses) are required + // to add more information about the parameter, specifically: // - Documentation // - Code to allocate the parameter (if allocation is needed in the storage // class constructor) // // For example: - // (ins - // "int":$width, - // ArrayRefParameter<"bool", "list of bools">:$yesNoArray) + // (ins "int":$width, + // ArrayRefParameter<"bool", "list of bools">:$yesNoArray) // - // (ArrayRefParameter is a subclass of TypeParameter which has allocation code - // for re-allocating ArrayRefs. It is defined below.) + // (ArrayRefParameter is a subclass of AttrOrTypeParameter which has + // allocation code for re-allocating ArrayRefs. It is defined below.) dag parameters = (ins); - // Custom type builder methods. + // Custom builder methods. 
// In addition to the custom builders provided here, and unless // skipDefaultBuilders is set, a default builder is generated with the // following signature: // // ```c++ - // static get(MLIRContext *, ); + // static get(MLIRContext *, ); // ``` // - // Note that builders should only be provided when a type has parameters. - list builders = ?; + // Note that builders should only be provided when a def has parameters. + list builders = ?; // Use the lowercased name as the keyword for parsing/printing. Specify only // if you want tblgen to generate declarations and/or definitions of - // printer/parser for this type. + // the printer/parser. string mnemonic = ?; + // If 'mnemonic' specified, // If null, generate just the declarations. // If a non-empty code block, just use that code as the definition code. @@ -2607,29 +2616,53 @@ code printer = ?; code parser = ?; - // If set, generate accessors for each Type parameter. + // If set, generate accessors for each parameter. bit genAccessors = 1; + // Avoid generating default get/getChecked functions. Custom get methods must // be provided. bit skipDefaultBuilders = 0; + // Generate the verify and getChecked methods. bit genVerifyDecl = 0; + // Extra code to include in the class declaration. code extraClassDeclaration = [{}]; +} + +// Define a new attribute, named `name`, belonging to `dialect` that inherits +// from the given C++ base class. +class AttrDef + : DialectAttr, /*descr*/"">, + AttrOrTypeDef<"Attr", name, baseCppClass> { + // The name of the C++ Attribute class. + string cppClassName = name # "Attr"; - // The predicate for when this type is used as a type constraint. + // The predicate for when this def is used as a constraint. let predicate = CPred<"$_self.isa<" # dialect.cppNamespace # "::" # cppClassName # ">()">; +} + +// Define a new type, named `name`, belonging to `dialect` that inherits from +// the given C++ base class. 
+class TypeDef + : DialectType, /*descr*/"", name # "Type">, + AttrOrTypeDef<"Type", name, baseCppClass> { // A constant builder provided when the type has no parameters. let builderCall = !if(!empty(parameters), "$_builder.getType<" # dialect.cppNamespace # "::" # cppClassName # ">()", ""); + // The predicate for when this def is used as a constraint. + let predicate = CPred<"$_self.isa<" # dialect.cppNamespace # + "::" # cppClassName # ">()">; } // 'Parameters' should be subclasses of this or simple strings (which is a -// shorthand for TypeParameter<"C++Type">). -class TypeParameter { +// shorthand for AttrOrTypeParameter<"C++Type">). +class AttrOrTypeParameter { // Custom memory allocation code for storage constructor. code allocator = ?; // The C++ type of this parameter. @@ -2639,28 +2672,30 @@ // The format string for the asm syntax (documentation only). string syntax = ?; } +class AttrParameter : AttrOrTypeParameter; +class TypeParameter : AttrOrTypeParameter; // For StringRefs, which require allocation. class StringRefParameter : - TypeParameter<"::llvm::StringRef", desc> { + AttrOrTypeParameter<"::llvm::StringRef", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; } // For standard ArrayRefs, which require allocation. class ArrayRefParameter : - TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { let allocator = [{$_dst = $_allocator.copyInto($_self);}]; } // For classes which require allocation and have their own allocateInto method. class SelfAllocationParameter : - TypeParameter { + AttrOrTypeParameter { let allocator = [{$_dst = $_self.allocateInto($_allocator);}]; } // For ArrayRefs which contain things which allocate themselves. 
class ArrayRefOfSelfAllocationParameter : - TypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { + AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { let allocator = [{ llvm::SmallVector<}] # arrayOf # [{, 4> tmpFields; for (size_t i = 0, e = $_self.size(); i < e; ++i) @@ -2669,5 +2704,4 @@ }]; } - #endif // OP_BASE diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h rename from mlir/include/mlir/TableGen/TypeDef.h rename to mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/TypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -1,4 +1,4 @@ -//===-- TypeDef.h - Record wrapper for type definitions ---------*- C++ -*-===// +//===-- AttrOrTypeDef.h - Wrapper for attr and type definitions -*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,13 @@ // //===----------------------------------------------------------------------===// // -// TypeDef wrapper to simplify using TableGen Record defining a MLIR type. +// AttrOrTypeDef, AttrDef, and TypeDef wrappers to simplify using TableGen +// Records defining MLIR attributes and types. // //===----------------------------------------------------------------------===// -#ifndef MLIR_TABLEGEN_TYPEDEF_H -#define MLIR_TABLEGEN_TYPEDEF_H +#ifndef MLIR_TABLEGEN_ATTRORTYPEDEF_H +#define MLIR_TABLEGEN_ATTRORTYPEDEF_H #include "mlir/Support/LLVM.h" #include "mlir/TableGen/Builder.h" @@ -25,14 +26,14 @@ namespace mlir { namespace tblgen { class Dialect; -class TypeParameter; +class AttrOrTypeParameter; //===----------------------------------------------------------------------===// -// TypeBuilder +// AttrOrTypeBuilder //===----------------------------------------------------------------------===// -/// Wrapper class that represents a Tablegen TypeBuilder.
-class TypeBuilder : public Builder { +/// Wrapper class that represents a Tablegen AttrOrTypeBuilder. +class AttrOrTypeBuilder : public Builder { public: using Builder::Builder; @@ -41,22 +42,22 @@ }; //===----------------------------------------------------------------------===// -// TypeDef +// AttrOrTypeDef //===----------------------------------------------------------------------===// -/// Wrapper class that contains a TableGen TypeDef's record and provides helper -/// methods for accessing them. -class TypeDef { +/// Wrapper class that contains a TableGen AttrOrTypeDef's record and provides +/// helper methods for accessing them. +class AttrOrTypeDef { public: - explicit TypeDef(const llvm::Record *def); + explicit AttrOrTypeDef(const llvm::Record *def); - // Get the dialect for which this type belongs. + // Get the dialect for which this def belongs. Dialect getDialect() const; - // Returns the name of this TypeDef record. + // Returns the name of this AttrOrTypeDef record. StringRef getName() const; - // Query functions for the documentation of the operator. + // Query functions for the documentation of the def. bool hasDescription() const; StringRef getDescription() const; bool hasSummary() const; @@ -65,13 +66,13 @@ // Returns the name of the C++ class to generate. StringRef getCppClassName() const; - // Returns the name of the C++ base class to use when generating this type. + // Returns the name of the C++ base class to use when generating this def. StringRef getCppBaseClassName() const; - // Returns the name of the storage class for this type. + // Returns the name of the storage class for this def. StringRef getStorageClassName() const; - // Returns the C++ namespace for this types storage class. + // Returns the C++ namespace for this def's storage class. StringRef getStorageNamespace() const; // Returns true if we should generate the storage class. @@ -80,10 +81,11 @@ // Indicates whether or not to generate the storage class constructor. 
bool hasStorageCustomConstructor() const; - // Fill a list with this types parameters. See TypeDef in OpBase.td for + // Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for // documentation of parameter usage. - void getParameters(SmallVectorImpl &) const; - // Return the number of type parameters + void getParameters(SmallVectorImpl &) const; + + // Return the number of parameters unsigned getNumParameters() const; // Return the keyword/mnemonic to use in the printer/parser methods if we are @@ -94,19 +96,18 @@ // return a non-value. Otherwise, return the contents of that code block. Optional getPrinterCode() const; - // Returns the code to use as the types parser method. If not specified, - // return a non-value. Otherwise, return the contents of that code block. + // Returns the code to use as the parser method. If not specified, returns + // None. Otherwise, returns the contents of that code block. Optional getParserCode() const; - // Returns true if the accessors based on the types parameters should be - // generated. + // Returns true if the accessors based on the parameters should be generated. bool genAccessors() const; // Return true if we need to generate the verify declaration and getChecked // method. bool genVerifyDecl() const; - // Returns the dialects extra class declaration code. + // Returns the def's extra class declaration code. Optional getExtraDecls() const; // Get the code location (for error printing). @@ -116,54 +117,80 @@ // generation. bool skipDefaultBuilders() const; - // Returns the builders of this type. - ArrayRef getBuilders() const { return builders; } + // Returns the builders of this def. + ArrayRef getBuilders() const { return builders; } - // Returns whether two TypeDefs are equal by checking the equality of the - // underlying record. - bool operator==(const TypeDef &other) const; + // Returns whether two AttrOrTypeDefs are equal by checking the equality of + // the underlying record. 
+ bool operator==(const AttrOrTypeDef &other) const; - // Compares two TypeDefs by comparing the names of the dialects. - bool operator<(const TypeDef &other) const; + // Compares two AttrOrTypeDefs by the names of their underlying records. + bool operator<(const AttrOrTypeDef &other) const; - // Returns whether the TypeDef is defined. + // Returns whether the AttrOrTypeDef is defined. operator bool() const { return def != nullptr; } private: const llvm::Record *def; // The builders of this type definition. - SmallVector builders; + SmallVector builders; +}; + +//===----------------------------------------------------------------------===// +// AttrDef +//===----------------------------------------------------------------------===// + +/// This class represents a wrapper around a tablegen AttrDef record. +class AttrDef : public AttrOrTypeDef { +public: + using AttrOrTypeDef::AttrOrTypeDef; +}; + +//===----------------------------------------------------------------------===// +// TypeDef +//===----------------------------------------------------------------------===// + +/// This class represents a wrapper around a tablegen TypeDef record. +class TypeDef : public AttrOrTypeDef { +public: + using AttrOrTypeDef::AttrOrTypeDef; +}; //===----------------------------------------------------------------------===// -// TypeParameter +// AttrOrTypeParameter //===----------------------------------------------------------------------===// -// A wrapper class for tblgen TypeParameter, arrays of which belong to TypeDefs -// to parameterize them. -class TypeParameter { +// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to +// AttrOrTypeDefs to parameterize them. +class AttrOrTypeParameter { public: - explicit TypeParameter(const llvm::DagInit *def, unsigned num) - : def(def), num(num) {} + explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index) + : def(def), index(index) {} // Get the parameter name.
StringRef getName() const; + // If specified, get the custom allocator code for this parameter. Optional getAllocator() const; + // Get the C++ type of this parameter. StringRef getCppType() const; + // Get a description of this parameter for documentation purposes. Optional getSummary() const; + // Get the assembly syntax documentation. StringRef getSyntax() const; private: + /// The underlying tablegen parameter list this parameter is a part of. const llvm::DagInit *def; - const unsigned num; + /// The index of the parameter within the parameter list (`def`). + unsigned index; }; } // end namespace tblgen } // end namespace mlir -#endif // MLIR_TABLEGEN_TYPEDEF_H +#endif // MLIR_TABLEGEN_ATTRORTYPEDEF_H diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -23,14 +23,15 @@ // Simple RAII helper for defining ifdef-undef-endif scopes. class IfDefScope { public: - IfDefScope(llvm::StringRef name, llvm::raw_ostream &os) : name(name), os(os) { + IfDefScope(llvm::StringRef name, llvm::raw_ostream &os) + : name(name.str()), os(os) { os << "#ifdef " << name << "\n" << "#undef " << name << "\n\n"; } ~IfDefScope() { os << "\n#endif // " << name << "\n\n"; } private: - llvm::StringRef name; + std::string name; llvm::raw_ostream &os; }; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -0,0 +1,221 @@ +//===- AttrOrTypeDef.cpp - AttrOrTypeDef wrapper classes ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/AttrOrTypeDef.h" +#include "mlir/TableGen/Dialect.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// AttrOrTypeBuilder +//===----------------------------------------------------------------------===// + +/// Returns true if this builder is able to infer the MLIRContext parameter. +bool AttrOrTypeBuilder::hasInferredContextParameter() const { + return def->getValueAsBit("hasInferredContextParam"); +} + +//===----------------------------------------------------------------------===// +// AttrOrTypeDef +//===----------------------------------------------------------------------===// + +AttrOrTypeDef::AttrOrTypeDef(const llvm::Record *def) : def(def) { + // Populate the builders. + auto *builderList = + dyn_cast_or_null(def->getValueInit("builders")); + if (builderList && !builderList->empty()) { + for (llvm::Init *init : builderList->getValues()) { + AttrOrTypeBuilder builder(cast(init)->getDef(), + def->getLoc()); + + // Ensure that all parameters have names. + for (const AttrOrTypeBuilder::Parameter ¶m : + builder.getParameters()) { + if (!param.getName()) + PrintFatalError(def->getLoc(), "builder parameters must have a name"); + } + builders.emplace_back(builder); + } + } else if (skipDefaultBuilders()) { + PrintFatalError( + def->getLoc(), + "default builders are skipped and no custom builders provided"); + } +} + +Dialect AttrOrTypeDef::getDialect() const { + auto *dialect = dyn_cast(def->getValue("dialect")->getValue()); + return Dialect(dialect ? 
dialect->getDef() : nullptr); +} + +StringRef AttrOrTypeDef::getName() const { return def->getName(); } + +StringRef AttrOrTypeDef::getCppClassName() const { + return def->getValueAsString("cppClassName"); +} + +StringRef AttrOrTypeDef::getCppBaseClassName() const { + return def->getValueAsString("cppBaseClassName"); +} + +bool AttrOrTypeDef::hasDescription() const { + const llvm::RecordVal *desc = def->getValue("description"); + return desc && isa(desc->getValue()); +} + +StringRef AttrOrTypeDef::getDescription() const { + return def->getValueAsString("description"); +} + +bool AttrOrTypeDef::hasSummary() const { + const llvm::RecordVal *summary = def->getValue("summary"); + return summary && isa(summary->getValue()); +} + +StringRef AttrOrTypeDef::getSummary() const { + return def->getValueAsString("summary"); +} + +StringRef AttrOrTypeDef::getStorageClassName() const { + return def->getValueAsString("storageClass"); +} + +StringRef AttrOrTypeDef::getStorageNamespace() const { + return def->getValueAsString("storageNamespace"); +} + +bool AttrOrTypeDef::genStorageClass() const { + return def->getValueAsBit("genStorageClass"); +} + +bool AttrOrTypeDef::hasStorageCustomConstructor() const { + return def->getValueAsBit("hasStorageCustomConstructor"); +} + +void AttrOrTypeDef::getParameters( + SmallVectorImpl ¶meters) const { + if (auto *parametersDag = def->getValueAsDag("parameters")) { + for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i) + parameters.push_back(AttrOrTypeParameter(parametersDag, i)); + } +} + +unsigned AttrOrTypeDef::getNumParameters() const { + auto *parametersDag = def->getValueAsDag("parameters"); + return parametersDag ? 
parametersDag->getNumArgs() : 0; +} + +Optional AttrOrTypeDef::getMnemonic() const { + return def->getValueAsOptionalString("mnemonic"); +} + +Optional AttrOrTypeDef::getPrinterCode() const { + return def->getValueAsOptionalString("printer"); +} + +Optional AttrOrTypeDef::getParserCode() const { + return def->getValueAsOptionalString("parser"); +} + +bool AttrOrTypeDef::genAccessors() const { + return def->getValueAsBit("genAccessors"); +} + +bool AttrOrTypeDef::genVerifyDecl() const { + return def->getValueAsBit("genVerifyDecl"); +} + +Optional AttrOrTypeDef::getExtraDecls() const { + auto value = def->getValueAsString("extraClassDeclaration"); + return value.empty() ? Optional() : value; +} + +ArrayRef AttrOrTypeDef::getLoc() const { return def->getLoc(); } + +bool AttrOrTypeDef::skipDefaultBuilders() const { + return def->getValueAsBit("skipDefaultBuilders"); +} + +bool AttrOrTypeDef::operator==(const AttrOrTypeDef &other) const { + return def == other.def; +} + +bool AttrOrTypeDef::operator<(const AttrOrTypeDef &other) const { + return getName() < other.getName(); +} + +//===----------------------------------------------------------------------===// +// AttrOrTypeParameter +//===----------------------------------------------------------------------===// + +StringRef AttrOrTypeParameter::getName() const { + return def->getArgName(index)->getValue(); +} + +Optional AttrOrTypeParameter::getAllocator() const { + llvm::Init *parameterType = def->getArg(index); + if (isa(parameterType)) + return Optional(); + + if (auto *param = dyn_cast(parameterType)) { + llvm::RecordVal *code = param->getDef()->getValue("allocator"); + if (!code) + return Optional(); + if (llvm::StringInit *ci = dyn_cast(code->getValue())) + return ci->getValue(); + if (isa(code->getValue())) + return Optional(); + + llvm::PrintFatalError( + param->getDef()->getLoc(), + "Parameter `" + def->getArgName(index)->getValue() + + "', field `allocator' does not have a code initializer!"); + } +
llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from AttrOrTypeParameter\n"); +} + +StringRef AttrOrTypeParameter::getCppType() const { + auto *parameterType = def->getArg(index); + if (auto *stringType = dyn_cast(parameterType)) + return stringType->getValue(); + if (auto *param = dyn_cast(parameterType)) + return param->getDef()->getValueAsString("cppType"); + llvm::PrintFatalError( + "Parameters DAG arguments must be either strings or defs " + "which inherit from AttrOrTypeParameter\n"); +} + +Optional AttrOrTypeParameter::getSummary() const { + auto *parameterType = def->getArg(index); + if (auto *param = dyn_cast(parameterType)) { + const auto *desc = param->getDef()->getValue("summary"); + if (llvm::StringInit *ci = dyn_cast(desc->getValue())) + return ci->getValue(); + } + return Optional(); +} + +StringRef AttrOrTypeParameter::getSyntax() const { + auto *parameterType = def->getArg(index); + if (auto *stringType = dyn_cast(parameterType)) + return stringType->getValue(); + if (auto *param = dyn_cast(parameterType)) { + const auto *syntax = param->getDef()->getValue("syntax"); + if (syntax && isa(syntax->getValue())) + return cast(syntax->getValue())->getValue(); + return getCppType(); + } + llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from AttrOrTypeParameter"); +} diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -11,6 +11,7 @@ llvm_add_library(MLIRTableGen STATIC Argument.cpp Attribute.cpp + AttrOrTypeDef.cpp Builder.cpp Constraint.cpp Dialect.cpp @@ -26,7 +27,6 @@ SideEffects.cpp Successor.cpp Type.cpp - TypeDef.cpp DISABLE_LLVM_LINK_LLVM_DYLIB diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp deleted file mode 100644 --- a/mlir/lib/TableGen/TypeDef.cpp +++ /dev/null @@ -1,212 +0,0 @@ -//===- TypeDef.cpp - TypeDef 
wrapper class --------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// TypeDef wrapper to simplify using TableGen Record defining a MLIR dialect. -// -//===----------------------------------------------------------------------===// - -#include "mlir/TableGen/TypeDef.h" -#include "mlir/TableGen/Dialect.h" -#include "llvm/ADT/StringExtras.h" -#include "llvm/TableGen/Error.h" -#include "llvm/TableGen/Record.h" - -using namespace mlir; -using namespace mlir::tblgen; - -//===----------------------------------------------------------------------===// -// TypeBuilder -//===----------------------------------------------------------------------===// - -/// Returns true if this builder is able to infer the MLIRContext parameter. 
-bool TypeBuilder::hasInferredContextParameter() const { - return def->getValueAsBit("hasInferredContextParam"); -} - -//===----------------------------------------------------------------------===// -// TypeDef -//===----------------------------------------------------------------------===// - -Dialect TypeDef::getDialect() const { - auto *dialectDef = - dyn_cast(def->getValue("dialect")->getValue()); - if (dialectDef == nullptr) - return Dialect(nullptr); - return Dialect(dialectDef->getDef()); -} - -StringRef TypeDef::getName() const { return def->getName(); } -StringRef TypeDef::getCppClassName() const { - return def->getValueAsString("cppClassName"); -} - -StringRef TypeDef::getCppBaseClassName() const { - return def->getValueAsString("cppBaseClassName"); -} - -bool TypeDef::hasDescription() const { - const llvm::RecordVal *s = def->getValue("description"); - return s != nullptr && isa(s->getValue()); -} - -StringRef TypeDef::getDescription() const { - return def->getValueAsString("description"); -} - -bool TypeDef::hasSummary() const { - const llvm::RecordVal *s = def->getValue("summary"); - return s != nullptr && isa(s->getValue()); -} - -StringRef TypeDef::getSummary() const { - return def->getValueAsString("summary"); -} - -StringRef TypeDef::getStorageClassName() const { - return def->getValueAsString("storageClass"); -} -StringRef TypeDef::getStorageNamespace() const { - return def->getValueAsString("storageNamespace"); -} - -bool TypeDef::genStorageClass() const { - return def->getValueAsBit("genStorageClass"); -} -bool TypeDef::hasStorageCustomConstructor() const { - return def->getValueAsBit("hasStorageCustomConstructor"); -} -void TypeDef::getParameters(SmallVectorImpl ¶meters) const { - auto *parametersDag = def->getValueAsDag("parameters"); - if (parametersDag != nullptr) { - size_t numParams = parametersDag->getNumArgs(); - for (unsigned i = 0; i < numParams; i++) - parameters.push_back(TypeParameter(parametersDag, i)); - } -} -unsigned 
TypeDef::getNumParameters() const { - auto *parametersDag = def->getValueAsDag("parameters"); - return parametersDag ? parametersDag->getNumArgs() : 0; -} -llvm::Optional TypeDef::getMnemonic() const { - return def->getValueAsOptionalString("mnemonic"); -} -llvm::Optional TypeDef::getPrinterCode() const { - return def->getValueAsOptionalString("printer"); -} -llvm::Optional TypeDef::getParserCode() const { - return def->getValueAsOptionalString("parser"); -} -bool TypeDef::genAccessors() const { - return def->getValueAsBit("genAccessors"); -} -bool TypeDef::genVerifyDecl() const { - return def->getValueAsBit("genVerifyDecl"); -} -llvm::Optional TypeDef::getExtraDecls() const { - auto value = def->getValueAsString("extraClassDeclaration"); - return value.empty() ? llvm::Optional() : value; -} -llvm::ArrayRef TypeDef::getLoc() const { return def->getLoc(); } - -bool TypeDef::skipDefaultBuilders() const { - return def->getValueAsBit("skipDefaultBuilders"); -} - -bool TypeDef::operator==(const TypeDef &other) const { - return def == other.def; -} - -bool TypeDef::operator<(const TypeDef &other) const { - return getName() < other.getName(); -} - -//===----------------------------------------------------------------------===// -// TypeParameter -//===----------------------------------------------------------------------===// - -TypeDef::TypeDef(const llvm::Record *def) : def(def) { - // Populate the builders. - auto *builderList = - dyn_cast_or_null(def->getValueInit("builders")); - if (builderList && !builderList->empty()) { - for (llvm::Init *init : builderList->getValues()) { - TypeBuilder builder(cast(init)->getDef(), def->getLoc()); - - // Ensure that all parameters have names. 
- for (const TypeBuilder::Parameter ¶m : builder.getParameters()) { - if (!param.getName()) - PrintFatalError(def->getLoc(), - "type builder parameters must have a name"); - } - builders.emplace_back(builder); - } - } else if (skipDefaultBuilders()) { - PrintFatalError( - def->getLoc(), - "default builders are skipped and no custom builders provided"); - } -} - -StringRef TypeParameter::getName() const { - return def->getArgName(num)->getValue(); -} -Optional TypeParameter::getAllocator() const { - llvm::Init *parameterType = def->getArg(num); - if (isa(parameterType)) - return llvm::Optional(); - - if (auto *typeParameter = dyn_cast(parameterType)) { - llvm::RecordVal *code = typeParameter->getDef()->getValue("allocator"); - if (!code) - return llvm::Optional(); - if (llvm::StringInit *ci = dyn_cast(code->getValue())) - return ci->getValue(); - if (isa(code->getValue())) - return llvm::Optional(); - - llvm::PrintFatalError( - typeParameter->getDef()->getLoc(), - "Record `" + def->getArgName(num)->getValue() + - "', field `printer' does not have a code initializer!"); - } - - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from TypeParameter\n"); -} -StringRef TypeParameter::getCppType() const { - auto *parameterType = def->getArg(num); - if (auto *stringType = dyn_cast(parameterType)) - return stringType->getValue(); - if (auto *typeParameter = dyn_cast(parameterType)) - return typeParameter->getDef()->getValueAsString("cppType"); - llvm::PrintFatalError( - "Parameters DAG arguments must be either strings or defs " - "which inherit from TypeParameter\n"); -} -Optional TypeParameter::getSummary() const { - auto *parameterType = def->getArg(num); - if (auto *typeParameter = dyn_cast(parameterType)) { - const auto *desc = typeParameter->getDef()->getValue("summary"); - if (llvm::StringInit *ci = dyn_cast(desc->getValue())) - return ci->getValue(); - } - return Optional(); -} -StringRef TypeParameter::getSyntax() const 
{ - auto *parameterType = def->getArg(num); - if (auto *stringType = dyn_cast(parameterType)) - return stringType->getValue(); - if (auto *typeParameter = dyn_cast(parameterType)) { - const auto *syntax = typeParameter->getDef()->getValue("syntax"); - if (syntax && isa(syntax->getValue())) - return dyn_cast(syntax->getValue())->getValue(); - return getCppType(); - } - llvm::PrintFatalError("Parameters DAG arguments must be either strings or " - "defs which inherit from TypeParameter"); -} diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -11,10 +11,15 @@ mlir_tablegen(TestOpInterfaces.cpp.inc -gen-op-interface-defs) add_public_tablegen_target(MLIRTestInterfaceIncGen) +set(LLVM_TARGET_DEFINITIONS TestAttrDefs.td) +mlir_tablegen(TestAttrDefs.h.inc -gen-attrdef-decls) +mlir_tablegen(TestAttrDefs.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRTestAttrDefIncGen) + set(LLVM_TARGET_DEFINITIONS TestTypeDefs.td) mlir_tablegen(TestTypeDefs.h.inc -gen-typedef-decls) mlir_tablegen(TestTypeDefs.cpp.inc -gen-typedef-defs) -add_public_tablegen_target(MLIRTestDefIncGen) +add_public_tablegen_target(MLIRTestTypeDefIncGen) set(LLVM_TARGET_DEFINITIONS TestOps.td) @@ -30,6 +35,7 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestDialect + TestAttributes.cpp TestDialect.cpp TestInterfaces.cpp TestPatterns.cpp @@ -39,8 +45,9 @@ EXCLUDE_FROM_LIBMLIR DEPENDS + MLIRTestAttrDefIncGen MLIRTestInterfaceIncGen - MLIRTestDefIncGen + MLIRTestTypeDefIncGen MLIRTestOpsIncGen LINK_LIBS PUBLIC diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -0,0 +1,44 @@ +//===-- TestAttrDefs.td - Test dialect attr definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 
with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// TableGen data attribute definitions for Test dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_ATTRDEFS +#define TEST_ATTRDEFS + +// To get the test dialect definition. +include "TestOps.td" + +// All of the attributes will extend this class. +class Test_Attr : AttrDef; + +def SimpleAttrA : Test_Attr<"SimpleA"> { + let mnemonic = "smpla"; +} + +// A more complex parameterized attribute. +def CompoundAttrA : Test_Attr<"CompoundA"> { + let mnemonic = "cmpnd_a"; + + // List of type parameters. + let parameters = ( + ins + "int":$widthOfSomething, + "::mlir::Type":$oneType, + // This is special syntax since ArrayRefs require allocation in the + // constructor. + ArrayRefParameter< + "int", // The parameter C++ type. + "An example of an array of ints" // Parameter description. + >: $arrayOfInts + ); +} + +#endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -0,0 +1,27 @@ +//===- TestTypes.h - MLIR Test Dialect Types --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains types defined by the TestDialect for testing various +// features of MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTATTRIBUTES_H +#define MLIR_TESTATTRIBUTES_H + +#include + +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectImplementation.h" + +#define GET_ATTRDEF_CLASSES +#include "TestAttrDefs.h.inc" + +#endif // MLIR_TESTATTRIBUTES_H diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -0,0 +1,82 @@ +//===- TestAttributes.cpp - MLIR Test Dialect Attributes --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains attributes defined by the TestDialect for testing various +// features of MLIR. 
+// +//===----------------------------------------------------------------------===// + +#include "TestAttributes.h" +#include "TestDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/Types.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::test; + +Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser, + Type type) { + int widthOfSomething; + Type oneType; + SmallVector arrayOfInts; + if (parser.parseLess() || parser.parseInteger(widthOfSomething) || + parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || + parser.parseLSquare()) + return Attribute(); + + int intVal; + while (!*parser.parseOptionalInteger(intVal)) { + arrayOfInts.push_back(intVal); + if (parser.parseOptionalComma()) + break; + } + + if (parser.parseRSquare() || parser.parseGreater()) + return Attribute(); + return get(context, widthOfSomething, oneType, arrayOfInts); +} + +void CompoundAAttr::print(DialectAsmPrinter &printer) const { + printer << "cmpnd_a<" << getWidthOfSomething() << ", " << getOneType() + << ", ["; + llvm::interleaveComma(getArrayOfInts(), printer); + printer << "]>"; +} + +//===----------------------------------------------------------------------===// +// Tablegen Generated Definitions +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "TestAttrDefs.cpp.inc" + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + +Attribute TestDialect::parseAttribute(DialectAsmParser &parser, + Type type) const { + StringRef attrTag; + if (failed(parser.parseKeyword(&attrTag))) + return Attribute(); + if (auto attr = generatedAttributeParser(getContext(), parser, attrTag, type)) + return attr; + + 
parser.emitError(parser.getNameLoc(), "unknown test attribute"); + return Attribute(); +} + +void TestDialect::printAttribute(Attribute attr, + DialectAsmPrinter &printer) const { + if (succeeded(generatedAttributePrinter(attr, printer))) + return; +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestAttributes.h" #include "TestTypes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/BuiltinOps.h" @@ -168,6 +169,10 @@ #define GET_OP_LIST #include "TestOps.cpp.inc" >(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "TestAttrDefs.cpp.inc" + >(); addInterfaces(); addTypes traits = []> : diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -0,0 +1,96 @@ +// RUN: mlir-tblgen -gen-attrdef-decls -I %S/../../include %s | FileCheck %s --check-prefix=DECL +// RUN: mlir-tblgen -gen-attrdef-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEF + +include "mlir/IR/OpBase.td" + +// DECL: #ifdef GET_ATTRDEF_CLASSES +// DECL: #undef GET_ATTRDEF_CLASSES + +// DECL: namespace mlir { +// DECL: class DialectAsmParser; +// DECL: class DialectAsmPrinter; +// DECL: } // namespace mlir + +// DEF: #ifdef GET_ATTRDEF_LIST +// DEF: #undef GET_ATTRDEF_LIST +// DEF: ::mlir::test::SimpleAAttr, +// DEF: ::mlir::test::CompoundAAttr, +// DEF: ::mlir::test::IndexAttr, +// DEF: ::mlir::test::SingleParameterAttr + +// DEF-LABEL: ::mlir::Attribute generatedAttributeParser(::mlir::MLIRContext *context, +// DEF-NEXT: ::mlir::DialectAsmParser &parser, +// DEF-NEXT: ::llvm::StringRef mnemonic, ::mlir::Type type) { +// DEF: if (mnemonic == ::mlir::test::CompoundAAttr::getMnemonic()) return 
::mlir::test::CompoundAAttr::parse(context, parser, type); +// DEF-NEXT: if (mnemonic == ::mlir::test::IndexAttr::getMnemonic()) return ::mlir::test::IndexAttr::parse(context, parser, type); +// DEF-NEXT: return ::mlir::Attribute(); + +def Test_Dialect: Dialect { +// DECL-NOT: TestDialect +// DEF-NOT: TestDialect + let name = "TestDialect"; + let cppNamespace = "::mlir::test"; +} + +class TestAttr : AttrDef { } + +def A_SimpleAttrA : TestAttr<"SimpleA"> { +// DECL: class SimpleAAttr : public ::mlir::Attribute +} + +// A more complex parameterized type +def B_CompoundAttrA : TestAttr<"CompoundA"> { + let summary = "A more complex parameterized attribute"; + let description = "This attribute is to test a reasonably complex attribute"; + let mnemonic = "cmpnd_a"; + let parameters = ( + ins + "int":$widthOfSomething, + "::mlir::test::SimpleTypeA": $exampleTdType, + "SomeCppStruct": $exampleCppType, + ArrayRefParameter<"int", "Matrix dimensions">:$dims, + "::mlir::Type":$inner + ); + + let genVerifyDecl = 1; + +// DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute +// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); +// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); +// DECL: static constexpr ::llvm::StringLiteral getMnemonic() { +// DECL: return ::llvm::StringLiteral("cmpnd_a"); +// DECL: } +// DECL: static ::mlir::Attribute parse(::mlir::MLIRContext *context, +// DECL-NEXT: ::mlir::DialectAsmParser &parser, ::mlir::Type type); +// DECL: void print(::mlir::DialectAsmPrinter &printer) const; +// DECL: int getWidthOfSomething() const; +// DECL: 
::mlir::test::SimpleTypeA getExampleTdType() const; +// DECL: SomeCppStruct getExampleCppType() const; +} + +def C_IndexAttr : TestAttr<"Index"> { + let mnemonic = "index"; + + let parameters = ( + ins + StringRefParameter<"Label for index">:$label + ); + +// DECL-LABEL: class IndexAttr : public ::mlir::Attribute +// DECL: static constexpr ::llvm::StringLiteral getMnemonic() { +// DECL: return ::llvm::StringLiteral("index"); +// DECL: } +// DECL: static ::mlir::Attribute parse(::mlir::MLIRContext *context, +// DECL-NEXT: ::mlir::DialectAsmParser &parser, ::mlir::Type type); +// DECL: void print(::mlir::DialectAsmPrinter &printer) const; +} + +def D_SingleParameterAttr : TestAttr<"SingleParameter"> { + let parameters = ( + ins + "int": $num + ); +// DECL-LABEL: struct SingleParameterAttrStorage; +// DECL-LABEL: class SingleParameterAttr +// DECL-NEXT: detail::SingleParameterAttrStorage +} diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -0,0 +1,5 @@ +// RUN: mlir-opt %s | mlir-opt -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func private @compoundA() +// CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]> +func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]>} diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -19,9 +19,11 @@ // DEF: ::mlir::test::SingleParameterType, // DEF: ::mlir::test::IntegerType -// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &parser, ::llvm::StringRef mnemonic) +// DEF-LABEL: ::mlir::Type generatedTypeParser(::mlir::MLIRContext *context, +// DEF-NEXT: ::mlir::DialectAsmParser &parser, +// DEF-NEXT: ::llvm::StringRef mnemonic) { // DEF: if (mnemonic == 
::mlir::test::CompoundAType::getMnemonic()) return ::mlir::test::CompoundAType::parse(context, parser); -// DEF return ::mlir::Type(); +// DEF: return ::mlir::Type(); def Test_Dialect: Dialect { // DECL-NOT: TestDialect diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -0,0 +1,849 @@ +//===- AttrOrTypeDefGen.cpp - MLIR AttrOrType definitions generator -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/AttrOrTypeDef.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "mlir/TableGen/Format.h" +#include "mlir/TableGen/GenInfo.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/TableGenBackend.h" + +#define DEBUG_TYPE "mlir-tblgen-attrortypedefgen" + +using namespace mlir; +using namespace mlir::tblgen; + +/// Find all the AttrOrTypeDef for the specified dialect. If no dialect +/// specified and can only find one dialect's defs, use that. 
+static void collectAllDefs(StringRef selectedDialect, + std::vector records, + SmallVectorImpl &resultDefs) { + auto defs = llvm::map_range( + records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); }); + if (defs.empty()) + return; + + StringRef dialectName; + if (selectedDialect.empty()) { + if (defs.empty()) + return; + + Dialect dialect(nullptr); + for (const AttrOrTypeDef &typeDef : defs) { + if (!dialect) { + dialect = typeDef.getDialect(); + } else if (dialect != typeDef.getDialect()) { + llvm::PrintFatalError("defs belonging to more than one dialect. Must " + "select one via '--(attr|type)defs-dialect'"); + } + } + + dialectName = dialect.getName(); + } else { + dialectName = selectedDialect; + } + + for (const AttrOrTypeDef &def : defs) + if (def.getDialect().getName().equals(dialectName)) + resultDefs.push_back(def); +} + +//===----------------------------------------------------------------------===// +// ParamCommaFormatter +//===----------------------------------------------------------------------===// + +namespace { + +/// Pass an instance of this class to llvm::formatv() to emit a comma separated +/// list of parameters in the format by 'EmitFormat'. +class ParamCommaFormatter : public llvm::detail::format_adapter { +public: + /// Choose the output format + enum EmitFormat { + /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name, + /// [...]". + TypeNamePairs, + + /// Emit "parameter1(parameter1), parameter2(parameter2), [...]". + TypeNameInitializer, + + /// Emit "param1Name, param2Name, [...]". + JustParams, + }; + + ParamCommaFormatter(EmitFormat emitFormat, + ArrayRef params, + bool prependComma = true) + : emitFormat(emitFormat), params(params), prependComma(prependComma) {} + + /// llvm::formatv will call this function when using an instance as a + /// replacement value. 
+ void format(raw_ostream &os, StringRef options) override { + if (!params.empty() && prependComma) + os << ", "; + + switch (emitFormat) { + case EmitFormat::TypeNamePairs: + interleaveComma(params, os, [&](const AttrOrTypeParameter &p) { + emitTypeNamePair(p, os); + }); + break; + case EmitFormat::TypeNameInitializer: + interleaveComma(params, os, [&](const AttrOrTypeParameter &p) { + emitTypeNameInitializer(p, os); + }); + break; + case EmitFormat::JustParams: + interleaveComma(params, os, + [&](const AttrOrTypeParameter &p) { os << p.getName(); }); + break; + } + } + +private: + // Emit "paramType paramName". + static void emitTypeNamePair(const AttrOrTypeParameter ¶m, + raw_ostream &os) { + os << param.getCppType() << " " << param.getName(); + } + // Emit "paramName(paramName)" + void emitTypeNameInitializer(const AttrOrTypeParameter ¶m, + raw_ostream &os) { + os << param.getName() << "(" << param.getName() << ")"; + } + + EmitFormat emitFormat; + ArrayRef params; + bool prependComma; +}; + +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// DefGenerator +//===----------------------------------------------------------------------===// + +namespace { +/// This struct is the base generator used when processing tablegen interfaces. +class DefGenerator { +public: + bool emitDecls(StringRef selectedDialect); + bool emitDefs(StringRef selectedDialect); + +protected: + DefGenerator(std::vector &&defs, raw_ostream &os) + : defRecords(std::move(defs)), os(os), isAttrGenerator(false) {} + + /// Emit the declaration of a single def. + void emitDefDecl(const AttrOrTypeDef &def); + /// Emit the list of def type names. + void emitTypeDefList(ArrayRef defs); + /// Emit the code to dispatch between different defs during parsing/printing. + void emitParsePrintDispatch(ArrayRef defs); + /// Emit the definition of a single def. 
+ void emitDefDef(const AttrOrTypeDef &def); + /// Emit the storage class for the given def. + void emitStorageClass(const AttrOrTypeDef &def); + /// Emit the parser/printer for the given def. + void emitParsePrint(const AttrOrTypeDef &def); + + /// The set of def records to emit. + std::vector defRecords; + /// The stream to emit to. + raw_ostream &os; + /// The prefix of the tablegen def name, e.g. Attr or Type. + StringRef defTypePrefix; + /// The C++ base value type of the def, e.g. Attribute or Type. + StringRef valueType; + /// Flag indicating if this generator is for Attributes. False if the + /// generator is for types. + bool isAttrGenerator; +}; + +/// A specialized generator for AttrDefs. +struct AttrDefGenerator : public DefGenerator { + AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os) { + isAttrGenerator = true; + defTypePrefix = "Attr"; + valueType = "Attribute"; + } +}; +/// A specialized generator for TypeDefs. +struct TypeDefGenerator : public DefGenerator { + TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) + : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os) { + defTypePrefix = "Type"; + valueType = "Type"; + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// GEN: Declarations +//===----------------------------------------------------------------------===// + +/// Print this above all the other declarations. Contains type declarations used +/// later on. +static const char *const typeDefDeclHeader = R"( +namespace mlir { +class DialectAsmParser; +class DialectAsmPrinter; +} // namespace mlir +)"; + +/// The code block for the start of a typeDef class declaration -- singleton +/// case. +/// +/// {0}: The name of the def class. +/// {1}: The name of the type base class. +/// {2}: The name of the base value type, e.g. Attribute or Type. 
+/// {3}: The tablegen record type prefix, e.g. Attr or Type. +static const char *const defDeclSingletonBeginStr = R"( + class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage> {{ + public: + /// Inherit some necessary constructors from '{3}Base'. + using Base::Base; +)"; + +/// The code block for the start of a typeDef class declaration -- parametric +/// case. +/// +/// {0}: The name of the typeDef class. +/// {1}: The name of the type base class. +/// {2}: The typeDef storage class namespace. +/// {3}: The storage class name. +/// {4}: The name of the base value type, e.g. Attribute or Type. +/// {5}: The tablegen record type prefix, e.g. Attr or Type. +static const char *const defDeclParametricBeginStr = R"( + namespace {2} { + struct {3}; + } // end namespace {2} + class {0} : public ::mlir::{4}::{5}Base<{0}, {1}, + {2}::{3}> {{ + public: + /// Inherit some necessary constructors from '{5}Base'. + using Base::Base; + +)"; + +/// The code snippet for print/parse of an Attribute/Type. +/// +/// {0}: The name of the base value type, e.g. Attribute or Type. +/// {1}: Extra parser parameters. +static const char *const defDeclParsePrintStr = R"( + static ::mlir::{0} parse(::mlir::MLIRContext *context, + ::mlir::DialectAsmParser &parser{1}); + void print(::mlir::DialectAsmPrinter &printer) const; +)"; + +/// The code block for the verify method declaration. +/// +/// {0}: List of parameters, parameters style. +static const char *const defDeclVerifyStr = R"( + using Base::getChecked; + static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0}); +)"; + +/// Emit the builders for the given def. 
+static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os, + ParamCommaFormatter ¶mTypes) { + StringRef typeClass = def.getCppClassName(); + bool genCheckedMethods = def.genVerifyDecl(); + if (!def.skipDefaultBuilders()) { + os << llvm::formatv( + " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass, + paramTypes); + if (genCheckedMethods) { + os << llvm::formatv(" static {0} " + "getChecked(llvm::function_ref<::mlir::" + "InFlightDiagnostic()> emitError, " + "::mlir::MLIRContext *context{1});\n", + typeClass, paramTypes); + } + } + + // Generate the builders specified by the user. + for (const AttrOrTypeBuilder &builder : def.getBuilders()) { + std::string paramStr; + llvm::raw_string_ostream paramOS(paramStr); + llvm::interleaveComma( + builder.getParameters(), paramOS, + [&](const AttrOrTypeBuilder::Parameter ¶m) { + // Note: AttrOrTypeBuilder parameters are guaranteed to have names. + paramOS << param.getCppType() << " " << *param.getName(); + if (Optional defaultParamValue = param.getDefaultValue()) + paramOS << " = " << *defaultParamValue; + }); + paramOS.flush(); + + // Generate the `get` variant of the builder. + os << " static " << typeClass << " get("; + if (!builder.hasInferredContextParameter()) { + os << "::mlir::MLIRContext *context"; + if (!paramStr.empty()) + os << ", "; + } + os << paramStr << ");\n"; + + // Generate the `getChecked` variant of the builder. + if (genCheckedMethods) { + os << " static " << typeClass + << " getChecked(llvm::function_ref " + "emitError"; + if (!builder.hasInferredContextParameter()) + os << ", ::mlir::MLIRContext *context"; + if (!paramStr.empty()) + os << ", "; + os << paramStr << ");\n"; + } + } +} + +void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) { + SmallVector params; + def.getParameters(params); + + // Emit the beginning string template: either the singleton or parametric + // template. 
+ if (def.getNumParameters() == 0) { + os << formatv(defDeclSingletonBeginStr, def.getCppClassName(), + def.getCppBaseClassName(), valueType, defTypePrefix); + } else { + os << formatv(defDeclParametricBeginStr, def.getCppClassName(), + def.getCppBaseClassName(), def.getStorageNamespace(), + def.getStorageClassName(), valueType, defTypePrefix); + } + + // Emit the extra declarations first in case there's a definition in there. + if (Optional extraDecl = def.getExtraDecls()) + os << *extraDecl << "\n"; + + ParamCommaFormatter emitTypeNamePairsAfterComma( + ParamCommaFormatter::EmitFormat::TypeNamePairs, params); + if (!params.empty()) { + emitBuilderDecls(def, os, emitTypeNamePairsAfterComma); + + // Emit the verify invariants declaration. + if (def.genVerifyDecl()) + os << llvm::formatv(defDeclVerifyStr, emitTypeNamePairsAfterComma); + } + + // Emit the mnenomic, if specified. + if (auto mnenomic = def.getMnemonic()) { + os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n" + << " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n" + << " }\n"; + + // If mnemonic specified, emit print/parse declarations. + if (def.getParserCode() || def.getPrinterCode() || !params.empty()) { + os << llvm::formatv(defDeclParsePrintStr, valueType, + isAttrGenerator ? ", ::mlir::Type type" : ""); + } + } + + if (def.genAccessors()) { + SmallVector parameters; + def.getParameters(parameters); + + for (AttrOrTypeParameter ¶meter : parameters) { + SmallString<16> name = parameter.getName(); + name[0] = llvm::toUpper(name[0]); + os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name); + } + } + + // End the decl. + os << " };\n"; +} + +bool DefGenerator::emitDecls(StringRef selectedDialect) { + emitSourceFileHeader((defTypePrefix + "Def Declarations").str(), os); + IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); + + // Output the common "header". 
+ os << typeDefDeclHeader; + + SmallVector defs; + collectAllDefs(selectedDialect, defRecords, defs); + if (defs.empty()) + return false; + + NamespaceEmitter nsEmitter(os, defs.front().getDialect()); + + // Declare all the def classes first (in case they reference each other). + for (const AttrOrTypeDef &def : defs) + os << " class " << def.getCppClassName() << ";\n"; + + // Emit the declarations. + for (const AttrOrTypeDef &def : defs) + emitDefDecl(def); + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: Def List +//===----------------------------------------------------------------------===// + +void DefGenerator::emitTypeDefList(ArrayRef defs) { + IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_LIST", os); + auto interleaveFn = [&](const AttrOrTypeDef &def) { + os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName(); + }; + llvm::interleave(defs, os, interleaveFn, ",\n"); + os << "\n"; +} + +//===----------------------------------------------------------------------===// +// GEN: Definitions +//===----------------------------------------------------------------------===// + +/// The code block used to start the auto-generated parser function. +/// +/// {0}: The name of the base value type, e.g. Attribute or Type. +/// {1}: Additional parser parameters. +static const char *const defParserDispatchStartStr = R"( +static ::mlir::{0} generated{0}Parser(::mlir::MLIRContext *context, + ::mlir::DialectAsmParser &parser, + ::llvm::StringRef mnemonic{1}) {{ +)"; + +/// The code block used to start the auto-generated printer function. +/// +/// {0}: The name of the base value type, e.g. Attribute or Type. +static const char *const defPrinterDispatchStartStr = R"( +static ::mlir::LogicalResult generated{0}Printer( + ::mlir::{0} def, ::mlir::DialectAsmPrinter &printer) {{ + return ::llvm::TypeSwitch<::mlir::{0}, ::mlir::LogicalResult>(def) +)"; + +/// Beginning of storage class. 
+/// {0}: Storage class namespace. +/// {1}: Storage class c++ name. +/// {2}: Parameters parameters. +/// {3}: Parameter initializer string. +/// {4}: Parameter name list. +/// {5}: Parameter types. +/// {6}: The name of the base value type, e.g. Attribute or Type. +static const char *const defStorageClassBeginStr = R"( +namespace {0} {{ + struct {1} : public ::mlir::{6}Storage {{ + {1} ({2}) + : {3} {{ } + + /// The hash key is a tuple of the parameter types. + using KeyTy = std::tuple<{5}>; + + /// Define the comparison function for the key type. + bool operator==(const KeyTy &key) const {{ + return key == KeyTy({4}); + } +)"; + +/// The storage class' constructor template. +/// +/// {0}: storage class name. +/// {1}: The name of the base value type, e.g. Attribute or Type. +static const char *const defStorageClassConstructorBeginStr = R"( + /// Define a construction method for creating a new instance of this + /// storage. + static {0} *construct(::mlir::{1}StorageAllocator &allocator, + const KeyTy &key) {{ +)"; + +/// The storage class' constructor return template. +/// +/// {0}: storage class name. +/// {1}: list of parameters. +static const char *const defStorageClassConstructorEndStr = R"( + return new (allocator.allocate<{0}>()) + {0}({1}); + } +)"; + +/// Use tgfmt to emit custom allocation code for each parameter, if necessary. 
+static void emitStorageParameterAllocation(const AttrOrTypeDef &def, + raw_ostream &os) { + SmallVector parameters; + def.getParameters(parameters); + FmtContext fmtCtxt = FmtContext().addSubst("_allocator", "allocator"); + for (AttrOrTypeParameter ¶meter : parameters) { + if (Optional allocCode = parameter.getAllocator()) { + fmtCtxt.withSelf(parameter.getName()); + fmtCtxt.addSubst("_dst", parameter.getName()); + os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n"; + } + } +} + +void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) { + SmallVector parameters; + def.getParameters(parameters); + + // Collect the parameter names and types. + auto parameterNames = + map_range(parameters, [](AttrOrTypeParameter parameter) { + return parameter.getName(); + }); + auto parameterTypes = + map_range(parameters, [](AttrOrTypeParameter parameter) { + return parameter.getCppType(); + }); + auto parameterList = join(parameterNames, ", "); + auto parameterTypeList = join(parameterTypes, ", "); + + // 1) Emit most of the storage class up until the hashKey body. + os << formatv( + defStorageClassBeginStr, def.getStorageNamespace(), + def.getStorageClassName(), + ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, + parameters, /*prependComma=*/false), + ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNameInitializer, + parameters, /*prependComma=*/false), + parameterList, parameterTypeList, valueType); + + // 2) Emit the haskKey method. + os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n"; + + // Extract each parameter from the key. + os << " return ::llvm::hash_combine("; + llvm::interleaveComma( + llvm::seq(0, parameters.size()), os, + [&](unsigned it) { os << "std::get<" << it << ">(key)"; }); + os << ");\n }\n"; + + // 3) Emit the construct method. + + // If user wants to build the storage constructor themselves, declare it + // here and then they can write the definition elsewhere. 
+ if (def.hasStorageCustomConstructor()) { + os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator " + "&allocator, const KeyTy &key);\n", + def.getStorageClassName(), valueType); + + // Otherwise, generate one. + } else { + // First, unbox the parameters. + os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(), + valueType); + for (unsigned i = 0, e = parameters.size(); i < e; ++i) { + os << formatv(" auto {0} = std::get<{1}>(key);\n", + parameters[i].getName(), i); + } + + // Second, reassign the parameter variables with allocation code, if it's + // specified. + emitStorageParameterAllocation(def, os); + + // Last, return an allocated copy. + os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(), + parameterList); + } + + // 4) Emit the parameters as storage class members. + for (auto parameter : parameters) { + os << " " << parameter.getCppType() << " " << parameter.getName() + << ";\n"; + } + os << " };\n"; + + os << "} // namespace " << def.getStorageNamespace() << "\n"; +} + +void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) { + // Emit the printer code, if specified. + if (Optional printerCode = def.getPrinterCode()) { + // Both the mnenomic and printerCode must be defined (for parity with + // parserCode). + os << "void " << def.getCppClassName() + << "::print(::mlir::DialectAsmPrinter &printer) const {\n"; + if (printerCode->empty()) { + // If no code specified, emit error. + PrintFatalError(def.getLoc(), + def.getName() + + ": printer (if specified) must have non-empty code"); + } + FmtContext fmtCtxt = FmtContext().addSubst("_printer", "printer"); + os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n"; + } + + // Emit the parser code, if specified. + if (Optional parserCode = def.getParserCode()) { + FmtContext fmtCtxt; + fmtCtxt.addSubst("_parser", "parser").addSubst("_ctxt", "context"); + + // The mnenomic must be defined so the dispatcher knows how to dispatch. 
+ os << llvm::formatv("::mlir::{0} {1}::parse(::mlir::MLIRContext *context, " + "::mlir::DialectAsmParser &parser", + valueType, def.getCppClassName()); + if (isAttrGenerator) { + // Attributes also accept a type parameter instead of a context. + os << ", ::mlir::Type type"; + fmtCtxt.addSubst("_type", "type"); + } + os << ") {\n"; + + if (parserCode->empty()) { + PrintFatalError(def.getLoc(), + def.getName() + + ": parser (if specified) must have non-empty code"); + } + os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n"; + } +} + +/// Replace all instances of 'from' to 'to' in `str` and return the new string. +static std::string replaceInStr(std::string str, StringRef from, StringRef to) { + size_t pos = 0; + while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) + str.replace(pos, from.size(), to.data(), to.size()); + return str; +} + +/// Emit the builders for the given def. +static void emitBuilderDefs(const AttrOrTypeDef &def, raw_ostream &os, + ArrayRef params) { + bool genCheckedMethods = def.genVerifyDecl(); + StringRef className = def.getCppClassName(); + if (!def.skipDefaultBuilders()) { + os << llvm::formatv( + "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n" + " return Base::get(context{2});\n}\n", + className, + ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, + params), + ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams, + params)); + if (genCheckedMethods) { + os << llvm::formatv( + "{0} {0}::getChecked(" + "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, " + "::mlir::MLIRContext *context{1}) {{\n" + " return Base::getChecked(emitError, context{2});\n}\n", + className, + ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, + params), + ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams, + params)); + } + } + + auto builderFmtCtx = + FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get"); + auto inferredCtxBuilderFmtCtx = 
FmtContext().addSubst("_get", "Base::get"); + auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context"); + + // Generate the builders specified by the user. + for (const AttrOrTypeBuilder &builder : def.getBuilders()) { + Optional body = builder.getBody(); + if (!body) + continue; + std::string paramStr; + llvm::raw_string_ostream paramOS(paramStr); + llvm::interleaveComma(builder.getParameters(), paramOS, + [&](const AttrOrTypeBuilder::Parameter ¶m) { + // Note: AttrOrTypeBuilder parameters are guaranteed + // to have names. + paramOS << param.getCppType() << " " + << *param.getName(); + }); + paramOS.flush(); + + // Emit the `get` variant of the builder. + os << llvm::formatv("{0} {0}::get(", className); + if (!builder.hasInferredContextParameter()) { + os << "::mlir::MLIRContext *context"; + if (!paramStr.empty()) + os << ", "; + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, + tgfmt(*body, &builderFmtCtx).str()); + } else { + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, + tgfmt(*body, &inferredCtxBuilderFmtCtx).str()); + } + + // Emit the `getChecked` variant of the builder. + if (genCheckedMethods) { + os << llvm::formatv("{0} " + "{0}::getChecked(llvm::function_ref<::mlir::" + "InFlightDiagnostic()> emitErrorFn", + className); + std::string checkedBody = + replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, "); + if (!builder.hasInferredContextParameter()) { + os << ", ::mlir::MLIRContext *context"; + checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str(); + } + if (!paramStr.empty()) + os << ", "; + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody); + } + } +} + +/// Print all the def-specific definition code. +void DefGenerator::emitDefDef(const AttrOrTypeDef &def) { + NamespaceEmitter ns(os, def.getDialect()); + + SmallVector parameters; + def.getParameters(parameters); + if (!parameters.empty()) { + // Emit the storage class, if requested and necessary. 
+ if (def.genStorageClass()) + emitStorageClass(def); + + // Emit the builders for this def. + emitBuilderDefs(def, os, parameters); + + // Generate accessor definitions only if we also generate the storage class. + // Otherwise, let the user define the exact accessor definition. + if (def.genAccessors() && def.genStorageClass()) { + for (const AttrOrTypeParameter ¶meter : parameters) { + SmallString<16> name = parameter.getName(); + name[0] = llvm::toUpper(name[0]); + os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n", + parameter.getCppType(), name, parameter.getName(), + def.getCppClassName()); + } + } + } + + // If mnemonic is specified maybe print definitions for the parser and printer + // code, if they're specified. + if (def.getMnemonic()) + emitParsePrint(def); +} + +/// Emit the dialect printer/parser dispatcher. User's code should call these +/// functions from their dialect's print/parse methods. +void DefGenerator::emitParsePrintDispatch(ArrayRef defs) { + if (llvm::none_of(defs, [](const AttrOrTypeDef &def) { + return def.getMnemonic().hasValue(); + })) { + return; + } + + // The parser dispatch is just a list of if-elses, matching on the mnemonic + // and calling the def's parse function. + os << llvm::formatv(defParserDispatchStartStr, valueType, + isAttrGenerator ? ", ::mlir::Type type" : ""); + for (const AttrOrTypeDef &def : defs) { + if (def.getMnemonic()) { + os << formatv( + " if (mnemonic == {0}::{1}::getMnemonic()) return {0}::{1}::", + def.getDialect().getCppNamespace(), def.getCppClassName()); + + // If the def has no parameters and no parser code, just invoke a normal + // `get`. + if (def.getNumParameters() == 0 && !def.getParserCode()) { + os << "get(context);\n"; + continue; + } + + os << "parse(context, parser" << (isAttrGenerator ? 
", type" : "") + << ");\n"; + } + } + os << " return ::mlir::" << valueType << "();\n"; + os << "}\n\n"; + + // The printer dispatch uses llvm::TypeSwitch to find and call the correct + // printer. + os << llvm::formatv(defPrinterDispatchStartStr, valueType); + for (const AttrOrTypeDef &def : defs) { + Optional mnemonic = def.getMnemonic(); + if (!mnemonic) + continue; + + StringRef cppNamespace = def.getDialect().getCppNamespace(); + StringRef cppClassName = def.getCppClassName(); + os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ", + cppNamespace, cppClassName); + + // If the def has no parameters and no printer, just print the mnemonic. + if (def.getNumParameters() == 0 && !def.getPrinterCode()) { + os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace, + cppClassName); + } else { + os << "t.print(printer);"; + } + os << "\n return ::mlir::success();\n })\n"; + } + os << llvm::formatv( + " .Default([](::mlir::{0}) {{ return ::mlir::failure(); });\n}\n\n", + valueType); +} + +bool DefGenerator::emitDefs(StringRef selectedDialect) { + emitSourceFileHeader((defTypePrefix + "Def Definitions").str(), os); + + SmallVector defs; + collectAllDefs(selectedDialect, defRecords, defs); + if (defs.empty()) + return false; + emitTypeDefList(defs); + + IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); + emitParsePrintDispatch(defs); + for (const AttrOrTypeDef &def : defs) + emitDefDef(def); + + return false; +} + +//===----------------------------------------------------------------------===// +// GEN: Registration hooks +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// AttrDef + +static llvm::cl::OptionCategory attrdefGenCat("Options for -gen-attrdef-*"); +static llvm::cl::opt + attrDialect("attrdefs-dialect", + llvm::cl::desc("Generate attributes for this dialect"), + llvm::cl::cat(attrdefGenCat), 
llvm::cl::CommaSeparated); + +static mlir::GenRegistration + genAttrDefs("gen-attrdef-defs", "Generate AttrDef definitions", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + AttrDefGenerator generator(records, os); + return generator.emitDefs(attrDialect); + }); +static mlir::GenRegistration + genAttrDecls("gen-attrdef-decls", "Generate AttrDef declarations", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + AttrDefGenerator generator(records, os); + return generator.emitDecls(attrDialect); + }); + +//===----------------------------------------------------------------------===// +// TypeDef + +static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*"); +static llvm::cl::opt + typeDialect("typedefs-dialect", + llvm::cl::desc("Generate types for this dialect"), + llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); + +static mlir::GenRegistration + genTypeDefs("gen-typedef-defs", "Generate TypeDef definitions", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + TypeDefGenerator generator(records, os); + return generator.emitDefs(typeDialect); + }); +static mlir::GenRegistration + genTypeDecls("gen-typedef-decls", "Generate TypeDef declarations", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + TypeDefGenerator generator(records, os); + return generator.emitDecls(typeDialect); + }); diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -5,6 +5,7 @@ ) add_tablegen(mlir-tblgen MLIR + AttrOrTypeDefGen.cpp DialectGen.cpp DirectiveCommonGen.cpp EnumsGen.cpp @@ -22,7 +23,6 @@ RewriterGen.cpp SPIRVUtilsGen.cpp StructsGen.cpp - TypeDefGen.cpp ) set_target_properties(mlir-tblgen PROPERTIES FOLDER "Tablegenning") diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ 
-13,9 +13,9 @@ #include "DocGenUtilities.h" #include "mlir/Support/IndentedOstream.h" +#include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" -#include "mlir/TableGen/TypeDef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/FormatVariadic.h" @@ -164,7 +164,7 @@ /// Emit the assembly format of a type. static void emitTypeAssemblyFormat(TypeDef td, raw_ostream &os) { - SmallVector parameters; + SmallVector parameters; td.getParameters(parameters); if (parameters.size() == 0) { os << "\nSyntax: `!" << td.getDialect().getName() << "." << td.getMnemonic() @@ -198,7 +198,7 @@ } // Emit attribute documentation. - SmallVector parameters; + SmallVector parameters; td.getParameters(parameters); if (!parameters.empty()) { os << "\n#### Type parameters:\n\n"; diff --git a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp deleted file mode 100644 --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ /dev/null @@ -1,739 +0,0 @@ -//===- TypeDefGen.cpp - MLIR typeDef definitions generator ----------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// TypeDefGen uses the description of typeDefs to generate C++ definitions. 
-// -//===----------------------------------------------------------------------===// - -#include "mlir/Support/LogicalResult.h" -#include "mlir/TableGen/CodeGenHelpers.h" -#include "mlir/TableGen/Format.h" -#include "mlir/TableGen/GenInfo.h" -#include "mlir/TableGen/TypeDef.h" -#include "llvm/ADT/SmallSet.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/TableGen/Error.h" -#include "llvm/TableGen/TableGenBackend.h" - -#define DEBUG_TYPE "mlir-tblgen-typedefgen" - -using namespace mlir; -using namespace mlir::tblgen; - -static llvm::cl::OptionCategory typedefGenCat("Options for -gen-typedef-*"); -static llvm::cl::opt - selectedDialect("typedefs-dialect", - llvm::cl::desc("Gen types for this dialect"), - llvm::cl::cat(typedefGenCat), llvm::cl::CommaSeparated); - -/// Find all the TypeDefs for the specified dialect. If no dialect specified and -/// can only find one dialect's types, use that. -static void findAllTypeDefs(const llvm::RecordKeeper &recordKeeper, - SmallVectorImpl &typeDefs) { - auto recDefs = recordKeeper.getAllDerivedDefinitions("TypeDef"); - auto defs = llvm::map_range( - recDefs, [&](const llvm::Record *rec) { return TypeDef(rec); }); - if (defs.empty()) - return; - - StringRef dialectName; - if (selectedDialect.getNumOccurrences() == 0) { - if (defs.empty()) - return; - - llvm::SmallSet dialects; - for (const TypeDef typeDef : defs) - dialects.insert(typeDef.getDialect()); - if (dialects.size() != 1) - llvm::PrintFatalError("TypeDefs belonging to more than one dialect. 
Must " - "select one via '--typedefs-dialect'"); - - dialectName = (*dialects.begin()).getName(); - } else if (selectedDialect.getNumOccurrences() == 1) { - dialectName = selectedDialect.getValue(); - } else { - llvm::PrintFatalError("Cannot select multiple dialects for which to " - "generate types via '--typedefs-dialect'."); - } - - for (const TypeDef typeDef : defs) - if (typeDef.getDialect().getName().equals(dialectName)) - typeDefs.push_back(typeDef); -} - -namespace { - -/// Pass an instance of this class to llvm::formatv() to emit a comma separated -/// list of parameters in the format by 'EmitFormat'. -class TypeParamCommaFormatter : public llvm::detail::format_adapter { -public: - /// Choose the output format - enum EmitFormat { - /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name, - /// [...]". - TypeNamePairs, - - /// Emit "parameter1(parameter1), parameter2(parameter2), [...]". - TypeNameInitializer, - - /// Emit "param1Name, param2Name, [...]". - JustParams, - }; - - TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef params, - bool prependComma = true) - : emitFormat(emitFormat), params(params), prependComma(prependComma) {} - - /// llvm::formatv will call this function when using an instance as a - /// replacement value. - void format(raw_ostream &os, StringRef options) override { - if (!params.empty() && prependComma) - os << ", "; - - switch (emitFormat) { - case EmitFormat::TypeNamePairs: - interleaveComma(params, os, - [&](const TypeParameter &p) { emitTypeNamePair(p, os); }); - break; - case EmitFormat::TypeNameInitializer: - interleaveComma(params, os, [&](const TypeParameter &p) { - emitTypeNameInitializer(p, os); - }); - break; - case EmitFormat::JustParams: - interleaveComma(params, os, - [&](const TypeParameter &p) { os << p.getName(); }); - break; - } - } - -private: - // Emit "paramType paramName". 
- static void emitTypeNamePair(const TypeParameter ¶m, raw_ostream &os) { - os << param.getCppType() << " " << param.getName(); - } - // Emit "paramName(paramName)" - void emitTypeNameInitializer(const TypeParameter ¶m, raw_ostream &os) { - os << param.getName() << "(" << param.getName() << ")"; - } - - EmitFormat emitFormat; - ArrayRef params; - bool prependComma; -}; - -} // end anonymous namespace - -//===----------------------------------------------------------------------===// -// GEN: TypeDef declarations -//===----------------------------------------------------------------------===// - -/// Print this above all the other declarations. Contains type declarations used -/// later on. -static const char *const typeDefDeclHeader = R"( -namespace mlir { -class DialectAsmParser; -class DialectAsmPrinter; -} // namespace mlir -)"; - -/// The code block for the start of a typeDef class declaration -- singleton -/// case. -/// -/// {0}: The name of the typeDef class. -/// {1}: The name of the type base class. -static const char *const typeDefDeclSingletonBeginStr = R"( - class {0} : public ::mlir::Type::TypeBase<{0}, {1}, ::mlir::TypeStorage> {{ - public: - /// Inherit some necessary constructors from 'TypeBase'. - using Base::Base; - -)"; - -/// The code block for the start of a typeDef class declaration -- parametric -/// case. -/// -/// {0}: The name of the typeDef class. -/// {1}: The name of the type base class. -/// {2}: The typeDef storage class namespace. -/// {3}: The storage class name. -/// {4}: The list of parameters with types. -static const char *const typeDefDeclParametricBeginStr = R"( - namespace {2} { - struct {3}; - } // end namespace {2} - class {0} : public ::mlir::Type::TypeBase<{0}, {1}, - {2}::{3}> {{ - public: - /// Inherit some necessary constructors from 'TypeBase'. - using Base::Base; - -)"; - -/// The snippet for print/parse. 
-static const char *const typeDefParsePrint = R"( - static ::mlir::Type parse(::mlir::MLIRContext *context, - ::mlir::DialectAsmParser &parser); - void print(::mlir::DialectAsmPrinter &printer) const; -)"; - -/// The code block for the verify method declaration. -/// -/// {0}: List of parameters, parameters style. -static const char *const typeDefDeclVerifyStr = R"( - using Base::getChecked; - static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0}); -)"; - -/// Emit the builders for the given type. -static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os, - TypeParamCommaFormatter ¶mTypes) { - StringRef typeClass = typeDef.getCppClassName(); - bool genCheckedMethods = typeDef.genVerifyDecl(); - if (!typeDef.skipDefaultBuilders()) { - os << llvm::formatv( - " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass, - paramTypes); - if (genCheckedMethods) { - os << llvm::formatv(" static {0} " - "getChecked(llvm::function_ref<::mlir::" - "InFlightDiagnostic()> emitError, " - "::mlir::MLIRContext *context{1});\n", - typeClass, paramTypes); - } - } - - // Generate the builders specified by the user. - for (const TypeBuilder &builder : typeDef.getBuilders()) { - std::string paramStr; - llvm::raw_string_ostream paramOS(paramStr); - llvm::interleaveComma( - builder.getParameters(), paramOS, - [&](const TypeBuilder::Parameter ¶m) { - // Note: TypeBuilder parameters are guaranteed to have names. - paramOS << param.getCppType() << " " << *param.getName(); - if (Optional defaultParamValue = param.getDefaultValue()) - paramOS << " = " << *defaultParamValue; - }); - paramOS.flush(); - - // Generate the `get` variant of the builder. - os << " static " << typeClass << " get("; - if (!builder.hasInferredContextParameter()) { - os << "::mlir::MLIRContext *context"; - if (!paramStr.empty()) - os << ", "; - } - os << paramStr << ");\n"; - - // Generate the `getChecked` variant of the builder. 
- if (genCheckedMethods) { - os << " static " << typeClass - << " getChecked(llvm::function_ref " - "emitError"; - if (!builder.hasInferredContextParameter()) - os << ", ::mlir::MLIRContext *context"; - if (!paramStr.empty()) - os << ", "; - os << paramStr << ");\n"; - } - } -} - -/// Generate the declaration for the given typeDef class. -static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) { - SmallVector params; - typeDef.getParameters(params); - - // Emit the beginning string template: either the singleton or parametric - // template. - if (typeDef.getNumParameters() == 0) - os << formatv(typeDefDeclSingletonBeginStr, typeDef.getCppClassName(), - typeDef.getCppBaseClassName()); - else - os << formatv(typeDefDeclParametricBeginStr, typeDef.getCppClassName(), - typeDef.getCppBaseClassName(), typeDef.getStorageNamespace(), - typeDef.getStorageClassName()); - - // Emit the extra declarations first in case there's a type definition in - // there. - if (Optional extraDecl = typeDef.getExtraDecls()) - os << *extraDecl << "\n"; - - TypeParamCommaFormatter emitTypeNamePairsAfterComma( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params); - if (!params.empty()) { - emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma); - - // Emit the verify invariants declaration. - if (typeDef.genVerifyDecl()) - os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma); - } - - // Emit the mnenomic, if specified. - if (auto mnenomic = typeDef.getMnemonic()) { - os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n" - << " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n" - << " }\n"; - - // If mnemonic specified, emit print/parse declarations. 
- if (typeDef.getParserCode() || typeDef.getPrinterCode() || !params.empty()) - os << typeDefParsePrint; - } - - if (typeDef.genAccessors()) { - SmallVector parameters; - typeDef.getParameters(parameters); - - for (TypeParameter ¶meter : parameters) { - SmallString<16> name = parameter.getName(); - name[0] = llvm::toUpper(name[0]); - os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name); - } - } - - // End the typeDef decl. - os << " };\n"; -} - -/// Main entry point for decls. -static bool emitTypeDefDecls(const llvm::RecordKeeper &recordKeeper, - raw_ostream &os) { - emitSourceFileHeader("TypeDef Declarations", os); - - SmallVector typeDefs; - findAllTypeDefs(recordKeeper, typeDefs); - - IfDefScope scope("GET_TYPEDEF_CLASSES", os); - - // Output the common "header". - os << typeDefDeclHeader; - - if (!typeDefs.empty()) { - NamespaceEmitter nsEmitter(os, typeDefs.begin()->getDialect()); - - // Declare all the type classes first (in case they reference each other). - for (const TypeDef &typeDef : typeDefs) - os << " class " << typeDef.getCppClassName() << ";\n"; - - // Declare all the typedefs. - for (const TypeDef &typeDef : typeDefs) - emitTypeDefDecl(typeDef, os); - } - - return false; -} - -//===----------------------------------------------------------------------===// -// GEN: TypeDef list -//===----------------------------------------------------------------------===// - -static void emitTypeDefList(SmallVectorImpl &typeDefs, - raw_ostream &os) { - IfDefScope scope("GET_TYPEDEF_LIST", os); - for (auto *i = typeDefs.begin(); i != typeDefs.end(); i++) { - os << i->getDialect().getCppNamespace() << "::" << i->getCppClassName(); - if (i < typeDefs.end() - 1) - os << ",\n"; - else - os << "\n"; - } -} - -//===----------------------------------------------------------------------===// -// GEN: TypeDef definitions -//===----------------------------------------------------------------------===// - -/// Beginning of storage class. 
-/// {0}: Storage class namespace. -/// {1}: Storage class c++ name. -/// {2}: Parameters parameters. -/// {3}: Parameter initializer string. -/// {4}: Parameter name list. -/// {5}: Parameter types. -static const char *const typeDefStorageClassBegin = R"( -namespace {0} {{ - struct {1} : public ::mlir::TypeStorage {{ - {1} ({2}) - : {3} {{ } - - /// The hash key for this storage is a pair of the integer and type params. - using KeyTy = std::tuple<{5}>; - - /// Define the comparison function for the key type. - bool operator==(const KeyTy &key) const {{ - return key == KeyTy({4}); - } -)"; - -/// The storage class' constructor template. -/// {0}: storage class name. -static const char *const typeDefStorageClassConstructorBegin = R"( - /// Define a construction method for creating a new instance of this storage. - static {0} *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &key) {{ -)"; - -/// The storage class' constructor return template. -/// {0}: storage class name. -/// {1}: list of parameters. -static const char *const typeDefStorageClassConstructorReturn = R"( - return new (allocator.allocate<{0}>()) - {0}({1}); - } -)"; - -/// Use tgfmt to emit custom allocation code for each parameter, if necessary. -static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) { - SmallVector parameters; - typeDef.getParameters(parameters); - auto fmtCtxt = FmtContext().addSubst("_allocator", "allocator"); - for (TypeParameter ¶meter : parameters) { - auto allocCode = parameter.getAllocator(); - if (allocCode) { - fmtCtxt.withSelf(parameter.getName()); - fmtCtxt.addSubst("_dst", parameter.getName()); - os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n"; - } - } -} - -/// Emit the storage class code for type 'typeDef'. -/// This includes (in-order): -/// 1) typeDefStorageClassBegin, which includes: -/// - The class constructor. -/// - The KeyTy definition. -/// - The equality (==) operator. -/// 2) The hashKey method. -/// 3) The construct method. 
-/// 4) The list of parameters as the storage class member variables. -static void emitStorageClass(TypeDef typeDef, raw_ostream &os) { - SmallVector parameters; - typeDef.getParameters(parameters); - - // Initialize a bunch of variables to be used later on. - auto parameterNames = map_range( - parameters, [](TypeParameter parameter) { return parameter.getName(); }); - auto parameterTypes = map_range(parameters, [](TypeParameter parameter) { - return parameter.getCppType(); - }); - auto parameterList = join(parameterNames, ", "); - auto parameterTypeList = join(parameterTypes, ", "); - - // 1) Emit most of the storage class up until the hashKey body. - os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(), - typeDef.getStorageClassName(), - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, - parameters, /*prependComma=*/false), - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, - parameters, /*prependComma=*/false), - parameterList, parameterTypeList); - - // 2) Emit the haskKey method. - os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n"; - // Extract each parameter from the key. - for (size_t i = 0, e = parameters.size(); i < e; ++i) - os << llvm::formatv(" const auto &{0} = std::get<{1}>(key);\n", - parameters[i].getName(), i); - // Then combine them all. This requires all the parameters types to have a - // hash_value defined. - os << llvm::formatv( - " return ::llvm::hash_combine({0});\n }\n", - TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, - parameters, /* prependComma */ false)); - - // 3) Emit the construct method. - if (typeDef.hasStorageCustomConstructor()) { - // If user wants to build the storage constructor themselves, declare it - // here and then they can write the definition elsewhere. 
- os << " static " << typeDef.getStorageClassName() - << " *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy " - "&key);\n"; - } else { - // If not, autogenerate one. - - // First, unbox the parameters. - os << formatv(typeDefStorageClassConstructorBegin, - typeDef.getStorageClassName()); - for (size_t i = 0; i < parameters.size(); ++i) { - os << formatv(" auto {0} = std::get<{1}>(key);\n", - parameters[i].getName(), i); - } - // Second, reassign the parameter variables with allocation code, if it's - // specified. - emitParameterAllocationCode(typeDef, os); - - // Last, return an allocated copy. - os << formatv(typeDefStorageClassConstructorReturn, - typeDef.getStorageClassName(), parameterList); - } - - // 4) Emit the parameters as storage class members. - for (auto parameter : parameters) { - os << " " << parameter.getCppType() << " " << parameter.getName() - << ";\n"; - } - os << " };\n"; - - os << "} // namespace " << typeDef.getStorageNamespace() << "\n"; -} - -/// Emit the parser and printer for a particular type, if they're specified. -void emitParserPrinter(TypeDef typeDef, raw_ostream &os) { - // Emit the printer code, if specified. - if (auto printerCode = typeDef.getPrinterCode()) { - // Both the mnenomic and printerCode must be defined (for parity with - // parserCode). - os << "void " << typeDef.getCppClassName() - << "::print(::mlir::DialectAsmPrinter &printer) const {\n"; - if (*printerCode == "") { - // If no code specified, emit error. - PrintFatalError(typeDef.getLoc(), - typeDef.getName() + - ": printer (if specified) must have non-empty code"); - } - auto fmtCtxt = FmtContext().addSubst("_printer", "printer"); - os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n"; - } - - // emit a parser, if specified. - if (auto parserCode = typeDef.getParserCode()) { - // The mnenomic must be defined so the dispatcher knows how to dispatch. 
- os << "::mlir::Type " << typeDef.getCppClassName() - << "::parse(::mlir::MLIRContext *context, ::mlir::DialectAsmParser &" - "parser) " - "{\n"; - if (*parserCode == "") { - // if no code specified, emit error. - PrintFatalError(typeDef.getLoc(), - typeDef.getName() + - ": parser (if specified) must have non-empty code"); - } - auto fmtCtxt = - FmtContext().addSubst("_parser", "parser").addSubst("_ctxt", "context"); - os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n"; - } -} - -/// Replace all instances of 'from' to 'to' in `str` and return the new string. -static std::string replaceInStr(std::string str, StringRef from, StringRef to) { - size_t pos = 0; - while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) - str.replace(pos, from.size(), to.data(), to.size()); - return str; -} - -/// Emit the builders for the given type. -static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os, - ArrayRef typeDefParams) { - bool genCheckedMethods = typeDef.genVerifyDecl(); - StringRef typeClass = typeDef.getCppClassName(); - if (!typeDef.skipDefaultBuilders()) { - os << llvm::formatv( - "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n" - " return Base::get(context{2});\n}\n", - typeClass, - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, typeDefParams), - TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams, - typeDefParams)); - if (genCheckedMethods) { - os << llvm::formatv( - "{0} {0}::getChecked(" - "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, " - "::mlir::MLIRContext *context{1}) {{\n" - " return Base::getChecked(emitError, context{2});\n}\n", - typeClass, - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::TypeNamePairs, - typeDefParams), - TypeParamCommaFormatter( - TypeParamCommaFormatter::EmitFormat::JustParams, typeDefParams)); - } - } - - auto builderFmtCtx = - FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get"); - auto 
inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get"); - auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context"); - - // Generate the builders specified by the user. - for (const TypeBuilder &builder : typeDef.getBuilders()) { - Optional body = builder.getBody(); - if (!body) - continue; - std::string paramStr; - llvm::raw_string_ostream paramOS(paramStr); - llvm::interleaveComma(builder.getParameters(), paramOS, - [&](const TypeBuilder::Parameter ¶m) { - // Note: TypeBuilder parameters are guaranteed to - // have names. - paramOS << param.getCppType() << " " - << *param.getName(); - }); - paramOS.flush(); - - // Emit the `get` variant of the builder. - os << llvm::formatv("{0} {0}::get(", typeClass); - if (!builder.hasInferredContextParameter()) { - os << "::mlir::MLIRContext *context"; - if (!paramStr.empty()) - os << ", "; - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, - tgfmt(*body, &builderFmtCtx).str()); - } else { - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, - tgfmt(*body, &inferredCtxBuilderFmtCtx).str()); - } - - // Emit the `getChecked` variant of the builder. - if (genCheckedMethods) { - os << llvm::formatv("{0} " - "{0}::getChecked(llvm::function_ref<::mlir::" - "InFlightDiagnostic()> emitErrorFn", - typeClass); - std::string checkedBody = - replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, "); - if (!builder.hasInferredContextParameter()) { - os << ", ::mlir::MLIRContext *context"; - checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str(); - } - if (!paramStr.empty()) - os << ", "; - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody); - } - } -} - -/// Print all the typedef-specific definition code. -static void emitTypeDefDef(const TypeDef &typeDef, raw_ostream &os) { - NamespaceEmitter ns(os, typeDef.getDialect()); - - SmallVector parameters; - typeDef.getParameters(parameters); - if (!parameters.empty()) { - // Emit the storage class, if requested and necessary. 
- if (typeDef.genStorageClass()) - emitStorageClass(typeDef, os); - - // Emit the builders for this type. - emitTypeBuilderDefs(typeDef, os, parameters); - - // Generate accessor definitions only if we also generate the storage class. - // Otherwise, let the user define the exact accessor definition. - if (typeDef.genAccessors() && typeDef.genStorageClass()) { - // Emit the parameter accessors. - for (const TypeParameter ¶meter : parameters) { - SmallString<16> name = parameter.getName(); - name[0] = llvm::toUpper(name[0]); - os << formatv("{0} {3}::get{1}() const { return getImpl()->{2}; }\n", - parameter.getCppType(), name, parameter.getName(), - typeDef.getCppClassName()); - } - } - } - - // If mnemonic is specified maybe print definitions for the parser and printer - // code, if they're specified. - if (typeDef.getMnemonic()) - emitParserPrinter(typeDef, os); -} - -/// Emit the dialect printer/parser dispatcher. User's code should call these -/// functions from their dialect's print/parse methods. -static void emitParsePrintDispatch(ArrayRef types, raw_ostream &os) { - if (llvm::none_of(types, [](const TypeDef &type) { - return type.getMnemonic().hasValue(); - })) { - return; - } - - // The parser dispatch is just a list of if-elses, matching on the - // mnemonic and calling the class's parse function. - os << "static ::mlir::Type generatedTypeParser(::mlir::MLIRContext *" - "context, ::mlir::DialectAsmParser &parser, " - "::llvm::StringRef mnemonic) {\n"; - for (const TypeDef &type : types) { - if (type.getMnemonic()) { - os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) return " - "{0}::{1}::", - type.getDialect().getCppNamespace(), - type.getCppClassName()); - - // If the type has no parameters and no parser code, just invoke a normal - // `get`. 
- if (type.getNumParameters() == 0 && !type.getParserCode()) - os << "get(context);\n"; - else - os << "parse(context, parser);\n"; - } - } - os << " return ::mlir::Type();\n"; - os << "}\n\n"; - - // The printer dispatch uses llvm::TypeSwitch to find and call the correct - // printer. - os << "static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type " - "type, " - "::mlir::DialectAsmPrinter &printer) {\n" - << " return ::llvm::TypeSwitch<::mlir::Type, " - "::mlir::LogicalResult>(type)\n"; - for (const TypeDef &type : types) { - if (Optional mnemonic = type.getMnemonic()) { - StringRef cppNamespace = type.getDialect().getCppNamespace(); - StringRef cppClassName = type.getCppClassName(); - os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ", - cppNamespace, cppClassName); - - // If the type has no parameters and no printer code, just print the - // mnemonic. - if (type.getNumParameters() == 0 && !type.getPrinterCode()) - os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace, - cppClassName); - else - os << "t.print(printer);"; - os << "\n return ::mlir::success();\n })\n"; - } - } - os << " .Default([](::mlir::Type) { return ::mlir::failure(); });\n" - << "}\n\n"; -} - -/// Entry point for typedef definitions. 
-static bool emitTypeDefDefs(const llvm::RecordKeeper &recordKeeper, - raw_ostream &os) { - emitSourceFileHeader("TypeDef Definitions", os); - - SmallVector typeDefs; - findAllTypeDefs(recordKeeper, typeDefs); - emitTypeDefList(typeDefs, os); - - IfDefScope scope("GET_TYPEDEF_CLASSES", os); - emitParsePrintDispatch(typeDefs, os); - for (const TypeDef &typeDef : typeDefs) - emitTypeDefDef(typeDef, os); - - return false; -} - -//===----------------------------------------------------------------------===// -// GEN: TypeDef registration hooks -//===----------------------------------------------------------------------===// - -static mlir::GenRegistration - genTypeDefDefs("gen-typedef-defs", "Generate TypeDef definitions", - [](const llvm::RecordKeeper &records, raw_ostream &os) { - return emitTypeDefDefs(records, os); - }); - -static mlir::GenRegistration - genTypeDefDecls("gen-typedef-decls", "Generate TypeDef declarations", - [](const llvm::RecordKeeper &records, raw_ostream &os) { - return emitTypeDefDecls(records, os); - });