diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1029,10 +1029,11 @@ private: std::tuple Ranges; - template iterator begin_impl(std::index_sequence) { + template + iterator begin_impl(std::index_sequence) const { return iterator(std::get(Ranges)...); } - template iterator end_impl(std::index_sequence) { + template iterator end_impl(std::index_sequence) const { return iterator(make_range(std::end(std::get(Ranges)), std::end(std::get(Ranges)))...); } @@ -1041,8 +1042,12 @@ concat_range(RangeTs &&... Ranges) : Ranges(std::forward(Ranges)...) {} - iterator begin() { return begin_impl(std::index_sequence_for{}); } - iterator end() { return end_impl(std::index_sequence_for{}); } + iterator begin() const { + return begin_impl(std::index_sequence_for{}); + } + iterator end() const { + return end_impl(std::index_sequence_for{}); + } }; } // end namespace detail @@ -1997,10 +2002,16 @@ enumerator_iter begin() { return enumerator_iter(0, std::begin(TheRange)); } + enumerator_iter begin() const { + return enumerator_iter(0, std::begin(TheRange)); + } enumerator_iter end() { return enumerator_iter(std::end(TheRange)); } + enumerator_iter end() const { + return enumerator_iter(std::end(TheRange)); + } private: R TheRange; diff --git a/mlir/include/mlir/Support/IndentedOstream.h b/mlir/include/mlir/Support/IndentedOstream.h --- a/mlir/include/mlir/Support/IndentedOstream.h +++ b/mlir/include/mlir/Support/IndentedOstream.h @@ -29,28 +29,32 @@ /// Simple RAII struct to use to indentation around entering/exiting region. 
struct DelimitedScope { explicit DelimitedScope(raw_indented_ostream &os, StringRef open = "", - StringRef close = "") - : os(os), open(open), close(close) { + StringRef close = "", bool indent = true) + : os(os), open(open), close(close), indent(indent) { os << open; - os.indent(); + if (indent) + os.indent(); } ~DelimitedScope() { - os.unindent(); + if (indent) + os.unindent(); os << close; } raw_indented_ostream &os; private: - llvm::StringRef open, close; + StringRef open, close; + bool indent; }; /// Returns the underlying (unindented) raw_ostream. raw_ostream &getOStream() const { return os; } /// Returns DelimitedScope. - DelimitedScope scope(StringRef open = "", StringRef close = "") { - return DelimitedScope(*this, open, close); + DelimitedScope scope(StringRef open = "", StringRef close = "", + bool indent = true) { + return DelimitedScope(*this, open, close, indent); } /// Re-indents by removing the leading whitespace from the first non-empty diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -27,7 +27,6 @@ namespace mlir { namespace tblgen { class Dialect; -class AttrOrTypeParameter; //===----------------------------------------------------------------------===// // AttrOrTypeBuilder @@ -42,6 +41,76 @@ bool hasInferredContextParameter() const; }; +//===----------------------------------------------------------------------===// +// AttrOrTypeParameter +//===----------------------------------------------------------------------===// + +// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to +// AttrOrTypeDefs to parameterize them. +class AttrOrTypeParameter { +public: + explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index) + : def(def), index(index) {} + + // Get the parameter name. 
+ StringRef getName() const; + + // If specified, get the custom allocator code for this parameter. + Optional getAllocator() const; + + // If specified, get the custom comparator code for this parameter. + Optional getComparator() const; + + // Get the C++ type of this parameter. + StringRef getCppType() const; + + // Get the C++ accessor type of this parameter. + StringRef getCppAccessorType() const; + + // Get the C++ storage type of this parameter. + StringRef getCppStorageType() const; + + // Get an optional C++ parameter parser. + Optional getParser() const; + + // Get an optional C++ parameter printer. + Optional getPrinter() const; + + // Get a description of this parameter for documentation purposes. + Optional getSummary() const; + + // Get the assembly syntax documentation. + StringRef getSyntax() const; + + // Return the underlying def of this parameter. + const llvm::Init *getDef() const; + + // The parameter is pointer-comparable. + bool operator==(const AttrOrTypeParameter &other) const { + return def == other.def && index == other.index; + } + bool operator!=(const AttrOrTypeParameter &other) const { + return !(*this == other); + } + +private: + /// The underlying tablegen parameter list this parameter is a part of. + const llvm::DagInit *def; + /// The index of the parameter within the parameter list (`def`). + unsigned index; +}; + +//===----------------------------------------------------------------------===// +// AttributeSelfTypeParameter +//===----------------------------------------------------------------------===// + +// A wrapper class for the AttributeSelfTypeParameter tblgen class. This +// represents a parameter of mlir::Type that is the value type of an AttrDef. 
+class AttributeSelfTypeParameter : public AttrOrTypeParameter { +public: + static bool classof(const AttrOrTypeParameter *param); +}; + //===----------------------------------------------------------------------===// // AttrOrTypeDef //===----------------------------------------------------------------------===// @@ -82,9 +151,8 @@ // Indicates whether or not to generate the storage class constructor. bool hasStorageCustomConstructor() const; - // Fill a list with this def's parameters. See AttrOrTypeDef in OpBase.td for - // documentation of parameter usage. - void getParameters(SmallVectorImpl &) const; + // Get a list of this attribute or type's parameters. + SmallVector getParameters() const; // Return the number of parameters unsigned getNumParameters() const; @@ -104,6 +172,19 @@ // Returns the custom assembly format, if one was specified. Optional getAssemblyFormat() const; + // An attribute or type with parameters needs a parser. + bool needsParserPrinter() const { return getNumParameters() != 0; } + + // Returns true if this attribute or type has a generated parser. + bool hasGeneratedParser() const { + return getParserCode() || getAssemblyFormat(); + } + + // Returns true if this attribute or type has a generated printer. + bool hasGeneratedPrinter() const { + return getPrinterCode() || getAssemblyFormat(); + } + // Returns true if the accessors based on the parameters should be generated. bool genAccessors() const; @@ -176,68 +257,6 @@ using AttrOrTypeDef::AttrOrTypeDef; }; -//===----------------------------------------------------------------------===// -// AttrOrTypeParameter -//===----------------------------------------------------------------------===// - -// A wrapper class for tblgen AttrOrTypeParameter, arrays of which belong to -// AttrOrTypeDefs to parameterize them. -class AttrOrTypeParameter { -public: - explicit AttrOrTypeParameter(const llvm::DagInit *def, unsigned index) - : def(def), index(index) {} - - // Get the parameter name. 
- StringRef getName() const; - - // If specified, get the custom allocator code for this parameter. - Optional getAllocator() const; - - // If specified, get the custom comparator code for this parameter. - Optional getComparator() const; - - // Get the C++ type of this parameter. - StringRef getCppType() const; - - // Get the C++ accessor type of this parameter. - StringRef getCppAccessorType() const; - - // Get the C++ storage type of this parameter. - StringRef getCppStorageType() const; - - // Get an optional C++ parameter parser. - Optional getParser() const; - - // Get an optional C++ parameter printer. - Optional getPrinter() const; - - // Get a description of this parameter for documentation purposes. - Optional getSummary() const; - - // Get the assembly syntax documentation. - StringRef getSyntax() const; - - // Return the underlying def of this parameter. - const llvm::Init *getDef() const; - -private: - /// The underlying tablegen parameter list this parameter is a part of. - const llvm::DagInit *def; - /// The index of the parameter within the parameter list (`def`). - unsigned index; -}; - -//===----------------------------------------------------------------------===// -// AttributeSelfTypeParameter -//===----------------------------------------------------------------------===// - -// A wrapper class for the AttributeSelfTypeParameter tblgen class. This -// represents a parameter of mlir::Type that is the value type of an AttrDef. -class AttributeSelfTypeParameter : public AttrOrTypeParameter { -public: - static bool classof(const AttrOrTypeParameter *param); -}; - } // end namespace tblgen } // end namespace mlir diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h --- a/mlir/include/mlir/TableGen/Class.h +++ b/mlir/include/mlir/TableGen/Class.h @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// // -// This file defines several classes for Op C++ code emission. 
They are only +// This file defines several classes for C++ code emission. They are only // expected to be used by MLIR TableGen backends. // -// We emit the op declaration and definition into separate files: *Ops.h.inc -// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and -// the latter for dialect *Ops.cpp. This way provides a cleaner interface. +// We emit the declarations and definitions into separate files: *.h.inc and +// *.cpp.inc. The former is to be included in the dialect *.h and the latter for +// dialect *.cpp. This way provides a cleaner interface. // // In order to do this split, we need to track method signature and // implementation logic separately. Signature information is used for both @@ -23,6 +23,7 @@ #ifndef MLIR_TABLEGEN_CLASS_H_ #define MLIR_TABLEGEN_CLASS_H_ +#include "mlir/Support/IndentedOstream.h" #include "mlir/Support/LLVM.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "llvm/ADT/SetVector.h" @@ -30,7 +31,6 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/Twine.h" -#include "llvm/Support/raw_ostream.h" #include #include @@ -61,9 +61,9 @@ /*defaultValue=*/"", optional) {} /// Write the parameter as part of a method declaration. - void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); } + void writeDeclTo(raw_indented_ostream &os) const; /// Write the parameter as part of a method definition. - void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } + void writeDefTo(raw_indented_ostream &os) const; /// Get the C++ type. const std::string &getType() const { return type; } @@ -71,8 +71,6 @@ bool hasDefaultValue() const { return !defaultValue.empty(); } private: - void writeTo(raw_ostream &os, bool emitDefault) const; - /// The C++ type. std::string type; /// The variable name. @@ -95,9 +93,9 @@ : parameters(std::move(parameters)) {} /// Write the parameters as part of a method declaration. 
- void writeDeclTo(raw_ostream &os) const; + void writeDeclTo(raw_indented_ostream &os) const; /// Write the parameters as part of a method definition. - void writeDefTo(raw_ostream &os) const; + void writeDefTo(raw_indented_ostream &os) const; /// Determine whether this list of parameters "subsumes" another, which occurs /// when this parameter list is identical to the other and has zero or more @@ -119,10 +117,16 @@ SmallVector &¶meters) : returnType(retType), methodName(name), parameters(std::move(parameters)) {} + MethodSignature(StringRef retType, StringRef name, + ArrayRef parameters) + : MethodSignature(retType, name, + SmallVector(parameters.begin(), + parameters.end())) {} template MethodSignature(StringRef retType, StringRef name, Parameters &&...parameters) - : returnType(retType), methodName(name), - parameters({std::forward(parameters)...}) {} + : MethodSignature(retType, name, + ArrayRef( + {std::forward(parameters)...})) {} /// Determine whether a method with this signature makes a method with /// `other` signature redundant. This occurs if the signatures have the same @@ -140,12 +144,12 @@ unsigned getNumParameters() const { return parameters.getNumParameters(); } /// Write the signature as part of a method declaration. - void writeDeclTo(raw_ostream &os) const; + void writeDeclTo(raw_indented_ostream &os) const; /// Write the signature as part of a method definition. `namePrefix` is to be /// prepended to the method name (typically namespaces for qualifying the /// method definition). - void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + void writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const; private: /// The method's C++ return type. @@ -156,60 +160,152 @@ MethodParameters parameters; }; -// Class for holding the body of an op's method for C++ code emission +/// This class contains the body of a C++ method. 
class MethodBody { public: - explicit MethodBody(bool declOnly); + /// Create a method body, indicating whether it should be elided for methods + /// that are declaration-only. + MethodBody(bool declOnly); + + /// Define a move constructor to correctly initialize the streams. + MethodBody(MethodBody &&other) + : declOnly(other.declOnly), body(std::move(other.body)), stringOs(body), + os(stringOs) {} + + /// Write a value to the method body. + template + MethodBody &operator<<(ValueT &&value) { + if (!declOnly) { + os << std::forward(value); + os.flush(); + } + return *this; + } - MethodBody &operator<<(Twine content); - MethodBody &operator<<(int content); - MethodBody &operator<<(const FmtObjectBase &content); + void writeTo(raw_indented_ostream &os) const; - void writeTo(raw_ostream &os) const; + /// Expose methods to manipulate the indented output stream. + MethodBody &indent() { + os.indent(); + return *this; + } + MethodBody &unindent() { + os.unindent(); + return *this; + } + raw_indented_ostream::DelimitedScope + scope(StringRef open = "", StringRef close = "", bool indent = false) { + return os.scope(open, close, indent); + } + raw_indented_ostream &getStream() { return os; } private: - // Whether this class should record method body. - bool isEffective; + /// Whether the body should be elided. + bool declOnly; + /// The body data. std::string body; + /// The string output stream. + llvm::raw_string_ostream stringOs; + /// An indented output stream for formatting input. + raw_indented_ostream os; +}; + +/// A class declaration is a class element that appears as part of its +/// declaration. +class ClassDeclaration { +public: + virtual ~ClassDeclaration() = default; + + /// Kinds for LLVM-style RTTI. + enum Kind { + Method, + UsingDeclaration, + VisibilityDeclaration, + Field, + ExtraClassDeclaration + }; + ClassDeclaration(Kind kind) : kind(kind) {} + + /// Get the class declaration kind. 
+ Kind getKind() const { return kind; } + + /// Write the declaration. + virtual void writeDeclTo(raw_indented_ostream &os) const = 0; + + /// Write the definition, if any. `namePrefix` is the namespace prefix, which + /// may contains a class name. + virtual void writeDefTo(raw_indented_ostream &os, + StringRef namePrefix) const {} + +private: + /// The class declaration kind. + Kind kind; +}; + +/// Base class for class declarations. +template +class ClassDeclarationBase : public ClassDeclaration { +public: + using Base = ClassDeclarationBase; + ClassDeclarationBase() : ClassDeclaration(DeclKind) {} + + static bool classof(const ClassDeclaration *other) { + return other->getKind() == DeclKind; + } }; // Class for holding an op's method for C++ code emission -class Method { +class Method : public ClassDeclarationBase { public: // Properties (qualifiers) of class methods. Bitfield is used here to help // querying properties. - enum Property { - MP_None = 0x0, - MP_Static = 0x1, - MP_Constructor = 0x2, - MP_Private = 0x4, - MP_Declaration = 0x8, - MP_Inline = 0x10, - MP_Constexpr = 0x20 | MP_Inline, - MP_StaticDeclaration = MP_Static | MP_Declaration, + enum Properties { + None = 0x0, + Static = 0x1, + Constructor = 0x2, + Private = 0x4, + Declaration = 0x8, + Inline = 0x10, + ConstexprValue = 0x20, + Const = 0x40, + + Constexpr = ConstexprValue | Inline, + StaticDeclaration = Static | Declaration, + StaticInline = Static | Inline, + ConstInline = Const | Inline, + ConstDeclaration = Const | Declaration }; template - Method(StringRef retType, StringRef name, Property property, Args &&...args) - : properties(property), + Method(StringRef retType, StringRef name, Properties properties, + Args &&...args) + : properties(properties), methodSignature(retType, name, std::forward(args)...), - methodBody(properties & MP_Declaration) {} + methodBody(properties & Declaration) {} + Method(StringRef retType, StringRef name, Properties properties, + std::initializer_list params) 
+ : properties(properties), methodSignature(retType, name, params), + methodBody(properties & Declaration) {} Method(Method &&) = default; Method &operator=(Method &&) = default; - virtual ~Method() = default; - MethodBody &body() { return methodBody; } // Returns true if this is a static method. - bool isStatic() const { return properties & MP_Static; } + bool isStatic() const { return properties & Static; } // Returns true if this is a private method. - bool isPrivate() const { return properties & MP_Private; } + bool isPrivate() const { return properties & Private; } // Returns true if this is an inline method. - bool isInline() const { return properties & MP_Inline; } + bool isInline() const { return properties & Inline; } + + // Returns true if this is a constructor. + bool isConstructor() const { return properties & Constructor; } + + /// Returns true if this class method is const. + bool isConst() const { return properties & Const; } // Returns the name of this method. StringRef getName() const { return methodSignature.getName(); } @@ -219,158 +315,301 @@ return methodSignature.makesRedundant(other.methodSignature); } - // Writes the method as a declaration to the given `os`. - virtual void writeDeclTo(raw_ostream &os) const; - - // Writes the method as a definition to the given `os`. `namePrefix` is the - // prefix to be prepended to the method name (typically namespaces for - // qualifying the method definition). - virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + void writeDeclTo(raw_indented_ostream &os) const override; + void writeDefTo(raw_indented_ostream &os, + StringRef namePrefix) const override; protected: /// A collection of method properties. - Property properties; + Properties properties; /// The signature of the method. MethodSignature methodSignature; /// The body of the method, if it has one. MethodBody methodBody; }; +/// This enum describes C++ inheritance visibility. 
+enum class Visibility { Public, Protected, Private }; + +// Class for holding an op's constructor method for C++ code emission. +class Constructor : public Method { +public: + template + Constructor(StringRef className, Properties properties, Args &&...args) + : Method("", className, properties, std::forward(args)...) {} + + // Add member initializer to constructor initializing `name` with `value`. + template + void addMemberInitializer(NameT name, ValueT value) { + initializers.emplace_back(stringify(name), stringify(value)); + } + + void writeDeclTo(raw_indented_ostream &os) const override; + void writeDefTo(raw_indented_ostream &os, + StringRef namePrefix) const override; + + /// Return true if a method is a constructor. + static bool classof(const ClassDeclaration *other) { + return isa(other) && cast(other)->isConstructor(); + } + + /// Initialization of a class field in a constructor. + class MemberInitializer { + public: + MemberInitializer(std::string name, std::string value) + : name(std::move(name)), value(std::move(value)) {} + + void writeTo(raw_indented_ostream &os) const; + + private: + /// The name of the class field. + std::string name; + /// The value with which to initialize it. + std::string value; + }; + +private: + /// The list of member initializers. + SmallVector initializers; +}; + } // end namespace tblgen } // end namespace mlir +/// Write "public", "protected", or "private". +llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + mlir::tblgen::Visibility visibility); + /// The OR of two method properties should return method properties. Ensure that /// this function is visible to `Class`. 
-inline constexpr mlir::tblgen::Method::Property -operator|(mlir::tblgen::Method::Property lhs, - mlir::tblgen::Method::Property rhs) { - return mlir::tblgen::Method::Property(static_cast(lhs) | - static_cast(rhs)); +inline constexpr mlir::tblgen::Method::Properties +operator|(mlir::tblgen::Method::Properties lhs, + mlir::tblgen::Method::Properties rhs) { + return mlir::tblgen::Method::Properties(static_cast(lhs) | + static_cast(rhs)); +} + +/// Select a property with a bool. +inline constexpr mlir::tblgen::Method::Properties +operator&(mlir::tblgen::Method::Properties lhs, bool rhs) { + return rhs ? lhs : mlir::tblgen::Method::None; +} +inline constexpr bool operator&(mlir::tblgen::Method::Properties lhs, + mlir::tblgen::Method::Properties rhs) { + return static_cast(lhs) & static_cast(rhs); } namespace mlir { namespace tblgen { -// Class for holding an op's constructor method for C++ code emission. -class Constructor : public Method { +/// This class describes a C++ parent class. +class ParentClass { public: - template - Constructor(StringRef className, Property property, - Parameters &&...parameters) - : Method("", className, property, - std::forward(parameters)...) {} + ParentClass(std::string name, Visibility visibility = Visibility::Public) + : name(std::move(name)), visibility(visibility) {} - // Add member initializer to constructor initializing `name` with `value`. - void addMemberInitializer(StringRef name, StringRef value); + /// Add a template parameter. + void addTemplateParam(std::string param) { + templateParams.insert(std::move(param)); + } + template + void addTemplateParams(ContainerT &&container) { + templateParams.insert(std::begin(container), std::end(container)); + } + + void writeTo(raw_indented_ostream &os) const; + +private: + /// The fully resolved C++ name of the parent class. + std::string name; + /// The visibility of the parent class. + Visibility visibility; + /// An optional list of class template parameters. 
+ SetVector, StringSet<>> templateParams; +}; + +/// This class describes a using-declaration for a class. E.g. +/// +/// using Op::Op; +/// using Adaptor = OpAdaptor; +/// +class UsingDeclaration + : public ClassDeclarationBase { +public: + UsingDeclaration(std::string name, std::string value = "") + : name(std::move(name)), value(std::move(value)) {} - // Writes the method as a definition to the given `os`. `namePrefix` is the - // prefix to be prepended to the method name (typically namespaces for - // qualifying the method definition). - void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; + void writeDeclTo(raw_indented_ostream &os) const override; private: - // Member initializers. - std::string memberInitializers; + /// The name of the declaration, or a resolved name to an inherited function. + std::string name; + /// The type that is being aliased. Leave empty for inheriting functions. + std::string value; +}; + +/// This class describes a class field. Class fields are always private and are +/// always declared at the bottom of the class declaration. +class Field : public ClassDeclarationBase { +public: + Field(std::string type, std::string name, bool publicField = false) + : type(std::move(type)), name(std::move(name)), publicField(publicField) { + } + + void writeDeclTo(raw_indented_ostream &os) const override; + bool isPublic() const { return publicField; } + +private: + std::string type; + std::string name; + const bool publicField; +}; + +/// A declaration for the visibility of subsequent declarations. +class VisibilityDeclaration + : public ClassDeclarationBase { +public: + VisibilityDeclaration(Visibility visibility) : visibility(visibility) {} + + /// Get the visibility. + Visibility getVisibility() const { return visibility; } + + void writeDeclTo(raw_indented_ostream &os) const override; + +private: + Visibility visibility; +}; + +/// Unstructured extra class declarations. 
+class ExtraClassDeclaration + : public ClassDeclarationBase { +public: + ExtraClassDeclaration(StringRef extraClassDeclaration) + : extraClassDeclaration(extraClassDeclaration) {} + + void writeDeclTo(raw_indented_ostream &os) const override; + +private: + StringRef extraClassDeclaration; }; // A class used to emit C++ classes from Tablegen. Contains a list of public // methods and a list of private fields to be emitted. class Class { public: - explicit Class(StringRef name); + Class(std::string name, bool isStruct = false) + : className(std::move(name)), isStruct(isStruct) {} + virtual ~Class() = default; /// Add a new constructor to this class and prune and constructors made /// redundant by it. Returns null if the constructor was not added. Else, /// returns a pointer to the new constructor. - template - Constructor *addConstructorAndPrune(Parameters &&...parameters) { - return addConstructorAndPrune( - Constructor(getClassName(), Method::MP_Constructor, - std::forward(parameters)...)); + template + Constructor *addConstructor(Args &&...args) { + return addConstructorAndPrune(Constructor(getClassName(), + Properties | Method::Constructor, + std::forward(args)...)); } /// Add a new method to this class and prune any methods made redundant by it. /// Returns null if the method was not added (because an existing method would /// make it redundant). Else, returns a pointer to the new method. - template + template Method *addMethod(StringRef retType, StringRef name, - Method::Property properties, Parameters &&...parameters) { - return addMethodAndPrune(Method(retType, name, properties, - std::forward(parameters)...)); + Method::Properties properties, Args &&...args) { + return addMethodAndPrune(Method(retType, name, Properties | properties, + std::forward(args)...)); } /// Add a method with statically-known properties. 
- template - Method *addMethod(StringRef retType, StringRef name, - Parameters &&...parameters) { - return addMethod(retType, name, Properties, - std::forward(parameters)...); + template + Method *addMethod(StringRef retType, StringRef name, Args &&...args) { + return addMethod(retType, name, Properties, std::forward(args)...); } /// Add a static method. - template - Method *addStaticMethod(StringRef retType, StringRef name, - Parameters &&...parameters) { - return addMethod( - retType, name, std::forward(parameters)...); + template + Method *addStaticMethod(StringRef retType, StringRef name, Args &&...args) { + return addMethod(retType, name, + std::forward(args)...); } /// Add an inline static method. - template + template Method *addStaticInlineMethod(StringRef retType, StringRef name, - Parameters &&...parameters) { - return addMethod( - retType, name, std::forward(parameters)...); + Args &&...args) { + return addMethod( + retType, name, std::forward(args)...); } /// Add an inline method. - template - Method *addInlineMethod(StringRef retType, StringRef name, - Parameters &&...parameters) { - return addMethod( - retType, name, std::forward(parameters)...); + template + Method *addInlineMethod(StringRef retType, StringRef name, Args &&...args) { + return addMethod(retType, name, + std::forward(args)...); + } + + /// Add a const method. + template + Method *addConstMethod(StringRef retType, StringRef name, Args &&...args) { + return addMethod(retType, name, + std::forward(args)...); } /// Add a declaration for a method. - template - Method *declareMethod(StringRef retType, StringRef name, - Parameters &&...parameters) { - return addMethod( - retType, name, std::forward(parameters)...); + template + Method *declareMethod(StringRef retType, StringRef name, Args &&...args) { + return addMethod( + retType, name, std::forward(args)...); } /// Add a declaration for a static method. 
- template + template Method *declareStaticMethod(StringRef retType, StringRef name, - Parameters &&...parameters) { - return addMethod( - retType, name, std::forward(parameters)...); + Args &&...args) { + return addMethod( + retType, name, std::forward(args)...); } - // Creates a new field in this class. - void newField(StringRef type, StringRef name, StringRef defaultValue = ""); + /// Add a new field to the class. + void addField(std::string type, std::string name); - // Writes this op's class as a declaration to the given `os`. - void writeDeclTo(raw_ostream &os) const; - // Writes the method definitions in this op's class to the given `os`. - void writeDefTo(raw_ostream &os) const; + /// Add a declaration. + template + DeclT *declare(Args &&...args) { + auto decl = std::make_unique(std::forward(args)...); + auto *ret = decl.get(); + declarations.push_back(std::move(decl)); + return ret; + } + + /// Add a parent class. + ParentClass &addParent(ParentClass parent); + ParentClass &addParent(std::string name, + Visibility visibility = Visibility::Public); - // Returns the C++ class name of the op. + /// Return the C++ name of the class. StringRef getClassName() const { return className; } -protected: - // Get a list of all the methods to emit, filtering out hidden ones. - void forAllMethods(llvm::function_ref func) const { - llvm::for_each(constructors, [&](auto &ctor) { func(ctor); }); - llvm::for_each(methods, [&](auto &method) { func(method); }); + void writeDeclTo(raw_ostream &rawOs) const { + raw_indented_ostream os(rawOs); + writeDeclTo(os); } + void writeDefTo(raw_ostream &rawOs) const { + raw_indented_ostream os(rawOs); + writeDefTo(os); + } + + /// Write the declaration of this class, including its parent classes, + /// constructors, methods and inline method bodies, and fields. + void writeDeclTo(raw_indented_ostream &os) const; + /// Write the definitions of thiss class's out-of-line constructors and + /// methods. 
+ void writeDefTo(raw_indented_ostream &os) const; + virtual void finalize(); + +protected: /// Add a new constructor if it is not made redundant by any existing /// constructors and prune and existing constructors made redundant. Constructor *addConstructorAndPrune(Constructor &&newCtor); @@ -378,31 +617,22 @@ /// prune and existing methods made redundant. Method *addMethodAndPrune(Method &&newMethod); + /// Get the last visibility declaration. + Visibility getLastVisibilityDecl() const; + /// The C++ class name. std::string className; - /// The list of constructors. - std::vector constructors; - /// The list of class methods. - std::vector methods; - /// The list of class members. - SmallVector fields; -}; - -// Class for holding an op for C++ code emission -class OpClass : public Class { -public: - explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); - - // Adds an op trait. - void addTrait(Twine trait); - - // Writes this op's class as a declaration to the given `os`. Redefines - // Class::writeDeclTo to also emit traits and extra class declarations. - void writeDeclTo(raw_ostream &os) const; - -private: - StringRef extraClassDeclaration; - llvm::SetVector, StringSet<>> traits; + /// The list of parent classes. + SmallVector parents; + /// A list of methods. + std::vector> methods; + /// A list of declarations in the class, emitted in order. + std::vector> declarations; + /// A list of class fields, always private and emitted at the end of the class + /// declaration. + SmallVector fields; + /// Whether this is a `class` or a `struct`. 
+ const bool isStruct; }; } // namespace tblgen diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -29,6 +29,12 @@ class Constraint; class DagLeaf; +// Format into a std::string +template +std::string strfmt(const char *fmt, Parameters &&...parameters) { + return llvm::formatv(fmt, std::forward(parameters)...).str(); +} + // Simple RAII helper for defining ifdef-undef-endif scopes. class IfDefScope { public: @@ -58,7 +64,7 @@ ~NamespaceEmitter() { for (StringRef ns : llvm::reverse(namespaces)) - os << "} // namespace " << ns << "\n"; + os << "} // end namespace " << ns << "\n"; } private: @@ -230,6 +236,13 @@ return twine.str(); } }; +template +struct stringifier> { + static std::string apply(Optional optional) { + return optional.hasValue() ? stringifier::apply(*optional) + : std::string(); + } +}; } // end namespace detail /// Generically convert a value to a std::string. diff --git a/mlir/include/mlir/TableGen/Format.h b/mlir/include/mlir/TableGen/Format.h --- a/mlir/include/mlir/TableGen/Format.h +++ b/mlir/include/mlir/TableGen/Format.h @@ -50,6 +50,9 @@ FmtContext() = default; + // Create a format context with a list of substitutions. + FmtContext(std::initializer_list> subs); + // Setter for custom placeholders FmtContext &addSubst(StringRef placeholder, Twine subst); diff --git a/mlir/lib/Support/IndentedOstream.cpp b/mlir/lib/Support/IndentedOstream.cpp --- a/mlir/lib/Support/IndentedOstream.cpp +++ b/mlir/lib/Support/IndentedOstream.cpp @@ -16,19 +16,29 @@ using namespace mlir; raw_indented_ostream &mlir::raw_indented_ostream::reindent(StringRef str) { - StringRef remaining = str; - // Find leading whitespace indent. - while (!remaining.empty()) { - auto split = remaining.split('\n'); + StringRef output = str; + // Skip empty lines. 
+ while (!output.empty()) { + auto split = output.split('\n'); size_t indent = split.first.find_first_not_of(" \t"); if (indent != StringRef::npos) { + // Set an initial value. leadingWs = indent; break; } + output = split.second; + } + // Determine the maximum indent. + StringRef remaining = output; + while (!remaining.empty()) { + auto split = remaining.split('\n'); + size_t indent = split.first.find_first_not_of(" \t"); + if (indent != StringRef::npos) + leadingWs = std::min(leadingWs, static_cast(indent)); remaining = split.second; } // Print, skipping the empty lines. - *this << remaining; + *this << output; leadingWs = 0; return *this; } diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -107,12 +107,13 @@ return def->getValueAsBit("hasStorageCustomConstructor"); } -void AttrOrTypeDef::getParameters( - SmallVectorImpl &parameters) const { +SmallVector AttrOrTypeDef::getParameters() const { + SmallVector parameters; if (auto *parametersDag = def->getValueAsDag("parameters")) { for (unsigned i = 0, e = parametersDag->getNumArgs(); i < e; ++i) parameters.push_back(AttrOrTypeParameter(parametersDag, i)); } + return parameters; } unsigned AttrOrTypeDef::getNumParameters() const { diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp --- a/mlir/lib/TableGen/Class.cpp +++ b/mlir/lib/TableGen/Class.cpp @@ -7,15 +7,10 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Class.h" - #include "mlir/TableGen/Format.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include - -#define DEBUG_TYPE "mlir-tblgen-opclass" using namespace mlir; using namespace mlir::tblgen; @@ -30,23 +25,29 @@ // MethodParameter definitions //===----------------------------------------------------------------------===// -void
MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const { +void MethodParameter::writeDeclTo(raw_indented_ostream &os) const { if (optional) os << "/*optional*/"; os << type << getSpaceAfterType(type) << name; - if (emitDefault && hasDefaultValue()) + if (hasDefaultValue()) os << " = " << defaultValue; } +void MethodParameter::writeDefTo(raw_indented_ostream &os) const { + if (optional) + os << "/*optional*/"; + os << type << getSpaceAfterType(type) << name; +} + //===----------------------------------------------------------------------===// // MethodParameters definitions //===----------------------------------------------------------------------===// -void MethodParameters::writeDeclTo(raw_ostream &os) const { +void MethodParameters::writeDeclTo(raw_indented_ostream &os) const { llvm::interleaveComma(parameters, os, [&os](auto &param) { param.writeDeclTo(os); }); } -void MethodParameters::writeDefTo(raw_ostream &os) const { +void MethodParameters::writeDefTo(raw_indented_ostream &os) const { llvm::interleaveComma(parameters, os, [&os](auto &param) { param.writeDefTo(os); }); } @@ -78,14 +79,14 @@ parameters.subsumes(other.parameters); } -void MethodSignature::writeDeclTo(raw_ostream &os) const { +void MethodSignature::writeDeclTo(raw_indented_ostream &os) const { os << returnType << getSpaceAfterType(returnType) << methodName << "("; parameters.writeDeclTo(os); os << ")"; } -void MethodSignature::writeDefTo(raw_ostream &os, - StringRef namePrefix) const { +void MethodSignature::writeDefTo(raw_indented_ostream &os, + StringRef namePrefix) const { os << returnType << getSpaceAfterType(returnType) << namePrefix << (namePrefix.empty() ?
"" : "::") << methodName << "("; parameters.writeDefTo(os); @@ -96,30 +97,15 @@ // MethodBody definitions //===----------------------------------------------------------------------===// -MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {} - -MethodBody &MethodBody::operator<<(Twine content) { - if (isEffective) - body.append(content.str()); - return *this; -} - -MethodBody &MethodBody::operator<<(int content) { - if (isEffective) - body.append(std::to_string(content)); - return *this; -} +MethodBody::MethodBody(bool declOnly) + : declOnly(declOnly), stringOs(body), os(stringOs) {} -MethodBody &MethodBody::operator<<(const FmtObjectBase &content) { - if (isEffective) - body.append(content.str()); - return *this; -} - -void MethodBody::writeTo(raw_ostream &os) const { +void MethodBody::writeTo(raw_indented_ostream &os) const { auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); os << bodyRef; - if (bodyRef.empty() || bodyRef.back() != '\n') + if (bodyRef.empty()) + return; + if (bodyRef.back() != '\n') os << "\n"; } @@ -127,173 +113,247 @@ // Method definitions //===----------------------------------------------------------------------===// -void Method::writeDeclTo(raw_ostream &os) const { - os.indent(2); +void Method::writeDeclTo(raw_indented_ostream &os) const { if (isStatic()) os << "static "; - if ((properties & MP_Constexpr) == MP_Constexpr) + if (properties & ConstexprValue) os << "constexpr "; methodSignature.writeDeclTo(os); + if (isConst()) + os << " const"; if (!isInline()) { - os << ";"; - } else { - os << " {\n"; - methodBody.writeTo(os.indent(2)); - os.indent(2) << "}"; + os << ";\n\n"; + return; } + os << " {\n"; + methodBody.writeTo(os); + os << "}\n\n"; } -void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const { - // Do not write definition if the method is decl only. 
- if (properties & MP_Declaration) - return; - // Do not generate separate definition for inline method - if (isInline()) +void Method::writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const { + /// The method has no definition to write if it is declaration only or inline. + if (properties & Declaration || isInline()) return; + methodSignature.writeDefTo(os, namePrefix); + if (isConst()) + os << " const"; os << " {\n"; methodBody.writeTo(os); - os << "}"; + os << "}\n\n"; } //===----------------------------------------------------------------------===// // Constructor definitions //===----------------------------------------------------------------------===// -void Constructor::addMemberInitializer(StringRef name, StringRef value) { - memberInitializers.append(std::string(llvm::formatv( - "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); +void Constructor::writeDeclTo(raw_indented_ostream &os) const { + if (properties & ConstexprValue) + os << "constexpr "; + methodSignature.writeDeclTo(os); + if (!isInline()) { + os << ";\n\n"; + return; + } + os << ' '; + if (!initializers.empty()) + os << ": "; + llvm::interleaveComma(initializers, os, + [&](auto &initializer) { initializer.writeTo(os); }); + if (!initializers.empty()) + os << ' '; + os << "{"; + methodBody.writeTo(os); + os << "}\n\n"; } -void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { - // Do not write definition if the method is decl only. - if (properties & MP_Declaration) +void Constructor::writeDefTo(raw_indented_ostream &os, + StringRef namePrefix) const { + /// The method has no definition to write if it is declaration only or inline. 
+ if (properties & Declaration || isInline()) return; methodSignature.writeDefTo(os, namePrefix); - os << " " << memberInitializers << " {\n"; + os << ' '; + if (!initializers.empty()) + os << ": "; + llvm::interleaveComma(initializers, os, + [&](auto &initializer) { initializer.writeTo(os); }); + if (!initializers.empty()) + os << ' '; + os << "{"; methodBody.writeTo(os); - os << "}\n"; + os << "}\n\n"; +} + +void Constructor::MemberInitializer::writeTo(raw_indented_ostream &os) const { + os << name << '(' << value << ')'; } //===----------------------------------------------------------------------===// -// Class definitions +// Visibility definitions //===----------------------------------------------------------------------===// -Class::Class(StringRef name) : className(name) {} - -void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { - std::string varName = formatv("{0} {1}", type, name).str(); - std::string field = defaultValue.empty() - ? varName - : formatv("{0} = {1}", varName, defaultValue).str(); - fields.push_back(std::move(field)); +raw_ostream &operator<<(raw_ostream &os, Visibility visibility) { + switch (visibility) { + case Visibility::Public: + return os << "public"; + case Visibility::Protected: + return os << "protected"; + default: + assert(visibility == Visibility::Private); + return os << "private"; + } } -void Class::writeDeclTo(raw_ostream &os) const { - bool hasPrivateMethod = false; - os << "class " << className << " {\n"; - os << "public:\n"; - - forAllMethods([&](const Method &method) { - if (!method.isPrivate()) { - method.writeDeclTo(os); - os << '\n'; - } else { - hasPrivateMethod = true; - } - }); +//===----------------------------------------------------------------------===// +// ParentClass definitions +//===----------------------------------------------------------------------===// - os << '\n'; - os << "private:\n"; - if (hasPrivateMethod) { - forAllMethods([&](const Method &method) { - if 
(method.isPrivate()) { - method.writeDeclTo(os); - os << '\n'; - } - }); - os << '\n'; +void ParentClass::writeTo(raw_indented_ostream &os) const { + os << visibility << ' ' << name; + if (!templateParams.empty()) { + auto scope = os.scope("<", ">", /*indent=*/false); + llvm::interleaveComma(templateParams, os, + [&](auto &param) { os << param; }); } - - for (const auto &field : fields) - os.indent(2) << field << ";\n"; - os << "};\n"; } -void Class::writeDefTo(raw_ostream &os) const { - forAllMethods([&](const Method &method) { - method.writeDefTo(os, className); - os << "\n"; - }); +//===----------------------------------------------------------------------===// +// UsingDeclaration definitions +//===----------------------------------------------------------------------===// + +void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const { + os << "using " << name; + if (!value.empty()) + os << " = " << value; + os << ";\n"; } -/// Insert a new method into a list of methods, if it would not be pruned, and -/// prune and existing methods.
-template -MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) { - if (llvm::any_of(methods, [&](auto &method) { - return method.makesRedundant(newMethod); - })) - return nullptr; +//===----------------------------------------------------------------------===// +// Field definitions +//===----------------------------------------------------------------------===// - llvm::erase_if( - methods, [&](auto &method) { return newMethod.makesRedundant(method); }); - methods.push_back(std::move(newMethod)); - return &methods.back(); +void Field::writeDeclTo(raw_indented_ostream &os) const { + os << type << ' ' << name << ";\n"; } -Method *Class::addMethodAndPrune(Method &&newMethod) { - return insertAndPrune(methods, std::move(newMethod)); +//===----------------------------------------------------------------------===// +// VisibilityDeclaration definitions +//===----------------------------------------------------------------------===// + +void VisibilityDeclaration::writeDeclTo(raw_indented_ostream &os) const { + os.unindent(); + os << visibility << ":\n"; + os.indent(); } -Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) { - return insertAndPrune(constructors, std::move(newCtor)); +//===----------------------------------------------------------------------===// +// ExtraClassDeclaration definitions +//===----------------------------------------------------------------------===// + +void ExtraClassDeclaration::writeDeclTo(raw_indented_ostream &os) const { + os.reindent(extraClassDeclaration); } //===----------------------------------------------------------------------===// -// OpClass definitions +// Class definitions //===----------------------------------------------------------------------===// -OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) - : Class(name), extraClassDeclaration(extraClassDeclaration) {} +ParentClass &Class::addParent(ParentClass parent) { + parents.push_back(std::move(parent)); + return parents.back(); 
+} -void OpClass::addTrait(Twine trait) { - traits.insert(trait.str()); +ParentClass &Class::addParent(std::string name, Visibility visibility) { + return addParent({std::move(name), visibility}); } -void OpClass::writeDeclTo(raw_ostream &os) const { - os << "class " << className << " : public ::mlir::Op<" << className; - for (const auto &trait : traits) - os << ", " << trait; - os << "> {\npublic:\n" - << " using Op::Op;\n" - << " using Op::print;\n" - << " using Adaptor = " << className << "Adaptor;\n"; - - bool hasPrivateMethod = false; - forAllMethods([&](const Method &method) { - if (!method.isPrivate()) { - method.writeDeclTo(os); - os << "\n"; - } else { - hasPrivateMethod = true; - } - }); +void Class::addField(std::string type, std::string name) { + fields.emplace_back(std::move(type), std::move(name)); +} + +void Class::writeDeclTo(raw_indented_ostream &os) const { + /// Declare the class. + os << (isStruct ? "struct" : "class") << ' ' << className << ' '; + + /// Declare the parent classes, if any. + if (!parents.empty()) { + os << ": "; + llvm::interleaveComma(parents, os, + [&](auto &parent) { parent.writeTo(os); }); + os << ' '; + } + auto classScope = os.scope("{\n", "};\n", /*indent=*/true); + + /// Print all the class declarations. + for (auto &decl : declarations) + decl->writeDeclTo(os); +} + +void Class::writeDefTo(raw_indented_ostream &os) const { + /// Print all the definitions. + for (auto &decl : declarations) + decl->writeDefTo(os, className); +} - // TODO: Add line control markers to make errors easier to debug. - if (!extraClassDeclaration.empty()) - os << extraClassDeclaration << "\n"; - - if (hasPrivateMethod) { - os << "\nprivate:\n"; - forAllMethods([&](const Method &method) { - if (method.isPrivate()) { - method.writeDeclTo(os); - os << "\n"; - } - }); +void Class::finalize() { + SmallVector> publicMethods, privateMethods; + for (auto &method : methods) { + (method->isPrivate() ? 
privateMethods : publicMethods) + .push_back(std::move(method)); } + methods.clear(); + + const auto declareMethod = [&](auto &method) { + declarations.push_back(std::move(method)); + }; + + if (!publicMethods.empty() && getLastVisibilityDecl() != Visibility::Public) + declare(Visibility::Public); + llvm::for_each(publicMethods, declareMethod); + + if (!privateMethods.empty() && getLastVisibilityDecl() != Visibility::Private) + declare(Visibility::Private); + llvm::for_each(privateMethods, declareMethod); + + if (!fields.empty() && getLastVisibilityDecl() != Visibility::Private) + declare(Visibility::Private); + for (auto &field : fields) + declare(std::move(field)); + fields.clear(); +} - os << "};\n"; +Visibility Class::getLastVisibilityDecl() const { + auto reverseDecls = llvm::reverse(declarations); + auto it = llvm::find_if(reverseDecls, [](auto &decl) { + return isa(decl); + }); + return it == reverseDecls.end() + ? (isStruct ? Visibility::Public : Visibility::Private) + : cast(*it).getVisibility(); +} + +Method *insertAndPruneMethods(std::vector> &methods, + std::unique_ptr newMethod) { + if (llvm::any_of(methods, [&](auto &method) { + return method->makesRedundant(*newMethod); + })) + return nullptr; + + llvm::erase_if(methods, [&](auto &method) { + return newMethod->makesRedundant(*method); + }); + methods.push_back(std::move(newMethod)); + return methods.back().get(); +} + +Method *Class::addMethodAndPrune(Method &&newMethod) { + return insertAndPruneMethods(methods, + std::make_unique(std::move(newMethod))); +} + +Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) { + return dyn_cast_or_null(insertAndPruneMethods( + methods, std::make_unique(std::move(newCtor)))); } diff --git a/mlir/lib/TableGen/Format.cpp b/mlir/lib/TableGen/Format.cpp --- a/mlir/lib/TableGen/Format.cpp +++ b/mlir/lib/TableGen/Format.cpp @@ -21,6 +21,12 @@ // Marker to indicate an error happened when replacing a placeholder. 
const char *const kMarkerForNoSubst = ""; +FmtContext::FmtContext( + std::initializer_list> subs) { + for (auto &sub : subs) + addSubst(sub.first, sub.second); +} + FmtContext &FmtContext::addSubst(StringRef placeholder, Twine subst) { customSubstMap[placeholder] = subst.str(); return *this; diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -35,7 +35,7 @@ /// Check simple attribute parser and printer are generated correctly. // ATTR: ::mlir::Attribute TestAAttr::parse(::mlir::AsmParser &parser, -// ATTR: ::mlir::Type attrType) { +// ATTR: ::mlir::Type type) { // ATTR: FailureOr _result_value; // ATTR: FailureOr _result_complex; // ATTR: if (parser.parseKeyword("hello")) @@ -47,7 +47,7 @@ // ATTR: return {}; // ATTR: if (parser.parseComma()) // ATTR: return {}; -// ATTR: _result_complex = ::parseAttrParamA(parser, attrType); +// ATTR: _result_complex = ::parseAttrParamA(parser, type); // ATTR: if (failed(_result_complex)) // ATTR: return {}; // ATTR: if (parser.parseRParen()) @@ -81,7 +81,7 @@ /// Test simple struct parser and printer are generated correctly. // ATTR: ::mlir::Attribute TestBAttr::parse(::mlir::AsmParser &parser, -// ATTR: ::mlir::Type attrType) { +// ATTR: ::mlir::Type type) { // ATTR: bool _seen_v0 = false; // ATTR: bool _seen_v1 = false; // ATTR: for (unsigned _index = 0; _index < 2; ++_index) { @@ -92,12 +92,12 @@ // ATTR: return {}; // ATTR: if (!_seen_v0 && _paramKey == "v0") { // ATTR: _seen_v0 = true; -// ATTR: _result_v0 = ::parseAttrParamA(parser, attrType); +// ATTR: _result_v0 = ::parseAttrParamA(parser, type); // ATTR: if (failed(_result_v0)) // ATTR: return {}; // ATTR: } else if (!_seen_v1 && _paramKey == "v1") { // ATTR: _seen_v1 = true; -// ATTR: _result_v1 = attrType ? ::parseAttrWithType(parser, attrType) : ::parseAttrWithout(parser); +// ATTR: _result_v1 = type ? 
::parseAttrWithType(parser, type) : ::parseAttrWithout(parser); // ATTR: if (failed(_result_v1)) // ATTR: return {}; // ATTR: } else { @@ -136,7 +136,7 @@ /// Test attribute with capture-all params has correct parser and printer. // ATTR: ::mlir::Attribute TestFAttr::parse(::mlir::AsmParser &parser, -// ATTR: ::mlir::Type attrType) { +// ATTR: ::mlir::Type type) { // ATTR: ::mlir::FailureOr _result_v0; // ATTR: ::mlir::FailureOr _result_v1; // ATTR: _result_v0 = ::mlir::FieldParser::parse(parser); diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -7,9 +7,9 @@ // DECL: #undef GET_ATTRDEF_CLASSES // DECL: namespace mlir { -// DECL: class DialectAsmParser; -// DECL: class DialectAsmPrinter; -// DECL: } // namespace mlir +// DECL: class AsmParser; +// DECL: class AsmPrinter; +// DECL: } // end namespace mlir // DEF: #ifdef GET_ATTRDEF_LIST // DEF: #undef GET_ATTRDEF_LIST @@ -19,9 +19,9 @@ // DEF: ::test::SingleParameterAttr // DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser( -// DEF-NEXT: ::mlir::AsmParser &parser, -// DEF-NEXT: ::llvm::StringRef mnemonic, ::mlir::Type type, -// DEF-NEXT: ::mlir::Attribute &value) { +// DEF-SAME: ::mlir::AsmParser &parser, +// DEF-SAME: ::llvm::StringRef mnemonic, ::mlir::Type type, +// DEF-SAME: ::mlir::Attribute &value) { // DEF: if (mnemonic == ::test::CompoundAAttr::getMnemonic()) { // DEF-NEXT: value = ::test::CompoundAAttr::parse(parser, type); // DEF-NEXT: return ::mlir::success(!!value); @@ -61,10 +61,10 @@ let genVerifyDecl = 1; // DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute -// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef dims, ::mlir::Type inner); +// DECL: static CompoundAAttr 
getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { -// DECL: return ::llvm::StringLiteral("cmpnd_a"); +// DECL: return "cmpnd_a"; // DECL: } // DECL: static ::mlir::Attribute parse( // DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type); @@ -75,27 +75,23 @@ // Check that AttributeSelfTypeParameter is handled properly. // DEF-LABEL: struct CompoundAAttrStorage -// DEF: CompoundAAttrStorage ( -// DEF-NEXT: : ::mlir::AttributeStorage(inner), +// DEF: CompoundAAttrStorage( +// DEF-SAME: : ::mlir::AttributeStorage(inner), // DEF: bool operator==(const KeyTy &tblgenKey) const { -// DEF-NEXT: if (!(widthOfSomething == std::get<0>(tblgenKey))) -// DEF-NEXT: return false; -// DEF-NEXT: if (!(exampleTdType == std::get<1>(tblgenKey))) -// DEF-NEXT: return false; -// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(tblgenKey)))) -// DEF-NEXT: return false; -// DEF-NEXT: if (!(dims == std::get<3>(tblgenKey))) -// DEF-NEXT: return false; -// DEF-NEXT: if (!(getType() == std::get<4>(tblgenKey))) -// DEF-NEXT: return false; -// DEF-NEXT: return true; +// DEF-NEXT: return +// DEF-SAME: (widthOfSomething == std::get<0>(tblgenKey)) && +// DEF-SAME: (exampleTdType == std::get<1>(tblgenKey)) && +// DEF-SAME: (apFloat.bitwiseIsEqual(std::get<2>(tblgenKey))) && +// DEF-SAME: (dims == std::get<3>(tblgenKey)) && +// DEF-SAME: (getType() == std::get<4>(tblgenKey)); // DEF: static CompoundAAttrStorage *construct // DEF: return new (allocator.allocate()) -// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, 
exampleTdType, apFloat, dims, inner); +// DEF-SAME: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner); -// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType().cast<::mlir::Type>(); } +// DEF: ::mlir::Type CompoundAAttr::getInner() const { +// DEF-NEXT: return getImpl()->getType().cast<::mlir::Type>(); } def C_IndexAttr : TestAttr<"Index"> { @@ -108,7 +104,7 @@ // DECL-LABEL: class IndexAttr : public ::mlir::Attribute // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { -// DECL: return ::llvm::StringLiteral("index"); +// DECL: return "index"; // DECL: } // DECL: static ::mlir::Attribute parse( // DECL-SAME: ::mlir::AsmParser &parser, ::mlir::Type type); @@ -122,7 +118,7 @@ ); // DECL-LABEL: struct SingleParameterAttrStorage; // DECL-LABEL: class SingleParameterAttr -// DECL-NEXT: detail::SingleParameterAttrStorage +// DECL-SAME: detail::SingleParameterAttrStorage } // An attribute testing AttributeSelfTypeParameter. @@ -133,8 +129,8 @@ } // DEF-LABEL: struct AttrWithTypeBuilderAttrStorage -// DEF: AttrWithTypeBuilderAttrStorage (::mlir::IntegerAttr attr) -// DEF-NEXT: : ::mlir::AttributeStorage(attr.getType()), attr(attr) +// DEF: AttrWithTypeBuilderAttrStorage(::mlir::IntegerAttr attr) +// DEF-SAME: : ::mlir::AttributeStorage(attr.getType()), attr(attr) def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> { let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param); @@ -143,6 +139,5 @@ // DECL-LABEL: class ParamWithAccessorTypeAttr // DECL: StringRef getParam() // DEF: ParamWithAccessorTypeAttrStorage -// DEF-NEXT: ParamWithAccessorTypeAttrStorage (std::string param) +// DEF: ParamWithAccessorTypeAttrStorage(std::string param) // DEF: StringRef ParamWithAccessorTypeAttr::getParam() - diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -9,7 +9,7 @@ // DECL: 
namespace mlir { // DECL: class AsmParser; // DECL: class AsmPrinter; -// DECL: } // namespace mlir +// DECL: } // end namespace mlir // DEF: #ifdef GET_TYPEDEF_LIST // DEF: #undef GET_TYPEDEF_LIST @@ -20,9 +20,9 @@ // DEF: ::test::IntegerType // DEF-LABEL: ::mlir::OptionalParseResult generatedTypeParser( -// DEF-NEXT: ::mlir::AsmParser &parser, -// DEF-NEXT: ::llvm::StringRef mnemonic, -// DEF-NEXT: ::mlir::Type &value) { +// DEF-SAME: ::mlir::AsmParser &parser, +// DEF-SAME: ::llvm::StringRef mnemonic, +// DEF-SAME: ::mlir::Type &value) { // DEF: if (mnemonic == ::test::CompoundAType::getMnemonic()) { // DEF-NEXT: value = ::test::CompoundAType::parse(parser); // DEF-NEXT: return ::mlir::success(!!value); @@ -66,10 +66,10 @@ let genVerifyDecl = 1; // DECL-LABEL: class CompoundAType : public ::mlir::Type -// DECL: static CompoundAType getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); +// DECL: static CompoundAType getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { -// DECL: return ::llvm::StringLiteral("cmpnd_a"); +// DECL: return "cmpnd_a"; // DECL: } // DECL: static ::mlir::Type parse(::mlir::AsmParser &parser); // DECL: void print(::mlir::AsmPrinter &printer) const; @@ -88,7 +88,7 @@ // DECL-LABEL: class IndexType : public ::mlir::Type // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { -// DECL: 
return ::llvm::StringLiteral("index"); +// DECL: return "index"; // DECL: } // DECL: static ::mlir::Type parse(::mlir::AsmParser &parser); // DECL: void print(::mlir::AsmPrinter &printer) const; @@ -101,7 +101,7 @@ ); // DECL-LABEL: struct SingleParameterTypeStorage; // DECL-LABEL: class SingleParameterType -// DECL-NEXT: detail::SingleParameterTypeStorage +// DECL-SAME: detail::SingleParameterTypeStorage } def E_IntegerType : TestType<"Integer"> { diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -9,11 +9,13 @@ #include "AttrOrTypeFormatGen.h" #include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/AttrOrTypeDef.h" +#include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" @@ -43,104 +45,559 @@ SmallVectorImpl &resultDefs) { auto defs = llvm::map_range( records, [&](const llvm::Record *rec) { return AttrOrTypeDef(rec); }); - if (defs.empty()) - return; - - StringRef dialectName; if (selectedDialect.empty()) { - if (defs.empty()) - return; - - Dialect dialect(nullptr); - for (const AttrOrTypeDef &typeDef : defs) { - if (!dialect) { - dialect = typeDef.getDialect(); - } else if (dialect != typeDef.getDialect()) { - llvm::PrintFatalError("defs belonging to more than one dialect. Must " - "select one via '--(attr|type)defs-dialect'"); - } + if (!llvm::is_splat( + llvm::map_range(defs, [](auto def) { return def.getDialect(); }))) { + llvm::PrintFatalError("defs belonging to more than one dialect. 
Must " + "select one via '--(attr|type)defs-dialect'"); } - - dialectName = dialect.getName(); + resultDefs.assign(defs.begin(), defs.end()); } else { - dialectName = selectedDialect; + auto dialectDefs = llvm::make_filter_range(defs, [&](auto def) { + return def.getDialect().getName().equals(selectedDialect); + }); + resultDefs.assign(dialectDefs.begin(), dialectDefs.end()); } - - for (const AttrOrTypeDef &def : defs) - if (def.getDialect().getName().equals(dialectName)) - resultDefs.push_back(def); } //===----------------------------------------------------------------------===// -// ParamCommaFormatter +// DefGen //===----------------------------------------------------------------------===// namespace { - -/// Pass an instance of this class to llvm::formatv() to emit a comma separated -/// list of parameters in the format by 'EmitFormat'. -class ParamCommaFormatter : public llvm::detail::format_adapter { +class DefGen { public: - /// Choose the output format - enum EmitFormat { - /// Emit "parameter1Type parameter1Name, parameter2Type parameter2Name, - /// [...]". - TypeNamePairs, + /// Create the attribute or type class. + DefGen(const AttrOrTypeDef &def); - /// Emit "parameter1(parameter1), parameter2(parameter2), [...]". - TypeNameInitializer, - - /// Emit "param1Name, param2Name, [...]". - JustParams, - }; - - ParamCommaFormatter(EmitFormat emitFormat, - ArrayRef params, - bool prependComma = true) - : emitFormat(emitFormat), params(params), prependComma(prependComma) {} - - /// llvm::formatv will call this function when using an instance as a - /// replacement value. 
- void format(raw_ostream &os, StringRef options) override { - if (!params.empty() && prependComma) - os << ", "; - - switch (emitFormat) { - case EmitFormat::TypeNamePairs: - interleaveComma(params, os, [&](const AttrOrTypeParameter &p) { - emitTypeNamePair(p, os); - }); - break; - case EmitFormat::TypeNameInitializer: - interleaveComma(params, os, [&](const AttrOrTypeParameter &p) { - emitTypeNameInitializer(p, os); - }); - break; - case EmitFormat::JustParams: - interleaveComma(params, os, - [&](const AttrOrTypeParameter &p) { os << p.getName(); }); - break; + void emitDecl(raw_ostream &os) const { + if (storageCls) { + NamespaceEmitter ns(os, def.getStorageNamespace()); + os << "struct " << def.getStorageClassName() << ";\n"; + } + defCls.writeDeclTo(os); + } + void emitDef(raw_ostream &os) const { + if (storageCls && def.genStorageClass()) { + NamespaceEmitter ns(os, def.getStorageNamespace()); + storageCls->writeDeclTo(os); // everything is inline } + defCls.writeDefTo(os); } private: - // Emit "paramType paramName". - static void emitTypeNamePair(const AttrOrTypeParameter ¶m, - raw_ostream &os) { - os << param.getCppType() << " " << param.getName(); + /// Add traits from the TableGen definition to the class. + void createParentWithTraits(); + /// Emit top-level declarations: using declarations and any extra class + /// declarations. + void emitTopLevelDeclarations(); + /// Emit attribute or type builders. + void emitBuilders(); + /// Emit a verifier for the def. + void emitVerifier(); + /// Emit parsers and printers. + void emitParserPrinter(); + /// Emit parameter accessors, if required. + void emitAccessors(); + /// Emit interface methods. 
+ void emitInterfaceMethods(); + + //===--------------------------------------------------------------------===// + // Builder Emission + + /// Emit the default builder `Attribute::get` + void emitDefaultBuilder(); + /// Emit the checked builder `Attribute::getChecked` + void emitCheckedBuilder(); + /// Emit a custom builder. + void emitCustomBuilder(const AttrOrTypeBuilder &builder); + /// Emit a checked custom builder. + void emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder); + + //===--------------------------------------------------------------------===// + // Parser and Printer Emission + void emitParserPrinterBody(MethodBody &parser, MethodBody &printer); + + //===--------------------------------------------------------------------===// + // Interface Method Emission + + /// Emit methods for a trait. + void emitTraitMethods(const InterfaceTrait &trait); + /// Emit a trait method. + void emitTraitMethod(const InterfaceMethod &method); + + //===--------------------------------------------------------------------===// + // Storage Class Emission + void emitStorageClass(); + /// Generate the storage class constructor. + void emitStorageConstructor(); + /// Emit the key type `KeyTy`. + void emitKeyType(); + /// Emit the equality comparison operator. + void emitEquals(); + /// Emit the key hash function. + void emitHashKey(); + /// Emit the function to construct the storage class. + void emitConstruct(); + + //===--------------------------------------------------------------------===// + // Utility Function Declarations + + /// Get the method parameters for a def builder, where the first several + /// parameters may be different. + SmallVector + getBuilderParams(std::initializer_list prefix) const; + + //===--------------------------------------------------------------------===// + // Class fields + + /// The attribute or type definition. + const AttrOrTypeDef &def; + /// The list of attribute or type parameters. 
+ const SmallVector params; + /// The attribute or type class. + Class defCls; + /// An optional attribute or type storage class. The storage class will + /// existif and only if the def has more than zero parameters. + Optional storageCls; + + /// The C++ base value of the def, either "Attribute" or "Type". + StringRef valueType; + /// The prefix/suffix of the TableGen def name, either "Attr" or "Type". + StringRef defType; +}; +} // end anonymous namespace + +DefGen::DefGen(const AttrOrTypeDef &def) + : def(def), params(def.getParameters()), + defCls(def.getCppClassName().str()), + valueType(isa(def) ? "Attribute" : "Type"), + defType(isa(def) ? "Attr" : "Type") { + /// If a storage class is needed, create one. + if (def.getNumParameters() > 0) + storageCls.emplace(def.getStorageClassName().str(), /*isStruct=*/true); + + /// Create the parent class with any indicated traits. + createParentWithTraits(); + /// Emit top-level declarations. + emitTopLevelDeclarations(); + /// Emit builders for defs with parameters + if (storageCls) + emitBuilders(); + /// Emit the verifier. + if (storageCls && def.genVerifyDecl()) + emitVerifier(); + /// Emit the mnemonic, if there is one, and any associated parser and printer. + if (def.getMnemonic()) + emitParserPrinter(); + /// Emit accessors + if (def.genAccessors()) + emitAccessors(); + /// Emit trait interface methods + emitInterfaceMethods(); + defCls.finalize(); + /// Emit a storage class if one is needed + if (storageCls && def.genStorageClass()) + emitStorageClass(); +} + +void DefGen::createParentWithTraits() { + ParentClass defParent(strfmt("::mlir::{0}::{1}Base", valueType, defType)); + defParent.addTemplateParam(def.getCppClassName().str()); + defParent.addTemplateParam(def.getCppBaseClassName().str()); + defParent.addTemplateParam(storageCls + ? 
strfmt("{0}::{1}", def.getStorageNamespace(), + def.getStorageClassName()) + : strfmt("::mlir::{0}Storage", valueType)); + llvm::for_each(def.getTraits(), [&](auto &trait) { + defParent.addTemplateParam( + isa(&trait) + ? cast(&trait)->getFullyQualifiedTraitName() + : cast(&trait)->getFullyQualifiedTraitName()); + }); + defCls.addParent(std::move(defParent)); +} + +void DefGen::emitTopLevelDeclarations() { + /// Inherit constructors from the attribute or type class. + defCls.declare(Visibility::Public); + defCls.declare("Base::Base"); + + /// Emit the extra declarations first in case there's a definition in there. + if (Optional extraDecl = def.getExtraDecls()) + defCls.declare(*extraDecl); +} + +void DefGen::emitBuilders() { + if (!def.skipDefaultBuilders()) { + emitDefaultBuilder(); + if (def.genVerifyDecl()) + emitCheckedBuilder(); } - // Emit "paramName(paramName)" - void emitTypeNameInitializer(const AttrOrTypeParameter ¶m, - raw_ostream &os) { - os << param.getName() << "(" << param.getName() << ")"; + for (auto &builder : def.getBuilders()) { + emitCustomBuilder(builder); + if (def.genVerifyDecl()) + emitCheckedCustomBuilder(builder); } +} - EmitFormat emitFormat; - ArrayRef params; - bool prependComma; -}; +void DefGen::emitVerifier() { + defCls.declare("Base::getChecked"); + defCls.declareStaticMethod( + "::mlir::LogicalResult", "verify", + getBuilderParams({{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", + "emitError"}})); +} -} // end anonymous namespace +void DefGen::emitParserPrinter() { + auto *mnemonic = defCls.addStaticMethod( + "::llvm::StringLiteral", "getMnemonic"); + mnemonic->body().indent() << strfmt("return \"{0}\";", *def.getMnemonic()); + /// Declare the parser and printer, if needed. + if (!def.needsParserPrinter() && !def.hasGeneratedParser() && + !def.hasGeneratedPrinter()) + return; + + /// Declare the parser. 
+ SmallVector parserParams; + parserParams.emplace_back("::mlir::AsmParser &", "parser"); + if (isa(&def)) + parserParams.emplace_back("::mlir::Type", "type"); + auto *parser = defCls.addMethod( + strfmt("::mlir::{0}", valueType), "parse", + def.hasGeneratedParser() ? Method::Static : Method::StaticDeclaration, + std::move(parserParams)); + /// Declare the printer. + auto props = + def.hasGeneratedPrinter() ? Method::Const : Method::ConstDeclaration; + Method *printer = + defCls.addMethod("void", "print", props, + MethodParameter("::mlir::AsmPrinter &", "printer")); + /// Emit the bodies. + emitParserPrinterBody(parser->body(), printer->body()); +} + +void DefGen::emitAccessors() { + for (auto ¶m : params) { + Method *m = defCls.addMethod( + param.getCppAccessorType(), getParameterAccessorName(param.getName()), + def.genStorageClass() ? Method::Const : Method::ConstDeclaration); + /// Generate accessor definitions only if we also generate the storage + /// class. Otherwise, let the user define the exact accessor definition. 
+ if (!def.genStorageClass()) + continue; + auto scope = m->body().indent().scope("return getImpl()->", ";"); + if (isa(param)) + m->body() << formatv("getType().cast<{0}>()", param.getCppType()); + else + m->body() << param.getName(); + } +} + +void DefGen::emitInterfaceMethods() { + for (auto &traitDef : def.getTraits()) + if (auto *trait = dyn_cast(&traitDef)) + if (trait->shouldDeclareMethods()) + emitTraitMethods(*trait); +} + +//===----------------------------------------------------------------------===// +// Builder Emission + +SmallVector +DefGen::getBuilderParams(std::initializer_list prefix) const { + SmallVector builderParams; + builderParams.append(prefix.begin(), prefix.end()); + for (auto ¶m : params) + builderParams.emplace_back(param.getCppType(), param.getName()); + return builderParams; +} + +void DefGen::emitDefaultBuilder() { + Method *m = defCls.addStaticMethod( + def.getCppClassName(), "get", + getBuilderParams({{"::mlir::MLIRContext *", "context"}})); + MethodBody &body = m->body().indent(); + auto scope = body.scope("return Base::get(context", ");"); + llvm::for_each(params, [&](auto ¶m) { body << ", " << param.getName(); }); +} + +void DefGen::emitCheckedBuilder() { + Method *m = defCls.addStaticMethod( + def.getCppClassName(), "getChecked", + getBuilderParams( + {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}, + {"::mlir::MLIRContext *", "context"}})); + MethodBody &body = m->body().indent(); + auto scope = body.scope("return Base::getChecked(emitError, context", ");"); + llvm::for_each(params, [&](auto ¶m) { body << ", " << param.getName(); }); +} + +static SmallVector +getCustomBuilderParams(std::initializer_list prefix, + const AttrOrTypeBuilder &builder) { + auto params = builder.getParameters(); + SmallVector builderParams; + builderParams.append(prefix.begin(), prefix.end()); + if (!builder.hasInferredContextParameter()) + builderParams.emplace_back("::mlir::MLIRContext *", "context"); + for (auto ¶m : params) { 
+ builderParams.emplace_back(param.getCppType(), param.getName(), + param.getDefaultValue()); + } + return builderParams; +} + +void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) { + /// Don't emit a body if there isn't one. + auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; + Method *m = defCls.addMethod(def.getCppClassName(), "get", props, + getCustomBuilderParams({}, builder)); + if (!builder.getBody()) + return; + + /// Format the body and emit it. + FmtContext ctx; + ctx.addSubst("_get", "Base::get"); + if (!builder.hasInferredContextParameter()) + ctx.addSubst("_ctxt", "context"); + std::string bodyStr = tgfmt(*builder.getBody(), &ctx); + m->body().indent().getStream().reindent(bodyStr); +} + +/// Replace all instances of 'from' to 'to' in `str` and return the new string. +static std::string replaceInStr(std::string str, StringRef from, StringRef to) { + size_t pos = 0; + while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) + str.replace(pos, from.size(), to.data(), to.size()); + return str; +} + +void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) { + /// Don't emit a body if there isn't one. + auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; + Method *m = defCls.addMethod( + def.getCppClassName(), "getChecked", props, + getCustomBuilderParams( + {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}, + builder)); + if (!builder.getBody()) + return; + + /// Format the body and emit it. Replace $_get(...) with + /// Base::getChecked(emitError, ...) 
+ FmtContext ctx; + if (!builder.hasInferredContextParameter()) + ctx.addSubst("_ctxt", "context"); + std::string bodyStr = replaceInStr(builder.getBody()->str(), "$_get(", + "Base::getChecked(emitError, "); + bodyStr = tgfmt(bodyStr, &ctx); + m->body().indent().getStream().reindent(bodyStr); +} + +//===----------------------------------------------------------------------===// +// Parser and Printer Emission + +void DefGen::emitParserPrinterBody(MethodBody &parser, MethodBody &printer) { + Optional parserCode = def.getParserCode(); + Optional printerCode = def.getPrinterCode(); + Optional asmFormat = def.getAssemblyFormat(); + /// Verify the parser-printer specification first. + if (asmFormat && (parserCode || printerCode)) { + PrintFatalError(def.getLoc(), + def.getName() + ": assembly format cannot be specified at " + "the same time as printer or parser code"); + } + /// Specified code cannot be empty. + if (parserCode && parserCode->empty()) + PrintFatalError(def.getLoc(), def.getName() + ": parser cannot be empty"); + if (printerCode && printerCode->empty()) + PrintFatalError(def.getLoc(), def.getName() + ": printer cannot be empty"); + /// Assembly format requires accessors to be generated. + if (asmFormat && !def.genAccessors()) { + PrintFatalError(def.getLoc(), + def.getName() + + ": the generated printer from 'assemblyFormat' " + "requires 'genAccessors' to be true"); + } + + /// Generate the parser and printer bodies. 
+ if (asmFormat) + return generateAttrOrTypeFormat(def, parser, printer); + + FmtContext ctx = FmtContext( + {{"_parser", "parser"}, {"_printer", "printer"}, {"_type", "type"}}); + if (parserCode) { + ctx.addSubst("_ctxt", "parser.getContext()"); + parser.indent().getStream().reindent(tgfmt(*parserCode, &ctx).str()); + } + if (printerCode) { + ctx.addSubst("_ctxt", "printer.getContext()"); + printer.indent().getStream().reindent(tgfmt(*printerCode, &ctx).str()); + } +} + +//===----------------------------------------------------------------------===// +// Interface Method Emission + +void DefGen::emitTraitMethods(const InterfaceTrait &trait) { + /// Get the set of methods that should always be declared. + auto alwaysDeclaredMethods = trait.getAlwaysDeclaredMethods(); + StringSet<> alwaysDeclared; + alwaysDeclared.insert(alwaysDeclaredMethods.begin(), + alwaysDeclaredMethods.end()); + + Interface iface = trait.getInterface(); // causes strange bugs if elided + for (auto &method : iface.getMethods()) { + /// Don't declare if the method has a body. Or if the method has a default + /// implementation and the def didn't request that it always be declared. + if (method.getBody() || (method.getDefaultImplementation() && + !alwaysDeclared.count(method.getName()))) + continue; + emitTraitMethod(method); + } +} + +void DefGen::emitTraitMethod(const InterfaceMethod &method) { + /// All interface methods are declaration-only. + auto props = Method::Declaration | + (method.isStatic() ? 
Method::Static : Method::Const); + SmallVector params; + for (auto ¶m : method.getArguments()) + params.emplace_back(param.type, param.name); + defCls.addMethod(method.getReturnType(), method.getName(), props, + std::move(params)); +} + +//===----------------------------------------------------------------------===// +// Storage Class Emission + +void DefGen::emitStorageConstructor() { + Constructor *ctor = + storageCls->addConstructor(getBuilderParams({})); + if (auto *attrDef = dyn_cast(&def)) { + /// For attributes, a parameter marked with AttributeSelfTypeParameter is + /// the type initializer that must be passed to the parent constructor. + const auto isSelfType = [](const AttrOrTypeParameter ¶m) { + return isa(param); + }; + auto *selfTypeParam = llvm::find_if(params, isSelfType); + if (std::count_if(selfTypeParam, params.end(), isSelfType) > 1) { + PrintFatalError(def.getLoc(), + "Only one attribute parameter can be marked as " + "AttributeSelfTypeParameter"); + } + /// Alternatively, if a type builder was specified, use that instead. + std::string attrStorageInit = + selfTypeParam == params.end() ? "" : selfTypeParam->getName().str(); + if (attrDef->getTypeBuilder()) { + FmtContext ctx; + for (auto ¶m : params) + ctx.addSubst(strfmt("_{0}", param.getName()), param.getName()); + attrStorageInit = tgfmt(*attrDef->getTypeBuilder(), &ctx); + } + ctor->addMemberInitializer("::mlir::AttributeStorage", + std::move(attrStorageInit)); + /// Initialize members that aren't the attribute's type. 
+ for (auto ¶m : params) + if (selfTypeParam == params.end() || *selfTypeParam != param) + ctor->addMemberInitializer(param.getName(), param.getName()); + } else { + for (auto ¶m : params) + ctor->addMemberInitializer(param.getName(), param.getName()); + } +} + +void DefGen::emitKeyType() { + std::string keyType("std::tuple<"); + llvm::interleave( + params, [&](auto ¶m) { keyType += param.getCppType(); }, + [&] { keyType += ", "; }); + keyType.push_back('>'); + storageCls->declare("KeyTy", std::move(keyType)); +} + +void DefGen::emitEquals() { + Method *eq = storageCls->addConstMethod( + "bool", "operator==", MethodParameter("const KeyTy &", "tblgenKey")); + auto &body = eq->body().indent(); + auto scope = body.scope("return (", ");"); + const auto eachFn = [&](auto it) { + FmtContext ctx({{"_lhs", isa(it.value()) + ? "getType()" + : it.value().getName()}, + {"_rhs", strfmt("std::get<{0}>(tblgenKey)", it.index())}}); + Optional comparator = it.value().getComparator(); + body << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &ctx); + }; + llvm::interleave(llvm::enumerate(params), body, eachFn, ") && ("); +} + +void DefGen::emitHashKey() { + Method *hash = storageCls->addStaticInlineMethod( + "::llvm::hash_code", "hashKey", + MethodParameter("const KeyTy &", "tblgenKey")); + auto &body = hash->body().indent(); + auto scope = body.scope("return ::llvm::hash_combine(", ");"); + llvm::interleaveComma(llvm::enumerate(params), body, [&](auto it) { + body << llvm::formatv("std::get<{0}>(tblgenKey)", it.index()); + }); +} + +void DefGen::emitConstruct() { + Method *construct = storageCls->addMethod( + strfmt("{0} *", def.getStorageClassName()), "construct", + def.hasStorageCustomConstructor() ? 
Method::StaticDeclaration + : Method::Static, + MethodParameter(strfmt("::mlir::{0}StorageAllocator &", valueType), + "allocator"), + MethodParameter("const KeyTy &", "tblgenKey")); + if (!def.hasStorageCustomConstructor()) { + auto &body = construct->body().indent(); + for (auto it : llvm::enumerate(params)) { + body << formatv("auto {0} = std::get<{1}>(tblgenKey);\n", + it.value().getName(), it.index()); + } + /// Use the parameters' custom allocator code, if provided. + FmtContext ctx = FmtContext().addSubst("_allocator", "allocator"); + for (auto ¶m : params) { + if (Optional allocCode = param.getAllocator()) { + ctx.withSelf(param.getName()).addSubst("_dst", param.getName()); + body << tgfmt(*allocCode, &ctx) << '\n'; + } + } + auto scope = + body.scope(strfmt("return new (allocator.allocate<{0}>()) {0}(", + def.getStorageClassName()), + ");"); + llvm::interleaveComma(params, body, + [&](auto ¶m) { body << param.getName(); }); + } +} + +void DefGen::emitStorageClass() { + /// Add the appropriate parent class. + storageCls->addParent(strfmt("::mlir::{0}Storage", valueType)); + /// Add the constructor. + emitStorageConstructor(); + /// Declare the key type. + emitKeyType(); + /// Add the comparison method. + emitEquals(); + /// Emit the key hash method. + emitHashKey(); + /// Emit the storage constructor. Just declare it if the user wants to define + /// it themself. + emitConstruct(); + /// Emit the storage class members as public, at the very end of the struct. 
+ storageCls->finalize(); + for (auto ¶m : params) { + if (!isa(param)) { + storageCls->declare(param.getCppType().str(), + param.getName().str()); + } + } +} //===----------------------------------------------------------------------===// // DefGenerator @@ -154,51 +611,41 @@ bool emitDefs(StringRef selectedDialect); protected: - DefGenerator(std::vector &&defs, raw_ostream &os) - : defRecords(std::move(defs)), os(os), isAttrGenerator(false) {} + DefGenerator(std::vector &&defs, raw_ostream &os, + StringRef defType, StringRef valueType, bool isAttrGenerator) + : defRecords(std::move(defs)), os(os), defType(defType), + valueType(valueType), isAttrGenerator(isAttrGenerator) {} - /// Emit the declaration of a single def. - void emitDefDecl(const AttrOrTypeDef &def); /// Emit the list of def type names. void emitTypeDefList(ArrayRef defs); /// Emit the code to dispatch between different defs during parsing/printing. void emitParsePrintDispatch(ArrayRef defs); - /// Emit the definition of a single def. - void emitDefDef(const AttrOrTypeDef &def); - /// Emit the storage class for the given def. - void emitStorageClass(const AttrOrTypeDef &def); - /// Emit the parser/printer for the given def. - void emitParsePrint(const AttrOrTypeDef &def); /// The set of def records to emit. std::vector defRecords; + /// The attribute or type class to emit. /// The stream to emit to. raw_ostream &os; /// The prefix of the tablegen def name, e.g. Attr or Type. - StringRef defTypePrefix; + const StringRef defType; /// The C++ base value type of the def, e.g. Attribute or Type. - StringRef valueType; + const StringRef valueType; /// Flag indicating if this generator is for Attributes. False if the /// generator is for types. - bool isAttrGenerator; + const bool isAttrGenerator; }; /// A specialized generator for AttrDefs. 
struct AttrDefGenerator : public DefGenerator { AttrDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os) { - isAttrGenerator = true; - defTypePrefix = "Attr"; - valueType = "Attribute"; - } + : DefGenerator(records.getAllDerivedDefinitions("AttrDef"), os, "Attr", + "Attribute", /*isAttrGenerator=*/true) {} }; /// A specialized generator for TypeDefs. struct TypeDefGenerator : public DefGenerator { TypeDefGenerator(const llvm::RecordKeeper &records, raw_ostream &os) - : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os) { - defTypePrefix = "Type"; - valueType = "Type"; - } + : DefGenerator(records.getAllDerivedDefinitions("TypeDef"), os, "Type", + "Type", /*isAttrGenerator=*/false) {} }; } // end anonymous namespace @@ -211,240 +658,13 @@ static const char *const typeDefDeclHeader = R"( namespace mlir { class AsmParser; -class DialectAsmParser; class AsmPrinter; -class DialectAsmPrinter; -} // namespace mlir +} // end namespace mlir )"; -/// The code block for the start of a typeDef class declaration -- singleton -/// case. -/// -/// {0}: The name of the def class. -/// {1}: The name of the type base class. -/// {2}: The name of the base value type, e.g. Attribute or Type. -/// {3}: The tablegen record type prefix, e.g. Attr or Type. -/// {4}: The traits of the def class. -static const char *const defDeclSingletonBeginStr = R"( - class {0} : public ::mlir::{2}::{3}Base<{0}, {1}, ::mlir::{2}Storage{4}> {{ - public: - /// Inherit some necessary constructors from '{3}Base'. - using Base::Base; -)"; - -/// The code block for the start of a class declaration -- parametric case. -/// -/// {0}: The name of the def class. -/// {1}: The name of the base class. -/// {2}: The def storage class namespace. -/// {3}: The storage class name. -/// {4}: The name of the base value type, e.g. Attribute or Type. -/// {5}: The tablegen record type prefix, e.g. Attr or Type. 
-/// {6}: The traits of the def class. -static const char *const defDeclParametricBeginStr = R"( - namespace {2} { - struct {3}; - } // end namespace {2} - class {0} : public ::mlir::{4}::{5}Base<{0}, {1}, - {2}::{3}{6}> {{ - public: - /// Inherit some necessary constructors from '{5}Base'. - using Base::Base; - -)"; - -/// The code snippet for print/parse of an Attribute/Type. -/// -/// {0}: The name of the base value type, e.g. Attribute or Type. -/// {1}: Extra parser parameters. -static const char *const defDeclParsePrintStr = R"( - static ::mlir::{0} parse(::mlir::AsmParser &parser{1}); - void print(::mlir::AsmPrinter &printer) const; -)"; - -/// The code block for the verify method declaration. -/// -/// {0}: List of parameters, parameters style. -static const char *const defDeclVerifyStr = R"( - using Base::getChecked; - static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0}); -)"; - -/// Emit the builders for the given def. -static void emitBuilderDecls(const AttrOrTypeDef &def, raw_ostream &os, - ParamCommaFormatter ¶mTypes) { - StringRef typeClass = def.getCppClassName(); - bool genCheckedMethods = def.genVerifyDecl(); - if (!def.skipDefaultBuilders()) { - os << llvm::formatv( - " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass, - paramTypes); - if (genCheckedMethods) { - os << llvm::formatv(" static {0} " - "getChecked(llvm::function_ref<::mlir::" - "InFlightDiagnostic()> emitError, " - "::mlir::MLIRContext *context{1});\n", - typeClass, paramTypes); - } - } - - // Generate the builders specified by the user. - for (const AttrOrTypeBuilder &builder : def.getBuilders()) { - std::string paramStr; - llvm::raw_string_ostream paramOS(paramStr); - llvm::interleaveComma( - builder.getParameters(), paramOS, - [&](const AttrOrTypeBuilder::Parameter ¶m) { - // Note: AttrOrTypeBuilder parameters are guaranteed to have names. 
- paramOS << param.getCppType() << " " << *param.getName(); - if (Optional defaultParamValue = param.getDefaultValue()) - paramOS << " = " << *defaultParamValue; - }); - paramOS.flush(); - - // Generate the `get` variant of the builder. - os << " static " << typeClass << " get("; - if (!builder.hasInferredContextParameter()) { - os << "::mlir::MLIRContext *context"; - if (!paramStr.empty()) - os << ", "; - } - os << paramStr << ");\n"; - - // Generate the `getChecked` variant of the builder. - if (genCheckedMethods) { - os << " static " << typeClass - << " getChecked(llvm::function_ref " - "emitError"; - if (!builder.hasInferredContextParameter()) - os << ", ::mlir::MLIRContext *context"; - if (!paramStr.empty()) - os << ", "; - os << paramStr << ");\n"; - } - } -} - -static void emitInterfaceMethodDecls(const InterfaceTrait *trait, - raw_ostream &os) { - Interface interface = trait->getInterface(); - - // Get the set of methods that should always be declared. - auto alwaysDeclaredMethodsVec = trait->getAlwaysDeclaredMethods(); - llvm::StringSet<> alwaysDeclaredMethods; - alwaysDeclaredMethods.insert(alwaysDeclaredMethodsVec.begin(), - alwaysDeclaredMethodsVec.end()); - - for (const InterfaceMethod &method : interface.getMethods()) { - // Don't declare if the method has a body. - if (method.getBody()) - continue; - // Don't declare if the method has a default implementation and the def - // didn't request that it always be declared. - if (method.getDefaultImplementation() && - !alwaysDeclaredMethods.count(method.getName())) - continue; - - // Emit the method declaration. - os << " " << (method.isStatic() ? "static " : "") - << method.getReturnType() << " " << method.getName() << "("; - llvm::interleaveComma(method.getArguments(), os, - [&](const InterfaceMethod::Argument &arg) { - os << arg.type << " " << arg.name; - }); - os << ")" << (method.isStatic() ? 
"" : " const") << ";\n"; - } -} - -void DefGenerator::emitDefDecl(const AttrOrTypeDef &def) { - SmallVector params; - def.getParameters(params); - - // Build the trait list for this def. - std::vector traitList; - StringSet<> traitSet; - for (const Trait &baseTrait : def.getTraits()) { - std::string traitStr; - if (const auto *trait = dyn_cast(&baseTrait)) - traitStr = trait->getFullyQualifiedTraitName(); - else if (const auto *trait = dyn_cast(&baseTrait)) - traitStr = trait->getFullyQualifiedTraitName(); - else - llvm_unreachable("unexpected Attribute/Type trait type"); - - if (traitSet.insert(traitStr).second) - traitList.emplace_back(std::move(traitStr)); - } - std::string traitStr; - if (!traitList.empty()) - traitStr = ", " + llvm::join(traitList, ", "); - - // Emit the beginning string template: either the singleton or parametric - // template. - if (def.getNumParameters() == 0) { - os << formatv(defDeclSingletonBeginStr, def.getCppClassName(), - def.getCppBaseClassName(), valueType, defTypePrefix, - traitStr); - } else { - os << formatv(defDeclParametricBeginStr, def.getCppClassName(), - def.getCppBaseClassName(), def.getStorageNamespace(), - def.getStorageClassName(), valueType, defTypePrefix, - traitStr); - } - - // Emit the extra declarations first in case there's a definition in there. - if (Optional extraDecl = def.getExtraDecls()) - os << *extraDecl << "\n"; - - ParamCommaFormatter emitTypeNamePairsAfterComma( - ParamCommaFormatter::EmitFormat::TypeNamePairs, params); - if (!params.empty()) { - emitBuilderDecls(def, os, emitTypeNamePairsAfterComma); - - // Emit the verify invariants declaration. - if (def.genVerifyDecl()) - os << llvm::formatv(defDeclVerifyStr, emitTypeNamePairsAfterComma); - } - - // Emit the mnenomic, if specified. 
- if (auto mnenomic = def.getMnemonic()) { - os << " static constexpr ::llvm::StringLiteral getMnemonic() {\n" - << " return ::llvm::StringLiteral(\"" << mnenomic << "\");\n" - << " }\n"; - - // If mnemonic specified, emit print/parse declarations. - if (def.getParserCode() || def.getPrinterCode() || - def.getAssemblyFormat() || !params.empty()) { - os << llvm::formatv(defDeclParsePrintStr, valueType, - isAttrGenerator ? ", ::mlir::Type type" : ""); - } - } - - if (def.genAccessors()) { - SmallVector parameters; - def.getParameters(parameters); - - for (AttrOrTypeParameter ¶meter : parameters) { - os << formatv(" {0} {1}() const;\n", parameter.getCppAccessorType(), - getParameterAccessorName(parameter.getName())); - } - } - - // Emit any interface method declarations. - for (const Trait &trait : def.getTraits()) { - if (const auto *traitDef = dyn_cast(&trait)) { - if (traitDef->shouldDeclareMethods()) - emitInterfaceMethodDecls(traitDef, os); - } - } - - // End the decl. - os << " };\n"; -} - bool DefGenerator::emitDecls(StringRef selectedDialect) { - emitSourceFileHeader((defTypePrefix + "Def Declarations").str(), os); - IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); + emitSourceFileHeader((defType + "Def Declarations").str(), os); + IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os); // Output the common "header". os << typeDefDeclHeader; @@ -458,11 +678,11 @@ // Declare all the def classes first (in case they reference each other). for (const AttrOrTypeDef &def : defs) - os << " class " << def.getCppClassName() << ";\n"; + os << "class " << def.getCppClassName() << ";\n"; // Emit the declarations. for (const AttrOrTypeDef &def : defs) - emitDefDecl(def); + DefGen(def).emitDecl(os); } // Emit the TypeID explicit specializations to have a single definition for // each of these. 
@@ -479,7 +699,7 @@ //===----------------------------------------------------------------------===// void DefGenerator::emitTypeDefList(ArrayRef defs) { - IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_LIST", os); + IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os); auto interleaveFn = [&](const AttrOrTypeDef &def) { os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName(); }; @@ -491,17 +711,6 @@ // GEN: Definitions //===----------------------------------------------------------------------===// -/// The code block used to start the auto-generated parser function. -/// -/// {0}: The name of the base value type, e.g. Attribute or Type. -/// {1}: Additional parser parameters. -static const char *const defParserDispatchStartStr = R"( -static ::mlir::OptionalParseResult generated{0}Parser( - ::mlir::AsmParser &parser, - ::llvm::StringRef mnemonic{1}, - ::mlir::{0} &value) {{ -)"; - /// The code block for default attribute parser/printer dispatch boilerplate. /// {0}: the dialect fully qualified class name. static const char *const dialectDefaultAttrPrinterParserDispatch = R"( @@ -556,412 +765,6 @@ } )"; -/// The code block used to start the auto-generated printer function. -/// -/// {0}: The name of the base value type, e.g. Attribute or Type. -static const char *const defPrinterDispatchStartStr = R"( -static ::mlir::LogicalResult generated{0}Printer( - ::mlir::{0} def, ::mlir::AsmPrinter &printer) {{ - return ::llvm::TypeSwitch<::mlir::{0}, ::mlir::LogicalResult>(def) -)"; - -/// Beginning of storage class. -/// {0}: Storage class namespace. -/// {1}: Storage class c++ name. -/// {2}: Parameters parameters. -/// {3}: Parameter initializer string. -/// {4}: Parameter types. -/// {5}: The name of the base value type, e.g. Attribute or Type. 
-static const char *const defStorageClassBeginStr = R"( -namespace {0} {{ - struct {1} : public ::mlir::{5}Storage {{ - {1} ({2}) - : {3} {{ } - - /// The hash key is a tuple of the parameter types. - using KeyTy = std::tuple<{4}>; -)"; - -/// The storage class' constructor template. -/// -/// {0}: storage class name. -/// {1}: The name of the base value type, e.g. Attribute or Type. -static const char *const defStorageClassConstructorBeginStr = R"( - /// Define a construction method for creating a new instance of this - /// storage. - static {0} *construct(::mlir::{1}StorageAllocator &allocator, - const KeyTy &tblgenKey) {{ -)"; - -/// The storage class' constructor return template. -/// -/// {0}: storage class name. -/// {1}: list of parameters. -static const char *const defStorageClassConstructorEndStr = R"( - return new (allocator.allocate<{0}>()) - {0}({1}); - } -)"; - -/// Use tgfmt to emit custom allocation code for each parameter, if necessary. -static void emitStorageParameterAllocation(const AttrOrTypeDef &def, - raw_ostream &os) { - SmallVector parameters; - def.getParameters(parameters); - FmtContext fmtCtxt = FmtContext().addSubst("_allocator", "allocator"); - for (AttrOrTypeParameter ¶meter : parameters) { - if (Optional allocCode = parameter.getAllocator()) { - fmtCtxt.withSelf(parameter.getName()); - fmtCtxt.addSubst("_dst", parameter.getName()); - os << " " << tgfmt(*allocCode, &fmtCtxt) << "\n"; - } - } -} - -/// Builds a code block that initializes the attribute storage of 'def'. -/// Attribute initialization is separated from Type initialization given that -/// the Attribute also needs to initialize its self-type, which has multiple -/// means of initialization. 
-static std::string buildAttributeStorageParamInitializer( - const AttrOrTypeDef &def, ArrayRef parameters) { - std::string paramInitializer; - llvm::raw_string_ostream paramOS(paramInitializer); - paramOS << "::mlir::AttributeStorage("; - - // If this is an attribute, we need to check for value type initialization. - Optional selfParamIndex; - for (auto it : llvm::enumerate(parameters)) { - const auto *selfParam = dyn_cast(&it.value()); - if (!selfParam) - continue; - if (selfParamIndex) { - llvm::PrintFatalError(def.getLoc(), - "Only one attribute parameter can be marked as " - "AttributeSelfTypeParameter"); - } - paramOS << selfParam->getName(); - selfParamIndex = it.index(); - } - - // If we didn't find a self param, but the def has a type builder we use that - // to construct the type. - if (!selfParamIndex) { - const AttrDef &attrDef = cast(def); - if (Optional typeBuilder = attrDef.getTypeBuilder()) { - FmtContext fmtContext; - for (const AttrOrTypeParameter ¶m : parameters) - fmtContext.addSubst(("_" + param.getName()).str(), param.getName()); - paramOS << tgfmt(*typeBuilder, &fmtContext); - } - } - paramOS << ")"; - - // Append the parameters to the initializer. - for (auto it : llvm::enumerate(parameters)) - if (it.index() != selfParamIndex) - paramOS << llvm::formatv(", {0}({0})", it.value().getName()); - - return paramOS.str(); -} - -void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) { - SmallVector params; - def.getParameters(params); - - // Collect the parameter types. - auto parameterTypes = - llvm::map_range(params, [](const AttrOrTypeParameter ¶meter) { - return parameter.getCppType(); - }); - std::string parameterTypeList = llvm::join(parameterTypes, ", "); - - // Collect the parameter initializer. 
- std::string paramInitializer; - if (isAttrGenerator) { - paramInitializer = buildAttributeStorageParamInitializer(def, params); - - } else { - llvm::raw_string_ostream initOS(paramInitializer); - llvm::interleaveComma(params, initOS, [&](const AttrOrTypeParameter &it) { - initOS << llvm::formatv("{0}({0})", it.getName()); - }); - } - - // * Emit most of the storage class up until the hashKey body. - os << formatv( - defStorageClassBeginStr, def.getStorageNamespace(), - def.getStorageClassName(), - ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, - params, /*prependComma=*/false), - paramInitializer, parameterTypeList, valueType); - - // * Emit the comparison method. - os << " bool operator==(const KeyTy &tblgenKey) const {\n"; - for (auto it : llvm::enumerate(params)) { - os << " if (!("; - - // Build the comparator context. - bool isSelfType = isa(it.value()); - FmtContext context; - context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName()) - .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(tblgenKey)"); - - // Use the parameter specified comparator if possible, otherwise default to - // operator==. - Optional comparator = it.value().getComparator(); - os << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &context); - os << "))\n return false;\n"; - } - os << " return true;\n }\n"; - - // * Emit the haskKey method. - os << " static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {\n"; - - // Extract each parameter from the key. - os << " return ::llvm::hash_combine("; - llvm::interleaveComma( - llvm::seq(0, params.size()), os, - [&](unsigned it) { os << "std::get<" << it << ">(tblgenKey)"; }); - os << ");\n }\n"; - - // * Emit the construct method. - - // If user wants to build the storage constructor themselves, declare it - // here and then they can write the definition elsewhere. 
- if (def.hasStorageCustomConstructor()) { - os << llvm::formatv(" static {0} *construct(::mlir::{1}StorageAllocator " - "&allocator, const KeyTy &tblgenKey);\n", - def.getStorageClassName(), valueType); - - // Otherwise, generate one. - } else { - // First, unbox the parameters. - os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(), - valueType); - for (unsigned i = 0, e = params.size(); i < e; ++i) { - os << formatv(" auto {0} = std::get<{1}>(tblgenKey);\n", - params[i].getName(), i); - } - - // Second, reassign the parameter variables with allocation code, if it's - // specified. - emitStorageParameterAllocation(def, os); - - // Last, return an allocated copy. - auto parameterNames = llvm::map_range( - params, [](const auto ¶m) { return param.getName(); }); - os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(), - llvm::join(parameterNames, ", ")); - } - - // * Emit the parameters as storage class members. - for (const AttrOrTypeParameter ¶meter : params) { - // Attribute value types are not stored as fields in the storage. - if (!isa(parameter)) - os << " " << parameter.getCppType() << " " << parameter.getName() - << ";\n"; - } - os << " };\n"; - - os << "} // namespace " << def.getStorageNamespace() << "\n"; -} - -void DefGenerator::emitParsePrint(const AttrOrTypeDef &def) { - auto printerCode = def.getPrinterCode(); - auto parserCode = def.getParserCode(); - auto assemblyFormat = def.getAssemblyFormat(); - if (assemblyFormat && (printerCode || parserCode)) { - // Custom assembly format cannot be specified at the same time as either - // custom printer or parser code. - PrintFatalError(def.getLoc(), - def.getName() + ": assembly format cannot be specified at " - "the same time as printer or parser code"); - } - - // Generate a parser and printer based on the assembly format, if specified. - if (assemblyFormat) { - // A custom assembly format requires accessors to be generated for the - // generated printer. 
- if (!def.genAccessors()) { - PrintFatalError(def.getLoc(), - def.getName() + - ": the generated printer from 'assemblyFormat' " - "requires 'genAccessors' to be true"); - } - return generateAttrOrTypeFormat(def, os); - } - - // Emit the printer code, if specified. - if (printerCode) { - // Both the mnenomic and printerCode must be defined (for parity with - // parserCode). - os << "void " << def.getCppClassName() - << "::print(::mlir::AsmPrinter &printer) const {\n"; - if (printerCode->empty()) { - // If no code specified, emit error. - PrintFatalError(def.getLoc(), - def.getName() + - ": printer (if specified) must have non-empty code"); - } - FmtContext fmtCtxt = FmtContext().addSubst("_printer", "printer"); - os << tgfmt(*printerCode, &fmtCtxt) << "\n}\n"; - } - - // Emit the parser code, if specified. - if (parserCode) { - FmtContext fmtCtxt; - fmtCtxt.addSubst("_parser", "parser") - .addSubst("_ctxt", "parser.getContext()"); - - // The mnenomic must be defined so the dispatcher knows how to dispatch. - os << llvm::formatv("::mlir::{0} {1}::parse(" - "::mlir::AsmParser &parser", - valueType, def.getCppClassName()); - if (isAttrGenerator) { - // Attributes also accept a type parameter instead of a context. - os << ", ::mlir::Type type"; - fmtCtxt.addSubst("_type", "type"); - } - os << ") {\n"; - - if (parserCode->empty()) { - PrintFatalError(def.getLoc(), - def.getName() + - ": parser (if specified) must have non-empty code"); - } - os << tgfmt(*parserCode, &fmtCtxt) << "\n}\n"; - } -} - -/// Replace all instances of 'from' to 'to' in `str` and return the new string. -static std::string replaceInStr(std::string str, StringRef from, StringRef to) { - size_t pos = 0; - while ((pos = str.find(from.data(), pos, from.size())) != std::string::npos) - str.replace(pos, from.size(), to.data(), to.size()); - return str; -} - -/// Emit the builders for the given def. 
-static void emitBuilderDefs(const AttrOrTypeDef &def, raw_ostream &os, - ArrayRef params) { - bool genCheckedMethods = def.genVerifyDecl(); - StringRef className = def.getCppClassName(); - if (!def.skipDefaultBuilders()) { - os << llvm::formatv( - "{0} {0}::get(::mlir::MLIRContext *context{1}) {{\n" - " return Base::get(context{2});\n}\n", - className, - ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, - params), - ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams, - params)); - if (genCheckedMethods) { - os << llvm::formatv( - "{0} {0}::getChecked(" - "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, " - "::mlir::MLIRContext *context{1}) {{\n" - " return Base::getChecked(emitError, context{2});\n}\n", - className, - ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, - params), - ParamCommaFormatter(ParamCommaFormatter::EmitFormat::JustParams, - params)); - } - } - - auto builderFmtCtx = - FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get"); - auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get"); - auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context"); - - // Generate the builders specified by the user. - for (const AttrOrTypeBuilder &builder : def.getBuilders()) { - Optional body = builder.getBody(); - if (!body) - continue; - std::string paramStr; - llvm::raw_string_ostream paramOS(paramStr); - llvm::interleaveComma(builder.getParameters(), paramOS, - [&](const AttrOrTypeBuilder::Parameter ¶m) { - // Note: AttrOrTypeBuilder parameters are guaranteed - // to have names. - paramOS << param.getCppType() << " " - << *param.getName(); - }); - paramOS.flush(); - - // Emit the `get` variant of the builder. 
- os << llvm::formatv("{0} {0}::get(", className); - if (!builder.hasInferredContextParameter()) { - os << "::mlir::MLIRContext *context"; - if (!paramStr.empty()) - os << ", "; - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, - tgfmt(*body, &builderFmtCtx).str()); - } else { - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, - tgfmt(*body, &inferredCtxBuilderFmtCtx).str()); - } - - // Emit the `getChecked` variant of the builder. - if (genCheckedMethods) { - os << llvm::formatv("{0} " - "{0}::getChecked(llvm::function_ref<::mlir::" - "InFlightDiagnostic()> emitErrorFn", - className); - std::string checkedBody = - replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, "); - if (!builder.hasInferredContextParameter()) { - os << ", ::mlir::MLIRContext *context"; - checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str(); - } - if (!paramStr.empty()) - os << ", "; - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody); - } - } -} - -/// Print all the def-specific definition code. -void DefGenerator::emitDefDef(const AttrOrTypeDef &def) { - NamespaceEmitter ns(os, def.getDialect()); - - SmallVector parameters; - def.getParameters(parameters); - if (!parameters.empty()) { - // Emit the storage class, if requested and necessary. - if (def.genStorageClass()) - emitStorageClass(def); - - // Emit the builders for this def. - emitBuilderDefs(def, os, parameters); - - // Generate accessor definitions only if we also generate the storage class. - // Otherwise, let the user define the exact accessor definition. 
- if (def.genAccessors() && def.genStorageClass()) { - for (const AttrOrTypeParameter ¶m : parameters) { - SmallString<32> paramStorageName; - if (isa(param)) { - Twine("getType().cast<" + param.getCppType() + ">()") - .toVector(paramStorageName); - } else { - paramStorageName = param.getName(); - } - - os << formatv("{0} {3}::{1}() const {{ return getImpl()->{2}; }\n", - param.getCppAccessorType(), - getParameterAccessorName(param.getName()), - paramStorageName, def.getCppClassName()); - } - } - } - - // If mnemonic is specified maybe print definitions for the parser and printer - // code, if they're specified. - if (def.getMnemonic()) - emitParsePrint(def); -} - /// Emit the dialect printer/parser dispatcher. User's code should call these /// functions from their dialect's print/parse methods. void DefGenerator::emitParsePrintDispatch(ArrayRef defs) { @@ -970,59 +773,66 @@ })) { return; } + /// Declare the parser. + SmallVector params = {{"::mlir::AsmParser &", "parser"}, + {"::llvm::StringRef", "mnemonic"}}; + if (isAttrGenerator) + params.emplace_back("::mlir::Type", "type"); + params.emplace_back(strfmt("::mlir::{0} &", valueType), "value"); + Method parse("::mlir::OptionalParseResult", + strfmt("generated{0}Parser", valueType), Method::StaticInline, + std::move(params)); + /// Declare the printer. + Method printer("::mlir::LogicalResult", + strfmt("generated{0}Printer", valueType), Method::StaticInline, + {{strfmt("::mlir::{0}", valueType), "def"}, + {"::mlir::AsmPrinter &", "printer"}}); // The parser dispatch is just a list of if-elses, matching on the mnemonic // and calling the def's parse function. - os << llvm::formatv(defParserDispatchStartStr, valueType, - isAttrGenerator ? 
", ::mlir::Type type" : ""); - for (const AttrOrTypeDef &def : defs) { - if (def.getMnemonic()) { - os << formatv(" if (mnemonic == {0}::{1}::getMnemonic()) { \n" - " value = {0}::{1}::", - def.getDialect().getCppNamespace(), def.getCppClassName()); - - // If the def has no parameters and no parser code, just invoke a normal - // `get`. - if (def.getNumParameters() == 0 && !def.getParserCode()) { - os << "get(parser.getContext());\n"; - os << " return ::mlir::success(!!value);\n }\n"; - continue; - } - - os << "parse(parser" << (isAttrGenerator ? ", type" : "") - << ");\n return ::mlir::success(!!value);\n }\n"; - } + const char *const getValueForMnemonic = + R"( if (mnemonic == {0}::getMnemonic()) {{ + value = {0}::{1}; + return ::mlir::success(!!value); } - os << " return {};\n"; - os << "}\n\n"; - +)"; // The printer dispatch uses llvm::TypeSwitch to find and call the correct // printer. - os << llvm::formatv(defPrinterDispatchStartStr, valueType); - for (const AttrOrTypeDef &def : defs) { - Optional mnemonic = def.getMnemonic(); - if (!mnemonic) + printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType + << ", ::mlir::LogicalResult>(def)"; + const char *const printValue = R"( .Case<{0}>([&](auto t) {{ + printer << {0}::getMnemonic();{1} + return ::mlir::success(); + }) +)"; + for (auto &def : defs) { + if (!def.getMnemonic()) continue; - - StringRef cppNamespace = def.getDialect().getCppNamespace(); - StringRef cppClassName = def.getCppClassName(); - os << formatv(" .Case<{0}::{1}>([&]({0}::{1} t) {{\n ", - cppNamespace, cppClassName); - - os << formatv("printer << {0}::{1}::getMnemonic();", cppNamespace, - cppClassName); - // If the def has no parameters and no printer, just print the mnemonic. 
- if (def.getNumParameters() != 0 || def.getPrinterCode()) - os << "t.print(printer);"; - os << "\n return ::mlir::success();\n })\n"; - } - os << llvm::formatv( - " .Default([](::mlir::{0}) {{ return ::mlir::failure(); });\n}\n\n", - valueType); + std::string defClass = strfmt( + "{0}::{1}", def.getDialect().getCppNamespace(), def.getCppClassName()); + /// If the def has no parameters or parser code, invoke a normal `get`. + std::string parseOrGet = + def.needsParserPrinter() || def.hasGeneratedParser() + ? strfmt("parse(parser{0})", isAttrGenerator ? ", type" : "") + : "get(parser.getContext())"; + parse.body() << llvm::formatv(getValueForMnemonic, defClass, parseOrGet); + + /// If the def has no parameters and no printer, just print the mnemonic. + StringRef printDef = ""; + if (def.needsParserPrinter() || def.hasGeneratedPrinter()) + printDef = "\nt.print(printer);"; + printer.body() << llvm::formatv(printValue, defClass, printDef); + } + parse.body() << " return {};"; + printer.body() << " .Default([](auto) { return ::mlir::failure(); });"; + + raw_indented_ostream indentedOs(os); + parse.writeDeclTo(indentedOs); + printer.writeDeclTo(indentedOs); } bool DefGenerator::emitDefs(StringRef selectedDialect) { - emitSourceFileHeader((defTypePrefix + "Def Definitions").str(), os); + emitSourceFileHeader((defType + "Def Definitions").str(), os); SmallVector defs; collectAllDefs(selectedDialect, defRecords, defs); @@ -1030,10 +840,14 @@ return false; emitTypeDefList(defs); - IfDefScope scope("GET_" + defTypePrefix.upper() + "DEF_CLASSES", os); + IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os); emitParsePrintDispatch(defs); for (const AttrOrTypeDef &def : defs) { - emitDefDef(def); + { + NamespaceEmitter ns(os, def.getDialect()); + DefGen gen(def); + gen.emitDef(os); + } // Emit the TypeID explicit specializations to have a single symbol def. 
if (!def.getDialect().getCppNamespace().empty()) os << "DEFINE_EXPLICIT_TYPE_ID(" << def.getDialect().getCppNamespace() @@ -1041,16 +855,18 @@ } // Emit the default parser/printer for Attributes if the dialect asked for it. - if (valueType == "Attribute" && - defs.front().getDialect().useDefaultAttributePrinterParser()) + if (isAttrGenerator && + defs.front().getDialect().useDefaultAttributePrinterParser()) { os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, defs.front().getDialect().getCppClassName()); + } // Emit the default parser/printer for Types if the dialect asked for it. - if (valueType == "Type" && - defs.front().getDialect().useDefaultTypePrinterParser()) + if (!isAttrGenerator && + defs.front().getDialect().useDefaultTypePrinterParser()) { os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, defs.front().getDialect().getCppClassName()); + } return false; } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h @@ -9,9 +9,7 @@ #ifndef MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_ #define MLIR_TOOLS_MLIRTBLGEN_ATTRORTYPEFORMATGEN_H_ -#include "llvm/Support/raw_ostream.h" - -#include +#include "mlir/TableGen/Class.h" namespace mlir { namespace tblgen { @@ -19,7 +17,8 @@ /// Generate a parser and printer based on a custom assembly format for an /// attribute or type. -void generateAttrOrTypeFormat(const AttrOrTypeDef &def, llvm::raw_ostream &os); +void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser, + MethodBody &printer); /// From the parameter name, get the name of the accessor function in camelcase. /// The first letter of the parameter is upper-cased and prefixed with "get". 
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -158,21 +158,6 @@ // Format Strings //===----------------------------------------------------------------------===// -/// Format for defining an attribute parser. -/// -/// $0: The attribute C++ class name. -static const char *const attrParserDefn = R"( -::mlir::Attribute $0::parse(::mlir::AsmParser &$_parser, - ::mlir::Type $_type) { -)"; - -/// Format for defining a type parser. -/// -/// $0: The type C++ class name. -static const char *const typeParserDefn = R"( -::mlir::Type $0::parse(::mlir::AsmParser &$_parser) { -)"; - /// Default parser for attribute or type parameters. static const char *const defaultParameterParser = "::mlir::FieldParser<$0>::parse($_parser)"; @@ -186,13 +171,6 @@ static const char *const parseErrorStr = "$_parser.emitError($_parser.getCurrentLocation(), "; -/// Format for defining an attribute or type printer. -/// -/// $0: The attribute or type C++ class name. -static const char *const attrOrTypePrinterDefn = R"( -void $0::print(::mlir::AsmPrinter &$_printer) const { -)"; - /// Loop declaration for struct parser. /// /// $0: Number of expected parameters. @@ -212,12 +190,12 @@ /// {0}: Code template for printing an error. /// {1}: Number of elements in the struct. static const char *const structParseLoopEnd = R"({{ - {0}"duplicate or unknown struct parameter name: ") << _paramKey; - return {{}; - } - if ((_index != {1} - 1) && parser.parseComma()) - return {{}; + {0}"duplicate or unknown struct parameter name: ") << _paramKey; + return {{}; } + if ((_index != {1} - 1) && parser.parseComma()) + return {{}; +} )"; /// Code format to parse a variable. Separate by lines because variable parsers @@ -228,26 +206,14 @@ /// {2}: Code template for printing an error. /// {3}: Name of the attribute or type. 
/// {4}: C++ class of the parameter. -static const char *const variableParser[] = { - " // Parse variable '{0}'", - " _result_{0} = {1};", - " if (failed(_result_{0})) {{", - " {2}\"failed to parse {3} parameter '{0}' which is to be a `{4}`\");", - " return {{};", - " }", -}; - -//===----------------------------------------------------------------------===// -// Utility Functions -//===----------------------------------------------------------------------===// - -/// Get a list of an attribute's or type's parameters. These can be wrapper -/// objects around `AttrOrTypeParameter` or string inits. -static auto getParameters(const AttrOrTypeDef &def) { - SmallVector params; - def.getParameters(params); - return params; +static const char *const variableParser = R"( +// Parse variable '{0}' +_result_{0} = {1}; +if (failed(_result_{0})) {{ + {2}"failed to parse {3} parameter '{0}' which is to be a `{4}`"); + return {{}; } +)"; //===----------------------------------------------------------------------===// // AttrOrTypeFormat @@ -261,35 +227,34 @@ : def(def), elements(std::move(elements)) {} /// Generate the attribute or type parser. - void genParser(raw_ostream &os); + void genParser(MethodBody &os); /// Generate the attribute or type printer. - void genPrinter(raw_ostream &os); + void genPrinter(MethodBody &os); private: /// Generate the parser code for a specific format element. - void genElementParser(Element *el, FmtContext &ctx, raw_ostream &os); + void genElementParser(Element *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a literal. - void genLiteralParser(StringRef value, FmtContext &ctx, raw_ostream &os, - unsigned indent = 0); + void genLiteralParser(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a variable. void genVariableParser(const AttrOrTypeParameter ¶m, FmtContext &ctx, - raw_ostream &os, unsigned indent = 0); + MethodBody &os); /// Generate the parser code for a `params` directive. 
- void genParamsParser(ParamsDirective *el, FmtContext &ctx, raw_ostream &os); + void genParamsParser(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the parser code for a `struct` directive. - void genStructParser(StructDirective *el, FmtContext &ctx, raw_ostream &os); + void genStructParser(StructDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a specific format element. - void genElementPrinter(Element *el, FmtContext &ctx, raw_ostream &os); + void genElementPrinter(Element *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a literal. - void genLiteralPrinter(StringRef value, FmtContext &ctx, raw_ostream &os); + void genLiteralPrinter(StringRef value, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a variable. void genVariablePrinter(const AttrOrTypeParameter ¶m, FmtContext &ctx, - raw_ostream &os); + MethodBody &os); /// Generate the printer code for a `params` directive. - void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, raw_ostream &os); + void genParamsPrinter(ParamsDirective *el, FmtContext &ctx, MethodBody &os); /// Generate the printer code for a `struct` directive. - void genStructPrinter(StructDirective *el, FmtContext &ctx, raw_ostream &os); + void genStructPrinter(StructDirective *el, FmtContext &ctx, MethodBody &os); /// The ODS definition of the attribute or type whose format is being used to /// generate a parser and printer. @@ -308,23 +273,18 @@ // ParserGen //===----------------------------------------------------------------------===// -void AttrOrTypeFormat::genParser(raw_ostream &os) { +void AttrOrTypeFormat::genParser(MethodBody &os) { FmtContext ctx; ctx.addSubst("_parser", "parser"); - - /// Generate the definition. 
- if (isa(def)) { - ctx.addSubst("_type", "attrType"); - os << tgfmt(attrParserDefn, &ctx, def.getCppClassName()); - } else { - os << tgfmt(typeParserDefn, &ctx, def.getCppClassName()); - } + if (isa(def)) + ctx.addSubst("_type", "type"); + os.indent(); /// Declare variables to store all of the parameters. Allocated parameters /// such as `ArrayRef` and `StringRef` must provide a `storageType`. Store /// FailureOr to defer type construction for parameters that are parsed in /// a loop (parsers return FailureOr anyways). - SmallVector params = getParameters(def); + SmallVector params = def.getParameters(); for (const AttrOrTypeParameter ¶m : params) { os << formatv(" ::mlir::FailureOr<{0}> _result_{1};\n", param.getCppStorageType(), param.getName()); @@ -332,8 +292,8 @@ /// Store the initial location of the parser. ctx.addSubst("_loc", "loc"); - os << tgfmt(" ::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n" - " (void) $_loc;\n", + os << tgfmt("::llvm::SMLoc $_loc = $_parser.getCurrentLocation();\n" + "(void) $_loc;\n", &ctx); /// Generate call to each parameter parser. @@ -343,19 +303,19 @@ /// Generate call to the attribute or type builder. Use the checked getter /// if one was generated. 
if (def.genVerifyDecl()) { - os << tgfmt(" return $_parser.getChecked<$0>($_loc, $_parser.getContext()", + os << tgfmt("return $_parser.getChecked<$0>($_loc, $_parser.getContext()", &ctx, def.getCppClassName()); } else { - os << tgfmt(" return $0::get($_parser.getContext()", &ctx, + os << tgfmt("return $0::get($_parser.getContext()", &ctx, def.getCppClassName()); } for (const AttrOrTypeParameter ¶m : params) os << formatv(",\n _result_{0}.getValue()", param.getName()); - os << ");\n}\n\n"; + os << ");"; } void AttrOrTypeFormat::genElementParser(Element *el, FmtContext &ctx, - raw_ostream &os) { + MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralParser(literal->getSpelling(), ctx, os); if (auto *var = dyn_cast(el)) @@ -369,9 +329,9 @@ } void AttrOrTypeFormat::genLiteralParser(StringRef value, FmtContext &ctx, - raw_ostream &os, unsigned indent) { - os.indent(indent) << " // Parse literal '" << value << "'\n"; - os.indent(indent) << tgfmt(" if ($_parser.parse", &ctx); + MethodBody &os) { + os << "// Parse literal '" << value << "'\n"; + os << tgfmt("if ($_parser.parse", &ctx); if (value.front() == '_' || isalpha(value.front())) { os << "Keyword(\"" << value << "\")"; } else { @@ -395,28 +355,23 @@ } os << ")\n"; // Parser will emit an error - os.indent(indent) << " return {};\n"; + os << " return {};\n"; } void AttrOrTypeFormat::genVariableParser(const AttrOrTypeParameter ¶m, - FmtContext &ctx, raw_ostream &os, - unsigned indent) { + FmtContext &ctx, MethodBody &os) { /// Check for a custom parser. Use the default attribute parser otherwise. auto customParser = param.getParser(); auto parser = customParser ? 
*customParser : StringRef(defaultParameterParser); - for (const char *line : variableParser) { - os.indent(indent) << formatv(line, param.getName(), - tgfmt(parser, &ctx, param.getCppStorageType()), - tgfmt(parseErrorStr, &ctx), def.getName(), - param.getCppType()) - << "\n"; - } + os << formatv(variableParser, param.getName(), + tgfmt(parser, &ctx, param.getCppStorageType()), + tgfmt(parseErrorStr, &ctx), def.getName(), param.getCppType()); } void AttrOrTypeFormat::genParamsParser(ParamsDirective *el, FmtContext &ctx, - raw_ostream &os) { - os << " // Parse parameter list\n"; + MethodBody &os) { + os << "// Parse parameter list\n"; llvm::interleave( el->getParams(), [&](auto param) { this->genVariableParser(param, ctx, os); }, @@ -424,28 +379,30 @@ } void AttrOrTypeFormat::genStructParser(StructDirective *el, FmtContext &ctx, - raw_ostream &os) { - os << " // Parse parameter struct\n"; + MethodBody &os) { + os << "// Parse parameter struct\n"; /// Declare a "seen" variable for each key. for (const AttrOrTypeParameter ¶m : el->getParams()) - os << formatv(" bool _seen_{0} = false;\n", param.getName()); + os << formatv("bool _seen_{0} = false;\n", param.getName()); /// Generate the parsing loop. - os << tgfmt(structParseLoopStart, &ctx, el->getNumParams()); - genLiteralParser("=", ctx, os, 2); - os << " "; + os.getStream().reindent( + tgfmt(structParseLoopStart, &ctx, el->getNumParams()).str()); + os.indent(); + genLiteralParser("=", ctx, os); for (const AttrOrTypeParameter ¶m : el->getParams()) { os << formatv("if (!_seen_{0} && _paramKey == \"{0}\") {\n" - " _seen_{0} = true;\n", + " _seen_{0} = true;\n", param.getName()); - genVariableParser(param, ctx, os, 4); - os << " } else "; + genVariableParser(param, ctx, os.indent()); + os.unindent() << "} else "; } + os.unindent(); /// Duplicate or unknown parameter. 
- os << formatv(structParseLoopEnd, tgfmt(parseErrorStr, &ctx), - el->getNumParams()); + os.getStream().reindent(strfmt(structParseLoopEnd, tgfmt(parseErrorStr, &ctx), + el->getNumParams())); /// Because the loop loops N times and each non-failing iteration sets 1 of /// N flags, successfully exiting the loop means that all parameters have been @@ -457,24 +414,19 @@ // PrinterGen //===----------------------------------------------------------------------===// -void AttrOrTypeFormat::genPrinter(raw_ostream &os) { +void AttrOrTypeFormat::genPrinter(MethodBody &os) { FmtContext ctx; ctx.addSubst("_printer", "printer"); - /// Generate the definition. - os << tgfmt(attrOrTypePrinterDefn, &ctx, def.getCppClassName()); - /// Generate printers. shouldEmitSpace = true; lastWasPunctuation = false; for (auto &el : elements) genElementPrinter(el.get(), ctx, os); - - os << "}\n\n"; } void AttrOrTypeFormat::genElementPrinter(Element *el, FmtContext &ctx, - raw_ostream &os) { + MethodBody &os) { if (auto *literal = dyn_cast(el)) return genLiteralPrinter(literal->getSpelling(), ctx, os); if (auto *params = dyn_cast(el)) @@ -488,7 +440,7 @@ } void AttrOrTypeFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, - raw_ostream &os) { + MethodBody &os) { /// Don't insert a space before certain punctuation. bool needSpace = shouldEmitSpace && shouldEmitSpaceBefore(value, lastWasPunctuation); @@ -502,7 +454,7 @@ } void AttrOrTypeFormat::genVariablePrinter(const AttrOrTypeParameter ¶m, - FmtContext &ctx, raw_ostream &os) { + FmtContext &ctx, MethodBody &os) { /// Insert a space before the next parameter, if necessary. 
if (shouldEmitSpace || !lastWasPunctuation) os << tgfmt(" $_printer << ' ';\n", &ctx); @@ -518,7 +470,7 @@ } void AttrOrTypeFormat::genParamsPrinter(ParamsDirective *el, FmtContext &ctx, - raw_ostream &os) { + MethodBody &os) { llvm::interleave( el->getParams(), [&](auto param) { this->genVariablePrinter(param, ctx, os); }, @@ -526,13 +478,12 @@ } void AttrOrTypeFormat::genStructPrinter(StructDirective *el, FmtContext &ctx, - raw_ostream &os) { + MethodBody &os) { llvm::interleave( el->getParams(), [&](auto param) { this->genLiteralPrinter(param.getName(), ctx, os); this->genLiteralPrinter("=", ctx, os); - os << tgfmt(" $_printer << ' ';\n", &ctx); this->genVariablePrinter(param, ctx, os); }, [&]() { this->genLiteralPrinter(",", ctx, os); }); @@ -624,7 +575,7 @@ } /// Check that all parameters have been seen. - SmallVector params = getParameters(def); + SmallVector params = def.getParameters(); for (auto it : llvm::enumerate(params)) { if (!seenParams.test(it.index())) { return emitError("format is missing reference to parameter: " + @@ -669,7 +620,7 @@ auto name = curToken.getSpelling().drop_front(); /// Lookup the parameter. - SmallVector params = getParameters(def); + SmallVector params = def.getParameters(); auto *it = llvm::find_if( params, [&](auto ¶m) { return param.getName() == name; }); @@ -705,7 +656,7 @@ FailureOr> FormatParser::parseParamsDirective() { consumeToken(); /// Collect all of the attribute's or type's parameters. - SmallVector params = getParameters(def); + SmallVector params = def.getParameters(); SmallVector> vars; /// Ensure that none of the parameters have already been captured. 
for (auto it : llvm::enumerate(params)) { @@ -759,15 +710,16 @@ //===----------------------------------------------------------------------===// void mlir::tblgen::generateAttrOrTypeFormat(const AttrOrTypeDef &def, - raw_ostream &os) { + MethodBody &parser, + MethodBody &printer) { llvm::SourceMgr mgr; mgr.AddNewSourceBuffer( llvm::MemoryBuffer::getMemBuffer(*def.getAssemblyFormat()), llvm::SMLoc()); /// Parse the custom assembly format> - FormatParser parser(mgr, def); - FailureOr format = parser.parse(); + FormatParser fmtParser(mgr, def); + FailureOr format = fmtParser.parse(); if (failed(format)) { if (formatErrorIsFatal) PrintFatalError(def.getLoc(), "failed to parse assembly format"); @@ -775,6 +727,6 @@ } /// Generate the parser and printer. - format->genParser(os); - format->genPrinter(os); + format->genParser(parser); + format->genPrinter(printer); } diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/OpClass.h @@ -0,0 +1,37 @@ +//===- OpClass.h - Implementation of an Op Class --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_ +#define MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_ + +#include "mlir/TableGen/Class.h" + +namespace mlir { +namespace tblgen { + +// Class for holding an op for C++ code emission +class OpClass : public Class { +public: + OpClass(StringRef name, StringRef extraClassDeclaration); + + // Add an op trait. + void addTrait(Twine trait) { parent.addTemplateParam(trait.str()); } + + void finalize() override; + +private: + /// Hand-written extra class declarations. 
+ StringRef extraClassDeclaration; + /// The parent class, which also contains the traits to be inherited. + ParentClass &parent; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TOOLS_MLIRTBLGEN_OPCLASS_H_ diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -0,0 +1,36 @@ +//===- OpClass.cpp - Implementation of an Op Class ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "OpClass.h" + +using namespace mlir; +using namespace mlir::tblgen; + +//===----------------------------------------------------------------------===// +// OpClass definitions +//===----------------------------------------------------------------------===// + +OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) + : Class(name.str()), extraClassDeclaration(extraClassDeclaration), + parent(addParent("::mlir::Op")) { + parent.addTemplateParam(getClassName().str()); + declare(Visibility::Public); + /// Inherit functions from Op. + declare("Op::Op"); + declare("Op::print"); + /// Type alias for the adaptor class. 
+ declare("Adaptor", className + "Adaptor"); +} + +void OpClass::finalize() { + Class::finalize(); + if (getLastVisibilityDecl() != Visibility::Public && + !extraClassDeclaration.ltrim().startswith("public")) + declare(Visibility::Public); + declare(extraClassDeclaration); +} diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -11,6 +11,7 @@ // //===----------------------------------------------------------------------===// +#include "OpClass.h" #include "OpFormatGen.h" #include "OpGenHelpers.h" #include "mlir/TableGen/Class.h" @@ -42,21 +43,21 @@ static const char *const odsBuilder = "odsBuilder"; static const char *const builderOpState = "odsState"; -// Code for an Op to lookup an attribute. Uses cached identifiers. -// -// {0}: The attribute's getter name. +/// Code for an Op to lookup an attribute. Uses cached identifiers. +/// +/// {0}: The attribute's getter name. static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())"; -// The logic to calculate the actual value range for a declared operand/result -// of an op with variadic operands/results. Note that this logic is not for -// general use; it assumes all variadic operands/results must have the same -// number of values. -// -// {0}: The list of whether each declared operand/result is variadic. -// {1}: The total number of non-variadic operands/results. -// {2}: The total number of variadic operands/results. -// {3}: The total number of actual values. -// {4}: "operand" or "result". +/// The logic to calculate the actual value range for a declared operand/result +/// of an op with variadic operands/results. Note that this logic is not for +/// general use; it assumes all variadic operands/results must have the same +/// number of values. +/// +/// {0}: The list of whether each declared operand/result is variadic.
+/// {1}: The total number of non-variadic operands/results. +/// {2}: The total number of variadic operands/results. +/// {3}: The total number of actual values. +/// {4}: "operand" or "result". static const char *const sameVariadicSizeValueRangeCalcCode = R"( bool isVariadic[] = {{{0}}; int prevVariadicCount = 0; @@ -75,12 +76,12 @@ return {{start, size}; )"; -// The logic to calculate the actual value range for a declared operand/result -// of an op with variadic operands/results. Note that this logic is assumes -// the op has an attribute specifying the size of each operand/result segment -// (variadic or not). -// -// {0}: The name of the attribute specifying the segment sizes. +/// The logic to calculate the actual value range for a declared operand/result +/// of an op with variadic operands/results. Note that this logic is assumes +/// the op has an attribute specifying the size of each operand/result segment +/// (variadic or not). +/// +/// {0}: The name of the attribute specifying the segment sizes. static const char *const adapterSegmentSizeAttrInitCode = R"( assert(odsAttrs && "missing segment size attribute for op"); auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); @@ -99,11 +100,12 @@ start += sizeAttrValueIt[i]; return {start, sizeAttrValueIt[index]}; )"; -// The logic to calculate the actual value range for a declared operand -// of an op with variadic of variadic operands within the OpAdaptor. -// -// {0}: The name of the segment attribute. -// {1}: The index of the main operand. + +/// The logic to calculate the actual value range for a declared operand +/// of an op with variadic of variadic operands within the OpAdaptor. +/// +/// {0}: The name of the segment attribute. +/// {1}: The index of the main operand. 
static const char *const variadicOfVariadicAdaptorCalcCode = R"( auto tblgenTmpOperands = getODSOperands({1}); auto sizeAttrValues = {0}().getValues(); @@ -117,16 +119,20 @@ return tblgenTmpOperandGroups; )"; -// The logic to build a range of either operand or result values. -// -// {0}: The begin iterator of the actual values. -// {1}: The call to generate the start and length of the value range. +/// The logic to build a range of either operand or result values. +/// +/// {0}: The begin iterator of the actual values. +/// {1}: The call to generate the start and length of the value range. static const char *const valueRangeReturnCode = R"( auto valueRange = {1}; return {{std::next({0}, valueRange.first), std::next({0}, valueRange.first + valueRange.second)}; )"; +/// A header for indicating code sections. +/// +/// {0}: Some text, or a class name. +/// {1}: Some text. static const char *const opCommentHeader = R"( //===----------------------------------------------------------------------===// // {0} {1} @@ -554,7 +560,7 @@ OpEmitter::OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), - opClass(op.getCppClassName(), op.getExtraClassDeclaration()), + opClass(op.getCppClassName().str(), op.getExtraClassDeclaration()), staticVerifierEmitter(staticVerifierEmitter) { verifyCtx.withOp("(*this->getOperation())"); verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); @@ -597,9 +603,15 @@ OpEmitter(op, staticVerifierEmitter).emitDef(os); } -void OpEmitter::emitDecl(raw_ostream &os) { opClass.writeDeclTo(os); } +void OpEmitter::emitDecl(raw_ostream &os) { + opClass.finalize(); + opClass.writeDeclTo(os); +} -void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } +void OpEmitter::emitDef(raw_ostream &os) { + opClass.finalize(); + opClass.writeDefTo(os); +} static void errorIfPruned(size_t line, Method *m, const Twine &methodName, const Operator &op) { @@ -654,7 +666,7 @@ // Emit the 
getAttributeNameForIndex methods. { - auto *method = opClass.addInlineMethod( + auto *method = opClass.addInlineMethod( "::mlir::Identifier", "getAttributeNameForIndex", MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); @@ -662,15 +674,17 @@ << " return getAttributeNameForIndex((*this)->getName(), index);"; } { - auto *method = opClass.addStaticInlineMethod( + auto *method = opClass.addStaticInlineMethod( "::mlir::Identifier", "getAttributeNameForIndex", MethodParameter("::mlir::OperationName", "name"), MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); - method->body() << "assert(index < " << attributeNames.size() - << " && \"invalid attribute index\");\n" - " return name.getAbstractOperation()" - "->getAttributeNames()[index];"; + + const char *const getAttrName = R"( + assert(index < {0} && "invalid attribute index"); + return name.getAbstractOperation()->getAttributeNames()[index]; +)"; + method->body() << formatv(getAttrName, attributeNames.size()); } // Generate the AttrName methods, that expose the attribute names to @@ -1504,8 +1518,7 @@ SmallVector arguments = getBuilderSignature(builder); Optional body = builder.getBody(); - Method::Property properties = - body ? Method::MP_Static : Method::MP_StaticDeclaration; + auto properties = body ? Method::Static : Method::StaticDeclaration; auto *method = opClass.addMethod("void", "build", properties, std::move(arguments)); if (body) @@ -1874,7 +1887,7 @@ SmallVector paramList; paramList.emplace_back("::mlir::RewritePatternSet &", "results"); paramList.emplace_back("::mlir::MLIRContext *", "context"); - auto kind = hasBody ? Method::MP_Static : Method::MP_StaticDeclaration; + auto kind = hasBody ? 
Method::Static : Method::StaticDeclaration; auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind, std::move(paramList)); @@ -1937,11 +1950,9 @@ for (const InterfaceMethod::Argument &arg : method.getArguments()) paramList.emplace_back(arg.type, arg.name); - auto properties = method.isStatic() ? Method::MP_Static : Method::MP_None; - if (declaration) - properties = - static_cast(properties | Method::MP_Declaration); - return opClass.addMethod(method.getReturnType(), method.getName(), properties, + auto props = (Method::Static * method.isStatic()) | + (Method::Declaration * declaration); + return opClass.addMethod(method.getReturnType(), method.getName(), props, std::move(paramList)); } @@ -1960,10 +1971,10 @@ SideEffect effect; /// The index if the kind is not static. - unsigned index : 30; + unsigned index; /// The kind of the location. - unsigned kind : 2; + unsigned kind; }; StringMap> interfaceEffects; @@ -2461,7 +2472,7 @@ } void OpEmitter::genOpNameGetter() { - auto *method = opClass.addStaticMethod( + auto *method = opClass.addStaticMethod( "::llvm::StringLiteral", "getOperationName"); ERROR_IF_PRUNED(method, "getOperationName", op); method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName() @@ -2537,18 +2548,18 @@ const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter) : op(op), adaptor(op.getAdaptorName()), staticVerifierEmitter(staticVerifierEmitter) { - adaptor.newField("::mlir::ValueRange", "odsOperands"); - adaptor.newField("::mlir::DictionaryAttr", "odsAttrs"); - adaptor.newField("::mlir::RegionRange", "odsRegions"); + adaptor.addField("::mlir::ValueRange", "odsOperands"); + adaptor.addField("::mlir::DictionaryAttr", "odsAttrs"); + adaptor.addField("::mlir::RegionRange", "odsRegions"); const auto *attrSizedOperands = op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); { SmallVector paramList;
paramList.emplace_back("::mlir::ValueRange", "values"); paramList.emplace_back("::mlir::DictionaryAttr", "attrs", attrSizedOperands ? "" : "nullptr"); paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); - auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList)); + auto *constructor = adaptor.addConstructor(std::move(paramList)); constructor->addMemberInitializer("odsOperands", "values"); constructor->addMemberInitializer("odsAttrs", "attrs"); @@ -2556,7 +2567,7 @@ } { - auto *constructor = adaptor.addConstructorAndPrune( + auto *constructor = adaptor.addConstructor( MethodParameter(op.getCppClassName() + " &", "op")); constructor->addMemberInitializer("odsOperands", "op->getOperands()"); constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); @@ -2646,6 +2657,7 @@ // Add verification function. addVerification(); + adaptor.finalize(); } void OpOperandAdaptorEmitter::addVerification() { diff --git a/mlir/tools/mlir-tblgen/OpDocGen.cpp b/mlir/tools/mlir-tblgen/OpDocGen.cpp --- a/mlir/tools/mlir-tblgen/OpDocGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDocGen.cpp @@ -226,8 +226,7 @@ static void emitAttrOrTypeDefAssemblyFormat(const AttrOrTypeDef &def, raw_ostream &os) { - SmallVector parameters; - def.getParameters(parameters); + SmallVector parameters = def.getParameters(); if (parameters.empty()) { os << "\nSyntax: `!" << def.getDialect().getName() << "." << def.getMnemonic() << "`\n"; @@ -265,8 +264,7 @@ } // Emit parameter documentation. 
- SmallVector parameters; - def.getParameters(parameters); + SmallVector parameters = def.getParameters(); if (!parameters.empty()) { os << "\n#### Parameters:\n\n"; os << "| Parameter | C++ type | Description |\n" diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -8,11 +8,12 @@ #include "OpFormatGen.h" #include "FormatGen.h" +#include "OpClass.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/Class.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" -#include "mlir/TableGen/Class.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/MapVector.h"