Index: mlir/include/mlir/Conversion/LLVMCommon/Pattern.h =================================================================== --- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -149,13 +149,8 @@ /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - if constexpr (SourceOp::hasProperties()) - return rewrite(cast(op), - OpAdaptor(operands, op->getDiscardableAttrDictionary(), - cast(op).getProperties()), - rewriter); - rewrite(cast(op), - OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter); + rewrite(cast(op), OpAdaptor(operands, cast(op)), + rewriter); } LogicalResult match(Operation *op) const final { return match(cast(op)); @@ -163,15 +158,8 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - if constexpr (SourceOp::hasProperties()) - return matchAndRewrite(cast(op), - OpAdaptor(operands, - op->getDiscardableAttrDictionary(), - cast(op).getProperties()), - rewriter); - return matchAndRewrite( - cast(op), - OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter); + return matchAndRewrite(cast(op), + OpAdaptor(operands, cast(op)), rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. 
These must be Index: mlir/include/mlir/IR/OpDefinition.h =================================================================== --- mlir/include/mlir/IR/OpDefinition.h +++ mlir/include/mlir/IR/OpDefinition.h @@ -1917,15 +1917,8 @@ SmallVectorImpl &results) { OpFoldResult result; if constexpr (has_fold_adaptor_single_result_v) { - if constexpr (hasProperties()) { - result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( - operands, op->getDiscardableAttrDictionary(), - cast(op).getProperties(), op->getRegions())); - } else { - result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( - operands, op->getDiscardableAttrDictionary(), {}, - op->getRegions())); - } + result = cast(op).fold( + typename ConcreteOpT::FoldAdaptor(operands, cast(op))); } else { result = cast(op).fold(operands); } @@ -1948,19 +1941,9 @@ SmallVectorImpl &results) { auto result = LogicalResult::failure(); if constexpr (has_fold_adaptor_v) { - if constexpr (hasProperties()) { - result = cast(op).fold( - typename ConcreteOpT::FoldAdaptor( - operands, op->getDiscardableAttrDictionary(), - cast(op).getProperties(), op->getRegions()), - results); - } else { - result = cast(op).fold( - typename ConcreteOpT::FoldAdaptor( - operands, op->getDiscardableAttrDictionary(), {}, - op->getRegions()), - results); - } + result = cast(op).fold( + typename ConcreteOpT::FoldAdaptor(operands, cast(op)), + results); } else { result = cast(op).fold(operands, results); } Index: mlir/include/mlir/TableGen/Class.h =================================================================== --- mlir/include/mlir/TableGen/Class.h +++ mlir/include/mlir/TableGen/Class.h @@ -166,6 +166,21 @@ /// method definition). void writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const; + /// Write the template parameters of the signature. + void writeTemplateParamsTo(raw_indented_ostream &os) const; + + /// Add a template parameter. 
+ template + void addTemplateParam(ParamT param) { + templateParams.insert(stringify(param)); + } + + /// Add a list of template parameters. + template + void addTemplateParams(ContainerT &&container) { + templateParams.insert(std::begin(container), std::end(container)); + } + private: /// The method's C++ return type. std::string returnType; @@ -173,6 +188,8 @@ std::string methodName; /// The method's parameter list. MethodParameters parameters; + /// An optional list of template parameters. + SetVector, StringSet<>> templateParams; }; /// This class contains the body of a C++ method. @@ -367,6 +384,14 @@ void writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const override; + /// Add a template parameter. + template + void addTemplateParam(ParamT param); + + /// Add a list of template parameters. + template + void addTemplateParams(ContainerT &&container); + protected: /// A collection of method properties. Properties properties; @@ -459,6 +484,20 @@ namespace mlir { namespace tblgen { +template +void Method::addTemplateParam(ParamT param) { + // Templates imply inline. + properties |= Method::Inline; + methodSignature.addTemplateParam(param); +} + +template +void Method::addTemplateParams(ContainerT &&container) { + // Templates imply inline. + properties |= Method::Inline; + methodSignature.addTemplateParams(std::forward(container)); +} + /// This class describes a C++ parent class declaration. 
class ParentClass { public: Index: mlir/include/mlir/Transforms/DialectConversion.h =================================================================== --- mlir/include/mlir/Transforms/DialectConversion.h +++ mlir/include/mlir/Transforms/DialectConversion.h @@ -523,24 +523,13 @@ void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto sourceOp = cast(op); - rewrite(sourceOp, - OpAdaptor(operands, op->getDiscardableAttrDictionary(), - sourceOp.getProperties()), - rewriter); + rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto sourceOp = cast(op); - if constexpr (SourceOp::hasProperties()) - return matchAndRewrite(sourceOp, - OpAdaptor(operands, - op->getDiscardableAttrDictionary(), - sourceOp.getProperties()), - rewriter); - return matchAndRewrite( - sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()), - rewriter); + return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. 
These must be Index: mlir/include/mlir/Transforms/OneToNTypeConversion.h =================================================================== --- mlir/include/mlir/Transforms/OneToNTypeConversion.h +++ mlir/include/mlir/Transforms/OneToNTypeConversion.h @@ -231,12 +231,9 @@ OpAdaptor(const OneToNTypeMapping *operandMapping, const OneToNTypeMapping *resultMapping, - const ValueRange *convertedOperands, RangeT values, - DictionaryAttr attrs = nullptr, Properties &properties = {}, - RegionRange regions = {}) - : BaseT(values, attrs, properties, regions), - operandMapping(operandMapping), resultMapping(resultMapping), - convertedOperands(convertedOperands) {} + const ValueRange *convertedOperands, RangeT values, SourceOp op) + : BaseT(values, op), operandMapping(operandMapping), + resultMapping(resultMapping), convertedOperands(convertedOperands) {} /// Get the type mapping of the original operands to the converted operands. const OneToNTypeMapping &getOperandMapping() const { @@ -276,8 +273,7 @@ valueRanges.push_back(values); } OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands, - valueRanges, op->getAttrDictionary(), - cast(op).getProperties(), op->getRegions()); + valueRanges, cast(op)); // Call overload implemented by the derived class. 
return matchAndRewrite(cast(op), adaptor, rewriter); Index: mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp =================================================================== --- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -136,11 +136,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { auto reallocOp = cast(op); - return matchAndRewrite(reallocOp, - OpAdaptor(operands, - op->getDiscardableAttrDictionary(), - reallocOp.getProperties()), - rewriter); + return matchAndRewrite(reallocOp, OpAdaptor(operands, reallocOp), rewriter); } // A `realloc` is converted as follows: Index: mlir/lib/TableGen/Class.cpp =================================================================== --- mlir/lib/TableGen/Class.cpp +++ mlir/lib/TableGen/Class.cpp @@ -93,6 +93,17 @@ os << ")"; } +void MethodSignature::writeTemplateParamsTo( + mlir::raw_indented_ostream &os) const { + if (templateParams.empty()) + return; + + os << "template <"; + llvm::interleaveComma(templateParams, os, + [&](StringRef param) { os << "typename " << param; }); + os << ">\n"; +} + //===----------------------------------------------------------------------===// // MethodBody definitions //===----------------------------------------------------------------------===// @@ -114,6 +125,7 @@ //===----------------------------------------------------------------------===// void Method::writeDeclTo(raw_indented_ostream &os) const { + methodSignature.writeTemplateParamsTo(os); if (deprecationMessage) { os << "[[deprecated(\""; os.write_escaped(*deprecationMessage); @@ -153,6 +165,7 @@ //===----------------------------------------------------------------------===// void Constructor::writeDeclTo(raw_indented_ostream &os) const { + methodSignature.writeTemplateParamsTo(os); if (properties & ConstexprValue) os << "constexpr "; methodSignature.writeDeclTo(os); Index: mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp 
=================================================================== --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3958,6 +3958,35 @@ } } + // Create constructors constructing the adaptor from an instance of the op. + // This takes the attributes, properties and regions from the op instance + // and the value range from the parameter. + { + // Base class is in the cpp file and can simply access the members of the op + // class to initialize the template independent fields. + auto *constructor = genericAdaptorBase.addConstructor( + MethodParameter(op.getCppClassName(), "op")); + constructor->addMemberInitializer( + genericAdaptorBase.getClassName(), + llvm::Twine(!useProperties ? "op->getAttrDictionary()" + : "op->getDiscardableAttrDictionary()") + + ", op.getProperties(), op->getRegions()"); + + // Generic adaptor is templated and therefore defined inline in the header. + // We cannot use the Op class here as it is an incomplete type (we have a + // circular reference between the two). + // Use a template trick to make the constructor be instantiated at call site + // when the op class is complete. + constructor = genericAdaptor.addConstructor( + MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op")); + constructor->addTemplateParam("LateInst = " + op.getCppClassName()); + constructor->addTemplateParam( + "= std::enable_if_t>"); + constructor->addMemberInitializer("Base", "op"); + constructor->addMemberInitializer("odsOperands", "values"); + } + std::string sizeAttrInit; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { if (op.getDialect().usePropertiesForAttributes()) @@ -4071,9 +4100,8 @@ // Constructor taking the Op as single parameter. 
auto *constructor = adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op")); - constructor->addMemberInitializer( - adaptor.getClassName(), "op->getOperands(), op->getAttrDictionary(), " - "op.getProperties(), op->getRegions()"); + constructor->addMemberInitializer(genericAdaptorClassName, + "op->getOperands(), op"); } // Add verification function.