diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -149,13 +149,8 @@
   /// Wrappers around the RewritePattern methods that pass the derived op type.
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
-    if constexpr (SourceOp::hasProperties())
-      return rewrite(cast<SourceOp>(op),
-                     OpAdaptor(operands, op->getDiscardableAttrDictionary(),
-                               cast<SourceOp>(op).getProperties()),
-                     rewriter);
-    rewrite(cast<SourceOp>(op),
-            OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
+    rewrite(cast<SourceOp>(op), OpAdaptor(operands, cast<SourceOp>(op)),
+            rewriter);
   }
   LogicalResult match(Operation *op) const final {
     return match(cast<SourceOp>(op));
@@ -163,15 +158,8 @@
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    if constexpr (SourceOp::hasProperties())
-      return matchAndRewrite(cast<SourceOp>(op),
-                             OpAdaptor(operands,
-                                       op->getDiscardableAttrDictionary(),
-                                       cast<SourceOp>(op).getProperties()),
-                             rewriter);
-    return matchAndRewrite(
-        cast<SourceOp>(op),
-        OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter);
+    return matchAndRewrite(cast<SourceOp>(op),
+                           OpAdaptor(operands, cast<SourceOp>(op)), rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1915,15 +1915,8 @@
                        SmallVectorImpl<OpFoldResult> &results) {
     OpFoldResult result;
     if constexpr (has_fold_adaptor_single_result_v<ConcreteOpT>) {
-      if constexpr (hasProperties()) {
-        result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-            operands, op->getDiscardableAttrDictionary(),
-            cast<ConcreteOpT>(op).getProperties(), op->getRegions()));
-      } else {
-        result = cast<ConcreteOpT>(op).fold(typename ConcreteOpT::FoldAdaptor(
-            operands, op->getDiscardableAttrDictionary(), {},
-            op->getRegions()));
-      }
+      result = cast<ConcreteOpT>(op).fold(
+          typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)));
     } else {
       result = cast<ConcreteOpT>(op).fold(operands);
     }
@@ -1946,19 +1939,9 @@
                                 SmallVectorImpl<OpFoldResult> &results) {
     auto result = LogicalResult::failure();
     if constexpr (has_fold_adaptor_v<ConcreteOpT>) {
-      if constexpr (hasProperties()) {
-        result = cast<ConcreteOpT>(op).fold(
-            typename ConcreteOpT::FoldAdaptor(
-                operands, op->getDiscardableAttrDictionary(),
-                cast<ConcreteOpT>(op).getProperties(), op->getRegions()),
-            results);
-      } else {
-        result = cast<ConcreteOpT>(op).fold(
-            typename ConcreteOpT::FoldAdaptor(
-                operands, op->getDiscardableAttrDictionary(), {},
-                op->getRegions()),
-            results);
-      }
+      result = cast<ConcreteOpT>(op).fold(
+          typename ConcreteOpT::FoldAdaptor(operands, cast<ConcreteOpT>(op)),
+          results);
     } else {
       result = cast<ConcreteOpT>(op).fold(operands, results);
     }
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -166,6 +166,22 @@
   /// method definition).
   void writeDefTo(raw_indented_ostream &os, StringRef namePrefix) const;
 
+  /// Write the template parameters of the signature.
+  void writeTemplateParamsTo(raw_indented_ostream &os) const;
+
+  /// Add a template parameter.
+  template <typename ParamT>
+  void addTemplateParam(ParamT param) {
+    templateParams.push_back(stringify(param));
+  }
+
+  /// Add a list of template parameters, stringifying each element.
+  template <typename ContainerT>
+  void addTemplateParams(ContainerT &&container) {
+    for (auto &&param : container)
+      addTemplateParam(param);
+  }
+
 private:
   /// The method's C++ return type.
   std::string returnType;
@@ -173,6 +189,8 @@
   std::string methodName;
   /// The method's parameter list.
   MethodParameters parameters;
+  /// An optional list of template parameters.
+  SmallVector<std::string> templateParams;
 };
 
 /// This class contains the body of a C++ method.
@@ -367,6 +385,14 @@
   void writeDefTo(raw_indented_ostream &os,
                   StringRef namePrefix) const override;
 
+  /// Add a template parameter.
+  template <typename ParamT>
+  void addTemplateParam(ParamT param);
+
+  /// Add a list of template parameters.
+  template <typename ContainerT>
+  void addTemplateParams(ContainerT &&container);
+
 protected:
   /// A collection of method properties.
   Properties properties;
