diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/TableGen/Class.h @@ -0,0 +1,412 @@ +//===- Class.h - Helper classes for C++ code emission -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines several classes for Op C++ code emission. They are only +// expected to be used by MLIR TableGen backends. +// +// We emit the op declaration and definition into separate files: *Ops.h.inc +// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and +// the latter for dialect *Ops.cpp. This way provides a cleaner interface. +// +// In order to do this split, we need to track method signature and +// implementation logic separately. Signature information is used for both +// declaration and definition, while implementation logic is only for +// definition. So we have the following classes for C++ code emission. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TABLEGEN_CLASS_H_ +#define MLIR_TABLEGEN_CLASS_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/CodeGenHelpers.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +namespace mlir { +namespace tblgen { +class FmtObjectBase; + +/// This class contains a single method parameter for a C++ function. +class MethodParameter { +public: + /// Create a method parameter with a C++ type, parameter name, and an optional + /// default value.
Marking a parameter as "optional" has a cosmetic effect on + /// the generated code. + template + MethodParameter(TypeT &&type, NameT &&name, DefaultT &&defaultValue, + bool optional = false) + : type(stringify(std::forward(type))), + name(stringify(std::forward(name))), + defaultValue(stringify(std::forward(defaultValue))), + optional(optional) {} + + /// Create a method parameter with a C++ type, parameter name, and no default + /// value. + template + MethodParameter(TypeT &&type, NameT &&name, bool optional = false) + : MethodParameter(std::forward(type), std::forward(name), + /*defaultValue=*/"", optional) {} + + /// Write the parameter as part of a method declaration. + void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); } + /// Write the parameter as part of a method definition. + void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } + + /// Get the C++ type. + const std::string &getType() const { return type; } + /// Returns true if the parameter has a default value. + bool hasDefaultValue() const { return !defaultValue.empty(); } + +private: + void writeTo(raw_ostream &os, bool emitDefault) const; + + /// The C++ type. + std::string type; + /// The variable name. + std::string name; + /// An optional default value. The default value exists if the string is not + /// empty. + std::string defaultValue; + /// Whether the parameter should be indicated as "optional". + bool optional; +}; + +/// This class contains a list of method parameters for constructor, class +/// methods, and method signatures. +class MethodParameters { +public: + /// Create a list of method parameters. + MethodParameters(std::initializer_list parameters) + : parameters(parameters) {} + MethodParameters(SmallVector parameters) + : parameters(std::move(parameters)) {} + + /// Write the parameters as part of a method declaration. + void writeDeclTo(raw_ostream &os) const; + /// Write the parameters as part of a method definition.
+ void writeDefTo(raw_ostream &os) const; + + /// Determine whether this list of parameters "subsumes" another, which occurs + /// when this parameter list has the same types as the other, followed by zero + /// or more additional default-valued parameters. + bool subsumes(const MethodParameters &other) const; + + /// Return the number of parameters. + unsigned getNumParameters() const { return parameters.size(); } + +private: + llvm::SmallVector parameters; +}; + +/// This class contains the signature of a C++ method, including the return +/// type, method name, and method parameters. +class MethodSignature { +public: + MethodSignature(StringRef retType, StringRef name, + SmallVector &&parameters) + : returnType(retType), methodName(name), + parameters(std::move(parameters)) {} + template + MethodSignature(StringRef retType, StringRef name, Parameters &&...parameters) + : returnType(retType), methodName(name), + parameters({std::forward(parameters)...}) {} + + /// Determine whether a method with this signature makes a method with + /// `other` signature redundant. This occurs if the signatures have the same + /// name and this signature's parameters subsume the other's. + /// + /// A method that makes another method redundant with a different return type + /// can replace the other, the assumption being that the subsuming method + /// provides a more resolved return type, e.g. IntegerAttr vs. Attribute. + bool makesRedundant(const MethodSignature &other) const; + + /// Get the name of the method. + StringRef getName() const { return methodName; } + + /// Get the number of parameters. + unsigned getNumParameters() const { return parameters.getNumParameters(); } + + /// Write the signature as part of a method declaration. + void writeDeclTo(raw_ostream &os) const; + + /// Write the signature as part of a method definition. `namePrefix` is to be + /// prepended to the method name (typically namespaces for qualifying the + /// method definition).
+ void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + +private: + /// The method's C++ return type. + std::string returnType; + /// The method name. + std::string methodName; + /// The method's parameter list. + MethodParameters parameters; +}; + +/// Class for holding the body of a method for C++ code emission. +class MethodBody { +public: + explicit MethodBody(bool declOnly); + + MethodBody &operator<<(Twine content); + MethodBody &operator<<(int content); + MethodBody &operator<<(const FmtObjectBase &content); + + void writeTo(raw_ostream &os) const; + +private: + /// Whether this class should record the method body. + bool isEffective; + /// The body of the method. + std::string body; +}; + +/// Class for holding a method for C++ code emission. +class Method { +public: + /// Properties (qualifiers) of class methods. A bitfield is used here to help + /// query properties. + enum Property { + MP_None = 0x0, + MP_Static = 0x1, + MP_Constructor = 0x2, + MP_Private = 0x4, + MP_Declaration = 0x8, + MP_Inline = 0x10, + MP_Constexpr = 0x20 | MP_Inline, + MP_StaticDeclaration = MP_Static | MP_Declaration, + }; + + template + Method(StringRef retType, StringRef name, Property property, Args &&...args) + : properties(property), + methodSignature(retType, name, std::forward(args)...), + methodBody(properties & MP_Declaration) {} + + Method(Method &&) = default; + Method &operator=(Method &&) = default; + + virtual ~Method() = default; + + MethodBody &body() { return methodBody; } + + /// Returns true if this is a static method. + bool isStatic() const { return properties & MP_Static; } + + /// Returns true if this is a private method. + bool isPrivate() const { return properties & MP_Private; } + + /// Returns true if this is an inline method. + bool isInline() const { return properties & MP_Inline; } + + /// Returns the name of this method.
+ StringRef getName() const { return methodSignature.getName(); } + + /// Returns true if this method makes the `other` method redundant. + bool makesRedundant(const Method &other) const { + return methodSignature.makesRedundant(other.methodSignature); + } + + /// Writes the method as a declaration to the given `os`. + virtual void writeDeclTo(raw_ostream &os) const; + + /// Writes the method as a definition to the given `os`. `namePrefix` is the + /// prefix to be prepended to the method name (typically namespaces for + /// qualifying the method definition). + virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; + +protected: + /// A collection of method properties. + Property properties; + /// The signature of the method. + MethodSignature methodSignature; + /// The body of the method, if it has one. + MethodBody methodBody; +}; + +} // end namespace tblgen +} // end namespace mlir + +/// The OR of two method properties should return method properties. Ensure that +/// this function is visible to `Class`. +inline constexpr mlir::tblgen::Method::Property +operator|(mlir::tblgen::Method::Property lhs, + mlir::tblgen::Method::Property rhs) { + return mlir::tblgen::Method::Property(static_cast(lhs) | + static_cast(rhs)); +} + +namespace mlir { +namespace tblgen { + +/// Class for holding a constructor method for C++ code emission. +class Constructor : public Method { +public: + template + Constructor(StringRef className, Property property, + Parameters &&...parameters) + : Method("", className, property, + std::forward(parameters)...) {} + + /// Add member initializer to constructor initializing `name` with `value`. + void addMemberInitializer(StringRef name, StringRef value); + + /// Writes the method as a definition to the given `os`. `namePrefix` is the + /// prefix to be prepended to the method name (typically namespaces for + /// qualifying the method definition).
+ void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; + +private: + /// Member initializers. + std::string memberInitializers; +}; + +/// A class used to emit C++ classes from TableGen. Contains a list of public +/// methods and a list of private fields to be emitted. +class Class { +public: + explicit Class(StringRef name); + + /// Add a new constructor to this class and prune any constructors made + /// redundant by it. Returns null if the constructor was not added. Else, + /// returns a pointer to the new constructor. + template + Constructor *addConstructorAndPrune(Parameters &&...parameters) { + return addConstructorAndPrune( + Constructor(getClassName(), Method::MP_Constructor, + std::forward(parameters)...)); + } + + /// Add a new method to this class and prune any methods made redundant by it. + /// Returns null if an existing method makes it redundant, else a pointer to + /// the new method (NOTE(review): may dangle after later additions; confirm). + template + Method *addMethod(StringRef retType, StringRef name, + Method::Property properties, Parameters &&...parameters) { + return addMethodAndPrune(Method(retType, name, properties, + std::forward(parameters)...)); + } + + /// Add a method with statically-known properties. + template + Method *addMethod(StringRef retType, StringRef name, + Parameters &&...parameters) { + return addMethod(retType, name, Properties, + std::forward(parameters)...); + } + + /// Add a static method. + template + Method *addStaticMethod(StringRef retType, StringRef name, + Parameters &&...parameters) { + return addMethod( + retType, name, std::forward(parameters)...); + } + + /// Add an inline static method. + template + Method *addStaticInlineMethod(StringRef retType, StringRef name, + Parameters &&...parameters) { + return addMethod( + retType, name, std::forward(parameters)...); + } + + /// Add an inline method.
+ template + Method *addInlineMethod(StringRef retType, StringRef name, + Parameters &&...parameters) { + return addMethod( + retType, name, std::forward(parameters)...); + } + + /// Add a declaration for a method. + template + Method *declareMethod(StringRef retType, StringRef name, + Parameters &&...parameters) { + return addMethod( + retType, name, std::forward(parameters)...); + } + + /// Add a declaration for a static method. + template + Method *declareStaticMethod(StringRef retType, StringRef name, + Parameters &&...parameters) { + return addMethod( + retType, name, std::forward(parameters)...); + } + + /// Creates a new field in this class. + void newField(StringRef type, StringRef name, StringRef defaultValue = ""); + + /// Writes this class as a declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const; + /// Writes the method definitions in this class to the given `os`. + void writeDefTo(raw_ostream &os) const; + + /// Returns the C++ class name. + StringRef getClassName() const { return className; } + +protected: + /// Invoke `func` on every constructor and then on every method. + void forAllMethods(llvm::function_ref func) const { + llvm::for_each(constructors, [&](auto &ctor) { func(ctor); }); + llvm::for_each(methods, [&](auto &method) { func(method); }); + } + + /// Add a new constructor if it is not made redundant by any existing + /// constructors and prune any existing constructors made redundant. + Constructor *addConstructorAndPrune(Constructor &&newCtor); + /// Add a new method if it is not made redundant by any existing methods and + /// prune any existing methods made redundant. + Method *addMethodAndPrune(Method &&newMethod); + + /// The C++ class name. + std::string className; + /// The list of constructors. + std::vector constructors; + /// The list of class methods. + std::vector methods; + /// The list of class members.
+ SmallVector fields; +}; + +/// Class for holding an op for C++ code emission. +class OpClass : public Class { +public: + explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); + + /// Adds an op trait. + void addTrait(Twine trait); + + /// Writes this op's class as a declaration to the given `os`. Redefines + /// Class::writeDeclTo to also emit traits and extra class declarations. + void writeDeclTo(raw_ostream &os) const; + +private: + StringRef extraClassDeclaration; + llvm::SetVector, StringSet<>> traits; +}; + +} // namespace tblgen +} // namespace mlir + +#endif // MLIR_TABLEGEN_CLASS_H_ diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h --- a/mlir/include/mlir/TableGen/CodeGenHelpers.h +++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h @@ -216,9 +216,28 @@ ConstraintMap regionConstraints; }; -// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar. +/// Escape a string using C++ encoding. E.g. foo"bar -> foo\x22bar. std::string escapeString(StringRef value); +namespace detail { +template struct stringifier { + template static std::string apply(T &&t) { + return std::string(std::forward(t)); + } +}; +template <> struct stringifier { + static std::string apply(const Twine &twine) { + return twine.str(); + } +}; +} // end namespace detail + +/// Generically convert a value to a std::string. +template std::string stringify(T &&t) { + return detail::stringifier>>:: + apply(std::forward(t)); +} + } // namespace tblgen } // namespace mlir diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h deleted file mode 100644 --- a/mlir/include/mlir/TableGen/OpClass.h +++ /dev/null @@ -1,442 +0,0 @@ -//===- OpClass.h - Helper classes for Op C++ code emission ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines several classes for Op C++ code emission. They are only -// expected to be used by MLIR TableGen backends. -// -// We emit the op declaration and definition into separate files: *Ops.h.inc -// and *Ops.cpp.inc. The former is to be included in the dialect *Ops.h and -// the latter for dialect *Ops.cpp. This way provides a cleaner interface. -// -// In order to do this split, we need to track method signature and -// implementation logic separately. Signature information is used for both -// declaration and definition, while implementation logic is only for -// definition. So we have the following classes for C++ code emission. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_TABLEGEN_OPCLASS_H_ -#define MLIR_TABLEGEN_OPCLASS_H_ - -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSet.h" -#include "llvm/Support/raw_ostream.h" - -#include -#include - -namespace mlir { -namespace tblgen { -class FmtObjectBase; - -// Class for holding a single parameter of an op's method for C++ code emission. -class OpMethodParameter { -public: - // Properties (qualifiers) for the parameter. - enum Property { - PP_None = 0x0, - PP_Optional = 0x1, - }; - - OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "", - Property properties = PP_None) - : type(type), name(name), defaultValue(defaultValue), - properties(properties) {} - - OpMethodParameter(StringRef type, StringRef name, Property property) - : OpMethodParameter(type, name, "", property) {} - - // Writes the parameter as a part of a method declaration to `os`.
- void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); } - - // Writes the parameter as a part of a method definition to `os` - void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } - - const std::string &getType() const { return type; } - bool hasDefaultValue() const { return !defaultValue.empty(); } - -private: - void writeTo(raw_ostream &os, bool emitDefault) const; - - std::string type; - std::string name; - std::string defaultValue; - Property properties; -}; - -// Base class for holding parameters of an op's method for C++ code emission. -class OpMethodParameters { -public: - // Discriminator for LLVM-style RTTI. - enum ParamsKind { - // Separate type and name for each parameter is not known. - PK_Unresolved, - // Each parameter is resolved to a type and name. - PK_Resolved, - }; - - OpMethodParameters(ParamsKind kind) : kind(kind) {} - virtual ~OpMethodParameters() {} - - // LLVM-style RTTI support. - ParamsKind getKind() const { return kind; } - - // Writes the parameters as a part of a method declaration to `os`. - virtual void writeDeclTo(raw_ostream &os) const = 0; - - // Writes the parameters as a part of a method definition to `os` - virtual void writeDefTo(raw_ostream &os) const = 0; - - // Factory methods to create the correct type of `OpMethodParameters` - // object based on the arguments. - static std::unique_ptr create(); - - static std::unique_ptr create(StringRef params); - - static std::unique_ptr - create(llvm::SmallVectorImpl &&params); - - static std::unique_ptr - create(StringRef type, StringRef name, StringRef defaultValue = ""); - -private: - const ParamsKind kind; -}; - -// Class for holding unresolved parameters. -class OpMethodUnresolvedParameters : public OpMethodParameters { -public: - OpMethodUnresolvedParameters(StringRef params) - : OpMethodParameters(PK_Unresolved), parameters(params) {} - - // write the parameters as a part of a method declaration to the given `os`.
- void writeDeclTo(raw_ostream &os) const override; - - // write the parameters as a part of a method definition to the given `os` - void writeDefTo(raw_ostream &os) const override; - - // LLVM-style RTTI support. - static bool classof(const OpMethodParameters *params) { - return params->getKind() == PK_Unresolved; - } - -private: - std::string parameters; -}; - -// Class for holding resolved parameters. -class OpMethodResolvedParameters : public OpMethodParameters { -public: - OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {} - - OpMethodResolvedParameters(llvm::SmallVectorImpl &&params) - : OpMethodParameters(PK_Resolved) { - for (OpMethodParameter &param : params) - parameters.emplace_back(std::move(param)); - } - - OpMethodResolvedParameters(StringRef type, StringRef name, - StringRef defaultValue) - : OpMethodParameters(PK_Resolved) { - parameters.emplace_back(type, name, defaultValue); - } - - // Returns the number of parameters. - size_t getNumParameters() const { return parameters.size(); } - - // Returns if this method makes the `other` method redundant. Note that this - // is more than just finding conflicting methods. This method determines if - // the 2 set of parameters are conflicting and if so, returns true if this - // method has a more general set of parameters that can replace all possible - // calls to the `other` method. - bool makesRedundant(const OpMethodResolvedParameters &other) const; - - // write the parameters as a part of a method declaration to the given `os`. - void writeDeclTo(raw_ostream &os) const override; - - // write the parameters as a part of a method definition to the given `os` - void writeDefTo(raw_ostream &os) const override; - - // LLVM-style RTTI support.
- static bool classof(const OpMethodParameters *params) { - return params->getKind() == PK_Resolved; - } - -private: - llvm::SmallVector parameters; -}; - -// Class for holding the signature of an op's method for C++ code emission -class OpMethodSignature { -public: - template - OpMethodSignature(StringRef retType, StringRef name, Args &&...args) - : returnType(retType), methodName(name), - parameters(OpMethodParameters::create(std::forward(args)...)) {} - OpMethodSignature(OpMethodSignature &&) = default; - - // Returns if a method with this signature makes a method with `other` - // signature redundant. Only supports resolved parameters. - bool makesRedundant(const OpMethodSignature &other) const; - - // Returns the number of parameters (for resolved parameters). - size_t getNumParameters() const { - return cast(parameters.get()) - ->getNumParameters(); - } - - // Returns the name of the method. - StringRef getName() const { return methodName; } - - // Writes the signature as a method declaration to the given `os`. - void writeDeclTo(raw_ostream &os) const; - - // Writes the signature as the start of a method definition to the given `os`. - // `namePrefix` is the prefix to be prepended to the method name (typically - // namespaces for qualifying the method definition). - void writeDefTo(raw_ostream &os, StringRef namePrefix) const; - -private: - std::string returnType; - std::string methodName; - std::unique_ptr parameters; -}; - -// Class for holding the body of an op's method for C++ code emission -class OpMethodBody { -public: - explicit OpMethodBody(bool declOnly); - - OpMethodBody &operator<<(Twine content); - OpMethodBody &operator<<(int content); - OpMethodBody &operator<<(const FmtObjectBase &content); - - void writeTo(raw_ostream &os) const; - -private: - // Whether this class should record method body.
- bool isEffective; - std::string body; -}; - -// Class for holding an op's method for C++ code emission -class OpMethod { -public: - // Properties (qualifiers) of class methods. Bitfield is used here to help - // querying properties. - enum Property { - MP_None = 0x0, - MP_Static = 0x1, - MP_Constructor = 0x2, - MP_Private = 0x4, - MP_Declaration = 0x8, - MP_Inline = 0x10, - MP_Constexpr = 0x20 | MP_Inline, - MP_StaticDeclaration = MP_Static | MP_Declaration, - }; - - template - OpMethod(StringRef retType, StringRef name, Property property, unsigned id, - Args &&...args) - : properties(property), - methodSignature(retType, name, std::forward(args)...), - methodBody(properties & MP_Declaration), id(id) {} - - OpMethod(OpMethod &&) = default; - - virtual ~OpMethod() = default; - - OpMethodBody &body() { return methodBody; } - - // Returns true if this is a static method. - bool isStatic() const { return properties & MP_Static; } - - // Returns true if this is a private method. - bool isPrivate() const { return properties & MP_Private; } - - // Returns true if this is an inline method. - bool isInline() const { return properties & MP_Inline; } - - // Returns the name of this method. - StringRef getName() const { return methodSignature.getName(); } - - // Returns the ID for this method - unsigned getID() const { return id; } - - // Returns if this method makes the `other` method redundant. - bool makesRedundant(const OpMethod &other) const { - return methodSignature.makesRedundant(other.methodSignature); - } - - // Writes the method as a declaration to the given `os`. - virtual void writeDeclTo(raw_ostream &os) const; - - // Writes the method as a definition to the given `os`. `namePrefix` is the - // prefix to be prepended to the method name (typically namespaces for - // qualifying the method definition).
- virtual void writeDefTo(raw_ostream &os, StringRef namePrefix) const; - -protected: - Property properties; - OpMethodSignature methodSignature; - OpMethodBody methodBody; - const unsigned id; -}; - -// Class for holding an op's constructor method for C++ code emission. -class OpConstructor : public OpMethod { -public: - template - OpConstructor(StringRef className, Property property, unsigned id, - Args &&...args) - : OpMethod("", className, property, id, std::forward(args)...) {} - - // Add member initializer to constructor initializing `name` with `value`. - void addMemberInitializer(StringRef name, StringRef value); - - // Writes the method as a definition to the given `os`. `namePrefix` is the - // prefix to be prepended to the method name (typically namespaces for - // qualifying the method definition). - void writeDefTo(raw_ostream &os, StringRef namePrefix) const override; - -private: - // Member initializers. - std::string memberInitializers; -}; - -// A class used to emit C++ classes from Tablegen. Contains a list of public -// methods and a list of private fields to be emitted. -class Class { -public: - explicit Class(StringRef name); - - // Adds a new method to this class and prune redundant methods. Returns null - // if the method was not added (because an existing method would make it - // redundant), else returns a pointer to the added method. Note that this call - // may also delete existing methods that are made redundant by a method to the - // class.
- template - OpMethod *addMethodAndPrune(StringRef retType, StringRef name, - OpMethod::Property properties, Args &&...args) { - auto newMethod = std::make_unique( - retType, name, properties, nextMethodID++, std::forward(args)...); - return addMethodAndPrune(methods, std::move(newMethod)); - } - - template - OpMethod *addMethodAndPrune(StringRef retType, StringRef name, - Args &&...args) { - return addMethodAndPrune(retType, name, OpMethod::MP_None, - std::forward(args)...); - } - - template - OpConstructor *addConstructorAndPrune(Args &&...args) { - auto newConstructor = std::make_unique( - getClassName(), OpMethod::MP_Constructor, nextMethodID++, - std::forward(args)...); - return addMethodAndPrune(constructors, std::move(newConstructor)); - } - - // Creates a new field in this class. - void newField(StringRef type, StringRef name, StringRef defaultValue = ""); - - // Writes this op's class as a declaration to the given `os`. - void writeDeclTo(raw_ostream &os) const; - // Writes the method definitions in this op's class to the given `os`. - void writeDefTo(raw_ostream &os) const; - - // Returns the C++ class name of the op. - StringRef getClassName() const { return className; } - -protected: - // Get a list of all the methods to emit, filtering out hidden ones. - void forAllMethods(llvm::function_ref func) const { - using ConsRef = const std::unique_ptr &; - using MethodRef = const std::unique_ptr &; - llvm::for_each(constructors, [&](ConsRef ptr) { func(*ptr); }); - llvm::for_each(methods, [&](MethodRef ptr) { func(*ptr); }); - } - - // For deterministic code generation, keep methods sorted in the order in - // which they were generated.
- template - struct MethodCompare { - bool operator()(const std::unique_ptr &x, - const std::unique_ptr &y) const { - return x->getID() < y->getID(); - } - }; - - template - using MethodSet = - std::set, MethodCompare>; - - template - MethodTy *addMethodAndPrune(MethodSet &set, - std::unique_ptr &&newMethod) { - // Check if the new method will be made redundant by existing methods. - for (auto &method : set) - if (method->makesRedundant(*newMethod)) - return nullptr; - - // We can add this a method to the set. Prune any existing methods that will - // be made redundant by adding this new method. Note that the redundant - // check between two methods is more than a conflict check. makesRedundant() - // below will check if the new method conflicts with an existing method and - // if so, returns true if the new method makes the existing method redundant - // because all calls to the existing method can be subsumed by the new - // method. So makesRedundant() does a combined job of finding conflicts and - // deciding which of the 2 conflicting methods survive. - // - // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing - // it manually here. - for (auto it = set.begin(), end = set.end(); it != end;) { - if (newMethod->makesRedundant(*(it->get()))) - it = set.erase(it); - else - ++it; - } - - MethodTy *ret = newMethod.get(); - set.insert(std::move(newMethod)); - return ret; - } - - std::string className; - MethodSet constructors; - MethodSet methods; - unsigned nextMethodID = 0; - SmallVector fields; -}; - -// Class for holding an op for C++ code emission -class OpClass : public Class { -public: - explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); - - // Adds an op trait. - void addTrait(Twine trait); - - // Writes this op's class as a declaration to the given `os`. Redefines - // Class::writeDeclTo to also emit traits and extra class declarations.
- void writeDeclTo(raw_ostream &os) const; - -private: - StringRef extraClassDeclaration; - SmallVector traitsVec; - StringSet<> traitsSet; -}; - -} // namespace tblgen -} // namespace mlir - -#endif // MLIR_TABLEGEN_OPCLASS_H_ diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -13,12 +13,12 @@ Attribute.cpp AttrOrTypeDef.cpp Builder.cpp + Class.cpp Constraint.cpp Dialect.cpp Format.cpp Interfaces.cpp Operator.cpp - OpClass.cpp Pass.cpp Pattern.cpp Predicate.cpp diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/Class.cpp rename from mlir/lib/TableGen/OpClass.cpp rename to mlir/lib/TableGen/Class.cpp --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/Class.cpp @@ -1,4 +1,4 @@ -//===- OpClass.cpp - Helper classes for Op C++ code emission --------------===// +//===- Class.cpp - Helper classes for Op C++ code emission --------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/TableGen/OpClass.h" +#include "mlir/TableGen/Class.h" #include "mlir/TableGen/Format.h" #include "llvm/ADT/Sequence.h" @@ -20,173 +20,102 @@ using namespace mlir; using namespace mlir::tblgen; -namespace { - // Returns space to be emitted after the given C++ `type`. return "" if the // ends with '&' or '*', or is empty, else returns " ". -StringRef getSpaceAfterType(StringRef type) { +static StringRef getSpaceAfterType(StringRef type) { return (type.empty() || type.endswith("&") || type.endswith("*")) ? 
"" : " "; } -} // namespace - //===----------------------------------------------------------------------===// -// OpMethodParameter definitions +// MethodParameter definitions //===----------------------------------------------------------------------===// -void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const { - if (properties & PP_Optional) +void MethodParameter::writeTo(raw_ostream &os, bool emitDefault) const { + if (optional) os << "/*optional*/"; os << type << getSpaceAfterType(type) << name; - if (emitDefault && !defaultValue.empty()) + if (emitDefault && hasDefaultValue()) os << " = " << defaultValue; } //===----------------------------------------------------------------------===// -// OpMethodParameters definitions +// MethodParameters definitions //===----------------------------------------------------------------------===// -// Factory methods to construct the correct type of `OpMethodParameters` -// object based on the arguments. -std::unique_ptr OpMethodParameters::create() { - return std::make_unique(); +void MethodParameters::writeDeclTo(raw_ostream &os) const { + llvm::interleaveComma(parameters, os, + [&os](auto ¶m) { param.writeDeclTo(os); }); } - -std::unique_ptr -OpMethodParameters::create(StringRef params) { - return std::make_unique(params); -} - -std::unique_ptr -OpMethodParameters::create(llvm::SmallVectorImpl &¶ms) { - return std::make_unique(std::move(params)); +void MethodParameters::writeDefTo(raw_ostream &os) const { + llvm::interleaveComma(parameters, os, + [&os](auto ¶m) { param.writeDefTo(os); }); } -std::unique_ptr -OpMethodParameters::create(StringRef type, StringRef name, - StringRef defaultValue) { - return std::make_unique(type, name, defaultValue); -} - -//===----------------------------------------------------------------------===// -// OpMethodUnresolvedParameters definitions -//===----------------------------------------------------------------------===// -void 
OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const { - os << parameters; -} - -void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const { - // We need to remove the default values for parameters in method definition. - // TODO: We are using '=' and ',' as delimiters for parameter - // initializers. This is incorrect for initializer list with more than one - // element. Change to a more robust approach. - llvm::SmallVector tokens; - StringRef params = parameters; - while (!params.empty()) { - std::pair parts = params.split("="); - tokens.push_back(parts.first); - params = parts.second.split(',').second; - } - llvm::interleaveComma(tokens, os, [&](StringRef token) { os << token; }); -} - -//===----------------------------------------------------------------------===// -// OpMethodResolvedParameters definitions -//===----------------------------------------------------------------------===// - -// Returns true if a method with these parameters makes a method with parameters -// `other` redundant. This should return true only if all possible calls to the -// other method can be replaced by calls to this method. -bool OpMethodResolvedParameters::makesRedundant( - const OpMethodResolvedParameters &other) const { - const size_t otherNumParams = other.getNumParameters(); - const size_t thisNumParams = getNumParameters(); - - // All calls to the other method can be replaced this method only if this - // method has the same or more arguments number of arguments as the other, and - // the common arguments have the same type. - if (thisNumParams < otherNumParams) +bool MethodParameters::subsumes(const MethodParameters &other) const { + // These parameters do not subsume the others if there are fewer parameters + // or their types do not match. 
+ if (parameters.size() < other.parameters.size()) + return false; + if (!std::equal( + other.parameters.begin(), other.parameters.end(), parameters.begin(), + [](auto &lhs, auto &rhs) { return lhs.getType() == rhs.getType(); })) return false; - for (int idx : llvm::seq(0, otherNumParams)) - if (parameters[idx].getType() != other.parameters[idx].getType()) - return false; - - // If all the common arguments have the same type, we can elide the other - // method if this method has the same number of arguments as other or the - // first argument after the common ones has a default value (and by C++ - // requirement, all the later ones will also have a default value). - return thisNumParams == otherNumParams || - parameters[otherNumParams].hasDefaultValue(); -} -void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const { - llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) { - param.writeDeclTo(os); - }); -} - -void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const { - llvm::interleaveComma(parameters, os, [&](const OpMethodParameter &param) { - param.writeDefTo(os); - }); + // If all the common parameters have the same type, we can elide the other + // method if this method has the same number of parameters as other or if the + // first parameter after the common parameters has a default value (and, as + // required by C++, subsequent parameters will have default values too). + return parameters.size() == other.parameters.size() || + parameters[other.parameters.size()].hasDefaultValue(); } //===----------------------------------------------------------------------===// -// OpMethodSignature definitions +// MethodSignature definitions //===----------------------------------------------------------------------===// -// Returns if a method with this signature makes a method with `other` signature -// redundant. Only supports resolved parameters.
-bool OpMethodSignature::makesRedundant(const OpMethodSignature &other) const { - if (methodName != other.methodName) - return false; - auto *resolvedThis = dyn_cast(parameters.get()); - auto *resolvedOther = - dyn_cast(other.parameters.get()); - if (resolvedThis && resolvedOther) - return resolvedThis->makesRedundant(*resolvedOther); - return false; +bool MethodSignature::makesRedundant(const MethodSignature &other) const { + return methodName == other.methodName && + parameters.subsumes(other.parameters); } -void OpMethodSignature::writeDeclTo(raw_ostream &os) const { +void MethodSignature::writeDeclTo(raw_ostream &os) const { os << returnType << getSpaceAfterType(returnType) << methodName << "("; - parameters->writeDeclTo(os); + parameters.writeDeclTo(os); os << ")"; } -void OpMethodSignature::writeDefTo(raw_ostream &os, - StringRef namePrefix) const { +void MethodSignature::writeDefTo(raw_ostream &os, StringRef namePrefix) const { os << returnType << getSpaceAfterType(returnType) << namePrefix << (namePrefix.empty() ? 
"" : "::") << methodName << "("; - parameters->writeDefTo(os); + parameters.writeDefTo(os); os << ")"; } //===----------------------------------------------------------------------===// -// OpMethodBody definitions +// MethodBody definitions //===----------------------------------------------------------------------===// -OpMethodBody::OpMethodBody(bool declOnly) : isEffective(!declOnly) {} +MethodBody::MethodBody(bool declOnly) : isEffective(!declOnly) {} -OpMethodBody &OpMethodBody::operator<<(Twine content) { +MethodBody &MethodBody::operator<<(Twine content) { if (isEffective) body.append(content.str()); return *this; } -OpMethodBody &OpMethodBody::operator<<(int content) { +MethodBody &MethodBody::operator<<(int content) { if (isEffective) body.append(std::to_string(content)); return *this; } -OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) { +MethodBody &MethodBody::operator<<(const FmtObjectBase &content) { if (isEffective) body.append(content.str()); return *this; } -void OpMethodBody::writeTo(raw_ostream &os) const { +void MethodBody::writeTo(raw_ostream &os) const { auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; }); os << bodyRef; if (bodyRef.empty() || bodyRef.back() != '\n') @@ -194,10 +123,10 @@ } //===----------------------------------------------------------------------===// -// OpMethod definitions +// Method definitions //===----------------------------------------------------------------------===// -void OpMethod::writeDeclTo(raw_ostream &os) const { +void Method::writeDeclTo(raw_ostream &os) const { os.indent(2); if (isStatic()) os << "static "; @@ -213,7 +142,7 @@ } } -void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { +void Method::writeDefTo(raw_ostream &os, StringRef namePrefix) const { // Do not write definition if the method is decl only. 
if (properties & MP_Declaration) return; @@ -227,15 +156,15 @@ } //===----------------------------------------------------------------------===// -// OpConstructor definitions +// Constructor definitions //===----------------------------------------------------------------------===// -void OpConstructor::addMemberInitializer(StringRef name, StringRef value) { +void Constructor::addMemberInitializer(StringRef name, StringRef value) { memberInitializers.append(std::string(llvm::formatv( "{0}{1}({2})", memberInitializers.empty() ? " : " : ", ", name, value))); } -void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { +void Constructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { // Do not write definition if the method is decl only. if (properties & MP_Declaration) return; @@ -243,7 +172,7 @@ methodSignature.writeDefTo(os, namePrefix); os << " " << memberInitializers << " {\n"; methodBody.writeTo(os); - os << "}"; + os << "}\n"; } //===----------------------------------------------------------------------===// @@ -259,12 +188,13 @@ : formatv("{0} = {1}", varName, defaultValue).str(); fields.push_back(std::move(field)); } + void Class::writeDeclTo(raw_ostream &os) const { bool hasPrivateMethod = false; os << "class " << className << " {\n"; os << "public:\n"; - forAllMethods([&](const OpMethod &method) { + forAllMethods([&](const Method &method) { if (!method.isPrivate()) { method.writeDeclTo(os); os << '\n'; @@ -276,7 +206,7 @@ os << '\n'; os << "private:\n"; if (hasPrivateMethod) { - forAllMethods([&](const OpMethod &method) { + forAllMethods([&](const Method &method) { if (method.isPrivate()) { method.writeDeclTo(os); os << '\n'; @@ -291,12 +221,35 @@ } void Class::writeDefTo(raw_ostream &os) const { - forAllMethods([&](const OpMethod &method) { + forAllMethods([&](const Method &method) { method.writeDefTo(os, className); - os << "\n\n"; + os << "\n"; }); } +// Insert a new method into a list of methods, if it would not be 
pruned, and +// prune any existing methods. +template +static MethodT *insertAndPrune(ContainerT &methods, MethodT newMethod) { + if (llvm::any_of(methods, [&](auto &method) { + return method.makesRedundant(newMethod); + })) + return nullptr; + + llvm::erase_if( + methods, [&](auto &method) { return newMethod.makesRedundant(method); }); + methods.push_back(std::move(newMethod)); + return &methods.back(); +} + +Method *Class::addMethodAndPrune(Method &&newMethod) { + return insertAndPrune(methods, std::move(newMethod)); +} + +Constructor *Class::addConstructorAndPrune(Constructor &&newCtor) { + return insertAndPrune(constructors, std::move(newCtor)); +} + //===----------------------------------------------------------------------===// // OpClass definitions //===----------------------------------------------------------------------===// @@ -304,15 +257,11 @@ OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) : Class(name), extraClassDeclaration(extraClassDeclaration) {} -void OpClass::addTrait(Twine trait) { - auto traitStr = trait.str(); - if (traitsSet.insert(traitStr).second) - traitsVec.push_back(std::move(traitStr)); -} +void OpClass::addTrait(Twine trait) { traits.insert(trait.str()); } void OpClass::writeDeclTo(raw_ostream &os) const { os << "class " << className << " : public ::mlir::Op<" << className; - for (const auto &trait : traitsVec) + for (const auto &trait : traits) os << ", " << trait; os << "> {\npublic:\n" << " using Op::Op;\n" @@ -320,7 +269,7 @@ << " using Adaptor = " << className << "Adaptor;\n"; bool hasPrivateMethod = false; - forAllMethods([&](const OpMethod &method) { + forAllMethods([&](const Method &method) { if (!method.isPrivate()) { method.writeDeclTo(os); os << "\n"; @@ -335,7 +284,7 @@ if (hasPrivateMethod) { os << "\nprivate:\n"; - forAllMethods([&](const OpMethod &method) { + forAllMethods([&](const Method &method) { if (method.isPrivate()) { method.writeDeclTo(os); os << "\n"; diff --git
a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -10,11 +10,11 @@ // //===----------------------------------------------------------------------===// +#include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" -#include "mlir/TableGen/OpClass.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/Optional.h" diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -13,11 +13,11 @@ #include "OpFormatGen.h" #include "OpGenHelpers.h" +#include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" -#include "mlir/TableGen/OpClass.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/SideEffects.h" #include "mlir/TableGen/Trait.h" @@ -361,7 +361,7 @@ // types. `inferredAttributes` is populated with any attributes that are // elided from the build list. The given `typeParamKind` and `attrParamKind` // controls how result types and attributes are placed in the parameter list. - void buildParamList(llvm::SmallVectorImpl &paramList, + void buildParamList(SmallVectorImpl &paramList, llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, @@ -369,7 +369,7 @@ // Adds op arguments and regions into operation state for build() methods.
void - genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + genCodeForAddingArgAndRegionForBuilder(MethodBody &body, llvm::StringSet<> &inferredAttributes, bool isRawValueAttr = false); @@ -390,17 +390,16 @@ // Generates verify statements for operands and results in the operation. // The generated code will be attached to `body`. - void genOperandResultVerifier(OpMethodBody &body, - Operator::value_range values, + void genOperandResultVerifier(MethodBody &body, Operator::value_range values, StringRef valueKind); // Generates verify statements for regions in the operation. // The generated code will be attached to `body`. - void genRegionVerifier(OpMethodBody &body); + void genRegionVerifier(MethodBody &body); // Generates verify statements for successors in the operation. // The generated code will be attached to `body`. - void genSuccessorVerifier(OpMethodBody &body); + void genSuccessorVerifier(MethodBody &body); // Generates the traits used by the object. void genTraits(); @@ -413,8 +412,8 @@ // Generate op interface method for the given interface method. If // 'declaration' is true, generates a declaration, else a definition. - OpMethod *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, - bool declaration = true); + Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, + bool declaration = true); // Generate the side effect interface methods. void genSideEffectInterfaceMethods(); @@ -470,7 +469,7 @@ // Generate attribute verification. If an op instance is not available, then // attribute checks that require one will not be emitted. static void genAttributeVerifier( - const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body, + const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body, const StaticVerifierFunctionEmitter &staticVerifierEmitter) { // Check that a required attribute exists. 
// @@ -602,7 +601,7 @@ void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } -static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName, +static void errorIfPruned(size_t line, Method *m, const Twine &methodName, const Operator &op) { if (m) return; @@ -627,18 +626,15 @@ for (const NamedAttribute &namedAttr : op.getAttributes()) addAttrName(namedAttr.name); // Include key attributes from several traits as implicitly registered. - std::string operandSizes = "operand_segment_sizes"; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) - addAttrName(operandSizes); - std::string attrSizes = "result_segment_sizes"; + addAttrName("operand_segment_sizes"); if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) - addAttrName(attrSizes); + addAttrName("result_segment_sizes"); // Emit the getAttributeNames method. { - auto *method = opClass.addMethodAndPrune( - "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames", - OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline)); + auto *method = opClass.addStaticInlineMethod( + "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames"); ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); if (attributeNames.empty()) { @@ -658,20 +654,18 @@ // Emit the getAttributeNameForIndex methods. 
{ - auto *method = opClass.addMethodAndPrune( + auto *method = opClass.addInlineMethod( "::mlir::Identifier", "getAttributeNameForIndex", - OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline), - "unsigned", "index"); + MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); method->body() << " return getAttributeNameForIndex((*this)->getName(), index);"; } { - auto *method = opClass.addMethodAndPrune( + auto *method = opClass.addStaticInlineMethod( "::mlir::Identifier", "getAttributeNameForIndex", - OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline | - OpMethod::MP_Static), - "::mlir::OperationName name, unsigned index"); + MethodParameter("::mlir::OperationName", "name"), + MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); method->body() << "assert(index < " << attributeNames.size() << " && \"invalid attribute index\");\n" @@ -689,8 +683,7 @@ // Generate the non-static variant. { auto *method = - opClass.addMethodAndPrune("::mlir::Identifier", methodName, - OpMethod::Property(OpMethod::MP_Inline)); + opClass.addInlineMethod("::mlir::Identifier", methodName); ERROR_IF_PRUNED(method, methodName, op); method->body() << llvm::formatv(attrNameMethodBody, attrIt.second).str(); @@ -698,10 +691,9 @@ // Generate the static variant. { - auto *method = opClass.addMethodAndPrune( + auto *method = opClass.addStaticInlineMethod( "::mlir::Identifier", methodName, - OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static), - "::mlir::OperationName", "name"); + MethodParameter("::mlir::OperationName", "name")); ERROR_IF_PRUNED(method, methodName, op); method->body() << llvm::formatv(attrNameMethodBody, "name, " + Twine(attrIt.second)) @@ -717,13 +709,13 @@ // Emit the derived attribute body. 
auto emitDerivedAttr = [&](StringRef name, Attribute attr) { - if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name)) + if (auto *method = opClass.addMethod(attr.getReturnType(), name)) method->body() << " " << attr.getDerivedCodeBody() << "\n"; }; // Emit with return type specified. auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { - auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + auto *method = opClass.addMethod(attr.getReturnType(), name); ERROR_IF_PRUNED(method, name, op); auto &body = method->body(); body << " auto attr = " << name << "Attr();\n"; @@ -748,7 +740,7 @@ // use the string interface for better compile time verification. auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto *method = - opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str()); + opClass.addMethod(attr.getStorageType(), (name + "Attr").str()); if (!method) return; method->body() << formatv( @@ -773,68 +765,69 @@ [](const NamedAttribute &namedAttr) { return namedAttr.attr.isDerivedAttr(); }); - if (!derivedAttrs.empty()) { - opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait"); - // Generate helper method to query whether a named attribute is a derived - // attribute. This enables, for example, avoiding adding an attribute that - // overlaps with a derived attribute. - { - auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute", - OpMethod::MP_Static, - "::llvm::StringRef", "name"); - ERROR_IF_PRUNED(method, "isDerivedAttribute", op); - auto &body = method->body(); - for (auto namedAttr : derivedAttrs) - body << " if (name == \"" << namedAttr.name << "\") return true;\n"; - body << " return false;"; - } - // Generate method to materialize derived attributes as a DictionaryAttr. 
- { - auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr", - "materializeDerivedAttributes"); - ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op); - auto &body = method->body(); - - auto nonMaterializable = - make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { - return namedAttr.attr.getConvertFromStorageCall().empty(); - }); - if (!nonMaterializable.empty()) { - std::string attrs; - llvm::raw_string_ostream os(attrs); - interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) { - os << op.getGetterName(attr.name); - }); - PrintWarning( - op.getLoc(), - formatv( - "op has non-materializable derived attributes '{0}', skipping", - os.str())); - body << formatv(" emitOpError(\"op has non-materializable derived " - "attributes '{0}'\");\n", - attrs); - body << " return nullptr;"; - return; - } + if (derivedAttrs.empty()) + return; - body << " ::mlir::MLIRContext* ctx = getContext();\n"; - body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; - body << " return ::mlir::DictionaryAttr::get("; - body << " ctx, {\n"; - interleave( - derivedAttrs, body, - [&](const NamedAttribute &namedAttr) { - auto tmpl = namedAttr.attr.getConvertFromStorageCall(); - std::string name = op.getGetterName(namedAttr.name); - body << " {" << name << "AttrName(),\n" - << tgfmt(tmpl, &fctx.withSelf(name + "()") - .withBuilder("odsBuilder") - .addSubst("_ctx", "ctx")) - << "}"; - }, - ",\n"); - body << "});"; + opClass.addTrait("::mlir::DerivedAttributeOpInterface::Trait"); + // Generate helper method to query whether a named attribute is a derived + // attribute. This enables, for example, avoiding adding an attribute that + // overlaps with a derived attribute. 
+ { + auto *method = + opClass.addStaticMethod("bool", "isDerivedAttribute", + MethodParameter("::llvm::StringRef", "name")); + ERROR_IF_PRUNED(method, "isDerivedAttribute", op); + auto &body = method->body(); + for (auto namedAttr : derivedAttrs) + body << " if (name == \"" << namedAttr.name << "\") return true;\n"; + body << " return false;"; + } + // Generate method to materialize derived attributes as a DictionaryAttr. + { + auto *method = opClass.addMethod("::mlir::DictionaryAttr", + "materializeDerivedAttributes"); + ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op); + auto &body = method->body(); + + auto nonMaterializable = + make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { + return namedAttr.attr.getConvertFromStorageCall().empty(); + }); + if (!nonMaterializable.empty()) { + std::string attrs; + llvm::raw_string_ostream os(attrs); + interleaveComma(nonMaterializable, os, [&](const NamedAttribute &attr) { + os << op.getGetterName(attr.name); + }); + PrintWarning( + op.getLoc(), + formatv( + "op has non-materializable derived attributes '{0}', skipping", + os.str())); + body << formatv(" emitOpError(\"op has non-materializable derived " + "attributes '{0}'\");\n", + attrs); + body << " return nullptr;"; + return; } + + body << " ::mlir::MLIRContext* ctx = getContext();\n"; + body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; + body << " return ::mlir::DictionaryAttr::get("; + body << " ctx, {\n"; + interleave( + derivedAttrs, body, + [&](const NamedAttribute &namedAttr) { + auto tmpl = namedAttr.attr.getConvertFromStorageCall(); + std::string name = op.getGetterName(namedAttr.name); + body << " {" << name << "AttrName(),\n" + << tgfmt(tmpl, &fctx.withSelf(name + "()") + .withBuilder("odsBuilder") + .addSubst("_ctx", "ctx")) + << "}"; + }, + ",\n"); + body << "});"; } } @@ -844,19 +837,21 @@ // for better compile time verification. 
auto emitAttrWithStorageType = [&](StringRef setterName, StringRef getterName, Attribute attr) { - auto *method = opClass.addMethodAndPrune( - "void", (setterName + "Attr").str(), attr.getStorageType(), "attr"); + auto *method = + opClass.addMethod("void", (setterName + "Attr").str(), + MethodParameter(attr.getStorageType(), "attr")); if (method) method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);", getterName); }; for (const NamedAttribute &namedAttr : op.getAttributes()) { - if (!namedAttr.attr.isDerivedAttr()) - for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), - op.getGetterNames(namedAttr.name))) - emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), - namedAttr.attr); + if (namedAttr.attr.isDerivedAttr()) + continue; + for (auto names : llvm::zip(op.getSetterNames(namedAttr.name), + op.getGetterNames(namedAttr.name))) + emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), + namedAttr.attr); } } @@ -866,7 +861,7 @@ auto emitRemoveAttr = [&](StringRef name) { auto upperInitial = name.take_front().upper(); auto suffix = name.drop_front(); - auto *method = opClass.addMethodAndPrune( + auto *method = opClass.addMethod( "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str()); if (!method) return; @@ -887,8 +882,8 @@ int numVariadic, int numNonVariadic, StringRef rangeSizeCall, bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) { - auto *method = opClass.addMethodAndPrune("std::pair", - methodName, "unsigned", "index"); + auto *method = opClass.addMethod("std::pair", methodName, + MethodParameter("unsigned", "index")); if (!method) return; auto &body = method->body(); @@ -900,7 +895,7 @@ // Because the op can have arbitrarily interleaved variadic and non-variadic // operands, we need to embed a list in the "sink" getter method for // calculation at run-time. 
- llvm::SmallVector isVariadic; + SmallVector isVariadic; isVariadic.reserve(llvm::size(odsValues)); for (auto &it : odsValues) isVariadic.push_back(it.isVariableLength() ? "true" : "false"); @@ -959,8 +954,8 @@ rangeSizeCall, attrSizedOperands, sizeAttrInit, const_cast(op).getOperands()); - auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned", - "index"); + auto *m = opClass.addMethod(rangeType, "getODSOperands", + MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(m, "getODSOperands", op); auto &body = m->body(); body << formatv(valueRangeReturnCode, rangeBeginCall, @@ -974,7 +969,7 @@ continue; for (StringRef name : op.getGetterNames(operand.name)) { if (operand.isOptional()) { - m = opClass.addMethodAndPrune("::mlir::Value", name); + m = opClass.addMethod("::mlir::Value", name); ERROR_IF_PRUNED(m, name, op); m->body() << " auto operands = getODSOperands(" << i << ");\n" << " return operands.empty() ? ::mlir::Value() : " @@ -983,24 +978,24 @@ std::string segmentAttr = op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); if (isAdaptor) { - m = opClass.addMethodAndPrune( - "::llvm::SmallVector<::mlir::ValueRange>", name); + m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", + name); ERROR_IF_PRUNED(m, name, op); m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, segmentAttr, i); continue; } - m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", name); + m = opClass.addMethod("::mlir::OperandRangeRange", name); ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr << "Attr());"; } else if (operand.isVariadic()) { - m = opClass.addMethodAndPrune(rangeType, name); + m = opClass.addMethod(rangeType, name); ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSOperands(" << i << ");"; } else { - m = opClass.addMethodAndPrune("::mlir::Value", name); + m = opClass.addMethod("::mlir::Value", name); ERROR_IF_PRUNED(m, name, op); m->body() 
<< " return *getODSOperands(" << i << ").begin();"; } @@ -1035,10 +1030,10 @@ if (operand.name.empty()) continue; for (StringRef name : op.getGetterNames(operand.name)) { - auto *m = opClass.addMethodAndPrune( - operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange" - : "::mlir::MutableOperandRange", - (name + "Mutable").str()); + auto *m = opClass.addMethod(operand.isVariadicOfVariadic() + ? "::mlir::MutableOperandRangeRange" + : "::mlir::MutableOperandRange", + (name + "Mutable").str()); ERROR_IF_PRUNED(m, name, op); auto &body = m->body(); body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" @@ -1110,8 +1105,9 @@ numNormalResults, "getOperation()->getNumResults()", attrSizedResults, attrSizeInitCode, op.getResults()); - auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range", - "getODSResults", "unsigned", "index"); + auto *m = + opClass.addMethod("::mlir::Operation::result_range", "getODSResults", + MethodParameter("unsigned", "index")); ERROR_IF_PRUNED(m, "getODSResults", op); m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", "getODSResultIndexAndLength(index)"); @@ -1122,17 +1118,17 @@ continue; for (StringRef name : op.getGetterNames(result.name)) { if (result.isOptional()) { - m = opClass.addMethodAndPrune("::mlir::Value", name); + m = opClass.addMethod("::mlir::Value", name); ERROR_IF_PRUNED(m, name, op); m->body() << " auto results = getODSResults(" << i << ");\n" << " return results.empty() ? 
::mlir::Value() : *results.begin();"; } else if (result.isVariadic()) { - m = opClass.addMethodAndPrune("::mlir::Operation::result_range", name); + m = opClass.addMethod("::mlir::Operation::result_range", name); ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSResults(" << i << ");"; } else { - m = opClass.addMethodAndPrune("::mlir::Value", name); + m = opClass.addMethod("::mlir::Value", name); ERROR_IF_PRUNED(m, name, op); m->body() << " return *getODSResults(" << i << ").begin();"; } @@ -1150,15 +1146,15 @@ for (StringRef name : op.getGetterNames(region.name)) { // Generate the accessors for a variadic region. if (region.isVariadic()) { - auto *m = opClass.addMethodAndPrune( - "::mlir::MutableArrayRef<::mlir::Region>", name); + auto *m = + opClass.addMethod("::mlir::MutableArrayRef<::mlir::Region>", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv(" return (*this)->getRegions().drop_front({0});", i); continue; } - auto *m = opClass.addMethodAndPrune("::mlir::Region &", name); + auto *m = opClass.addMethod("::mlir::Region &", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv(" return (*this)->getRegion({0});", i); } @@ -1175,7 +1171,7 @@ for (StringRef name : op.getGetterNames(successor.name)) { // Generate the accessors for a variadic successor list. if (successor.isVariadic()) { - auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name); + auto *m = opClass.addMethod("::mlir::SuccessorRange", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv( " return {std::next((*this)->successor_begin(), {0}), " @@ -1184,7 +1180,7 @@ continue; } - auto *m = opClass.addMethodAndPrune("::mlir::Block *", name); + auto *m = opClass.addMethod("::mlir::Block *", name); ERROR_IF_PRUNED(m, name, op); m->body() << formatv(" return (*this)->getSuccessor({0});", i); } @@ -1227,14 +1223,13 @@ // inferring result type. 
auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, bool inferType) { - llvm::SmallVector paramList; - llvm::SmallVector resultNames; + SmallVector paramList; + SmallVector resultNames; llvm::StringSet<> inferredAttributes; buildParamList(paramList, inferredAttributes, resultNames, paramKind, attrType); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method. if (!m) return; @@ -1308,7 +1303,7 @@ int numResults = op.getNumResults(); // Signature - llvm::SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); @@ -1319,8 +1314,7 @@ if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; @@ -1348,14 +1342,13 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() { // TODO: Expand to support regions. 
- SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", "attributes", "{}"); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; @@ -1407,14 +1400,13 @@ } void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { - llvm::SmallVector paramList; - llvm::SmallVector resultNames; + SmallVector paramList; + SmallVector resultNames; llvm::StringSet<> inferredAttributes; buildParamList(paramList, inferredAttributes, resultNames, TypeParamKind::None); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; @@ -1436,14 +1428,13 @@ } void OpEmitter::genUseAttrAsResultTypeBuilder() { - SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", "attributes", "{}"); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; @@ -1480,16 +1471,15 @@ /// Returns a signature of the builder. Updates the context `fctx` to enable /// replacement of $_builder and $_state in the body. 
-static std::string getBuilderSignature(const Builder &builder) { +static SmallVector +getBuilderSignature(const Builder &builder) { ArrayRef params(builder.getParameters()); // Inject builder and state arguments. - llvm::SmallVector arguments; + SmallVector arguments; arguments.reserve(params.size() + 2); - arguments.push_back( - llvm::formatv("::mlir::OpBuilder &{0}", odsBuilder).str()); - arguments.push_back( - llvm::formatv("::mlir::OperationState &{0}", builderOpState).str()); + arguments.emplace_back("::mlir::OpBuilder &", odsBuilder); + arguments.emplace_back("::mlir::OperationState &", builderOpState); for (unsigned i = 0, e = params.size(); i < e; ++i) { // If no name is provided, generate one. @@ -1497,27 +1487,27 @@ std::string name = paramName ? paramName->str() : "odsArg" + std::to_string(i); - std::string defaultValue; + StringRef defaultValue; if (Optional defaultParamValue = params[i].getDefaultValue()) - defaultValue = llvm::formatv(" = {0}", *defaultParamValue).str(); - arguments.push_back( - llvm::formatv("{0} {1}{2}", params[i].getCppType(), name, defaultValue) - .str()); + defaultValue = *defaultParamValue; + + arguments.emplace_back(params[i].getCppType(), std::move(name), + defaultValue); } - return llvm::join(arguments, ", "); + return arguments; } void OpEmitter::genBuilder() { // Handle custom builders if provided. for (const Builder &builder : op.getBuilders()) { - std::string paramStr = getBuilderSignature(builder); + SmallVector arguments = getBuilderSignature(builder); Optional body = builder.getBody(); - OpMethod::Property properties = - body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; + Method::Property properties = + body ? 
Method::MP_Static : Method::MP_StaticDeclaration; auto *method = - opClass.addMethodAndPrune("void", "build", properties, paramStr); + opClass.addMethod("void", "build", properties, std::move(arguments)); if (body) ERROR_IF_PRUNED(method, "build", op); @@ -1561,7 +1551,7 @@ int numVariadicOperands = op.getNumVariableLengthOperands(); int numNonVariadicOperands = numOperands - numVariadicOperands; - SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::OpBuilder &", ""); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::TypeRange", "resultTypes"); @@ -1573,8 +1563,7 @@ if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); - auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, - std::move(paramList)); + auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) return; @@ -1612,7 +1601,7 @@ genInferredTypeCollectiveParamBuilder(); } -void OpEmitter::buildParamList(SmallVectorImpl ¶mList, +void OpEmitter::buildParamList(SmallVectorImpl ¶mList, llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, @@ -1637,11 +1626,8 @@ StringRef type = result.isVariadic() ? 
"::mlir::TypeRange" : "::mlir::Type"; - OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (result.isOptional()) - properties = OpMethodParameter::PP_Optional; - paramList.emplace_back(type, resultName, properties); + paramList.emplace_back(type, resultName, result.isOptional()); resultTypeNames.emplace_back(std::move(resultName)); } } break; @@ -1699,11 +1685,8 @@ else type = "::mlir::Value"; - OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (operand->isOptional()) - properties = OpMethodParameter::PP_Optional; paramList.emplace_back(type, getArgumentName(op, numOperands++), - properties); + operand->isOptional()); continue; } const NamedAttribute &namedAttr = *arg.get(); @@ -1713,10 +1696,6 @@ if (inferredAttributes.contains(namedAttr.name)) continue; - OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (attr.isOptional()) - properties = OpMethodParameter::PP_Optional; - StringRef type; switch (attrParamKind) { case AttrParamKind::WrappedAttr: @@ -1736,7 +1715,8 @@ i >= defaultValuedAttrStartIndex) { defaultValue += attr.getDefaultValue(); } - paramList.emplace_back(type, namedAttr.name, defaultValue, properties); + paramList.emplace_back(type, namedAttr.name, defaultValue, + attr.isOptional()); } /// Insert parameters for each successor. @@ -1754,7 +1734,7 @@ } void OpEmitter::genCodeForAddingArgAndRegionForBuilder( - OpMethodBody &body, llvm::StringSet<> &inferredAttributes, + MethodBody &body, llvm::StringSet<> &inferredAttributes, bool isRawValueAttr) { // Push all operands to the result. 
for (int i = 0, e = op.getNumOperands(); i < e; ++i) { @@ -1871,12 +1851,11 @@ if (hasCanonicalizeMethod) { // static LogicResult FooOp:: // canonicalize(FooOp op, PatternRewriter &rewriter); - SmallVector paramList; + SmallVector paramList; paramList.emplace_back(op.getCppClassName(), "op"); paramList.emplace_back("::mlir::PatternRewriter &", "rewriter"); - auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize", - OpMethod::MP_StaticDeclaration, - std::move(paramList)); + auto *m = opClass.declareStaticMethod("::mlir::LogicalResult", + "canonicalize", std::move(paramList)); ERROR_IF_PRUNED(m, "canonicalize", op); } @@ -1892,12 +1871,12 @@ // Add a signature for getCanonicalizationPatterns if implemented by the // dialect or if synthesized to call 'canonicalize'. - SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::RewritePatternSet &", "results"); paramList.emplace_back("::mlir::MLIRContext *", "context"); - auto kind = hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; - auto *method = opClass.addMethodAndPrune( - "void", "getCanonicalizationPatterns", kind, std::move(paramList)); + auto kind = hasBody ? Method::MP_Static : Method::MP_StaticDeclaration; + auto *method = opClass.addMethod("void", "getCanonicalizationPatterns", kind, + std::move(paramList)); // If synthesizing the method, fill it it. 
if (hasBody) { @@ -1912,18 +1891,17 @@ if (def.getValueAsBit("hasFolder")) { if (hasSingleResult) { - auto *m = opClass.addMethodAndPrune( - "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration, - "::llvm::ArrayRef<::mlir::Attribute>", "operands"); + auto *m = opClass.declareMethod( + "::mlir::OpFoldResult", "fold", + MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands")); ERROR_IF_PRUNED(m, "operands", op); } else { - SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", "results"); - auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "fold", - OpMethod::MP_Declaration, - std::move(paramList)); + auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold", + std::move(paramList)); ERROR_IF_PRUNED(m, "fold", op); } } @@ -1953,18 +1931,18 @@ } } -OpMethod *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, - bool declaration) { - SmallVector paramList; +Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, + bool declaration) { + SmallVector paramList; for (const InterfaceMethod::Argument &arg : method.getArguments()) paramList.emplace_back(arg.type, arg.name); - auto properties = method.isStatic() ? OpMethod::MP_Static : OpMethod::MP_None; + auto properties = method.isStatic() ? 
Method::MP_Static : Method::MP_None; if (declaration) properties = - static_cast(properties | OpMethod::MP_Declaration); - return opClass.addMethodAndPrune(method.getReturnType(), method.getName(), - properties, std::move(paramList)); + static_cast(properties | Method::MP_Declaration); + return opClass.addMethod(method.getReturnType(), method.getName(), properties, + std::move(paramList)); } void OpEmitter::genOpInterfaceMethods() { @@ -2039,8 +2017,8 @@ "SideEffects::EffectInstance<{0}>> &", it.first()) .str(); - auto *getEffects = - opClass.addMethodAndPrune("void", "getEffects", type, "effects"); + auto *getEffects = opClass.addMethod("void", "getEffects", + MethodParameter(type, "effects")); ERROR_IF_PRUNED(getEffects, "getEffects", op); auto &body = getEffects->body(); @@ -2082,7 +2060,7 @@ const auto *trait = dyn_cast( op.getTrait("::mlir::InferTypeOpInterface::Trait")); Interface interface = trait->getInterface(); - OpMethod *method = [&]() -> OpMethod * { + Method *method = [&]() -> Method * { for (const InterfaceMethod &interfaceMethod : interface.getMethods()) { if (interfaceMethod.getName() == "inferReturnTypes") { return genOpInterfaceMethod(interfaceMethod, /*declaration=*/false); @@ -2099,8 +2077,7 @@ fctx.withBuilder("odsBuilder"); body << " ::mlir::Builder odsBuilder(context);\n"; - auto emitType = - [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & { + auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> MethodBody & { if (!type.isArg()) return body << tgfmt(*type.getType().getBuilderCall(), &fctx); auto argIndex = type.getArg(); @@ -2129,12 +2106,11 @@ hasStringAttribute(def, "assemblyFormat")) return; - SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::OpAsmParser &", "parser"); paramList.emplace_back("::mlir::OperationState &", "result"); - auto *method = - opClass.addMethodAndPrune("::mlir::ParseResult", "parse", - OpMethod::MP_Static, std::move(paramList)); + auto *method = 
opClass.addStaticMethod("::mlir::ParseResult", "parse", + std::move(paramList)); ERROR_IF_PRUNED(method, "parse", op); FmtContext fctx; @@ -2152,8 +2128,8 @@ if (!stringInit) return; - auto *method = - opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p"); + auto *method = opClass.addMethod( + "void", "print", MethodParameter("::mlir::OpAsmPrinter &", "p")); ERROR_IF_PRUNED(method, "print", op); FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); @@ -2162,7 +2138,7 @@ } /// Generate verification on native traits requiring attributes. -static void genNativeTraitAttrVerifier(OpMethodBody &body, +static void genNativeTraitAttrVerifier(MethodBody &body, const OpOrAdaptorHelper &emitHelper) { // Check that the variadic segment sizes attribute exists and contains the // expected number of elements. @@ -2209,7 +2185,7 @@ } void OpEmitter::genVerifier() { - auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify"); + auto *method = opClass.addMethod("::mlir::LogicalResult", "verify"); ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); @@ -2247,7 +2223,7 @@ } } -void OpEmitter::genOperandResultVerifier(OpMethodBody &body, +void OpEmitter::genOperandResultVerifier(MethodBody &body, Operator::value_range values, StringRef valueKind) { // Check that an optional value is at most 1 element. @@ -2321,7 +2297,7 @@ body << " }\n"; } -void OpEmitter::genRegionVerifier(OpMethodBody &body) { +void OpEmitter::genRegionVerifier(MethodBody &body) { /// Code to verify a region. /// /// {0}: Getter for the regions. 
@@ -2363,7 +2339,7 @@ body << " }\n"; } -void OpEmitter::genSuccessorVerifier(OpMethodBody &body) { +void OpEmitter::genSuccessorVerifier(MethodBody &body) { const char *const verifySuccessor = R"( for (auto *successor : {0}) if (::mlir::failed({1}(*this, successor, "{2}", index++))) @@ -2485,9 +2461,8 @@ } void OpEmitter::genOpNameGetter() { - auto *method = opClass.addMethodAndPrune( - "::llvm::StringLiteral", "getOperationName", - OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr)); + auto *method = opClass.addStaticMethod( + "::llvm::StringLiteral", "getOperationName"); ERROR_IF_PRUNED(method, "getOperationName", op); method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName() << "\");"; @@ -2514,8 +2489,9 @@ opClass.addTrait("::mlir::OpAsmOpInterface::Trait"); // Generate the right accessor for the number of results. - auto *method = opClass.addMethodAndPrune( - "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn"); + auto *method = opClass.addMethod( + "void", "getAsmResultNames", + MethodParameter("::mlir::OpAsmSetValueNameFn", "setNameFn")); ERROR_IF_PRUNED(method, "getAsmResultNames", op); auto &body = method->body(); for (int i = 0; i != numResults; ++i) { @@ -2567,7 +2543,7 @@ const auto *attrSizedOperands = op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); { - SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::ValueRange", "values"); paramList.emplace_back("::mlir::DictionaryAttr", "attrs", attrSizedOperands ? 
"" : "nullptr"); @@ -2581,14 +2557,14 @@ { auto *constructor = adaptor.addConstructorAndPrune( - llvm::formatv("{0}&", op.getCppClassName()).str(), "op"); + MethodParameter(op.getCppClassName() + " &", "op")); constructor->addMemberInitializer("odsOperands", "op->getOperands()"); constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); constructor->addMemberInitializer("odsRegions", "op->getRegions()"); } { - auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands"); + auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands"); ERROR_IF_PRUNED(m, "getOperands", op); m->body() << " return odsOperands;"; } @@ -2605,7 +2581,7 @@ fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); auto emitAttr = [&](StringRef name, StringRef emitName, Attribute attr) { - auto *method = adaptor.addMethodAndPrune(attr.getStorageType(), emitName); + auto *method = adaptor.addMethod(attr.getStorageType(), emitName); ERROR_IF_PRUNED(method, "Adaptor::" + emitName, op); auto &body = method->body(); body << " assert(odsAttrs && \"no attributes when constructing adapter\");" @@ -2629,8 +2605,7 @@ }; { - auto *m = - adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes"); + auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes"); ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); m->body() << " return odsAttrs;"; } @@ -2645,7 +2620,7 @@ unsigned numRegions = op.getNumRegions(); if (numRegions > 0) { - auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions"); + auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions"); ERROR_IF_PRUNED(m, "Adaptor::getRegions", op); m->body() << " return odsRegions;"; } @@ -2657,13 +2632,13 @@ // Generate the accessors for a variadic region. 
for (StringRef name : op.getGetterNames(region.name)) { if (region.isVariadic()) { - auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", name); + auto *m = adaptor.addMethod("::mlir::RegionRange", name); ERROR_IF_PRUNED(m, "Adaptor::" + name, op); m->body() << formatv(" return odsRegions.drop_front({0});", i); continue; } - auto *m = adaptor.addMethodAndPrune("::mlir::Region &", name); + auto *m = adaptor.addMethod("::mlir::Region &", name); ERROR_IF_PRUNED(m, "Adaptor::" + name, op); m->body() << formatv(" return *odsRegions[{0}];", i); } @@ -2674,8 +2649,8 @@ } void OpOperandAdaptorEmitter::addVerification() { - auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify", - "::mlir::Location", "loc"); + auto *method = adaptor.addMethod("::mlir::LogicalResult", "verify", + MethodParameter("::mlir::Location", "loc")); ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -9,10 +9,10 @@ #include "OpFormatGen.h" #include "FormatGen.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/TableGen/Class.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" -#include "mlir/TableGen/OpClass.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/MapVector.h" @@ -140,8 +140,7 @@ namespace { /// This class implements single kind directives. -template -class DirectiveElement : public Element { +template class DirectiveElement : public Element { public: DirectiveElement() : Element(type){}; static bool classof(const Element *ele) { return ele->getKind() == type; } @@ -422,23 +421,23 @@ /// Generate the operation parser from this format. void genParser(Operator &op, OpClass &opClass); /// Generate the parser code for a specific format element. 
- void genElementParser(Element *element, OpMethodBody &body, + void genElementParser(Element *element, MethodBody &body, FmtContext &attrTypeCtx); /// Generate the c++ to resolve the types of operands and results during /// parsing. - void genParserTypeResolution(Operator &op, OpMethodBody &body); + void genParserTypeResolution(Operator &op, MethodBody &body); /// Generate the c++ to resolve regions during parsing. - void genParserRegionResolution(Operator &op, OpMethodBody &body); + void genParserRegionResolution(Operator &op, MethodBody &body); /// Generate the c++ to resolve successors during parsing. - void genParserSuccessorResolution(Operator &op, OpMethodBody &body); + void genParserSuccessorResolution(Operator &op, MethodBody &body); /// Generate the c++ to handling variadic segment size traits. - void genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body); + void genParserVariadicSegmentResolution(Operator &op, MethodBody &body); /// Generate the operation printer from this format. void genPrinter(Operator &op, OpClass &opClass); /// Generate the printer code for a specific format element. - void genElementPrinter(Element *element, OpMethodBody &body, Operator &op, + void genElementPrinter(Element *element, MethodBody &body, Operator &op, bool &shouldEmitSpace, bool &lastWasPunctuation); /// The various elements in this format. @@ -813,7 +812,7 @@ } /// Generate the parser for a literal value. -static void genLiteralParser(StringRef value, OpMethodBody &body) { +static void genLiteralParser(StringRef value, MethodBody &body) { // Handle the case of a keyword/identifier. if (value.front() == '_' || isalpha(value.front())) { body << "Keyword(\"" << value << "\")"; @@ -839,7 +838,7 @@ /// Generate the storage code required for parsing the given element. 
static void genElementParserStorage(Element *element, const Operator &op, - OpMethodBody &body) { + MethodBody &body) { if (auto *optional = dyn_cast(element)) { auto elements = optional->getThenElements(); @@ -937,7 +936,7 @@ } /// Generate the parser for a parameter to a custom directive. -static void genCustomParameterParser(Element ¶m, OpMethodBody &body) { +static void genCustomParameterParser(Element ¶m, MethodBody &body) { if (auto *attr = dyn_cast(¶m)) { body << attr->getVar()->name << "Attr"; } else if (isa(¶m)) { @@ -988,7 +987,7 @@ } /// Generate the parser for a custom directive. -static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) { +static void genCustomDirectiveParser(CustomDirective *dir, MethodBody &body) { body << " {\n"; // Preprocess the directive variables. @@ -1098,7 +1097,7 @@ } /// Generate the parser for a enum attribute. -static void genEnumAttrParser(const NamedAttribute *var, OpMethodBody &body, +static void genEnumAttrParser(const NamedAttribute *var, MethodBody &body, FmtContext &attrTypeCtx) { Attribute baseAttr = var->attr.getBaseAttr(); const EnumAttr &enumAttr = cast(baseAttr); @@ -1141,13 +1140,12 @@ } void OperationFormat::genParser(Operator &op, OpClass &opClass) { - llvm::SmallVector paramList; + SmallVector paramList; paramList.emplace_back("::mlir::OpAsmParser &", "parser"); paramList.emplace_back("::mlir::OperationState &", "result"); - auto *method = - opClass.addMethodAndPrune("::mlir::ParseResult", "parse", - OpMethod::MP_Static, std::move(paramList)); + auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse", + std::move(paramList)); auto &body = method->body(); // Generate variables to store the operands and type within the format. 
This @@ -1174,7 +1172,7 @@ body << " return ::mlir::success();\n"; } -void OperationFormat::genElementParser(Element *element, OpMethodBody &body, +void OperationFormat::genElementParser(Element *element, MethodBody &body, FmtContext &attrTypeCtx) { /// Optional Group. if (auto *optional = dyn_cast(element)) { @@ -1353,8 +1351,7 @@ } } -void OperationFormat::genParserTypeResolution(Operator &op, - OpMethodBody &body) { +void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { // If any of type resolutions use transformed variables, make sure that the // types of those variables are resolved. SmallPtrSet verifiedVariables; @@ -1528,7 +1525,7 @@ } void OperationFormat::genParserRegionResolution(Operator &op, - OpMethodBody &body) { + MethodBody &body) { // Check for the case where all regions were parsed. bool hasAllRegions = llvm::any_of( elements, [](auto &elt) { return isa(elt.get()); }); @@ -1547,7 +1544,7 @@ } void OperationFormat::genParserSuccessorResolution(Operator &op, - OpMethodBody &body) { + MethodBody &body) { // Check for the case where all successors were parsed. bool hasAllSuccessors = llvm::any_of( elements, [](auto &elt) { return isa(elt.get()); }); @@ -1566,7 +1563,7 @@ } void OperationFormat::genParserVariadicSegmentResolution(Operator &op, - OpMethodBody &body) { + MethodBody &body) { if (!allOperands) { if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { body << " result.addAttribute(\"operand_segment_sizes\", " @@ -1641,7 +1638,7 @@ /// Generate the printer for the 'attr-dict' directive. static void genAttrDictPrinter(OperationFormat &fmt, Operator &op, - OpMethodBody &body, bool withKeyword) { + MethodBody &body, bool withKeyword) { body << " _odsPrinter.printOptionalAttrDict" << (withKeyword ? "WithKeyword" : "") << "((*this)->getAttrs(), /*elidedAttrs=*/{"; @@ -1665,7 +1662,7 @@ /// Generate the printer for a literal value. `shouldEmitSpace` is true if a /// space should be emitted before this element. 
`lastWasPunctuation` is true if /// the previous element was a punctuation literal. -static void genLiteralPrinter(StringRef value, OpMethodBody &body, +static void genLiteralPrinter(StringRef value, MethodBody &body, bool &shouldEmitSpace, bool &lastWasPunctuation) { body << " _odsPrinter"; @@ -1682,8 +1679,8 @@ /// Generate the printer for a space. `shouldEmitSpace` and `lastWasPunctuation` /// are set to false. -static void genSpacePrinter(bool value, OpMethodBody &body, - bool &shouldEmitSpace, bool &lastWasPunctuation) { +static void genSpacePrinter(bool value, MethodBody &body, bool &shouldEmitSpace, + bool &lastWasPunctuation) { if (value) { body << " _odsPrinter << ' ';\n"; lastWasPunctuation = false; @@ -1696,7 +1693,7 @@ /// Generate the printer for a custom directive parameter. static void genCustomDirectiveParameterPrinter(Element *element, const Operator &op, - OpMethodBody &body) { + MethodBody &body) { if (auto *attr = dyn_cast(element)) { body << op.getGetterName(attr->getVar()->name) << "Attr()"; @@ -1734,7 +1731,7 @@ /// Generate the printer for a custom directive. static void genCustomDirectivePrinter(CustomDirective *customDir, - const Operator &op, OpMethodBody &body) { + const Operator &op, MethodBody &body) { body << " print" << customDir->getName() << "(_odsPrinter, *this"; for (Element ¶m : customDir->getArguments()) { body << ", "; @@ -1744,7 +1741,7 @@ } /// Generate the printer for a region with the given variable name. 
-static void genRegionPrinter(const Twine ®ionName, OpMethodBody &body, +static void genRegionPrinter(const Twine ®ionName, MethodBody &body, bool hasImplicitTermTrait) { if (hasImplicitTermTrait) body << llvm::formatv(regionSingleBlockImplicitTerminatorPrinterCode, @@ -1753,7 +1750,7 @@ body << " _odsPrinter.printRegion(" << regionName << ");\n"; } static void genVariadicRegionPrinter(const Twine ®ionListName, - OpMethodBody &body, + MethodBody &body, bool hasImplicitTermTrait) { body << " llvm::interleaveComma(" << regionListName << ", _odsPrinter, [&](::mlir::Region ®ion) {\n "; @@ -1762,8 +1759,8 @@ } /// Generate the C++ for an operand to a (*-)type directive. -static OpMethodBody &genTypeOperandPrinter(Element *arg, const Operator &op, - OpMethodBody &body) { +static MethodBody &genTypeOperandPrinter(Element *arg, const Operator &op, + MethodBody &body) { if (isa(arg)) return body << "getOperation()->getOperandTypes()"; if (isa(arg)) @@ -1786,7 +1783,7 @@ /// Generate the printer for an enum attribute. static void genEnumAttrPrinter(const NamedAttribute *var, const Operator &op, - OpMethodBody &body) { + MethodBody &body) { Attribute baseAttr = var->attr.getBaseAttr(); const EnumAttr &enumAttr = cast(baseAttr); std::vector cases = enumAttr.getAllCases(); @@ -1864,7 +1861,7 @@ /// Generate the check for the anchor of an optional group. 
static void genOptionalGroupPrinterAnchor(Element *anchor, const Operator &op, - OpMethodBody &body) { + MethodBody &body) { TypeSwitch(anchor) .Case([&](auto *element) { const NamedTypeConstraint *var = element->getVar(); @@ -1892,7 +1889,7 @@ }); } -void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body, +void OperationFormat::genElementPrinter(Element *element, MethodBody &body, Operator &op, bool &shouldEmitSpace, bool &lastWasPunctuation) { if (LiteralElement *literal = dyn_cast(element)) @@ -2047,8 +2044,9 @@ } void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { - auto *method = opClass.addMethodAndPrune("void", "print", - "::mlir::OpAsmPrinter &_odsPrinter"); + auto *method = opClass.addMethod( + "void", "print", + MethodParameter("::mlir::OpAsmPrinter &", "_odsPrinter")); auto &body = method->body(); // Flags for if we should emit a space, and if the last element was @@ -2065,8 +2063,7 @@ /// Function to find an element within the given range that has the same name as /// 'name'. -template -static auto findArg(RangeT &&range, StringRef name) { +template static auto findArg(RangeT &&range, StringRef name) { auto it = llvm::find_if(range, [=](auto &arg) { return arg.name == name; }); return it != range.end() ? &*it : nullptr; }