diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -903,7 +903,7 @@ Pure, SameVariadicResultSize, ViewLikeOpInterface, - InferTypeOpInterfaceAdaptor]> { + DeclareOpInterfaceMethods]> { let summary = "Extracts a buffer base with offset and strides"; let description = [{ Extracts a base buffer, offset and strides. This op allows additional layers diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -21,14 +21,7 @@ //===----------------------------------------------------------------------===// // These classes are used to define attribute specific traits. -class NativeAttrTrait - : NativeTrait { - let extraConcreteClassDeclaration = extraAttrDeclaration; - let extraConcreteClassDefinition = extraAttrDefinition; -} - +class NativeAttrTrait : NativeTrait; class ParamNativeAttrTrait : ParamNativeTrait; class GenInternalAttrTrait : GenInternalTrait; @@ -39,14 +32,7 @@ //===----------------------------------------------------------------------===// // These classes are used to define type specific traits. -class NativeTypeTrait - : NativeTrait { - let extraConcreteClassDeclaration = extraTypeDeclaration; - let extraConcreteClassDefinition = extraTypeDefinition; -} - +class NativeTypeTrait : NativeTrait; class ParamNativeTypeTrait : ParamNativeTrait; class GenInternalTypeTrait : GenInternalTrait; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1959,9 +1959,6 @@ class NativeTrait : Trait { string trait = name; string cppNamespace = "::mlir::" # entityType # "Trait"; - - code extraConcreteClassDeclaration = ?; - code extraConcreteClassDefinition = ?; } // ParamNativeTrait corresponds to the template-parameterized traits in the C++ @@ -1994,16 +1991,11 @@ class StructuralOpTrait; // These classes are used to define operation specific traits. -class NativeOpTrait traits = [], - code extraOpDeclaration = [{}], - code extraOpDefinition = [{}]> +class NativeOpTrait traits = []> : NativeTrait { // Specify the list of traits that need to be verified before the verification // of this NativeOpTrait. list dependentTraits = traits; - - let extraConcreteClassDeclaration = extraOpDeclaration; - let extraConcreteClassDefinition = extraOpDefinition; } class ParamNativeOpTrait traits = []> diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -241,7 +241,14 @@ // TODO: Consider generating typedefs for trait member functions if this usage // becomes more common. LogicalResult inferReturnTensorTypes( - SmallVector retComponents, + function_ref< + LogicalResult(MLIRContext *, std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &retComponents)> + componentTypeFn, + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes); /// Verifies that the inferred result types match the actual result types for @@ -261,10 +268,6 @@ namespace mlir { namespace OpTrait { -template -class InferTypeOpInterfaceAdaptor : - public TraitBase {}; - /// Tensor type inference trait that constructs a tensor from the inferred /// shape and elemental types. /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface. @@ -273,7 +276,24 @@ /// trait is currently only used where the interfaces are, so keep it /// restricted for now). template -class InferTensorType : public TraitBase {}; +class InferTensorType : public TraitBase { +public: + static LogicalResult + inferReturnTypes(MLIRContext *context, std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + static_assert( + ConcreteType::template hasTrait(), + "requires InferShapedTypeOpInterface to ensure succesful invocation"); + static_assert( + ConcreteType::template hasTrait(), + "requires InferTypeOpInterface to ensure succesful invocation"); + return ::mlir::detail::inferReturnTensorTypes( + ConcreteType::inferReturnTypeComponents, context, location, operands, + attributes, properties, regions, inferredReturnTypes); + } +}; } // namespace OpTrait } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -184,69 +184,18 @@ ]; } -// Convenient trait to define a wrapper to inferReturnTypes that passes in the -// Op Adaptor directly -def InferTypeOpInterfaceAdaptor : TraitList< - [ - // Op implements infer type op interface. - DeclareOpInterfaceMethods, - NativeOpTrait< - /*name=*/"InferTypeOpInterfaceAdaptor", - /*traits=*/[], - /*extraOpDeclaration=*/[{ - static LogicalResult - inferReturnTypesAdaptor(MLIRContext *context, - std::optional location, - Adaptor adaptor, - SmallVectorImpl &inferredReturnTypes); - }], - /*extraOpDefinition=*/[{ - LogicalResult - $cppClass::inferReturnTypes(MLIRContext *context, - std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - $cppClass::Adaptor adaptor(operands, attributes, properties, regions); - return $cppClass::inferReturnTypesAdaptor(context, - location, adaptor, inferredReturnTypes); - } - }] - > - ]>; - // Convenience class grouping together type and shaped type op interfaces for // ops that have tensor return types. class InferTensorTypeBase overridenMethods = []> : TraitList< [ // Op implements infer type op interface. - DeclareOpInterfaceMethods, + InferTypeOpInterface, // The op will have methods implementing the ShapedType type inference // interface. DeclareOpInterfaceMethods, // The op produces tensors and will use the ShapedType type infer interface // along with knowledge that it is producing Tensors to infer the type. - NativeOpTrait< - /*name=*/"InferTensorType", - /*traits=*/[], - /*extraOpDeclaration=*/[{}], - /*extraOpDefinition=*/[{ - LogicalResult - $cppClass::inferReturnTypes(MLIRContext *context, - std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - SmallVector retComponents; - if (failed($cppClass::inferReturnTypeComponents(context, location, - operands, attributes, properties, regions, - retComponents))) - return failure(); - return ::mlir::detail::inferReturnTensorTypes(retComponents, - inferredReturnTypes); - } - }] - > + NativeOpTrait<"InferTensorType"> ]>; def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>; diff --git a/mlir/include/mlir/TableGen/Trait.h b/mlir/include/mlir/TableGen/Trait.h --- a/mlir/include/mlir/TableGen/Trait.h +++ b/mlir/include/mlir/TableGen/Trait.h @@ -68,14 +68,6 @@ // Returns if this is a structural op trait. bool isStructuralOpTrait() const; - // Returns extra class declaration code to be added to the concrete type - // when the trait is specified - std::optional getExtraConcreteClassDeclaration() const; - - // Returns extra class definition code to be added to the concrete type - // when the trait is specified - std::optional getExtraConcreteClassDefinition() const; - static bool classof(const Trait *t) { return t->getKind() == Kind::Native; } }; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1355,12 +1355,14 @@ /// The number and type of the results are inferred from the /// shape of the source. -LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor( - MLIRContext *context, std::optional location, - ExtractStridedMetadataOp::Adaptor adaptor, +LogicalResult ExtractStridedMetadataOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { + ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, + properties); auto sourceType = - llvm::dyn_cast(adaptor.getSource().getType()); + llvm::dyn_cast(extractAdaptor.getSource().getType()); if (!sourceType) return failure(); diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -217,8 +217,19 @@ } LogicalResult mlir::detail::inferReturnTensorTypes( - SmallVector retComponents, + function_ref< + LogicalResult(MLIRContext *, std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &retComponents)> + componentTypeFn, + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { + SmallVector retComponents; + if (failed(componentTypeFn(context, location, operands, attributes, + properties, regions, retComponents))) + return failure(); for (const auto &shapeAndType : retComponents) { Type elementTy = shapeAndType.getElementType(); assert(elementTy && "element type required to construct tensor"); diff --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp --- a/mlir/lib/TableGen/Trait.cpp +++ b/mlir/lib/TableGen/Trait.cpp @@ -54,18 +54,6 @@ return def->isSubClassOf("StructuralOpTrait"); } -std::optional -NativeTrait::getExtraConcreteClassDeclaration() const { - auto value = def->getValueAsOptionalString("extraConcreteClassDeclaration"); - return value->empty() ? std::optional() : value; -} - -std::optional -NativeTrait::getExtraConcreteClassDefinition() const { - auto value = def->getValueAsOptionalString("extraConcreteClassDefinition"); - return value->empty() ? std::optional() : value; -} - //===----------------------------------------------------------------------===// // InternalTrait //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -214,42 +214,14 @@ defCls.addParent(std::move(defParent)); } -/// Include declarations specified on NativeTrait -static std::string formatExtraDeclarations(const AttrOrTypeDef &def) { - std::string extraDeclarations = ""; - // Include extra class declarations from NativeTrait - for (const auto &trait : def.getTraits()) { - if (auto *attrOrTypeTrait = dyn_cast(&trait)) { - if (std::optional extraDecl = - attrOrTypeTrait->getExtraConcreteClassDeclaration()) { - extraDeclarations += extraDecl->str() + "\n"; - } - } - } - if (std::optional extraDecl = def.getExtraDecls()) { - extraDeclarations += extraDecl->str(); - } - return extraDeclarations; -} - /// Extra class definitions have a `$cppClass` substitution that is to be /// replaced by the C++ class name. static std::string formatExtraDefinitions(const AttrOrTypeDef &def) { - std::string extraDefinitions = ""; - // Include extra class definitions from NativeTrait - for (const auto &trait : def.getTraits()) { - if (auto *attrOrTypeTrait = dyn_cast(&trait)) { - if (std::optional extraDef = - attrOrTypeTrait->getExtraConcreteClassDefinition()) { - extraDefinitions += extraDef->str() + "\n"; - } - } - } if (std::optional extraDef = def.getExtraDefs()) { - extraDefinitions += extraDef->str(); + FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName()); + return tgfmt(*extraDef, &ctx).str(); } - FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName()); - return tgfmt(extraDefinitions, &ctx).str(); + return ""; } void DefGen::emitTopLevelDeclarations() { @@ -258,9 +230,9 @@ defCls.declare("Base::Base"); // Emit the extra declarations first in case there's a definition in there. - std::string extraDecl = formatExtraDeclarations(def); + std::optional extraDecl = def.getExtraDecls(); std::string extraDef = formatExtraDefinitions(def); - defCls.declare(std::move(extraDecl), + defCls.declare(extraDecl ? *extraDecl : "", std::move(extraDef)); } diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h --- a/mlir/tools/mlir-tblgen/OpClass.h +++ b/mlir/tools/mlir-tblgen/OpClass.h @@ -25,7 +25,7 @@ /// - inheritance of `print` /// - a type alias for the associated adaptor class /// - OpClass(StringRef name, std::string extraClassDeclaration, + OpClass(StringRef name, StringRef extraClassDeclaration, std::string extraClassDefinition); /// Add an op trait. @@ -39,7 +39,7 @@ private: /// Hand-written extra class declarations. - std::string extraClassDeclaration; + StringRef extraClassDeclaration; /// Hand-written extra class definitions. std::string extraClassDefinition; /// The parent class, which also contains the traits to be inherited. diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp --- a/mlir/tools/mlir-tblgen/OpClass.cpp +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -15,10 +15,9 @@ // OpClass definitions //===----------------------------------------------------------------------===// -OpClass::OpClass(StringRef name, std::string extraClassDeclaration, +OpClass::OpClass(StringRef name, StringRef extraClassDeclaration, std::string extraClassDefinition) - : Class(name.str()), - extraClassDeclaration(std::move(extraClassDeclaration)), + : Class(name.str()), extraClassDeclaration(extraClassDeclaration), extraClassDefinition(std::move(extraClassDefinition)), parent(addParent("::mlir::Op")) { parent.addTemplateParam(getClassName().str()); @@ -38,6 +37,6 @@ void OpClass::finalize() { Class::finalize(); declare(Visibility::Public); - declare(extraClassDeclaration, + declare(extraClassDeclaration.str(), extraClassDefinition); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -853,45 +853,17 @@ emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name)); } -/// Include declarations specified on NativeTrait -static std::string formatExtraDeclarations(const Operator &op) { - std::string extraDeclarations = ""; - // Include extra class declarations from NativeTrait - for (const auto &trait : op.getTraits()) { - if (auto *opTrait = dyn_cast(&trait)) { - if (std::optional extraDecl = - opTrait->getExtraConcreteClassDeclaration()) { - extraDeclarations += extraDecl->str() + "\n"; - } - } - } - extraDeclarations += op.getExtraClassDeclaration().str(); - return extraDeclarations; -} - /// Op extra class definitions have a `$cppClass` substitution that is to be /// replaced by the C++ class name. -/// Include declarations specified on NativeTrait static std::string formatExtraDefinitions(const Operator &op) { - std::string extraDefinitions = ""; - // Include extra class definitions from NativeTrait - for (const auto &trait : op.getTraits()) { - if (auto *opTrait = dyn_cast(&trait)) { - if (std::optional extraDef = - opTrait->getExtraConcreteClassDefinition()) { - extraDefinitions += extraDef->str() + "\n"; - } - } - } - extraDefinitions += op.getExtraClassDefinition().str(); FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName()); - return tgfmt(extraDefinitions, &ctx).str(); + return tgfmt(op.getExtraClassDefinition(), &ctx).str(); } OpEmitter::OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), - opClass(op.getCppClassName(), formatExtraDeclarations(op), + opClass(op.getCppClassName(), op.getExtraClassDeclaration(), formatExtraDefinitions(op)), staticVerifierEmitter(staticVerifierEmitter), emitHelper(op, /*emitForOp=*/true) {