Index: mlir/include/mlir/TableGen/Class.h =================================================================== --- mlir/include/mlir/TableGen/Class.h +++ mlir/include/mlir/TableGen/Class.h @@ -436,6 +436,13 @@ static_cast<unsigned>(rhs)); } +inline constexpr mlir::tblgen::Method::Properties & +operator|=(mlir::tblgen::Method::Properties &lhs, + mlir::tblgen::Method::Properties rhs) { + return lhs = mlir::tblgen::Method::Properties(static_cast<unsigned>(lhs) | + static_cast<unsigned>(rhs)); +} + namespace mlir { namespace tblgen { @@ -488,11 +495,25 @@ /// Write the using declaration. void writeDeclTo(raw_indented_ostream &os) const override; + /// Add a template parameter. + template <typename ParamT> + void addTemplateParam(ParamT param) { + templateParams.insert(stringify(param)); + } + + /// Add a list of template parameters. + template <typename ContainerT> + void addTemplateParams(ContainerT &&container) { + templateParams.insert(std::begin(container), std::end(container)); + } + private: /// The name of the declaration, or a resolved name to an inherited function. std::string name; /// The type that is being aliased. Leave empty for inheriting functions. std::string value; + /// An optional list of class template parameters. + SetVector<std::string, SmallVector<std::string>, StringSet<>> templateParams; }; /// This class describes a class field. @@ -581,23 +602,50 @@ /// Add a new constructor to this class and prune and constructors made /// redundant by it. Returns null if the constructor was not added. Else, /// returns a pointer to the new constructor. + /// If the class has template parameters, the constructor is automatically + /// defined inline. template <Method::Properties Properties = Method::None, typename... Args> Constructor *addConstructor(Args &&...args) { + Method::Properties defaultProperties = Method::Constructor; + if (!templateParams.empty()) + defaultProperties |= Method::Inline; return addConstructorAndPrune(Constructor(getClassName(), - Properties | Method::Constructor, + Properties | defaultProperties, std::forward<Args>(args)...)); } /// Add a new method to this class and prune any methods made redundant by it.
/// Returns null if the method was not added (because an existing method would /// make it redundant). Else, returns a pointer to the new method. + /// If the class has template parameters, the method is automatically defined + /// inline. + template <Method::Properties Properties = Method::None, typename RetTypeT, + typename NameT> + Method *addMethod(RetTypeT &&retType, NameT &&name, + Method::Properties properties, + ArrayRef<MethodParameter> parameters) { + if (!templateParams.empty()) + properties |= Method::Inline; + return addMethodAndPrune(Method(std::forward<RetTypeT>(retType), + std::forward<NameT>(name), + Properties | properties, parameters)); + } + + /// Add a method with statically-known properties. + template <Method::Properties Properties = Method::None, typename RetTypeT, + typename NameT> + Method *addMethod(RetTypeT &&retType, NameT &&name, + ArrayRef<MethodParameter> parameters) { + return addMethod(std::forward<RetTypeT>(retType), std::forward<NameT>(name), + Properties, parameters); + } + template <Method::Properties Properties = Method::None, typename RetTypeT, typename NameT, typename... Args> Method *addMethod(RetTypeT &&retType, NameT &&name, Method::Properties properties, Args &&...args) { - return addMethodAndPrune( - Method(std::forward<RetTypeT>(retType), std::forward<NameT>(name), - Properties | properties, std::forward<Args>(args)...)); + return addMethod(std::forward<RetTypeT>(retType), std::forward<NameT>(name), + properties | Properties, {std::forward<Args>(args)...}); } /// Add a method with statically-known properties. @@ -674,6 +722,18 @@ /// Add a parent class. ParentClass &addParent(ParentClass parent); + /// Add a template parameter. + template <typename ParamT> + void addTemplateParam(ParamT param) { + templateParams.insert(stringify(param)); + } + + /// Add a list of template parameters. + template <typename ContainerT> + void addTemplateParams(ContainerT &&container) { + templateParams.insert(std::begin(container), std::end(container)); + } + /// Return the C++ name of the class. StringRef getClassName() const { return className; } @@ -751,6 +811,9 @@ /// A list of declarations in the class, emitted in order. std::vector<std::unique_ptr<ClassDeclaration>> declarations; + + /// An optional list of class template parameters.
+ SetVector<std::string, SmallVector<std::string>, StringSet<>> templateParams; }; } // namespace tblgen Index: mlir/include/mlir/TableGen/Operator.h =================================================================== --- mlir/include/mlir/TableGen/Operator.h +++ mlir/include/mlir/TableGen/Operator.h @@ -64,6 +64,9 @@ /// Returns the name of op's adaptor C++ class. std::string getAdaptorName() const; + /// Returns the name of op's generic adaptor C++ class. + std::string getGenericAdaptorName() const; + /// Check invariants (like no duplicated or conflicted names) and abort the /// process if any invariant is broken. void assertInvariants() const; Index: mlir/lib/TableGen/Class.cpp =================================================================== --- mlir/lib/TableGen/Class.cpp +++ mlir/lib/TableGen/Class.cpp @@ -228,6 +228,13 @@ //===----------------------------------------------------------------------===// void UsingDeclaration::writeDeclTo(raw_indented_ostream &os) const { + if (!templateParams.empty()) { + os << "template <"; + llvm::interleaveComma(templateParams, os, [&](StringRef paramName) { + os << "typename " << paramName; + }); + os << ">\n"; + } os << "using " << name; if (!value.empty()) os << " = " << value; @@ -275,6 +282,13 @@ } void Class::writeDeclTo(raw_indented_ostream &os) const { + if (!templateParams.empty()) { + os << "template <"; + llvm::interleaveComma(templateParams, os, + [&](StringRef param) { os << "typename " << param; }); + os << ">\n"; + } + // Declare the class. os << (isStruct ? "struct" : "class") << ' ' << className << ' '; @@ -341,7 +355,7 @@ }); return it == reverseDecls.end() ? (isStruct ?
Visibility::Public : Visibility::Private) - : cast<VisibilityDeclaration>(*it).getVisibility(); + : cast<VisibilityDeclaration>(**it).getVisibility(); } Method *insertAndPruneMethods(std::vector<std::unique_ptr<Method>> &methods, Index: mlir/lib/TableGen/Operator.cpp =================================================================== --- mlir/lib/TableGen/Operator.cpp +++ mlir/lib/TableGen/Operator.cpp @@ -69,6 +69,10 @@ return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); } +std::string Operator::getGenericAdaptorName() const { + return std::string(llvm::formatv("{0}GenericAdaptor", getCppClassName())); +} + /// Assert the invariants of accessors generated for the given name. static void assertAccessorInvariants(const Operator &op, StringRef name) { std::string accessorName = Index: mlir/test/mlir-tblgen/op-decl-and-defs.td =================================================================== --- mlir/test/mlir-tblgen/op-decl-and-defs.td +++ mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -52,20 +52,34 @@ // CHECK-LABEL: NS::AOp declarations -// CHECK: class AOpAdaptor { +// CHECK: namespace detail { +// CHECK: class AOpGenericAdaptorBase { // CHECK: public: -// CHECK: AOpAdaptor(::mlir::ValueRange values -// CHECK: ::mlir::ValueRange getODSOperands(unsigned index); -// CHECK: ::mlir::Value getA(); -// CHECK: ::mlir::ValueRange getB(); // CHECK: ::mlir::IntegerAttr getAttr1Attr(); // CHECK: uint32_t getAttr1(); // CHECK: ::mlir::FloatAttr getSomeAttr2Attr(); // CHECK: ::std::optional< ::llvm::APFloat > getSomeAttr2(); // CHECK: ::mlir::Region &getSomeRegion(); // CHECK: ::mlir::RegionRange getSomeRegions(); +// CHECK: }; +// CHECK: } + +// CHECK: template <typename RangeT> +// CHECK: class AOpGenericAdaptor : public detail::AOpGenericAdaptorBase { +// CHECK: public: +// CHECK: AOpGenericAdaptor(RangeT values, +// CHECK-SAME: odsOperands(values) +// CHECK: RangeT getODSOperands(unsigned index) { +// CHECK: ValueT getA() { +// CHECK: RangeT getB() { // CHECK: private: -// CHECK: ::mlir::ValueRange odsOperands; +// CHECK: RangeT odsOperands;
+// CHECK: }; + +// CHECK: class AOpAdaptor : public AOpGenericAdaptor<::mlir::ValueRange> { +// CHECK: public: +// CHECK: AOpAdaptor(AOp +// CHECK: ::mlir::LogicalResult verify( // CHECK: }; // CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsIsolatedFromAbove @@ -108,10 +122,10 @@ // DEFS-LABEL: NS::AOp definitions -// DEFS: AOpAdaptor::AOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsOperands(values), odsAttrs(attrs), odsRegions(regions) -// DEFS: ::mlir::RegionRange AOpAdaptor::getSomeRegions() +// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions) +// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions() // DEFS-NEXT: return odsRegions.drop_front(1); -// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions() +// DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions() // Check AttrSizedOperandSegments // --- @@ -127,15 +141,17 @@ ); } -// CHECK-LABEL: AttrSizedOperandOpAdaptor( -// CHECK-SAME: ::mlir::ValueRange values -// CHECK-SAME: ::mlir::DictionaryAttr attrs -// CHECK: ::mlir::ValueRange getA(); -// CHECK: ::mlir::ValueRange getB(); -// CHECK: ::mlir::Value getC(); -// CHECK: ::mlir::ValueRange getD(); +// CHECK-LABEL: class AttrSizedOperandOpGenericAdaptorBase { // CHECK: ::mlir::DenseIntElementsAttr getOperandSegmentSizes(); +// CHECK-LABEL: AttrSizedOperandOpGenericAdaptor( +// CHECK-SAME: RangeT values +// CHECK-SAME: ::mlir::DictionaryAttr attrs +// CHECK: RangeT getA() { +// CHECK: RangeT getB() { +// CHECK: ValueT getC() { +// CHECK: RangeT getD() { + // Check op trait for different number of operands // --- @@ -166,7 +182,7 @@ } // CHECK-LABEL: NS::EOp declarations -// CHECK: ::mlir::Value getA(); +// CHECK:
::mlir::TypedValue<::mlir::IntegerType> getA(); // CHECK: ::mlir::MutableOperandRange getAMutable(); // CHECK: ::mlir::TypedValue<::mlir::FloatType> getB(); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::Type b, /*optional*/::mlir::Value a) @@ -335,6 +351,18 @@ // Check leading underscore in op name // --- +def NS_VarOfVarOperandOp : NS_Op<"var_of_var_operand", []> { + let arguments = (ins + VariadicOfVariadic<F32, "var_size">:$var_of_var_attr, + DenseI32ArrayAttr:$var_size + ); +} + +// CHECK-LABEL: class VarOfVarOperandOpGenericAdaptor +// CHECK: public: +// CHECK: ::llvm::SmallVector<RangeT> getVarOfVarAttr() { + + def NS__AOp : NS_Op<"_op_with_leading_underscore", []>; // CHECK-LABEL: NS::_AOp declarations Index: mlir/test/mlir-tblgen/op-operand.td =================================================================== --- mlir/test/mlir-tblgen/op-operand.td +++ mlir/test/mlir-tblgen/op-operand.td @@ -14,8 +14,8 @@ // CHECK-LABEL: OpA definitions -// CHECK: OpAAdaptor::OpAAdaptor -// CHECK-SAME: odsOperands(values), odsAttrs(attrs) +// CHECK: OpAGenericAdaptorBase::OpAGenericAdaptorBase +// CHECK-SAME: odsAttrs(attrs) // CHECK: void OpA::build // CHECK: ::mlir::Value input @@ -39,15 +39,6 @@ let arguments = (ins Variadic<AnyTensor>:$input1, AnyTensor:$input2, Variadic<AnyTensor>:$input3); } -// CHECK-LABEL: ::mlir::ValueRange OpDAdaptor::getInput1 -// CHECK-NEXT: return getODSOperands(0); - -// CHECK-LABEL: ::mlir::Value OpDAdaptor::getInput2 -// CHECK-NEXT: return *getODSOperands(1).begin(); - -// CHECK-LABEL: ::mlir::ValueRange OpDAdaptor::getInput3 -// CHECK-NEXT: return getODSOperands(2); - // CHECK-LABEL: ::mlir::Operation::operand_range OpD::getInput1 // CHECK-NEXT: return getODSOperands(0); Index: mlir/tools/mlir-tblgen/OpClass.cpp =================================================================== --- mlir/tools/mlir-tblgen/OpClass.cpp +++ mlir/tools/mlir-tblgen/OpClass.cpp @@ -27,6 +27,11 @@ declare<UsingDeclaration>("Op::print"); /// Type alias for the
adaptor class. declare<UsingDeclaration>("Adaptor", className + "Adaptor"); + declare<UsingDeclaration>("GenericAdaptor", + className + "GenericAdaptor<RangeT>") + ->addTemplateParam("RangeT"); + declare<UsingDeclaration>( + "FoldAdaptor", "GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>"); } void OpClass::finalize() { Index: mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -117,11 +117,12 @@ /// /// {0}: The name of the segment attribute. /// {1}: The index of the main operand. +/// {2}: The range type of adaptor. static const char *const variadicOfVariadicAdaptorCalcCode = R"( auto tblgenTmpOperands = getODSOperands({1}); auto sizes = {0}(); - ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups; + ::llvm::SmallVector<{2}> tblgenTmpOperandGroups; for (int i = 0, e = sizes.size(); i < e; ++i) {{ tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(sizes[i])); tblgenTmpOperands = tblgenTmpOperands.drop_front(sizes[i]); @@ -1190,13 +1191,17 @@ // Generates the code to compute the start and end index of an operand or result // range.
template <typename RangeT> static void -generateValueRangeStartAndEnd(Class &opClass, StringRef methodName, - int numVariadic, int numNonVariadic, - StringRef rangeSizeCall, bool hasAttrSegmentSize, - StringRef sizeAttrInit, RangeT &&odsValues) { +static void generateValueRangeStartAndEnd( + Class &opClass, bool isGenericAdaptorBase, StringRef methodName, + int numVariadic, int numNonVariadic, StringRef rangeSizeCall, + bool hasAttrSegmentSize, StringRef sizeAttrInit, RangeT &&odsValues) { + + SmallVector<MethodParameter> parameters{MethodParameter("unsigned", "index")}; + if (isGenericAdaptorBase) + parameters.emplace_back("unsigned", "odsOperandsSize"); + auto *method = opClass.addMethod("std::pair<unsigned, unsigned>", methodName, - MethodParameter("unsigned", "index")); + parameters); if (!method) return; auto &body = method->body(); @@ -1218,8 +1223,7 @@ } } -static std::string generateTypeForGetter(bool isAdaptor, - const NamedTypeConstraint &value) { +static std::string generateTypeForGetter(const NamedTypeConstraint &value) { std::string str = "::mlir::Value"; /// If the CPPClassName is not a fully qualified type. Uses of types /// across Dialect fail because they are not in the correct namespace. So we @@ -1229,7 +1233,7 @@ /// https://github.com/llvm/llvm-project/issues/57279. /// Adaptor will have values that are not from the type of their operation and /// this is expected, so we dont generate TypedValue for Adaptor - if (!isAdaptor && value.constraint.getCPPClassName() != "::mlir::Type" && + if (value.constraint.getCPPClassName() != "::mlir::Type" && StringRef(value.constraint.getCPPClassName()).startswith("::")) str = llvm::formatv("::mlir::TypedValue<{0}>", value.constraint.getCPPClassName()) @@ -1248,12 +1252,12 @@ // "{0}" marker in the pattern. Note that the pattern should work for any kind // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method.
-static void generateNamedOperandGetters(const Operator &op, Class &opClass, - bool isAdaptor, StringRef sizeAttrInit, - StringRef rangeType, - StringRef rangeBeginCall, - StringRef rangeSizeCall, - StringRef getOperandCallPattern) { +static void +generateNamedOperandGetters(const Operator &op, Class &opClass, + Class *genericAdaptorBase, StringRef sizeAttrInit, + StringRef rangeType, StringRef rangeElementType, + StringRef rangeBeginCall, StringRef rangeSizeCall, + StringRef getOperandCallPattern) { const int numOperands = op.getNumOperands(); const int numVariadicOperands = op.getNumVariableLengthOperands(); const int numNormalOperands = numOperands - numVariadicOperands; @@ -1281,10 +1285,33 @@ // First emit a few "sink" getter methods upon which we layer all nicer named // getter methods. - generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength", - numVariadicOperands, numNormalOperands, - rangeSizeCall, attrSizedOperands, sizeAttrInit, - const_cast<Operator &>(op).getOperands()); + // If generating for an adaptor, the method is put into the non-templated + // generic base class, to not require being defined in the header. + // Since the operand size can't be determined from the base class however, + // it has to be passed as an additional argument. The trampoline below + // generates the function with the same signature as the Op in the generic + // adaptor. + generateValueRangeStartAndEnd( + /*opClass=*/genericAdaptorBase != nullptr ? *genericAdaptorBase : opClass, + /*isGenericAdaptorBase=*/genericAdaptorBase != nullptr, + /*methodName=*/"getODSOperandIndexAndLength", numVariadicOperands, + numNormalOperands, + /*rangeSizeCall=*/genericAdaptorBase != nullptr ? "odsOperandsSize" + : rangeSizeCall, + attrSizedOperands, sizeAttrInit, + const_cast<Operator &>(op).getOperands()); + if (genericAdaptorBase) { + // Generate trampoline for calling 'getODSOperandIndexAndLength' with just + // the index.
This just calls the implementation in the base class but + // passes the operand size as parameter. + Method *method = opClass.addMethod("std::pair<unsigned, unsigned>", + "getODSOperandIndexAndLength", + MethodParameter("unsigned", "index")); + ERROR_IF_PRUNED(method, "getODSOperandIndexAndLength", op); + MethodBody &body = method->body(); + body.indent() << formatv( + "return Base::getODSOperandIndexAndLength(index, {0});", rangeSizeCall); + } auto *m = opClass.addMethod(rangeType, "getODSOperands", MethodParameter("unsigned", "index")); @@ -1301,19 +1328,23 @@ continue; std::string name = op.getGetterName(operand.name); if (operand.isOptional()) { - m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); + m = opClass.addMethod(genericAdaptorBase != nullptr + ? rangeElementType + : generateTypeForGetter(operand), + name); ERROR_IF_PRUNED(m, name, op); - m->body() << " auto operands = getODSOperands(" << i << ");\n" - << " return operands.empty() ? ::mlir::Value() : " - "*operands.begin();"; + m->body().indent() << formatv( + "auto operands = getODSOperands({0});\n" + "return operands.empty() ? {1}{{} : *operands.begin();", + i, rangeElementType); } else if (operand.isVariadicOfVariadic()) { std::string segmentAttr = op.getGetterName( operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); - if (isAdaptor) { - m = opClass.addMethod("::llvm::SmallVector<::mlir::ValueRange>", name); + if (genericAdaptorBase) { + m = opClass.addMethod("::llvm::SmallVector<" + rangeType + ">", name); ERROR_IF_PRUNED(m, name, op); m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, - segmentAttr, i); + segmentAttr, i, rangeType); continue; } @@ -1326,7 +1357,10 @@ ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSOperands(" << i << ");"; } else { - m = opClass.addMethod(generateTypeForGetter(isAdaptor, operand), name); + m = opClass.addMethod(genericAdaptorBase != nullptr + ?
rangeElementType + : generateTypeForGetter(operand), + name); ERROR_IF_PRUNED(m, name, op); m->body() << " return *getODSOperands(" << i << ").begin();"; } @@ -1344,9 +1378,10 @@ generateNamedOperandGetters( op, opClass, - /*isAdaptor=*/false, + /*genericAdaptorBase=*/nullptr, /*sizeAttrInit=*/attrSizeInitCode, /*rangeType=*/"::mlir::Operation::operand_range", + /*rangeElementType=*/"::mlir::Value", /*rangeBeginCall=*/"getOperation()->operand_begin()", /*rangeSizeCall=*/"getOperation()->getNumOperands()", /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); @@ -1431,9 +1466,9 @@ } generateValueRangeStartAndEnd( - opClass, "getODSResultIndexAndLength", numVariadicResults, - numNormalResults, "getOperation()->getNumResults()", attrSizedResults, - attrSizeInitCode, op.getResults()); + opClass, /*isGenericAdaptorBase=*/false, "getODSResultIndexAndLength", + numVariadicResults, numNormalResults, "getOperation()->getNumResults()", + attrSizedResults, attrSizeInitCode, op.getResults()); auto *m = opClass.addMethod("::mlir::Operation::result_range", "getODSResults", @@ -1448,8 +1483,7 @@ continue; std::string name = op.getGetterName(result.name); if (result.isOptional()) { - m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result), - name); + m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); m->body() << " auto results = getODSResults(" << i << ");\n" @@ -1459,8 +1493,7 @@ ERROR_IF_PRUNED(m, name, op); m->body() << " return getODSResults(" << i << ");"; } else { - m = opClass.addMethod(generateTypeForGetter(/*isAdaptor=*/false, result), - name); + m = opClass.addMethod(generateTypeForGetter(result), name); ERROR_IF_PRUNED(m, name, op); m->body() << " return *getODSResults(" << i << ").begin();"; } @@ -2906,8 +2939,19 @@ namespace { // Helper class to emit Op operand adaptors to an output stream.
Operand -// adaptors are wrappers around ArrayRef<Value> that provide named operand +// adaptors are wrappers around random access ranges that provide named operand // getters identical to those defined in the Op. +// This currently generates 3 classes per Op: +// * A Base class within the 'detail' namespace, which contains all logic and +// members independent of the random access range that is indexed into. +// In other words, it contains all the attribute and region getters. +// * A templated class named '{OpName}GenericAdaptor' with a template parameter +// 'RangeT' that is indexed into by the getters to access the operands. +// It contains all getters to access operands and inherits from the previous +// class. +// * A class named '{OpName}Adaptor', which inherits from the 'GenericAdaptor' +// with 'mlir::ValueRange' as template parameter. It adds a constructor from +// an instance of the op type and a verify function. class OpOperandAdaptorEmitter { public: static void @@ -2931,7 +2975,9 @@ // The operation for which to emit an adaptor. const Operator &op; - // The generated adaptor class. + // The generated adaptor classes. + Class genericAdaptorBase; + Class genericAdaptor; Class adaptor; // The emitter containing all of the locally emitted verification functions.
@@ -2945,42 +2991,47 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) - : op(op), adaptor(op.getAdaptorName()), + : op(op), genericAdaptorBase(op.getGenericAdaptorName() + "Base"), + genericAdaptor(op.getGenericAdaptorName()), adaptor(op.getAdaptorName()), staticVerifierEmitter(staticVerifierEmitter), emitHelper(op, /*emitForOp=*/false) { - adaptor.addField("::mlir::ValueRange", "odsOperands"); - adaptor.addField("::mlir::DictionaryAttr", "odsAttrs"); - adaptor.addField("::mlir::RegionRange", "odsRegions"); - adaptor.addField("::std::optional<::mlir::OperationName>", "odsOpName"); + + genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected); + genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs"); + genericAdaptorBase.declare<Field>("::mlir::RegionRange", "odsRegions"); + genericAdaptorBase.declare<Field>("::std::optional<::mlir::OperationName>", + "odsOpName"); + + genericAdaptor.addTemplateParam("RangeT"); + genericAdaptor.addField("RangeT", "odsOperands"); + genericAdaptor.addParent( + ParentClass("detail::" + genericAdaptorBase.getClassName())); + genericAdaptor.declare<UsingDeclaration>( + "ValueT", "::llvm::detail::ValueOfRange<RangeT>"); + genericAdaptor.declare<UsingDeclaration>( + "Base", "detail::" + genericAdaptorBase.getClassName()); const auto *attrSizedOperands = - op.getTrait("::m::OpTrait::AttrSizedOperandSegments"); + op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); { SmallVector<MethodParameter> paramList; - paramList.emplace_back("::mlir::ValueRange", "values"); paramList.emplace_back("::mlir::DictionaryAttr", "attrs", attrSizedOperands ?
"" : "nullptr"); paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); - auto *constructor = adaptor.addConstructor(std::move(paramList)); - - constructor->addMemberInitializer("odsOperands", "values"); - constructor->addMemberInitializer("odsAttrs", "attrs"); - constructor->addMemberInitializer("odsRegions", "regions"); + auto *baseConstructor = genericAdaptorBase.addConstructor(paramList); + baseConstructor->addMemberInitializer("odsAttrs", "attrs"); + baseConstructor->addMemberInitializer("odsRegions", "regions"); - MethodBody &body = constructor->body(); + MethodBody &body = baseConstructor->body(); body.indent() << "if (odsAttrs)\n"; body.indent() << formatv( "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n", op.getOperationName()); - } - { - auto *constructor = - adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op")); - constructor->addMemberInitializer("odsOperands", "op->getOperands()"); - constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); - constructor->addMemberInitializer("odsRegions", "op->getRegions()"); - constructor->addMemberInitializer("odsOpName", "op->getName()"); + paramList.insert(paramList.begin(), MethodParameter("RangeT", "values")); + auto *constructor = genericAdaptor.addConstructor(std::move(paramList)); + constructor->addMemberInitializer("Base", "attrs, regions"); + constructor->addMemberInitializer("odsOperands", "values"); } std::string sizeAttrInit; @@ -2988,16 +3039,18 @@ sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, emitHelper.getAttr(operandSegmentAttrName)); } - generateNamedOperandGetters(op, adaptor, - /*isAdaptor=*/true, sizeAttrInit, - /*rangeType=*/"::mlir::ValueRange", + generateNamedOperandGetters(op, genericAdaptor, + /*genericAdaptorBase=*/&genericAdaptorBase, + /*sizeAttrInit=*/sizeAttrInit, + /*rangeType=*/"RangeT", + /*rangeElementType=*/"ValueT", /*rangeBeginCall=*/"odsOperands.begin()", /*rangeSizeCall=*/"odsOperands.size()",
/*getOperandCallPattern=*/"odsOperands[{0}]"); // Any invalid overlap for `getOperands` will have been diagnosed before here // already. - if (auto *m = adaptor.addMethod("::mlir::ValueRange", "getOperands")) + if (auto *m = genericAdaptor.addMethod("RangeT", "getOperands")) m->body() << " return odsOperands;"; FmtContext fctx; @@ -3006,7 +3059,8 @@ // Generate named accessor with Attribute return type. auto emitAttrWithStorageType = [&](StringRef name, StringRef emitName, Attribute attr) { - auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr"); + auto *method = + genericAdaptorBase.addMethod(attr.getStorageType(), emitName + "Attr"); ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op); auto &body = method->body().indent(); body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n" @@ -3028,7 +3082,8 @@ }; { - auto *m = adaptor.addMethod("::mlir::DictionaryAttr", "getAttributes"); + auto *m = + genericAdaptorBase.addMethod("::mlir::DictionaryAttr", "getAttributes"); ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); m->body() << " return odsAttrs;"; } @@ -3039,7 +3094,7 @@ continue; std::string emitName = op.getGetterName(name); emitAttrWithStorageType(name, emitName, attr); - emitAttrGetterWithReturnType(fctx, adaptor, op, emitName, attr); + emitAttrGetterWithReturnType(fctx, genericAdaptorBase, op, emitName, attr); } unsigned numRegions = op.getNumRegions(); @@ -3051,25 +3106,44 @@ // Generate the accessors for a variadic region.
std::string name = op.getGetterName(region.name); if (region.isVariadic()) { - auto *m = adaptor.addMethod("::mlir::RegionRange", name); + auto *m = genericAdaptorBase.addMethod("::mlir::RegionRange", name); ERROR_IF_PRUNED(m, "Adaptor::" + name, op); m->body() << formatv(" return odsRegions.drop_front({0});", i); continue; } - auto *m = adaptor.addMethod("::mlir::Region &", name); + auto *m = genericAdaptorBase.addMethod("::mlir::Region &", name); ERROR_IF_PRUNED(m, "Adaptor::" + name, op); m->body() << formatv(" return *odsRegions[{0}];", i); } if (numRegions > 0) { // Any invalid overlap for `getRegions` will have been diagnosed before here // already. - if (auto *m = adaptor.addMethod("::mlir::RegionRange", "getRegions")) + if (auto *m = + genericAdaptorBase.addMethod("::mlir::RegionRange", "getRegions")) m->body() << " return odsRegions;"; } + StringRef genericAdaptorClassName = genericAdaptor.getClassName(); + adaptor.addParent(ParentClass(genericAdaptorClassName)) + .addTemplateParam("::mlir::ValueRange"); + adaptor.declare<VisibilityDeclaration>(Visibility::Public); + adaptor.declare<UsingDeclaration>(genericAdaptorClassName + + "::" + genericAdaptorClassName); + { + // Constructor taking the Op as single parameter. + auto *constructor = + adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op")); + constructor->addMemberInitializer( + adaptor.getClassName(), + "op->getOperands(), op->getAttrDictionary(), op->getRegions()"); + } + // Add verification function.
addVerification(); + + genericAdaptorBase.finalize(); + genericAdaptor.finalize(); adaptor.finalize(); } @@ -3090,14 +3164,26 @@ const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter, raw_ostream &os) { - OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os); + OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter); + { + NamespaceEmitter ns(os, "detail"); + emitter.genericAdaptorBase.writeDeclTo(os); + } + emitter.genericAdaptor.writeDeclTo(os); + emitter.adaptor.writeDeclTo(os); } void OpOperandAdaptorEmitter::emitDef( const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter, raw_ostream &os) { - OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); + OpOperandAdaptorEmitter emitter(op, staticVerifierEmitter); + { + NamespaceEmitter ns(os, "detail"); + emitter.genericAdaptorBase.writeDefTo(os); + } + emitter.genericAdaptor.writeDefTo(os); + emitter.adaptor.writeDefTo(os); } // Emits the opcode enum and op classes. Index: mlir/unittests/IR/AdaptorTest.cpp =================================================================== --- /dev/null +++ mlir/unittests/IR/AdaptorTest.cpp @@ -0,0 +1,63 @@ +//===- AdaptorTest.cpp - Adaptor unit tests -------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace mlir; +using namespace test; + +using testing::ElementsAre; + +TEST(Adaptor, GenericAdaptorsOperandAccess) { + MLIRContext context; + context.loadDialect<test::TestDialect>(); + Builder builder(&context); + + // Has normal and Variadic arguments.
+ MixedNormalVariadicOperandOp::FoldAdaptor a({}); + { + SmallVector<int> v = {0, 1, 2, 3, 4}; + MixedNormalVariadicOperandOp::GenericAdaptor<ArrayRef<int>> b(v); + EXPECT_THAT(b.getInput1(), ElementsAre(0, 1)); + EXPECT_EQ(b.getInput2(), 2); + EXPECT_THAT(b.getInput3(), ElementsAre(3, 4)); + } + + // Has optional arguments. + OIListSimple::FoldAdaptor c({}, nullptr); + { + // Optional arguments return the default constructed value if not present. + // Using optional instead of plain int here to differentiate absence of + // value from the value 0. + SmallVector<std::optional<int>> v = {0, 4}; + OIListSimple::GenericAdaptor<ArrayRef<std::optional<int>>> d( + v, builder.getDictionaryAttr({builder.getNamedAttr( + "operand_segment_sizes", + builder.getDenseI32ArrayAttr({1, 0, 1}))})); + EXPECT_EQ(d.getArg0(), 0); + EXPECT_EQ(d.getArg1(), std::nullopt); + EXPECT_EQ(d.getArg2(), 4); + } + + // Has VariadicOfVariadic arguments. + FormatVariadicOfVariadicOperand::FoldAdaptor e({}); + { + SmallVector<int> v = {0, 1, 2, 3, 4}; + FormatVariadicOfVariadicOperand::GenericAdaptor<ArrayRef<int>> f( + v, builder.getDictionaryAttr({builder.getNamedAttr( + "operand_segments", builder.getDenseI32ArrayAttr({3, 2, 0}))})); + SmallVector<ArrayRef<int>> operand = f.getOperand(); + ASSERT_EQ(operand.size(), (std::size_t)3); + EXPECT_THAT(operand[0], ElementsAre(0, 1, 2)); + EXPECT_THAT(operand[1], ElementsAre(3, 4)); + EXPECT_THAT(operand[2], ElementsAre()); + } +} Index: mlir/unittests/IR/CMakeLists.txt =================================================================== --- mlir/unittests/IR/CMakeLists.txt +++ mlir/unittests/IR/CMakeLists.txt @@ -1,4 +1,5 @@ add_mlir_unittest(MLIRIRTests + AdaptorTest.cpp AttributeTest.cpp BlockAndValueMapping.cpp DialectTest.cpp