diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -903,7 +903,7 @@ Pure, SameVariadicResultSize, ViewLikeOpInterface, - DeclareOpInterfaceMethods]> { + InferTypeOpInterfaceAdaptor /* op hand-writes inferReturnTypesAdaptor; ODS generates inferReturnTypes from it */]> { let summary = "Extracts a buffer base with offset and strides"; let description = [{ Extracts a base buffer, offset and strides. This op allows additional layers diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -21,7 +21,14 @@ //===----------------------------------------------------------------------===// // These classes are used to define attribute specific traits. -class NativeAttrTrait : NativeTrait; +class NativeAttrTrait + : NativeTrait { /* Route the attribute-specific extra code into NativeTrait's extraConcreteClass* fields. */ + let extraConcreteClassDeclaration = extraAttrDeclaration; + let extraConcreteClassDefinition = extraAttrDefinition; +} + class ParamNativeAttrTrait : ParamNativeTrait; class GenInternalAttrTrait : GenInternalTrait; @@ -32,7 +39,14 @@ //===----------------------------------------------------------------------===// // These classes are used to define type specific traits. 
-class NativeTypeTrait : NativeTrait; +class NativeTypeTrait + : NativeTrait { /* Route the type-specific extra code into NativeTrait's extraConcreteClass* fields. */ + let extraConcreteClassDeclaration = extraTypeDeclaration; + let extraConcreteClassDefinition = extraTypeDefinition; +} + class ParamNativeTypeTrait : ParamNativeTrait; class GenInternalTypeTrait : GenInternalTrait; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1959,6 +1959,9 @@ class NativeTrait : Trait { string trait = name; string cppNamespace = "::mlir::" # entityType # "Trait"; + /* Extra C++ code spliced into the generated class of every def that lists this trait; mlir-tblgen substitutes $cppClass in the definition code. */ + code extraConcreteClassDeclaration = ""; + code extraConcreteClassDefinition = ""; } // ParamNativeTrait corresponds to the template-parameterized traits in the C++ @@ -1991,11 +1994,16 @@ class StructuralOpTrait; // These classes are used to define operation specific traits. -class NativeOpTrait traits = []> +class NativeOpTrait traits = [], + code extraOpDeclaration = [{}], + code extraOpDefinition = [{}]> : NativeTrait { // Specify the list of traits that need to be verified before the verification // of this NativeOpTrait. list dependentTraits = traits; + /* Forward the op-specific extra code to NativeTrait. */ + let extraConcreteClassDeclaration = extraOpDeclaration; + let extraConcreteClassDefinition = extraOpDefinition; } class ParamNativeOpTrait traits = []> diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -241,14 +241,7 @@ // TODO: Consider generating typedefs for trait member functions if this usage // becomes more common. 
LogicalResult inferReturnTensorTypes( - function_ref< - LogicalResult(MLIRContext *, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &retComponents)> - componentTypeFn, - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + ArrayRef<ShapedTypeComponents> retComponents, SmallVectorImpl &inferredReturnTypes); /// Verifies that the inferred result types match the actual result types for @@ -268,6 +261,11 @@ namespace mlir { namespace OpTrait { +/// Marker trait for ops whose inferReturnTypes wrapper is generated by ODS. +template +class InferTypeOpInterfaceAdaptor : + public TraitBase {}; + /// Tensor type inference trait that constructs a tensor from the inferred /// shape and elemental types. /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface. @@ -276,24 +274,7 @@ /// trait is currently only used where the interfaces are, so keep it /// restricted for now). template -class InferTensorType : public TraitBase { -public: - static LogicalResult - inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - static_assert( - ConcreteType::template hasTrait(), - "requires InferShapedTypeOpInterface to ensure succesful invocation"); - static_assert( - ConcreteType::template hasTrait(), - "requires InferTypeOpInterface to ensure succesful invocation"); - return ::mlir::detail::inferReturnTensorTypes( - ConcreteType::inferReturnTypeComponents, context, location, operands, - attributes, properties, regions, inferredReturnTypes); - } -}; +class InferTensorType : public TraitBase {}; } // namespace OpTrait } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ 
b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -184,18 +184,69 @@ ]; } +// Convenience trait: generates inferReturnTypes as a wrapper that builds the op +// Adaptor and forwards to the op's hand-written inferReturnTypesAdaptor. +def InferTypeOpInterfaceAdaptor : TraitList< + [ + // Op implements infer type op interface. + DeclareOpInterfaceMethods, + NativeOpTrait< + /*name=*/"InferTypeOpInterfaceAdaptor", + /*traits=*/[], + /*extraOpDeclaration=*/[{ + static LogicalResult + inferReturnTypesAdaptor(MLIRContext *context, + std::optional location, + Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes); + }], + /*extraOpDefinition=*/[{ + LogicalResult + $cppClass::inferReturnTypes(MLIRContext *context, + std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + $cppClass::Adaptor adaptor(operands, attributes, properties, regions); + return $cppClass::inferReturnTypesAdaptor(context, + location, adaptor, inferredReturnTypes); + } + }] + > + ]>; + // Convenience class grouping together type and shaped type op interfaces for // ops that have tensor return types. class InferTensorTypeBase overridenMethods = []> : TraitList< [ // Op implements infer type op interface. - InferTypeOpInterface, + DeclareOpInterfaceMethods, // The op will have methods implementing the ShapedType type inference // interface. DeclareOpInterfaceMethods, // The op produces tensors and will use the ShapedType type infer interface // along with knowledge that it is producing Tensors to infer the type. 
- NativeOpTrait<"InferTensorType"> + NativeOpTrait< + /*name=*/"InferTensorType", + /*traits=*/[], + /*extraOpDeclaration=*/[{}], + /*extraOpDefinition=*/[{ + LogicalResult + $cppClass::inferReturnTypes(MLIRContext *context, + std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + SmallVector retComponents; + if (failed($cppClass::inferReturnTypeComponents(context, location, + operands, attributes, properties, regions, + retComponents))) + return failure(); + return ::mlir::detail::inferReturnTensorTypes(retComponents, + inferredReturnTypes); + } + }] + > ]>; def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>; diff --git a/mlir/include/mlir/TableGen/Trait.h b/mlir/include/mlir/TableGen/Trait.h --- a/mlir/include/mlir/TableGen/Trait.h +++ b/mlir/include/mlir/TableGen/Trait.h @@ -68,6 +68,14 @@ // Returns if this is a structural op trait. bool isStructuralOpTrait() const; + // Returns extra class declaration code to be added to the concrete type + // when the trait is specified. + StringRef getExtraConcreteClassDeclaration() const; + + // Returns extra class definition code to be added to the concrete type + // when the trait is specified. + StringRef getExtraConcreteClassDefinition() const; + static bool classof(const Trait *t) { return t->getKind() == Kind::Native; } }; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1355,14 +1355,12 @@ /// The number and type of the results are inferred from the /// shape of the source. 
-LogicalResult ExtractStridedMetadataOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, +LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor( + MLIRContext *context, std::optional location, + ExtractStridedMetadataOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { - ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, - properties); auto sourceType = - llvm::dyn_cast(extractAdaptor.getSource().getType()); + llvm::dyn_cast(adaptor.getSource().getType()); if (!sourceType) return failure(); diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -217,19 +217,8 @@ } LogicalResult mlir::detail::inferReturnTensorTypes( - function_ref< - LogicalResult(MLIRContext *, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &retComponents)> - componentTypeFn, - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + ArrayRef<ShapedTypeComponents> retComponents, SmallVectorImpl &inferredReturnTypes) { - SmallVector retComponents; - if (failed(componentTypeFn(context, location, operands, attributes, - properties, regions, retComponents))) - return failure(); for (const auto &shapeAndType : retComponents) { Type elementTy = shapeAndType.getElementType(); assert(elementTy && "element type required to construct tensor"); diff --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp --- a/mlir/lib/TableGen/Trait.cpp +++ b/mlir/lib/TableGen/Trait.cpp @@ -54,6 +54,14 @@ return def->isSubClassOf("StructuralOpTrait"); } +StringRef NativeTrait::getExtraConcreteClassDeclaration() const { + return 
def->getValueAsString("extraConcreteClassDeclaration"); +} + +StringRef NativeTrait::getExtraConcreteClassDefinition() const { + return def->getValueAsString("extraConcreteClassDefinition"); +} + //===----------------------------------------------------------------------===// // InternalTrait //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -214,14 +214,40 @@ defCls.addParent(std::move(defParent)); } +/// Returns the extra declarations from NativeTraits followed by the def's own. +static std::string formatExtraDeclarations(const AttrOrTypeDef &def) { + std::string extraDeclarations; + // Include extra class declarations from NativeTrait + for (const auto &trait : def.getTraits()) { + if (auto *nativeTrait = dyn_cast<NativeTrait>(&trait)) { + StringRef value = nativeTrait->getExtraConcreteClassDeclaration(); + if (!value.empty()) + extraDeclarations += value.str() + "\n"; + } + } + if (std::optional<StringRef> extraDecl = def.getExtraDecls()) { + extraDeclarations += extraDecl->str(); + } + return extraDeclarations; +} + /// Extra class definitions have a `$cppClass` substitution that is to be /// replaced by the C++ class name. 
static std::string formatExtraDefinitions(const AttrOrTypeDef &def) { + std::string extraDefinitions; + // Include extra class definitions from NativeTrait + for (const auto &trait : def.getTraits()) { + if (auto *nativeTrait = dyn_cast<NativeTrait>(&trait)) { + StringRef value = nativeTrait->getExtraConcreteClassDefinition(); + if (!value.empty()) + extraDefinitions += value.str() + "\n"; + } + } if (std::optional extraDef = def.getExtraDefs()) { - FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName()); - return tgfmt(*extraDef, &ctx).str(); + extraDefinitions += extraDef->str(); } - return ""; + FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName()); + return tgfmt(extraDefinitions, &ctx).str(); } void DefGen::emitTopLevelDeclarations() { @@ -230,9 +256,9 @@ defCls.declare("Base::Base"); // Emit the extra declarations first in case there's a definition in there. - std::optional extraDecl = def.getExtraDecls(); + std::string extraDecl = formatExtraDeclarations(def); std::string extraDef = formatExtraDefinitions(def); - defCls.declare(extraDecl ? *extraDecl : "", + defCls.declare(std::move(extraDecl), std::move(extraDef)); } diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h --- a/mlir/tools/mlir-tblgen/OpClass.h +++ b/mlir/tools/mlir-tblgen/OpClass.h @@ -25,7 +25,7 @@ /// - inheritance of `print` /// - a type alias for the associated adaptor class /// - OpClass(StringRef name, StringRef extraClassDeclaration, + OpClass(StringRef name, std::string extraClassDeclaration, std::string extraClassDefinition); /// Add an op trait. @@ -39,7 +39,7 @@ private: /// Hand-written extra class declarations. - StringRef extraClassDeclaration; + std::string extraClassDeclaration; /// Hand-written extra class definitions. std::string extraClassDefinition; /// The parent class, which also contains the traits to be inherited.
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp --- a/mlir/tools/mlir-tblgen/OpClass.cpp +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -15,9 +15,10 @@ // OpClass definitions //===----------------------------------------------------------------------===// -OpClass::OpClass(StringRef name, StringRef extraClassDeclaration, +OpClass::OpClass(StringRef name, std::string extraClassDeclaration, std::string extraClassDefinition) - : Class(name.str()), extraClassDeclaration(extraClassDeclaration), + : Class(name.str()), + extraClassDeclaration(std::move(extraClassDeclaration)), extraClassDefinition(std::move(extraClassDefinition)), parent(addParent("::mlir::Op")) { parent.addTemplateParam(getClassName().str()); @@ -37,6 +38,5 @@ void OpClass::finalize() { Class::finalize(); declare(Visibility::Public); - declare(extraClassDeclaration.str(), - extraClassDefinition); + declare(extraClassDeclaration, extraClassDefinition); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -853,17 +853,43 @@ emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name)); } +/// Returns the extra declarations from NativeTraits followed by the op's own. +static std::string formatExtraDeclarations(const Operator &op) { + std::string extraDeclarations; + // Include extra class declarations from NativeTrait + for (const auto &trait : op.getTraits()) { + if (auto *opTrait = dyn_cast<NativeTrait>(&trait)) { + StringRef value = opTrait->getExtraConcreteClassDeclaration(); + if (!value.empty()) + extraDeclarations += value.str() + "\n"; + } + } + extraDeclarations += op.getExtraClassDeclaration().str(); + return extraDeclarations; +} + /// Op extra class definitions have a `$cppClass` substitution that is to be /// replaced by the C++ class name. 
+/// Extra definitions from NativeTraits are prepended and substituted too. static std::string formatExtraDefinitions(const Operator &op) { + std::string extraDefinitions; + // Include extra class definitions from NativeTrait + for (const auto &trait : op.getTraits()) { + if (auto *opTrait = dyn_cast<NativeTrait>(&trait)) { + StringRef value = opTrait->getExtraConcreteClassDefinition(); + if (!value.empty()) + extraDefinitions += value.str() + "\n"; + } + } + extraDefinitions += op.getExtraClassDefinition().str(); FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName()); - return tgfmt(op.getExtraClassDefinition(), &ctx).str(); + return tgfmt(extraDefinitions, &ctx).str(); } OpEmitter::OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), - opClass(op.getCppClassName(), op.getExtraClassDeclaration(), + opClass(op.getCppClassName(), formatExtraDeclarations(op), formatExtraDefinitions(op)), staticVerifierEmitter(staticVerifierEmitter), emitHelper(op, /*emitForOp=*/true) {