diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -100,6 +100,22 @@ `foldTrait` hook out-of-line as a free function when possible to avoid instantiating the implementation for every concrete operation type. +### Extra Declarations and Definitions +A trait may require additional declarations and definitions directly on +the Operation, Attribute or Type instances that specify that trait. +The `extraConcreteClassDeclaration` and `extraConcreteClassDefinition` +fields under the `NativeTrait` class are mechanisms designed for injecting +code directly into generated C++ Operation, Attribute or Type classes. + +Code within the `extraConcreteClassDeclaration` field will be formatted and copied +into the generated C++ Operation, Attribute or Type class. Code within +`extraConcreteClassDefinition` will be added to the generated source file inside +the class's C++ namespace. Within the definition code, the substitution +`$cppClass` is replaced by the C++ class name. + +The intention is to group trait-specific logic together and reduce +redundant extra declarations and definitions on the instances themselves. + ### Parametric Traits The above demonstrates the definition of a simple self-contained trait. It is diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -903,7 +903,7 @@ Pure, SameVariadicResultSize, ViewLikeOpInterface, - DeclareOpInterfaceMethods]> { + InferTypeOpInterfaceAdaptor]> { let summary = "Extracts a buffer base with offset and strides"; let description = [{ Extracts a base buffer, offset and strides.
This op allows additional layers diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -21,7 +21,14 @@ //===----------------------------------------------------------------------===// // These classes are used to define attribute specific traits. -class NativeAttrTrait : NativeTrait; + +// Specify attribute-specific declarations and definitions in `extraAttrDeclaration` +// and `extraAttrDefinition` template arguments. +class NativeAttrTrait + : NativeTrait; + class ParamNativeAttrTrait : ParamNativeTrait; class GenInternalAttrTrait : GenInternalTrait; @@ -32,7 +39,14 @@ //===----------------------------------------------------------------------===// // These classes are used to define type specific traits. -class NativeTypeTrait : NativeTrait; + +// Specify type-specific declarations and definitions in `extraTypeDeclaration` +// and `extraTypeDefinition` template arguments. +class NativeTypeTrait + : NativeTrait; + class ParamNativeTypeTrait : ParamNativeTrait; class GenInternalTypeTrait : GenInternalTrait; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1958,9 +1958,16 @@ // NativeTrait corresponds to the MLIR C++ trait mechanism. The purpose to wrap // around C++ symbol string with this class is to make traits specified for // entities in TableGen less alien and more integrated. -class NativeTrait : Trait { +// `extraConcreteClassDeclaration` and `extraConcreteClassDefinition` code +// is injected into the entities for which the NativeTrait is specified.
+class NativeTrait : Trait { string trait = name; string cppNamespace = "::mlir::" # entityType # "Trait"; + + code extraConcreteClassDeclaration = extraClassDeclaration; + code extraConcreteClassDefinition = extraClassDefinition; } // ParamNativeTrait corresponds to the template-parameterized traits in the C++ @@ -1993,8 +2000,13 @@ class StructuralOpTrait; // These classes are used to define operation specific traits. -class NativeOpTrait traits = []> - : NativeTrait { + +// Specify op-specific declarations and definitions in `extraOpDeclaration` +// and `extraOpDefinition` template arguments. +class NativeOpTrait traits = [], + code extraOpDeclaration = [{}], + code extraOpDefinition = [{}]> + : NativeTrait { // Specify the list of traits that need to be verified before the verification // of this NativeOpTrait. list dependentTraits = traits; diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -237,19 +237,9 @@ namespace detail { -// Helper function to infer return tensor returns types given element and -// shape inference function. -// -// TODO: Consider generating typedefs for trait member functions if this usage -// becomes more common. -LogicalResult inferReturnTensorTypes( - function_ref< - LogicalResult(MLIRContext *, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &retComponents)> - componentTypeFn, - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes); +// Helper function to construct tensor return types from already-inferred +// shaped type components. +LogicalResult +inferReturnTensorTypes(ArrayRef retComponents, + SmallVectorImpl &inferredReturnTypes); /// Verifies that the inferred result types match the actual result types for /// the op.
Precondition: op implements InferTypeOpInterface. @@ -268,6 +258,10 @@ namespace mlir { namespace OpTrait { +template +class InferTypeOpInterfaceAdaptor + : public TraitBase {}; + /// Tensor type inference trait that constructs a tensor from the inferred /// shape and elemental types. /// Requires: Op implements InferShapedTypeOpInterface and InferTypeOpInterface. @@ -276,24 +270,7 @@ /// trait is currently only used where the interfaces are, so keep it /// restricted for now). template -class InferTensorType : public TraitBase { -public: - static LogicalResult - inferReturnTypes(MLIRContext *context, std::optional location, - ValueRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - static_assert( - ConcreteType::template hasTrait(), - "requires InferShapedTypeOpInterface to ensure succesful invocation"); - static_assert( - ConcreteType::template hasTrait(), - "requires InferTypeOpInterface to ensure succesful invocation"); - return ::mlir::detail::inferReturnTensorTypes( - ConcreteType::inferReturnTypeComponents, context, location, operands, - attributes, properties, regions, inferredReturnTypes); - } -}; +class InferTensorType : public TraitBase {}; } // namespace OpTrait } // namespace mlir diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -184,18 +184,69 @@ ]; } +// Convenience trait that defines `inferReturnTypes` as a wrapper which builds +// the op's Adaptor and forwards to the op-provided `inferReturnTypesAdaptor`. +def InferTypeOpInterfaceAdaptor : TraitList< + [ + // Op implements infer type op interface.
+ DeclareOpInterfaceMethods, + NativeOpTrait< + /*name=*/"InferTypeOpInterfaceAdaptor", + /*traits=*/[], + /*extraOpDeclaration=*/[{ + static LogicalResult + inferReturnTypesAdaptor(MLIRContext *context, + std::optional location, + Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes); + }], + /*extraOpDefinition=*/[{ + LogicalResult + $cppClass::inferReturnTypes(MLIRContext *context, + std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + $cppClass::Adaptor adaptor(operands, attributes, properties, regions); + return $cppClass::inferReturnTypesAdaptor(context, + location, adaptor, inferredReturnTypes); + } + }] + > + ]>; + // Convenience class grouping together type and shaped type op interfaces for // ops that have tensor return types. class InferTensorTypeBase overridenMethods = []> : TraitList< [ // Op implements infer type op interface. - InferTypeOpInterface, + DeclareOpInterfaceMethods, // The op will have methods implementing the ShapedType type inference // interface. DeclareOpInterfaceMethods, // The op produces tensors and will use the ShapedType type infer interface // along with knowledge that it is producing Tensors to infer the type.
- NativeOpTrait<"InferTensorType"> + NativeOpTrait< + /*name=*/"InferTensorType", + /*traits=*/[], + /*extraOpDeclaration=*/[{}], + /*extraOpDefinition=*/[{ + LogicalResult + $cppClass::inferReturnTypes(MLIRContext *context, + std::optional location, + ValueRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + SmallVector retComponents; + if (failed($cppClass::inferReturnTypeComponents(context, location, + operands, attributes, properties, regions, + retComponents))) + return failure(); + return ::mlir::detail::inferReturnTensorTypes(retComponents, + inferredReturnTypes); + } + }] + > ]>; def InferTensorType : InferTensorTypeBase<["inferReturnTypeComponents"]>; diff --git a/mlir/include/mlir/TableGen/Trait.h b/mlir/include/mlir/TableGen/Trait.h --- a/mlir/include/mlir/TableGen/Trait.h +++ b/mlir/include/mlir/TableGen/Trait.h @@ -68,6 +68,14 @@ // Returns if this is a structural op trait. bool isStructuralOpTrait() const; + // Returns the extra class declaration code to be added to every concrete + // op, attribute or type that specifies this trait. + StringRef getExtraConcreteClassDeclaration() const; + + // Returns the extra class definition code to be added to every concrete + // op, attribute or type that specifies this trait. + StringRef getExtraConcreteClassDefinition() const; + static bool classof(const Trait *t) { return t->getKind() == Kind::Native; } }; diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1355,14 +1355,11 @@ /// The number and type of the results are inferred from the /// shape of the source.
-LogicalResult ExtractStridedMetadataOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, +LogicalResult ExtractStridedMetadataOp::inferReturnTypesAdaptor( + MLIRContext *context, std::optional location, + ExtractStridedMetadataOp::Adaptor adaptor, SmallVectorImpl &inferredReturnTypes) { - ExtractStridedMetadataOpAdaptor extractAdaptor(operands, attributes, - properties); - auto sourceType = - llvm::dyn_cast(extractAdaptor.getSource().getType()); + auto sourceType = llvm::dyn_cast(adaptor.getSource().getType()); if (!sourceType) return failure(); diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -217,19 +217,8 @@ } LogicalResult mlir::detail::inferReturnTensorTypes( - function_ref< - LogicalResult(MLIRContext *, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &retComponents)> - componentTypeFn, - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + ArrayRef retComponents, SmallVectorImpl &inferredReturnTypes) { - SmallVector retComponents; - if (failed(componentTypeFn(context, location, operands, attributes, - properties, regions, retComponents))) - return failure(); for (const auto &shapeAndType : retComponents) { Type elementTy = shapeAndType.getElementType(); assert(elementTy && "element type required to construct tensor"); diff --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp --- a/mlir/lib/TableGen/Trait.cpp +++ b/mlir/lib/TableGen/Trait.cpp @@ -54,6 +54,14 @@ return def->isSubClassOf("StructuralOpTrait"); } +StringRef NativeTrait::getExtraConcreteClassDeclaration() const { +
return def->getValueAsString("extraConcreteClassDeclaration"); +} + +StringRef NativeTrait::getExtraConcreteClassDefinition() const { + return def->getValueAsString("extraConcreteClassDefinition"); +} + //===----------------------------------------------------------------------===// // InternalTrait //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -214,14 +214,42 @@ defCls.addParent(std::move(defParent)); } +/// Joins NativeTrait-provided extra declarations, then the def's own ones. +static std::string formatExtraDeclarations(const AttrOrTypeDef &def) { + SmallVector extraDeclarations; + // Collect the non-empty extra class declarations of each NativeTrait. + for (const auto &trait : def.getTraits()) { + if (auto *attrOrTypeTrait = dyn_cast(&trait)) { + StringRef value = attrOrTypeTrait->getExtraConcreteClassDeclaration(); + if (value.empty()) + continue; + extraDeclarations.push_back(value); + } + } + if (std::optional extraDecl = def.getExtraDecls()) { + extraDeclarations.push_back(*extraDecl); + } + return llvm::join(extraDeclarations, "\n"); +} + /// Extra class definitions have a `$cppClass` substitution that is to be /// replaced by the C++ class name.
static std::string formatExtraDefinitions(const AttrOrTypeDef &def) { + SmallVector extraDefinitions; + // Collect the non-empty extra class definitions of each NativeTrait. + for (const auto &trait : def.getTraits()) { + if (auto *attrOrTypeTrait = dyn_cast(&trait)) { + StringRef value = attrOrTypeTrait->getExtraConcreteClassDefinition(); + if (value.empty()) + continue; + extraDefinitions.push_back(value); + } + } if (std::optional extraDef = def.getExtraDefs()) { - FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName()); - return tgfmt(*extraDef, &ctx).str(); + extraDefinitions.push_back(*extraDef); } - return ""; + FmtContext ctx = FmtContext().addSubst("cppClass", def.getCppClassName()); + return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str(); } void DefGen::emitTopLevelDeclarations() { @@ -230,9 +258,9 @@ defCls.declare("Base::Base"); // Emit the extra declarations first in case there's a definition in there. - std::optional extraDecl = def.getExtraDecls(); + std::string extraDecl = formatExtraDeclarations(def); std::string extraDef = formatExtraDefinitions(def); - defCls.declare(extraDecl ? *extraDecl : "", + defCls.declare(std::move(extraDecl), std::move(extraDef)); } diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h --- a/mlir/tools/mlir-tblgen/OpClass.h +++ b/mlir/tools/mlir-tblgen/OpClass.h @@ -25,7 +25,7 @@ /// - inheritance of `print` /// - a type alias for the associated adaptor class /// - OpClass(StringRef name, StringRef extraClassDeclaration, + OpClass(StringRef name, std::string extraClassDeclaration, std::string extraClassDefinition); /// Add an op trait. @@ -39,7 +39,7 @@ private: - /// Hand-written extra class declarations. - StringRef extraClassDeclaration; + /// Extra class declarations: NativeTrait-provided, then hand-written. + std::string extraClassDeclaration; - /// Hand-written extra class definitions. 
+ /// Extra class definitions: NativeTrait-provided, then hand-written. std::string extraClassDefinition; /// The parent class, which also contains the traits to be inherited.
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp --- a/mlir/tools/mlir-tblgen/OpClass.cpp +++ b/mlir/tools/mlir-tblgen/OpClass.cpp @@ -15,9 +15,10 @@ // OpClass definitions //===----------------------------------------------------------------------===// -OpClass::OpClass(StringRef name, StringRef extraClassDeclaration, +OpClass::OpClass(StringRef name, std::string extraClassDeclaration, std::string extraClassDefinition) - : Class(name.str()), extraClassDeclaration(extraClassDeclaration), + : Class(name.str()), + extraClassDeclaration(std::move(extraClassDeclaration)), extraClassDefinition(std::move(extraClassDefinition)), parent(addParent("::mlir::Op")) { parent.addTemplateParam(getClassName().str()); @@ -37,6 +38,5 @@ void OpClass::finalize() { Class::finalize(); declare(Visibility::Public); - declare(extraClassDeclaration.str(), - extraClassDefinition); + declare(extraClassDeclaration, extraClassDefinition); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -853,17 +853,45 @@ emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name)); } +/// Joins NativeTrait-provided extra declarations, then the op's own ones. +static std::string formatExtraDeclarations(const Operator &op) { + SmallVector extraDeclarations; + // Collect the non-empty extra class declarations of each NativeTrait. + for (const auto &trait : op.getTraits()) { + if (auto *opTrait = dyn_cast(&trait)) { + StringRef value = opTrait->getExtraConcreteClassDeclaration(); + if (value.empty()) + continue; + extraDeclarations.push_back(value); + } + } + extraDeclarations.push_back(op.getExtraClassDeclaration()); + return llvm::join(extraDeclarations, "\n"); +} + /// Op extra class definitions have a `$cppClass` substitution that is to be /// replaced by the C++ class name.
+/// NativeTrait-provided definitions are included ahead of the op's own. static std::string formatExtraDefinitions(const Operator &op) { + SmallVector extraDefinitions; + // Collect the non-empty extra class definitions of each NativeTrait. + for (const auto &trait : op.getTraits()) { + if (auto *opTrait = dyn_cast(&trait)) { + StringRef value = opTrait->getExtraConcreteClassDefinition(); + if (value.empty()) + continue; + extraDefinitions.push_back(value); + } + } + extraDefinitions.push_back(op.getExtraClassDefinition()); FmtContext ctx = FmtContext().addSubst("cppClass", op.getCppClassName()); - return tgfmt(op.getExtraClassDefinition(), &ctx).str(); + return tgfmt(llvm::join(extraDefinitions, "\n"), &ctx).str(); } OpEmitter::OpEmitter(const Operator &op, const StaticVerifierFunctionEmitter &staticVerifierEmitter) : def(op.getDef()), op(op), - opClass(op.getCppClassName(), op.getExtraClassDeclaration(), + opClass(op.getCppClassName(), formatExtraDeclarations(op), formatExtraDefinitions(op)), staticVerifierEmitter(staticVerifierEmitter), emitHelper(op, /*emitForOp=*/true) {