@@ -459,6 +485,20 @@
 namespace mlir {
 namespace tblgen {
 
+template <typename ParamT>
+void Method::addTemplateParam(ParamT param) {
+  // Templates imply inline.
+  properties |= Method::Inline;
+  methodSignature.addTemplateParam(param);
+}
+
+template <typename ContainerT>
+void Method::addTemplateParams(ContainerT &&container) {
+  // Templates imply inline.
+  properties |= Method::Inline;
+  methodSignature.addTemplateParams(std::forward<ContainerT>(container));
+}
+
 /// This class describes a C++ parent class declaration.
 class ParentClass {
 public:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -523,24 +523,13 @@
   void rewrite(Operation *op, ArrayRef<Value> operands,
                ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
-    rewrite(sourceOp,
-            OpAdaptor(operands, op->getDiscardableAttrDictionary(),
-                      sourceOp.getProperties()),
-            rewriter);
+    rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto sourceOp = cast<SourceOp>(op);
-    if constexpr (SourceOp::hasProperties())
-      return matchAndRewrite(sourceOp,
-                             OpAdaptor(operands,
-                                       op->getDiscardableAttrDictionary(),
-                                       sourceOp.getProperties()),
-                             rewriter);
-    return matchAndRewrite(
-        sourceOp, OpAdaptor(operands, op->getDiscardableAttrDictionary()),
-        rewriter);
+    return matchAndRewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
   }
 
   /// Rewrite and Match methods that operate on the SourceOp type. These must be
diff --git a/mlir/include/mlir/Transforms/OneToNTypeConversion.h b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
--- a/mlir/include/mlir/Transforms/OneToNTypeConversion.h
+++ b/mlir/include/mlir/Transforms/OneToNTypeConversion.h
@@ -231,12 +231,9 @@
 
     OpAdaptor(const OneToNTypeMapping *operandMapping,
               const OneToNTypeMapping *resultMapping,
-              const ValueRange *convertedOperands, RangeT values,
-              DictionaryAttr attrs = nullptr, Properties &properties = {},
-              RegionRange regions = {})
-        : BaseT(values, attrs, properties, regions),
-          operandMapping(operandMapping), resultMapping(resultMapping),
-          convertedOperands(convertedOperands) {}
+              const ValueRange *convertedOperands, RangeT values, SourceOp op)
+        : BaseT(values, op), operandMapping(operandMapping),
+          resultMapping(resultMapping), convertedOperands(convertedOperands) {}
 
     /// Get the type mapping of the original operands to the converted operands.
     const OneToNTypeMapping &getOperandMapping() const {
@@ -276,8 +273,7 @@
       valueRanges.push_back(values);
     }
     OpAdaptor adaptor(&operandMapping, &resultMapping, &convertedOperands,
-                      valueRanges, op->getAttrDictionary(),
-                      cast<SourceOp>(op).getProperties(), op->getRegions());
+                      valueRanges, cast<SourceOp>(op));
 
     // Call overload implemented by the derived class.
     return matchAndRewrite(cast<SourceOp>(op), adaptor, rewriter);
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -136,11 +136,7 @@
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
     auto reallocOp = cast<memref::ReallocOp>(op);
-    return matchAndRewrite(reallocOp,
-                           OpAdaptor(operands,
-                                     op->getDiscardableAttrDictionary(),
-                                     reallocOp.getProperties()),
-                           rewriter);
+    return matchAndRewrite(reallocOp, OpAdaptor(operands, reallocOp), rewriter);
   }
 
   // A `realloc` is converted as follows:
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -292,7 +292,7 @@
 
 void ConditionOp::getSuccessorRegions(
     ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
-  FoldAdaptor adaptor(operands);
+  FoldAdaptor adaptor(operands, *this);
 
   WhileOp whileOp = getParentOp();
 
@@ -2031,7 +2031,7 @@
 
 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
                                     SmallVectorImpl<RegionSuccessor> &regions) {
-  FoldAdaptor adaptor(operands);
+  FoldAdaptor adaptor(operands, *this);
   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
   if (!boolAttr || boolAttr.getValue())
     regions.emplace_back(&getThenRegion());
@@ -4039,7 +4039,7 @@
 
 void IndexSwitchOp::getEntrySuccessorRegions(
     ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &successors) {
-  FoldAdaptor adaptor(operands);
+  FoldAdaptor adaptor(operands, *this);
 
   // If a constant was not provided, all regions are possible successors.
   auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
diff --git a/mlir/lib/TableGen/Class.cpp b/mlir/lib/TableGen/Class.cpp
--- a/mlir/lib/TableGen/Class.cpp
+++ b/mlir/lib/TableGen/Class.cpp
@@ -93,6 +93,17 @@
   os << ")";
 }
 
+void MethodSignature::writeTemplateParamsTo(
+    mlir::raw_indented_ostream &os) const {
+  if (templateParams.empty())
+    return;
+
+  os << "template <";
+  llvm::interleaveComma(templateParams, os,
+                        [&](StringRef param) { os << "typename " << param; });
+  os << ">\n";
+}
+
 //===----------------------------------------------------------------------===//
 // MethodBody definitions
 //===----------------------------------------------------------------------===//
@@ -114,6 +125,7 @@
 //===----------------------------------------------------------------------===//
 
 void Method::writeDeclTo(raw_indented_ostream &os) const {
+  methodSignature.writeTemplateParamsTo(os);
   if (deprecationMessage) {
     os << "[[deprecated(\"";
     os.write_escaped(*deprecationMessage);
@@ -153,6 +165,7 @@
 //===----------------------------------------------------------------------===//
 
 void Constructor::writeDeclTo(raw_indented_ostream &os) const {
+  methodSignature.writeTemplateParamsTo(os);
   if (properties & ConstexprValue)
     os << "constexpr ";
   methodSignature.writeDeclTo(os);
diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td
--- a/mlir/test/mlir-tblgen/op-decl-and-defs.td
+++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td
@@ -57,6 +57,7 @@
 // CHECK: namespace detail {
 // CHECK: class AOpGenericAdaptorBase {
 // CHECK: public:
+// CHECK: AOpGenericAdaptorBase(AOp{{[[:space:]]}}
 // CHECK: ::mlir::IntegerAttr getAttr1Attr();
 // CHECK: uint32_t getAttr1();
 // CHECK: ::mlir::FloatAttr getSomeAttr2Attr();
@@ -127,6 +128,14 @@
 
 // DEFS-LABEL: NS::AOp definitions
 // DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(::mlir::DictionaryAttr attrs, const ::mlir::EmptyProperties &properties, ::mlir::RegionRange regions) : odsAttrs(attrs), odsRegions(regions)
+
+// Check that `getAttrDictionary()` is used when not using properties.
+
+// DEFS: AOpGenericAdaptorBase::AOpGenericAdaptorBase(AOp op)
+// DEFS-SAME: op->getAttrDictionary()
+// DEFS-SAME: op.getProperties()
+// DEFS-SAME: op->getRegions()
+
 // DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getSomeRegions()
 // DEFS-NEXT: return odsRegions.drop_front(1);
 // DEFS: ::mlir::RegionRange AOpGenericAdaptorBase::getRegions()
@@ -330,6 +339,17 @@
 // CHECK-LABEL: class MOp :
 // CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor);
 
+def NS_NOp : NS_Op<"op_with_properties", []> {
+  let arguments = (ins Property<"unsigned">:$value);
+}
+
+// Check that `getDiscardableAttrDictionary()` is used with properties.
+
+// DEFS: NOpGenericAdaptorBase::NOpGenericAdaptorBase(NOp op) : NOpGenericAdaptorBase(
+// DEFS-SAME: op->getDiscardableAttrDictionary()
+// DEFS-SAME: op.getProperties()
+// DEFS-SAME: op->getRegions()
+
 // Test that type defs have the proper namespaces when used as a constraint.
 // ---
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3961,6 +3961,35 @@
     }
   }
 
+  // Create constructors constructing the adaptor from an instance of the op.
+  // This takes the attributes, properties and regions from the op instance
+  // and the value range from the parameter.
+  {
+    // Base class is in the cpp file and can simply access the members of the op
+    // class to initialize the template independent fields.
+    auto *constructor = genericAdaptorBase.addConstructor(
+        MethodParameter(op.getCppClassName(), "op"));
+    constructor->addMemberInitializer(
+        genericAdaptorBase.getClassName(),
+        llvm::Twine(!useProperties ? "op->getAttrDictionary()"
+                                   : "op->getDiscardableAttrDictionary()") +
+            ", op.getProperties(), op->getRegions()");
+
+    // Generic adaptor is templated and therefore defined inline in the header.
+    // We cannot use the Op class here as it is an incomplete type (we have a
+    // circular reference between the two).
+    // Use a template trick to make the constructor be instantiated at call site
+    // when the op class is complete.
+    constructor = genericAdaptor.addConstructor(
+        MethodParameter("RangeT", "values"), MethodParameter("LateInst", "op"));
+    constructor->addTemplateParam("LateInst = " + op.getCppClassName());
+    constructor->addTemplateParam(
+        "= std::enable_if_t<std::is_same_v<LateInst, " + op.getCppClassName() +
+        ">>");
+    constructor->addMemberInitializer("Base", "op");
+    constructor->addMemberInitializer("odsOperands", "values");
+  }
+
   std::string sizeAttrInit;
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
     if (op.getDialect().usePropertiesForAttributes())
@@ -4074,9 +4103,8 @@
     // Constructor taking the Op as single parameter.
     auto *constructor =
         adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
-    constructor->addMemberInitializer(
-        adaptor.getClassName(), "op->getOperands(), op->getAttrDictionary(), "
-                                "op.getProperties(), op->getRegions()");
+    constructor->addMemberInitializer(genericAdaptorClassName,
+                                      "op->getOperands(), op");
   }
 
   // Add verification function.