diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h --- a/mlir/include/mlir/TableGen/OpClass.h +++ b/mlir/include/mlir/TableGen/OpClass.h @@ -27,32 +27,196 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/raw_ostream.h" +#include #include namespace mlir { namespace tblgen { class FmtObjectBase; +// Class for holding a single parameter of an op's method for C++ code emission. +class OpMethodParameter { +public: + // Properties (qualifiers) for the parameter. + enum Property { + PP_None = 0x0, + PP_Optional = 0x1, + }; + + OpMethodParameter(StringRef type, StringRef name, StringRef defaultValue = "", + Property properties = PP_None) + : type(type), name(name), defaultValue(defaultValue), + properties(properties) {} + + OpMethodParameter(StringRef type, StringRef name, Property property) + : OpMethodParameter(type, name, "", property) {} + + // write the parameter as a part of a method declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/true); } + + // write the parameter as a part of a method definition to the given `os` + void writeDefTo(raw_ostream &os) const { writeTo(os, /*emitDefault=*/false); } + + const std::string &getType() const { return type; } + bool hasDefaultValue() const { return !defaultValue.empty(); } + +private: + void writeTo(raw_ostream &os, bool emitDefault) const; + + std::string type; + std::string name; + std::string defaultValue; + Property properties; +}; + +// Base class for holding parameters of an op's method for C++ code emission. +class OpMethodParameters { +public: + // Discriminator for LLVM-style RTTI. + enum ParamsKind { + PK_Unresolved, // Separate type and name for each parameter is not known + PK_Resolved, // Each parameter is resolved to a type and name + }; + + OpMethodParameters(ParamsKind K) : Kind(K) {} + virtual ~OpMethodParameters() {} + + // LLVM-style RTTI support. + ParamsKind getKind() const { return Kind; } + + // write the parameters as a part of a method declaration to the given `os`. + virtual void writeDeclTo(raw_ostream &os) const = 0; + + // write the parameters as a part of a method definition to the given `os` + virtual void writeDefTo(raw_ostream &os) const = 0; + + // Factory methods to create the correct type of `OpMethodParameters` + // object based on the arguments. + static std::unique_ptr create(); + + static std::unique_ptr create(StringRef params); + + static std::unique_ptr + create(llvm::SmallVectorImpl &¶ms); + + static std::unique_ptr + create(StringRef type, StringRef name, StringRef defaultValue = ""); + +private: + const ParamsKind Kind; +}; + +// Class for holding unresolved parameters. +class OpMethodUnresolvedParameters : public OpMethodParameters { +public: + OpMethodUnresolvedParameters(StringRef params) + : OpMethodParameters(PK_Unresolved), parameters(params) {} + + // write the parameters as a part of a method declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const override; + + // write the parameters as a part of a method definition to the given `os` + void writeDefTo(raw_ostream &os) const override; + + // LLVM-style RTTI support. + static bool classof(const OpMethodParameters *P) { + return P->getKind() == PK_Unresolved; + } + +private: + std::string parameters; +}; + +// Class for holding resolved parameters. +class OpMethodResolvedParameters : public OpMethodParameters { +public: + OpMethodResolvedParameters() : OpMethodParameters(PK_Resolved) {} + + OpMethodResolvedParameters(llvm::SmallVectorImpl &¶ms) + : OpMethodParameters(PK_Resolved) { + for (OpMethodParameter ¶m : params) + parameters.emplace_back(std::move(param)); + } + + OpMethodResolvedParameters(StringRef type, StringRef name, + StringRef defaultValue) + : OpMethodParameters(PK_Resolved) { + parameters.emplace_back(type, name, defaultValue); + } + + // Returns the number of parameters. + size_t getNumParameters() const { return parameters.size(); } + + // Returns if this method makes the `other` method redundant. Note that this + // is more than just finding conflicting methods. This method determines if + // the 2 set of parameters are conflicting and if so, returns true if this + // method has a more general set of parameters that can replace all possible + // calls to the `other` method. + bool makesRedundant(const OpMethodResolvedParameters &other) const; + + // write the parameters as a part of a method declaration to the given `os`. + void writeDeclTo(raw_ostream &os) const override; + + // write the parameters as a part of a method definition to the given `os` + void writeDefTo(raw_ostream &os) const override; + + // LLVM-style RTTI support. + static bool classof(const OpMethodParameters *P) { + return P->getKind() == PK_Resolved; + } + +private: + llvm::SmallVector parameters; +}; + // Class for holding the signature of an op's method for C++ code emission class OpMethodSignature { public: - OpMethodSignature(StringRef retType, StringRef name, StringRef params); + template + OpMethodSignature(StringRef retType, StringRef name, Args &&...args) + : returnType(retType), methodName(name), + parameters(OpMethodParameters::create(std::forward(args)...)) {} + OpMethodSignature(OpMethodSignature &&) = default; + + // Returns if a method with this signature makes a method with `other` + // signature redundant. Only supports resolved parameters. If either + // signature is using unresolved parameters, returns false. + bool makesRedundant(const OpMethodSignature &other) const { + if (methodName != other.methodName) + return false; + auto *resolved_this = + dyn_cast(parameters.get()); + auto *resolved_other = + dyn_cast(other.parameters.get()); + if (resolved_this && resolved_other) + return resolved_this->makesRedundant(*resolved_other); + return false; + } + + // Returns the number of parameters (for resolved parameters). + size_t getNumParameters() const { + assert(isa(parameters.get())); + return cast(parameters.get()) + ->getNumParameters(); + } + + // Returns the name of the method. + StringRef getName() const { return methodName; } // Writes the signature as a method declaration to the given `os`. void writeDeclTo(raw_ostream &os) const; + // Writes the signature as the start of a method definition to the given `os`. // `namePrefix` is the prefix to be prepended to the method name (typically // namespaces for qualifying the method definition). void writeDefTo(raw_ostream &os, StringRef namePrefix) const; private: - // Returns true if the given C++ `type` ends with '&' or '*', or is empty. - static bool elideSpaceAfterType(StringRef type); - std::string returnType; std::string methodName; - std::string parameters; + std::unique_ptr parameters; }; // Class for holding the body of an op's method for C++ code emission @@ -79,13 +243,22 @@ // querying properties. enum Property { MP_None = 0x0, - MP_Static = 0x1, // Static method - MP_Constructor = 0x2, // Constructor - MP_Private = 0x4, // Private method + MP_Static = 0x1, // Static method + MP_Constructor = 0x2, // Constructor + MP_Private = 0x4, // Private method + MP_DeclOnly = 0x8, // Declaration only + MP_StaticDeclOnly = MP_Static | MP_DeclOnly, // Static Declaration only }; - OpMethod(StringRef retType, StringRef name, StringRef params, - Property property, bool declOnly); + template + OpMethod(StringRef retType, StringRef name, Property property, unsigned id, + Args &&...args) + : properties(property), + methodSignature(retType, name, std::forward(args)...), + methodBody(properties & MP_DeclOnly), id(id) {} + + OpMethod(OpMethod &&) = default; + virtual ~OpMethod() = default; OpMethodBody &body() { return methodBody; } @@ -96,8 +269,20 @@ // Returns true if this is a private method. bool isPrivate() const { return properties & MP_Private; } + // Returns the name of this method. + StringRef getName() const { return methodSignature.getName(); } + + // Returns the ID for this method + unsigned getID() const { return id; } + + // Returns if this method makes the `other` method redundant. + bool makesRedundant(const OpMethod &other) const { + return methodSignature.makesRedundant(other.methodSignature); + } + // Writes the method as a declaration to the given `os`. virtual void writeDeclTo(raw_ostream &os) const; + // Writes the method as a definition to the given `os`. `namePrefix` is the // prefix to be prepended to the method name (typically namespaces for // qualifying the method definition). @@ -105,18 +290,18 @@ protected: Property properties; - // Whether this method only contains a declaration. - bool isDeclOnly; OpMethodSignature methodSignature; OpMethodBody methodBody; + const unsigned id; }; // Class for holding an op's constructor method for C++ code emission. class OpConstructor : public OpMethod { public: - OpConstructor(StringRef retType, StringRef name, StringRef params, - Property property, bool declOnly) - : OpMethod(retType, name, params, property, declOnly){}; + template + OpConstructor(StringRef className, Property property, unsigned id, + Args &&...args) + : OpMethod("", className, property, id, std::forward(args)...){}; // Add member initializer to constructor initializing `name` with `value`. void addMemberInitializer(StringRef name, StringRef value); @@ -137,12 +322,32 @@ public: explicit Class(StringRef name); - // Creates a new method in this class. - OpMethod &newMethod(StringRef retType, StringRef name, StringRef params = "", - OpMethod::Property = OpMethod::MP_None, - bool declOnly = false); - - OpConstructor &newConstructor(StringRef params = "", bool declOnly = false); + // Adds a new method to this class and prune redundant methods. Returns null + // if the method being added gets pruned, else returns a pointer to the added + // method. Note that this call may also delete existing methods that are made + // redundant by this method to the class. + template + OpMethod *addMethodAndPrune(StringRef retType, StringRef name, + OpMethod::Property properties, Args &&...args) { + auto newMethod = std::make_unique( + retType, name, properties, nextMethodID++, std::forward(args)...); + return addMethodAndPrune(methods, std::move(newMethod)); + } + + template + OpMethod *addMethodAndPrune(StringRef retType, StringRef name, + Args &&...args) { + return addMethodAndPrune(retType, name, OpMethod::MP_None, + std::forward(args)...); + } + + template + OpConstructor *addConstructorAndPrune(Args &&...args) { + auto newConstructor = std::make_unique( + getClassName(), OpMethod::MP_Constructor, nextMethodID++, + std::forward(args)...); + return addMethodAndPrune(constructors, std::move(newConstructor)); + } // Creates a new field in this class. void newField(StringRef type, StringRef name, StringRef defaultValue = ""); @@ -156,9 +361,63 @@ StringRef getClassName() const { return className; } protected: + // Get a list of all the methods to emit, filtering out hidden ones. + void forAllMethods(llvm::function_ref Func) const { + using ConsRef = const std::unique_ptr &; + using MethodRef = const std::unique_ptr &; + llvm::for_each(constructors, [&](ConsRef ptr) { Func(*ptr); }); + llvm::for_each(methods, [&](MethodRef ptr) { Func(*ptr); }); + } + + // For deterministic code generation, keep methods sorted in the order in + // which they were generated. + template + struct MethodCompare { + bool operator()(const std::unique_ptr &x, + const std::unique_ptr &y) { + return x->getID() < y->getID(); + } + }; + + template + using MethodSet = + std::set, MethodCompare>; + + template + MethodTy *addMethodAndPrune(MethodSet &set, + std::unique_ptr &&newMethod) { + // Check if the new method will be made redundant by existing methods. + for (auto &method : set) + if (method->makesRedundant(*newMethod)) + return nullptr; + + // We can add this a method to the set. Prune any existing methods that will + // be made redundant by adding this new method. Note that the redundant + // check between two methods is more than a conflict check. makesRedundant() + // below will check if the new method conflicts with an existing method and + // if so, returns true if the new method makes the existing method redundant + // because all calls to the existing method can be subsumed by the new + // method. So makesRedundant() does a combined job of finding conflicts and + // deciding which of the 2 conflicting methods survive. + // + // Note: llvm::erase_if does not work with sets of std::unique_ptr, so doing + // it manually here. + for (auto it = set.begin(), end = set.end(); it != end;) { + if (newMethod->makesRedundant(*(it->get()))) + it = set.erase(it); + else + ++it; + } + + MethodTy *ret = newMethod.get(); + set.insert(std::move(newMethod)); + return ret; + } + std::string className; - SmallVector constructors; - SmallVector methods; + MethodSet constructors; + MethodSet methods; + unsigned nextMethodID = 0; SmallVector fields; }; diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/OpClass.cpp @@ -9,50 +9,148 @@ #include "mlir/TableGen/OpClass.h" #include "mlir/TableGen/Format.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include + +#define DEBUG_TYPE "mlir-tblgen-opclass" using namespace mlir; using namespace mlir::tblgen; +namespace { + +// Returns space to be emitted after the given C++ `type`. return "" if the +// ends with '&' or '*', or is empty, else returns " ". +StringRef getSpaceAfterType(StringRef type) { + return (type.empty() || type.endswith("&") || type.endswith("*")) ? "" : " "; +} + +} // namespace + //===----------------------------------------------------------------------===// -// OpMethodSignature definitions +// OpMethodParameter definitions //===----------------------------------------------------------------------===// -OpMethodSignature::OpMethodSignature(StringRef retType, StringRef name, - StringRef params) - : returnType(retType), methodName(name), parameters(params) {} +void OpMethodParameter::writeTo(raw_ostream &os, bool emitDefault) const { + if (properties & PP_Optional) + os << "/*optional*/"; + os << type << getSpaceAfterType(type) << name; + if (emitDefault && !defaultValue.empty()) + os << " = " << defaultValue; +} + +//===----------------------------------------------------------------------===// +// OpMethodParameters definitions +//===----------------------------------------------------------------------===// -void OpMethodSignature::writeDeclTo(raw_ostream &os) const { - os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << methodName - << "(" << parameters << ")"; +// Factory methods to construct the correct type of `OpMethodParameters` +// object based on the arguments. +std::unique_ptr OpMethodParameters::create() { + return std::make_unique(); } -void OpMethodSignature::writeDefTo(raw_ostream &os, - StringRef namePrefix) const { +std::unique_ptr +OpMethodParameters::create(StringRef params) { + return std::make_unique(params); +} + +std::unique_ptr +OpMethodParameters::create(llvm::SmallVectorImpl &¶ms) { + return std::make_unique(std::move(params)); +} + +std::unique_ptr +OpMethodParameters::create(StringRef type, StringRef name, + StringRef defaultValue) { + return std::make_unique(type, name, defaultValue); +} + +//===----------------------------------------------------------------------===// +// OpMethodUnresolvedParameters definitions +//===----------------------------------------------------------------------===// +void OpMethodUnresolvedParameters::writeDeclTo(raw_ostream &os) const { + os << parameters; +} + +void OpMethodUnresolvedParameters::writeDefTo(raw_ostream &os) const { // We need to remove the default values for parameters in method definition. // TODO: We are using '=' and ',' as delimiters for parameter // initializers. This is incorrect for initializer list with more than one // element. Change to a more robust approach. - auto removeParamDefaultValue = [](StringRef params) { - std::string result; - std::pair parts; - while (!params.empty()) { - parts = params.split("="); - result.append(result.empty() ? "" : ", "); - result += parts.first; - params = parts.second.split(",").second; - } - return result; - }; + StringRef params = parameters; + bool First = true; + while (!params.empty()) { + std::pair parts = params.split("="); + if (!First) + os << ", "; + os << parts.first; + params = parts.second.split(",").second; + First = false; + } +} + +//===----------------------------------------------------------------------===// +// OpMethodResolvedParameters definitions +//===----------------------------------------------------------------------===// + +// Returns true if a method with these parameters makes a method with parameters +// `other` redundant. This should return true only if all possible calls to the +// other method can be replaced by calls to this method. +bool OpMethodResolvedParameters::makesRedundant( + const OpMethodResolvedParameters &other) const { + const size_t otherNumParams = other.getNumParameters(); + const size_t thisNumParams = getNumParameters(); + + // All calls to the other method can be replaced this method only if this + // method has the same or more arguments number of arguments as the other, and + // the common arguments have the same type. + if (thisNumParams < otherNumParams) + return false; + for (int idx : llvm::seq(0, otherNumParams)) + if (parameters[idx].getType() != other.parameters[idx].getType()) + return false; + + // If all the common arguments have the same type, we can elide the other + // method if this method has the same number of arguments as other or the + // first argument after the common ones has a default value (and by C++ + // requirement, all the later ones will also have a default value). + return thisNumParams == otherNumParams || + parameters[otherNumParams].hasDefaultValue(); +} + +// write the parameters as a part of a method declaration to the given `os`. +void OpMethodResolvedParameters::writeDeclTo(raw_ostream &os) const { + llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) { + param.writeDeclTo(os); + }); +} + +// write the parameters as a part of a method definition to the given `os` +void OpMethodResolvedParameters::writeDefTo(raw_ostream &os) const { + llvm::interleaveComma(parameters, os, [&](const OpMethodParameter ¶m) { + param.writeDefTo(os); + }); +} - os << returnType << (elideSpaceAfterType(returnType) ? "" : " ") << namePrefix - << (namePrefix.empty() ? "" : "::") << methodName << "(" - << removeParamDefaultValue(parameters) << ")"; +//===----------------------------------------------------------------------===// +// OpMethodSignature definitions +//===----------------------------------------------------------------------===// + +void OpMethodSignature::writeDeclTo(raw_ostream &os) const { + os << returnType << getSpaceAfterType(returnType) << methodName << "("; + parameters->writeDeclTo(os); + os << ")"; } -bool OpMethodSignature::elideSpaceAfterType(StringRef type) { - return type.empty() || type.endswith("&") || type.endswith("*"); +void OpMethodSignature::writeDefTo(raw_ostream &os, + StringRef namePrefix) const { + os << returnType << getSpaceAfterType(returnType) << namePrefix + << (namePrefix.empty() ? "" : "::") << methodName << "("; + parameters->writeDefTo(os); + os << ")"; } //===----------------------------------------------------------------------===// @@ -90,10 +188,6 @@ // OpMethod definitions //===----------------------------------------------------------------------===// -OpMethod::OpMethod(StringRef retType, StringRef name, StringRef params, - OpMethod::Property property, bool declOnly) - : properties(property), isDeclOnly(declOnly), - methodSignature(retType, name, params), methodBody(declOnly) {} void OpMethod::writeDeclTo(raw_ostream &os) const { os.indent(2); if (isStatic()) @@ -103,9 +197,9 @@ } void OpMethod::writeDefTo(raw_ostream &os, StringRef namePrefix) const { - if (isDeclOnly) + // Do not write definition if the method is decl only. + if (properties & MP_DeclOnly) return; - methodSignature.writeDefTo(os, namePrefix); os << " {\n"; methodBody.writeTo(os); @@ -122,7 +216,8 @@ } void OpConstructor::writeDefTo(raw_ostream &os, StringRef namePrefix) const { - if (isDeclOnly) + // Do not write definition if the method is decl only. + if (properties & MP_DeclOnly) return; methodSignature.writeDefTo(os, namePrefix); @@ -137,18 +232,6 @@ Class::Class(StringRef name) : className(name) {} -OpMethod &Class::newMethod(StringRef retType, StringRef name, StringRef params, - OpMethod::Property property, bool declOnly) { - methods.emplace_back(retType, name, params, property, declOnly); - return methods.back(); -} - -OpConstructor &Class::newConstructor(StringRef params, bool declOnly) { - constructors.emplace_back("", getClassName(), params, - OpMethod::MP_Constructor, declOnly); - return constructors.back(); -} - void Class::newField(StringRef type, StringRef name, StringRef defaultValue) { std::string varName = formatv("{0} {1}", type, name).str(); std::string field = defaultValue.empty() @@ -156,43 +239,42 @@ : formatv("{0} = {1}", varName, defaultValue).str(); fields.push_back(std::move(field)); } - void Class::writeDeclTo(raw_ostream &os) const { bool hasPrivateMethod = false; os << "class " << className << " {\n"; os << "public:\n"; - for (const auto &method : - llvm::concat(constructors, methods)) { + + forAllMethods([&](const OpMethod &method) { if (!method.isPrivate()) { method.writeDeclTo(os); os << '\n'; } else { hasPrivateMethod = true; } - } + }); + os << '\n'; os << "private:\n"; if (hasPrivateMethod) { - for (const auto &method : - llvm::concat(constructors, methods)) { + forAllMethods([&](const OpMethod &method) { if (method.isPrivate()) { method.writeDeclTo(os); os << '\n'; } - } + }); os << '\n'; } + for (const auto &field : fields) os.indent(2) << field << ";\n"; os << "};\n"; } void Class::writeDefTo(raw_ostream &os) const { - for (const auto &method : - llvm::concat(constructors, methods)) { + forAllMethods([&](const OpMethod &method) { method.writeDefTo(os, className); os << "\n\n"; - } + }); } //===----------------------------------------------------------------------===// @@ -217,14 +299,14 @@ os << " using Adaptor = " << className << "Adaptor;\n"; bool hasPrivateMethod = false; - for (const auto &method : methods) { + forAllMethods([&](const OpMethod &method) { if (!method.isPrivate()) { method.writeDeclTo(os); os << "\n"; } else { hasPrivateMethod = true; } - } + }); // TODO: Add line control markers to make errors easier to debug. if (!extraClassDeclaration.empty()) @@ -232,12 +314,12 @@ if (hasPrivateMethod) { os << "\nprivate:\n"; - for (const auto &method : methods) { + forAllMethods([&](const OpMethod &method) { if (method.isPrivate()) { method.writeDeclTo(os); os << "\n"; } - } + }); } os << "};\n"; diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -107,7 +107,7 @@ StrAttr:$str_attr, ElementsAttr:$elements_attr, FlatSymbolRefAttr:$function_attr, - SomeTypeAttr:$type_attr, + SomeTypeAttr:$some_type_attr, ArrayAttr:$array_attr, TypedArrayAttrBase:$some_attr_array, TypeAttr:$type_attr @@ -128,7 +128,7 @@ // DEF: if (!((tblgen_str_attr.isa<::mlir::StringAttr>()))) // DEF: if (!((tblgen_elements_attr.isa<::mlir::ElementsAttr>()))) // DEF: if (!((tblgen_function_attr.isa<::mlir::FlatSymbolRefAttr>()))) -// DEF: if (!(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa())))) +// DEF: if (!(((tblgen_some_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_some_type_attr.cast<::mlir::TypeAttr>().getValue().isa())))) // DEF: if (!((tblgen_array_attr.isa<::mlir::ArrayAttr>()))) // DEF: if (!(((tblgen_some_attr_array.isa<::mlir::ArrayAttr>())) && (::llvm::all_of(tblgen_some_attr_array.cast<::mlir::ArrayAttr>(), [](::mlir::Attribute attr) { return (some-condition); })))) // DEF: if (!(((tblgen_type_attr.isa<::mlir::TypeAttr>())) && ((tblgen_type_attr.cast<::mlir::TypeAttr>().getValue().isa<::mlir::Type>())))) @@ -145,7 +145,7 @@ // DEF: ::llvm::StringRef BOp::str_attr() // DEF: ::mlir::ElementsAttr BOp::elements_attr() // DEF: ::llvm::StringRef BOp::function_attr() -// DEF: SomeType BOp::type_attr() +// DEF: SomeType BOp::some_type_attr() // DEF: ::mlir::ArrayAttr BOp::array_attr() // DEF: ::mlir::ArrayAttr BOp::some_attr_array() // DEF: ::mlir::Type BOp::type_attr() diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -110,7 +110,7 @@ let results = (outs AnyTensor:$result); } -// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes ) +// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) // CHECK: odsState.addTypes({operands[0].getType()}); // Test with inferred shapes and interleaved with operands/attributes. diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -232,10 +232,6 @@ // operand's type as all results' types. void genUseOperandAsResultTypeCollectiveParamBuilder(); - // Returns true if the inferred collective param build method should be - // generated. - bool shouldGenerateInferredTypeCollectiveParamBuilder(); - // Generates the build() method that takes aggregate operands/attributes // parameters. This build() method uses inferred types as result types. // Requires: The type needs to be inferable via InferTypeOpInterface. @@ -268,7 +264,7 @@ // `resultTypeNames` with the names for parameters for specifying result // types. The given `typeParamKind` and `attrParamKind` controls how result // types and attributes are placed in the parameter list. - void buildParamList(std::string ¶mList, + void buildParamList(llvm::SmallVectorImpl ¶mList, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); @@ -496,15 +492,19 @@ // Emit the derived attribute body. auto emitDerivedAttr = [&](StringRef name, Attribute attr) { - auto &method = opClass.newMethod(attr.getReturnType(), name); - auto &body = method.body(); + auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + if (!method) + return; + auto &body = method->body(); body << " " << attr.getDerivedCodeBody() << "\n"; }; // Emit with return type specified. auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { - auto &method = opClass.newMethod(attr.getReturnType(), name); - auto &body = method.body(); + auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + if (!method) + return; + auto &body = method->body(); body << " auto attr = " << name << "Attr();\n"; if (attr.hasDefaultValue()) { // Returns the default value if not set. @@ -526,9 +526,11 @@ // referring to the attributes via accessors instead of having to use // the string interface for better compile time verification. auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { - auto &method = - opClass.newMethod(attr.getStorageType(), (name + "Attr").str()); - auto &body = method.body(); + auto *method = + opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str()); + if (!method) + return; + auto &body = method->body(); body << " return this->getAttr(\"" << name << "\")."; if (attr.isOptional() || attr.hasDefaultValue()) body << "dyn_cast_or_null<"; @@ -558,19 +560,19 @@ // attribute. This enables, for example, avoiding adding an attribute that // overlaps with a derived attribute. { - auto &method = - opClass.newMethod("bool", "isDerivedAttribute", - "::llvm::StringRef name", OpMethod::MP_Static); - auto &body = method.body(); + auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute", + OpMethod::MP_Static, + "::llvm::StringRef", "name"); + auto &body = method->body(); for (auto namedAttr : derivedAttrs) body << " if (name == \"" << namedAttr.name << "\") return true;\n"; body << " return false;"; } // Generate method to materialize derived attributes as a DictionaryAttr. { - OpMethod &method = opClass.newMethod("::mlir::DictionaryAttr", - "materializeDerivedAttributes"); - auto &body = method.body(); + auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr", + "materializeDerivedAttributes"); + auto &body = method->body(); auto nonMaterializable = make_filter_range(derivedAttrs, [](const NamedAttribute &namedAttr) { @@ -618,9 +620,11 @@ // to the attributes via setters instead of having to use the string interface // for better compile time verification. auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { - auto &method = opClass.newMethod("void", (name + "Attr").str(), - (attr.getStorageType() + " attr").str()); - auto &body = method.body(); + auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(), + attr.getStorageType(), "attr"); + if (!method) + return; + auto &body = method->body(); body << " this->getOperation()->setAttr(\"" << name << "\", attr);"; }; @@ -640,13 +644,15 @@ int numVariadic, int numNonVariadic, StringRef rangeSizeCall, bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) { - auto &method = opClass.newMethod("std::pair", methodName, - "unsigned index"); - + auto *method = opClass.addMethodAndPrune("std::pair", + methodName, "unsigned", "index"); + if (!method) + return; + auto &body = method->body(); if (numVariadic == 0) { - method.body() << " return {index, 1};\n"; + body << " return {index, 1};\n"; } else if (hasAttrSegmentSize) { - method.body() << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; + body << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; } else { // Because the op can have arbitrarily interleaved variadic and non-variadic // operands, we need to embed a list in the "sink" getter method for @@ -656,9 +662,8 @@ for (auto &it : odsValues) isVariadic.push_back(it.isVariableLength() ? "true" : "false"); std::string isVariadicList = llvm::join(isVariadic, ", "); - method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, - numNonVariadic, numVariadic, rangeSizeCall, - "operand"); + body << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, + numNonVariadic, numVariadic, rangeSizeCall, "operand"); } } @@ -711,9 +716,11 @@ rangeSizeCall, attrSizedOperands, sizeAttrInit, const_cast(op).getOperands()); - auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); - m.body() << formatv(valueRangeReturnCode, rangeBeginCall, - "getODSOperandIndexAndLength(index)"); + auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned", + "index"); + auto &body = m->body(); + body << formatv(valueRangeReturnCode, rangeBeginCall, + "getODSOperandIndexAndLength(index)"); // Then we emit nicer named getter methods by redirecting to the "sink" getter // method. @@ -723,15 +730,15 @@ continue; if (operand.isOptional()) { - auto &m = opClass.newMethod("::mlir::Value", operand.name); - m.body() << " auto operands = getODSOperands(" << i << ");\n" - << " return operands.empty() ? Value() : *operands.begin();"; + m = opClass.addMethodAndPrune("::mlir::Value", operand.name); + m->body() << " auto operands = getODSOperands(" << i << ");\n" + << " return operands.empty() ? Value() : *operands.begin();"; } else if (operand.isVariadic()) { - auto &m = opClass.newMethod(rangeType, operand.name); - m.body() << " return getODSOperands(" << i << ");"; + m = opClass.addMethodAndPrune(rangeType, operand.name); + m->body() << " return getODSOperands(" << i << ");"; } else { - auto &m = opClass.newMethod("::mlir::Value", operand.name); - m.body() << " return *getODSOperands(" << i << ").begin();"; + m = opClass.addMethodAndPrune("::mlir::Value", operand.name); + m->body() << " return *getODSOperands(" << i << ").begin();"; } } } @@ -753,9 +760,9 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - auto &m = opClass.newMethod("::mlir::MutableOperandRange", - (operand.name + "Mutable").str()); - auto &body = m.body(); + auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange", + (operand.name + "Mutable").str()); + auto &body = m->body(); body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" << " return ::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; @@ -800,10 +807,11 @@ numNormalResults, "getOperation()->getNumResults()", attrSizedResults, formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(), op.getResults()); - auto &m = opClass.newMethod("::mlir::Operation::result_range", - "getODSResults", "unsigned index"); - m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", - "getODSResultIndexAndLength(index)"); + + auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range", + "getODSResults", "unsigned", "index"); + m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", + "getODSResultIndexAndLength(index)"); for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); @@ -811,17 +819,17 @@ continue; if (result.isOptional()) { - auto &m = opClass.newMethod("::mlir::Value", result.name); - m.body() + m = opClass.addMethodAndPrune("::mlir::Value", result.name); + m->body() << " auto results = getODSResults(" << i << ");\n" << " return results.empty() ? ::mlir::Value() : *results.begin();"; } else if (result.isVariadic()) { - auto &m = - opClass.newMethod("::mlir::Operation::result_range", result.name); - m.body() << " return getODSResults(" << i << ");"; + m = opClass.addMethodAndPrune("::mlir::Operation::result_range", + result.name); + m->body() << " return getODSResults(" << i << ");"; } else { - auto &m = opClass.newMethod("::mlir::Value", result.name); - m.body() << " return *getODSResults(" << i << ").begin();"; + m = opClass.addMethodAndPrune("::mlir::Value", result.name); + m->body() << " return *getODSResults(" << i << ").begin();"; } } } @@ -835,15 +843,15 @@ // Generate the accessors for a varidiadic region. if (region.isVariadic()) { - auto &m = - opClass.newMethod("::mlir::MutableArrayRef", region.name); - m.body() << formatv( + auto *m = opClass.addMethodAndPrune("::mlir::MutableArrayRef", + region.name); + m->body() << formatv( " return this->getOperation()->getRegions().drop_front({0});", i); continue; } - auto &m = opClass.newMethod("::mlir::Region &", region.name); - m.body() << formatv(" return this->getOperation()->getRegion({0});", i); + auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name); + m->body() << formatv(" return this->getOperation()->getRegion({0});", i); } } @@ -856,16 +864,18 @@ // Generate the accessors for a variadic successor list. if (successor.isVariadic()) { - auto &m = opClass.newMethod("::mlir::SuccessorRange", successor.name); - m.body() << formatv( + auto *m = + opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name); + m->body() << formatv( " return {std::next(this->getOperation()->successor_begin(), {0}), " "this->getOperation()->successor_end()};", i); continue; } - auto &m = opClass.newMethod("::mlir::Block *", successor.name); - m.body() << formatv(" return this->getOperation()->getSuccessor({0});", i); + auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name); + m->body() << formatv(" return this->getOperation()->getSuccessor({0});", + i); } } @@ -905,14 +915,16 @@ // inferring result type. auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, bool inferType) { - std::string paramList; + llvm::SmallVector paramList; llvm::SmallVector resultNames; buildParamList(paramList, resultNames, paramKind, attrType); - auto &m = - opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - auto &body = m.body(); - + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method. + if (!m) + return; + auto &body = m->body(); genCodeForAddingArgAndRegionForBuilder( body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); @@ -967,54 +979,13 @@ llvm_unreachable("unhandled TypeParamKind"); }; - // A separate arg param builder method will have a signature which is - // ambiguous with the collective params build method (generated in - // `genCollectiveParamBuilder` function below) if it has a single - // `ArrayReg` parameter for result types and a single `ArrayRef` - // parameter for the operands, no parameters after that, and the collective - // params build method has `attributes` as its last parameter (with - // a default value). This will happen when all of the following are true: - // 1. [`attributes` as last parameter in collective params build method]: - // getNumVariadicRegions must be 0 (otherwise the collective params build - // method ends with a `numRegions` param, and we don't specify default - // value for attributes). - // 2. [single `ArrayRef` parameter for operands, and no parameters - // after that]: numArgs() must be 1 (if not, each arg gets a separate param - // in the build methods generated here) and the single arg must be a - // non-attribute variadic argument. - // 3. [single `ArrayReg` parameter for result types]: - // 3a. paramKind should be Collective, or - // 3b. paramKind should be Separate and there should be a single variadic - // result - // - // In that case, skip generating such ambiguous build methods here. + // Some of the build methods generated here may be amiguous, but TableGen's + // ambiguous function detection will elide those ones. for (auto attrType : attrBuilderType) { - // Case 3b above. - if (!(op.hasNoVariadicRegions() && op.hasSingleVariadicArg() && - op.hasSingleVariadicResult())) - emit(attrType, TypeParamKind::Separate, /*inferType=*/false); - if (canInferType(op)) { - // When inferType = true, the generated build method does not have - // result types. If the op has a single variadic arg, then this build - // method will be ambiguous with the collective inferred build method - // generated in `genInferredTypeCollectiveParamBuilder`. If we are going - // to generate that collective inferred method, suppress generating the - // ambiguous build method here. - bool buildMethodAmbiguous = - op.hasSingleVariadicArg() && - shouldGenerateInferredTypeCollectiveParamBuilder(); - if (!buildMethodAmbiguous) - emit(attrType, TypeParamKind::None, /*inferType=*/true); - } - // The separate arg + collective param kind method will be: - // (a) Same as the separate arg + separate param kind method if there is - // only one variadic result. - // (b) Ambiguous with the collective params method under conditions in (3a) - // above. - // In either case, skip generating such build method. - if (!op.hasSingleVariadicResult() && - !(op.hasNoVariadicRegions() && op.hasSingleVariadicArg())) - emit(attrType, TypeParamKind::Collective, /*inferType=*/false); + emit(attrType, TypeParamKind::Separate, /*inferType=*/false); + if (canInferType(op)) + emit(attrType, TypeParamKind::None, /*inferType=*/true); + emit(attrType, TypeParamKind::Collective, /*inferType=*/false); } } @@ -1022,19 +993,23 @@ int numResults = op.getNumResults(); // Signature - std::string params = - std::string("::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &") + - builderOpState + - ", ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> " - "attributes"; - if (op.getNumVariadicRegions()) { - params += ", unsigned numRegions"; - } else { - // Provide default value for `attributes` since its the last parameter - params += " = {}"; - } - auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); - auto &body = m.body(); + llvm::SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::mlir::ValueRange", "operands"); + // Provide default value for `attributes` when its the last parameter + StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", attributesDefaultValue); + if (op.getNumVariadicRegions()) + paramList.emplace_back("unsigned", "numRegions"); + + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); // Operands body << " " << builderOpState << ".addOperands(operands);\n"; @@ -1056,19 +1031,20 @@ << llvm::join(resultTypes, ", ") << "});\n\n"; } -bool OpEmitter::shouldGenerateInferredTypeCollectiveParamBuilder() { - return canInferType(op) && op.getNumSuccessors() == 0; -} - void OpEmitter::genInferredTypeCollectiveParamBuilder() { // TODO: Expand to support regions. - std::string params = - std::string("::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &") + - builderOpState + - ", ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> " - "attributes = {}"; - auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); - auto &body = m.body(); + SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::mlir::ValueRange", "operands"); + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", "{}"); + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); @@ -1116,12 +1092,17 @@ } void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { - std::string paramList; + llvm::SmallVector paramList; llvm::SmallVector resultNames; buildParamList(paramList, resultNames, TypeParamKind::None); - auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static); - genCodeForAddingArgAndRegionForBuilder(m.body()); + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); + genCodeForAddingArgAndRegionForBuilder(body); auto numResults = op.getNumResults(); if (numResults == 0) @@ -1131,20 +1112,26 @@ const char *index = op.getOperand(0).isVariadic() ? ".front()" : ""; std::string resultType = formatv("{0}{1}.getType()", getArgumentName(op, 0), index).str(); - m.body() << " " << builderOpState << ".addTypes({" << resultType; + body << " " << builderOpState << ".addTypes({" << resultType; for (int i = 1; i != numResults; ++i) - m.body() << ", " << resultType; - m.body() << "});\n\n"; + body << ", " << resultType; + body << "});\n\n"; } void OpEmitter::genUseAttrAsResultTypeBuilder() { - std::string params = - std::string("::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &") + - builderOpState + - ", ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> " - "attributes"; - auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); - auto &body = m.body(); + SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::mlir::ValueRange", "operands"); + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", "{}"); + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + + auto &body = m->body(); // Push all result types to the operation state std::string resultType; @@ -1184,11 +1171,12 @@ StringRef body = builderDef->getValueAsString("body"); bool hasBody = !body.empty(); - auto &method = - opClass.newMethod("void", "build", params, OpMethod::MP_Static, - /*declOnly=*/!hasBody); + OpMethod::Property properties = + hasBody ? OpMethod::MP_Static : OpMethod::MP_StaticDeclOnly; + auto *method = + opClass.addMethodAndPrune("void", "build", properties, params); if (hasBody) - method.body() << body; + method->body() << body; } } if (op.skipDefaultBuilders()) { @@ -1214,21 +1202,8 @@ // to facilitate different call patterns. if (op.getNumVariableLengthResults() == 0) { if (op.getTrait("OpTrait::SameOperandsAndResultType")) { - // If the operation has a single variadic input, then the build method - // generated by `genUseOperandAsResultTypeSeparateParamBuilder` will be - // ambiguous with the one generated by - // `genUseOperandAsResultTypeCollectiveParamBuilder` (they both will have - // a single `ValueRange` argument for operands, and the collective one - // will have a `ArrayRef` argument initialized to empty). - // Suppress such ambiguous build method. - if (!op.hasSingleVariadicArg()) - genUseOperandAsResultTypeSeparateParamBuilder(); - - // The build method generated by the inferred type collective param - // builder and one generated here have the same arguments and hence - // generating both will be ambiguous. Enable just one of them. - if (!shouldGenerateInferredTypeCollectiveParamBuilder()) - genUseOperandAsResultTypeCollectiveParamBuilder(); + genUseOperandAsResultTypeSeparateParamBuilder(); + genUseOperandAsResultTypeCollectiveParamBuilder(); } if (op.getTrait("OpTrait::FirstAttrDerivedResultType")) genUseAttrAsResultTypeBuilder(); @@ -1243,21 +1218,25 @@ int numOperands = op.getNumOperands(); int numVariadicOperands = op.getNumVariableLengthOperands(); int numNonVariadicOperands = numOperands - numVariadicOperands; - // Signature - std::string params = - std::string("::mlir::OpBuilder &, ::mlir::OperationState &") + - builderOpState + - ", ::llvm::ArrayRef<::mlir::Type> resultTypes, ::mlir::ValueRange " - "operands, " - "::llvm::ArrayRef<::mlir::NamedAttribute> attributes"; - if (op.getNumVariadicRegions()) { - params += ", unsigned numRegions"; - } else { - // Provide default value for `attributes` since its the last parameter - params += " = {}"; - } - auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); - auto &body = m.body(); + + SmallVector paramList; + paramList.emplace_back("::mlir::OpBuilder &", ""); + paramList.emplace_back("::mlir::OperationState &", builderOpState); + paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); + paramList.emplace_back("::mlir::ValueRange", "operands"); + // Provide default value for `attributes` when its the last parameter + StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; + paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", + "attributes", attributesDefaultValue); + if (op.getNumVariadicRegions()) + paramList.emplace_back("unsigned", "numRegions"); + + auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, + std::move(paramList)); + // If the builder is redundant, skip generating the method + if (!m) + return; + auto &body = m->body(); // Operands if (numVariadicOperands == 0 || numNonVariadicOperands != 0) @@ -1287,11 +1266,11 @@ // Generate builder that infers type too. // TODO: Expand to handle regions and successors. - if (shouldGenerateInferredTypeCollectiveParamBuilder()) + if (canInferType(op) && op.getNumSuccessors() == 0) genInferredTypeCollectiveParamBuilder(); } -void OpEmitter::buildParamList(std::string ¶mList, +void OpEmitter::buildParamList(SmallVectorImpl ¶mList, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind) { @@ -1299,8 +1278,8 @@ auto numResults = op.getNumResults(); resultTypeNames.reserve(numResults); - paramList = "::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &"; - paramList.append(builderOpState); + paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); + paramList.emplace_back("::mlir::OperationState &", builderOpState); switch (typeParamKind) { case TypeParamKind::None: @@ -1313,19 +1292,18 @@ if (resultName.empty()) resultName = std::string(formatv("resultType{0}", i)); + StringRef type = result.isVariadic() ? "::llvm::ArrayRef<::mlir::Type>" + : "::mlir::Type"; + OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (result.isOptional()) - paramList.append(", /*optional*/::mlir::Type "); - else if (result.isVariadic()) - paramList.append(", ::llvm::ArrayRef<::mlir::Type> "); - else - paramList.append(", ::mlir::Type "); - paramList.append(resultName); + properties = OpMethodParameter::PP_Optional; + paramList.emplace_back(type, resultName, properties); resultTypeNames.emplace_back(std::move(resultName)); } } break; case TypeParamKind::Collective: { - paramList.append(", ::llvm::ArrayRef<::mlir::Type> resultTypes"); + paramList.emplace_back("::llvm::ArrayRef<::mlir::Type>", "resultTypes"); resultTypeNames.push_back("resultTypes"); } break; } @@ -1364,64 +1342,64 @@ auto argument = op.getArg(i); if (argument.is()) { const auto &operand = op.getOperand(numOperands); + StringRef type = + operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value"; + OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (operand.isOptional()) - paramList.append(", /*optional*/::mlir::Value "); - else if (operand.isVariadic()) - paramList.append(", ::mlir::ValueRange "); - else - paramList.append(", ::mlir::Value "); - paramList.append(getArgumentName(op, numOperands)); + properties = OpMethodParameter::PP_Optional; + + paramList.emplace_back(type, getArgumentName(op, numOperands), + properties); ++numOperands; } else { const auto &namedAttr = op.getAttribute(numAttrs); const auto &attr = namedAttr.attr; - paramList.append(", "); + OpMethodParameter::Property properties = OpMethodParameter::PP_None; if (attr.isOptional()) - paramList.append("/*optional*/"); + properties = OpMethodParameter::PP_Optional; + StringRef type; switch (attrParamKind) { case AttrParamKind::WrappedAttr: - paramList.append(std::string(attr.getStorageType())); + type = attr.getStorageType(); break; case AttrParamKind::UnwrappedValue: - if (canUseUnwrappedRawValue(attr)) { - paramList.append(std::string(attr.getReturnType())); - } else { - paramList.append(std::string(attr.getStorageType())); - } + if (canUseUnwrappedRawValue(attr)) + type = attr.getReturnType(); + else + type = attr.getStorageType(); break; } - paramList.append(" "); - paramList.append(std::string(namedAttr.name)); + std::string defaultValue; // Attach default value if requested and possible. if (attrParamKind == AttrParamKind::UnwrappedValue && i >= defaultValuedAttrStartIndex) { bool isString = attr.getReturnType() == "::llvm::StringRef"; - paramList.append(" = "); if (isString) - paramList.append("\""); - paramList.append(std::string(attr.getDefaultValue())); + defaultValue.append("\""); + defaultValue += attr.getDefaultValue(); if (isString) - paramList.append("\""); + defaultValue.append("\""); } + paramList.emplace_back(type, namedAttr.name, defaultValue, properties); ++numAttrs; } } /// Insert parameters for each successor. for (const NamedSuccessor &succ : op.getSuccessors()) { - paramList += (succ.isVariadic() ? ", ::llvm::ArrayRef<::mlir::Block *> " - : ", ::mlir::Block *"); - paramList += succ.name; + StringRef type = succ.isVariadic() ? "::llvm::ArrayRef<::mlir::Block *>" + : "::mlir::Block *"; + paramList.emplace_back(type, succ.name); } /// Insert parameters for variadic regions. - for (const NamedRegion ®ion : op.getRegions()) { + for (const NamedRegion ®ion : op.getRegions()) if (region.isVariadic()) - paramList += llvm::formatv(", unsigned {0}Count", region.name).str(); - } + paramList.emplace_back("unsigned", + llvm::formatv("{0}Count", region.name).str()); } void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, @@ -1508,10 +1486,11 @@ if (!def.getValueAsBit("hasCanonicalizer")) return; - const char *const params = - "::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context"; - opClass.newMethod("void", "getCanonicalizationPatterns", params, - OpMethod::MP_Static, /*declOnly=*/true); + SmallVector paramList; + paramList.emplace_back("::mlir::OwningRewritePatternList &", "results"); + paramList.emplace_back("::mlir::MLIRContext *", "context"); + opClass.addMethodAndPrune("void", "getCanonicalizationPatterns", + OpMethod::MP_StaticDeclOnly, std::move(paramList)); } void OpEmitter::genFolderDecls() { @@ -1520,17 +1499,16 @@ if (def.getValueAsBit("hasFolder")) { if (hasSingleResult) { - const char *const params = "::llvm::ArrayRef<::mlir::Attribute> operands"; - opClass.newMethod("::mlir::OpFoldResult", "fold", params, - OpMethod::MP_None, - /*declOnly=*/true); + opClass.addMethodAndPrune( + "::mlir::OpFoldResult", "fold", OpMethod::MP_DeclOnly, + "::llvm::ArrayRef<::mlir::Attribute>", "operands"); } else { - const char *const params = - "::llvm::ArrayRef<::mlir::Attribute> operands, " - "::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results"; - opClass.newMethod("::mlir::LogicalResult", "fold", params, - OpMethod::MP_None, - /*declOnly=*/true); + SmallVector paramList; + paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); + paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", + "results"); + opClass.addMethodAndPrune("::mlir::LogicalResult", "fold", + OpMethod::MP_DeclOnly, std::move(paramList)); } } } @@ -1554,16 +1532,14 @@ !alwaysDeclaredMethods.count(method.getName())) continue; - std::string args; - llvm::raw_string_ostream os(args); - interleaveComma(method.getArguments(), os, - [&](const InterfaceMethod::Argument &arg) { - os << arg.type << " " << arg.name; - }); - opClass.newMethod(method.getReturnType(), method.getName(), os.str(), - method.isStatic() ? OpMethod::MP_Static - : OpMethod::MP_None, - /*declOnly=*/true); + SmallVector paramList; + for (const InterfaceMethod::Argument &arg : method.getArguments()) + paramList.emplace_back(arg.type, arg.name); + + auto properties = + method.isStatic() ? OpMethod::MP_StaticDeclOnly : OpMethod::MP_DeclOnly; + opClass.addMethodAndPrune(method.getReturnType(), method.getName(), + properties, std::move(paramList)); } } @@ -1622,15 +1598,14 @@ resolveDecorators(op.getResultDecorators(i), i, EffectKind::Result); for (auto &it : interfaceEffects) { - auto effectsParam = - llvm::formatv("::mlir::SmallVectorImpl<::mlir::SideEffects::" - "EffectInstance<{0}>> &effects", - it.first()) - .str(); - // Generate the 'getEffects' method. - auto &getEffects = opClass.newMethod("void", "getEffects", effectsParam); - auto &body = getEffects.body(); + std::string type = llvm::formatv("::mlir::SmallVectorImpl<::mlir::" + "SideEffects::EffectInstance<{0}>> &", + it.first()) + .str(); + auto *getEffects = + opClass.addMethodAndPrune("void", "getEffects", type, "effects"); + auto &body = getEffects->body(); // Add effect instances for each of the locations marked on the operation. for (auto &location : it.second) { @@ -1655,21 +1630,24 @@ if (!op.allResultTypesKnown()) return; - auto &method = opClass.newMethod( - "::mlir::LogicalResult", "inferReturnTypes", - "::mlir::MLIRContext* context, " - "::llvm::Optional<::mlir::Location> location, " - "::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, " - "::mlir::RegionRange regions, " - "::llvm::SmallVectorImpl<::mlir::Type>& inferredReturnTypes", - OpMethod::MP_Static, - /*declOnly=*/false); - auto &os = method.body(); - os << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; + SmallVector paramList; + paramList.emplace_back("::mlir::MLIRContext *", "context"); + paramList.emplace_back("::llvm::Optional<::mlir::Location>", "location"); + paramList.emplace_back("::mlir::ValueRange", "operands"); + paramList.emplace_back("::mlir::DictionaryAttr", "attributes"); + paramList.emplace_back("::mlir::RegionRange", "regions"); + paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::Type>&", + "inferredReturnTypes"); + auto *method = + opClass.addMethodAndPrune("::mlir::LogicalResult", "inferReturnTypes", + OpMethod::MP_Static, std::move(paramList)); + + auto &body = method->body(); + body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; FmtContext fctx; fctx.withBuilder("odsBuilder"); - os << " ::mlir::Builder odsBuilder(context);\n"; + body << " ::mlir::Builder odsBuilder(context);\n"; auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & { @@ -1678,24 +1656,24 @@ assert(!op.getArg(argIndex).is()); auto arg = op.getArgToOperandOrAttribute(argIndex); if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) - return os << "operands[" << arg.operandOrAttributeIndex() + return body << "operands[" << arg.operandOrAttributeIndex() + << "].getType()"; + return body << "attributes[" << arg.operandOrAttributeIndex() << "].getType()"; - return os << "attributes[" << arg.operandOrAttributeIndex() - << "].getType()"; } else { - return os << tgfmt(*type.getType().getBuilderCall(), &fctx); + return body << tgfmt(*type.getType().getBuilderCall(), &fctx); } }; for (int i = 0, e = op.getNumResults(); i != e; ++i) { - os << " inferredReturnTypes[" << i << "] = "; + body << " inferredReturnTypes[" << i << "] = "; auto types = op.getSameTypeAsResult(i); emitType(types[0]) << ";\n"; if (types.size() == 1) continue; // TODO: We could verify equality here, but skipping that for verification. } - os << " return success();"; + body << " return success();"; } void OpEmitter::genParser() { @@ -1703,14 +1681,17 @@ hasStringAttribute(def, "assemblyFormat")) return; - auto &method = opClass.newMethod( - "::mlir::ParseResult", "parse", - "::mlir::OpAsmParser &parser, ::mlir::OperationState &result", - OpMethod::MP_Static); + SmallVector paramList; + paramList.emplace_back("::mlir::OpAsmParser &", "parser"); + paramList.emplace_back("::mlir::OperationState &", "result"); + auto *method = + opClass.addMethodAndPrune("::mlir::ParseResult", "parse", + OpMethod::MP_Static, std::move(paramList)); + FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); auto parser = def.getValueAsString("parser").ltrim().rtrim(" \t\v\f\r"); - method.body() << " " << tgfmt(parser, &fctx); + method->body() << " " << tgfmt(parser, &fctx); } void OpEmitter::genPrinter() { @@ -1722,17 +1703,17 @@ if (!codeInit) return; - auto &method = opClass.newMethod("void", "print", "::mlir::OpAsmPrinter &p"); + auto *method = + opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p"); FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); auto printer = codeInit->getValue().ltrim().rtrim(" \t\v\f\r"); - method.body() << " " << tgfmt(printer, &fctx); + method->body() << " " << tgfmt(printer, &fctx); } void OpEmitter::genVerifier() { - auto &method = - opClass.newMethod("::mlir::LogicalResult", "verify", /*params=*/""); - auto &body = method.body(); + auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify"); + auto &body = method->body(); body << " if (failed(" << op.getAdaptorName() << "(*this).verify(this->getLoc()))) " << "return failure();\n"; @@ -1975,9 +1956,9 @@ } void OpEmitter::genOpNameGetter() { - auto &method = opClass.newMethod("::llvm::StringRef", "getOperationName", - /*params=*/"", OpMethod::MP_Static); - method.body() << " return \"" << op.getOperationName() << "\";\n"; + auto *method = opClass.addMethodAndPrune( + "::llvm::StringRef", "getOperationName", OpMethod::MP_Static); + method->body() << " return \"" << op.getOperationName() << "\";\n"; } void OpEmitter::genOpAsmInterface() { @@ -2001,9 +1982,9 @@ opClass.addTrait("::mlir::OpAsmOpInterface::Trait"); // Generate the right accessor for the number of results. - auto &method = opClass.newMethod("void", "getAsmResultNames", - "OpAsmSetValueNameFn setNameFn"); - auto &body = method.body(); + auto *method = opClass.addMethodAndPrune("void", "getAsmResultNames", + "OpAsmSetValueNameFn", "setNameFn"); + auto &body = method->body(); for (int i = 0; i != numResults; ++i) { body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" << " if (!llvm::empty(resultGroup" << i << "))\n" @@ -2044,22 +2025,23 @@ const auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments"); { - auto &constructor = adaptor.newConstructor( - attrSizedOperands - ? "::mlir::ValueRange values, ::mlir::DictionaryAttr attrs" - : "::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = " - "nullptr"); - constructor.addMemberInitializer("odsOperands", "values"); - constructor.addMemberInitializer("odsAttrs", "attrs"); + SmallVector paramList; + paramList.emplace_back("::mlir::ValueRange", "values"); + paramList.emplace_back("::mlir::DictionaryAttr", "attrs", + attrSizedOperands ? "" : "nullptr"); + auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList)); + + constructor->addMemberInitializer("odsOperands", "values"); + constructor->addMemberInitializer("odsAttrs", "attrs"); } { - auto &constructor = adaptor.newConstructor( - llvm::formatv("{0}& op", op.getCppClassName()).str()); - constructor.addMemberInitializer("odsOperands", - "op.getOperation()->getOperands()"); - constructor.addMemberInitializer("odsAttrs", - "op.getOperation()->getAttrDictionary()"); + auto *constructor = adaptor.addConstructorAndPrune( + llvm::formatv("{0}&", op.getCppClassName()).str(), "op"); + constructor->addMemberInitializer("odsOperands", + "op.getOperation()->getOperands()"); + constructor->addMemberInitializer("odsAttrs", + "op.getOperation()->getAttrDictionary()"); } std::string sizeAttrInit = @@ -2074,7 +2056,7 @@ fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); auto emitAttr = [&](StringRef name, Attribute attr) { - auto &body = adaptor.newMethod(attr.getStorageType(), name).body(); + auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body(); body << " assert(odsAttrs && \"no attributes when constructing adapter\");" << "\n " << attr.getStorageType() << " attr = " << "odsAttrs.get(\"" << name << "\")."; @@ -2107,9 +2089,9 @@ } void OpOperandAdaptorEmitter::addVerification() { - auto &method = adaptor.newMethod("::mlir::LogicalResult", "verify", - /*params=*/"::mlir::Location loc"); - auto &body = method.body(); + auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify", + "::mlir::Location", "loc"); + auto &body = method->body(); const char *checkAttrSizedValueSegmentsCode = R"( { diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -922,11 +922,14 @@ } void OperationFormat::genParser(Operator &op, OpClass &opClass) { - auto &method = opClass.newMethod( - "::mlir::ParseResult", "parse", - "::mlir::OpAsmParser &parser, ::mlir::OperationState &result", - OpMethod::MP_Static); - auto &body = method.body(); + llvm::SmallVector paramList; + paramList.emplace_back("::mlir::OpAsmParser &", "parser"); + paramList.emplace_back("::mlir::OperationState &", "result"); + + auto *method = + opClass.addMethodAndPrune("::mlir::ParseResult", "parse", + OpMethod::MP_Static, std::move(paramList)); + auto &body = method->body(); // Generate variables to store the operands and type within the format. This // allows for referencing these variables in the presence of optional @@ -1607,8 +1610,9 @@ } void OperationFormat::genPrinter(Operator &op, OpClass &opClass) { - auto &method = opClass.newMethod("void", "print", "OpAsmPrinter &p"); - auto &body = method.body(); + auto *method = + opClass.addMethodAndPrune("void", "print", "OpAsmPrinter &", "p"); + auto &body = method->body(); // Emit the operation name, trimming the prefix if this is the standard // dialect.