diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1681,7 +1681,16 @@
 //===----------------------------------------------------------------------===//
 
 // Trait represents a trait regarding an attribute, operation, or type.
-class Trait;
+class Trait {
+  // Additional code that will be added to the public part of the generated
+  // C++ code of the entity (attribute, operation, or type) declaration.
+  code extraInjectedClassDeclaration = "";
+
+  // Additional code that will be added to the generated source file. The
+  // generated code is placed inside the entity's C++ namespace. `$cppClass` is
+  // replaced by the entity's C++ class name.
+  code extraInjectedClassDefinition = "";
+}
 
 // Define a Trait corresponding to a list of Traits, this allows for specifying
 // a list of traits as trait. Avoids needing to do `[Traits, ...] # ListOfTraits
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -235,10 +235,14 @@
   StringRef getAssemblyFormat() const;
 
-  // Returns this op's extra class declaration code.
-  StringRef getExtraClassDeclaration() const;
+  // Returns this op's extra class declaration code: the code injected by its
+  // traits, followed by the op's own `extraClassDeclaration`. The result is
+  // returned by value; do not bind it to a StringRef.
+  std::string getExtraClassDeclaration() const;
 
-  // Returns this op's extra class definition code.
-  StringRef getExtraClassDefinition() const;
+  // Returns this op's extra class definition code: the code injected by its
+  // traits, followed by the op's own `extraClassDefinition`. The result is
+  // returned by value; do not bind it to a StringRef.
+  std::string getExtraClassDefinition() const;
 
   // Returns the Tablegen definition this operator was constructed from.
   // TODO: do not expose the TableGen record, this is a temporary solution to
diff --git a/mlir/include/mlir/TableGen/Trait.h b/mlir/include/mlir/TableGen/Trait.h
--- a/mlir/include/mlir/TableGen/Trait.h
+++ b/mlir/include/mlir/TableGen/Trait.h
@@ -53,6 +53,14 @@
   // Returns the Tablegen definition this operator was constructed from.
   const llvm::Record &getDef() const { return *def; }
 
+  // Returns this trait's `extraInjectedClassDeclaration` code, or an empty
+  // string if the field is unset.
+  StringRef getExtraInjectedClassDeclaration() const;
+
+  // Returns this trait's `extraInjectedClassDefinition` code, or an empty
+  // string if the field is unset.
+  StringRef getExtraInjectedClassDefinition() const;
+
 protected:
   // The TableGen definition of this trait.
   const llvm::Record *def;
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -121,18 +121,38 @@
   return results->getNumArgs();
 }
 
-StringRef Operator::getExtraClassDeclaration() const {
-  constexpr auto attr = "extraClassDeclaration";
-  if (def.isValueUnset(attr))
-    return {};
-  return def.getValueAsString(attr);
+/// Returns the non-empty snippets that `getTraitCode` yields for the traits of
+/// `op`, followed by the op's own `fieldName` code (if set), joined with
+/// newlines. Trait snippets come first because the op's own code may change
+/// the access specifier (e.g. with `private:`), which must not apply to the
+/// injected declarations.
+static std::string
+concatWithTraitCode(const Operator &op, StringRef fieldName,
+                    StringRef (Trait::*getTraitCode)() const) {
+  std::string result;
+  auto append = [&result](StringRef snippet) {
+    if (snippet.empty())
+      return;
+    if (!result.empty())
+      result += '\n';
+    result.append(snippet.data(), snippet.size());
+  };
+  for (const Trait &trait : op.getTraits())
+    append((trait.*getTraitCode)());
+  const llvm::Record &def = op.getDef();
+  if (!def.isValueUnset(fieldName))
+    append(def.getValueAsString(fieldName));
+  return result;
+}
+
+std::string Operator::getExtraClassDeclaration() const {
+  return concatWithTraitCode(*this, "extraClassDeclaration",
+                             &Trait::getExtraInjectedClassDeclaration);
 }
 
-StringRef Operator::getExtraClassDefinition() const {
-  constexpr auto attr = "extraClassDefinition";
-  if (def.isValueUnset(attr))
-    return {};
-  return def.getValueAsString(attr);
+std::string Operator::getExtraClassDefinition() const {
+  return concatWithTraitCode(*this, "extraClassDefinition",
+                             &Trait::getExtraInjectedClassDefinition);
 }
 
 const llvm::Record &Operator::getDef() const { return def; }
diff --git a/mlir/lib/TableGen/Trait.cpp b/mlir/lib/TableGen/Trait.cpp
--- a/mlir/lib/TableGen/Trait.cpp
+++ b/mlir/lib/TableGen/Trait.cpp
@@ -39,6 +39,22 @@
 
 Trait::Trait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {}
 
+/// Returns the value of the code field named `field` in `def`, or an empty
+/// string if the field is unset.
+static StringRef getOptionalCode(const llvm::Record &def, StringRef field) {
+  if (def.isValueUnset(field))
+    return {};
+  return def.getValueAsString(field);
+}
+
+StringRef Trait::getExtraInjectedClassDeclaration() const {
+  return getOptionalCode(*def, "extraInjectedClassDeclaration");
+}
+
+StringRef Trait::getExtraInjectedClassDefinition() const {
+  return getOptionalCode(*def, "extraInjectedClassDefinition");
+}
+
 //===----------------------------------------------------------------------===//
 // NativeTrait
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -479,6 +479,26 @@
 // Test Traits
 //===----------------------------------------------------------------------===//
 
+// Native trait that also injects a `testInjection` method into each op it is
+// attached to, exercising `extraInjectedClass{Declaration,Definition}`.
+def OpTraitWithInjectedCode : NativeOpTrait<"SameOperandsElementType"> {
+  code extraInjectedClassDeclaration = [{
+    bool testInjection();
+  }];
+  code extraInjectedClassDefinition = [{
+    bool $cppClass::testInjection() {
+      return true;
+    }
+  }];
+}
+
+def SameOperandElementTypeInjectedOp
+    : TEST_Op<"same_operand_element_type_injected_code",
+              [OpTraitWithInjectedCode]> {
+  let arguments = (ins AnyType, AnyType);
+  let results = (outs AnyType);
+}
+
 def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
     [SameOperandsElementType]> {
   let arguments = (ins AnyType, AnyType);
diff --git a/mlir/tools/mlir-tblgen/OpClass.h b/mlir/tools/mlir-tblgen/OpClass.h
--- a/mlir/tools/mlir-tblgen/OpClass.h
+++ b/mlir/tools/mlir-tblgen/OpClass.h
@@ -25,7 +25,7 @@
   /// - inheritance of `print`
   /// - a type alias for the associated adaptor class
   ///
-  OpClass(StringRef name, StringRef extraClassDeclaration,
+  OpClass(StringRef name, std::string extraClassDeclaration,
           std::string extraClassDefinition);
 
   /// Add an op trait.
@@ -39,7 +39,7 @@
 
 private:
   /// Hand-written extra class declarations.
-  StringRef extraClassDeclaration;
+  std::string extraClassDeclaration;
   /// Hand-written extra class definitions.
   std::string extraClassDefinition;
   /// The parent class, which also contains the traits to be inherited.
diff --git a/mlir/tools/mlir-tblgen/OpClass.cpp b/mlir/tools/mlir-tblgen/OpClass.cpp
--- a/mlir/tools/mlir-tblgen/OpClass.cpp
+++ b/mlir/tools/mlir-tblgen/OpClass.cpp
@@ -15,9 +15,10 @@
 // OpClass definitions
 //===----------------------------------------------------------------------===//
 
-OpClass::OpClass(StringRef name, StringRef extraClassDeclaration,
+OpClass::OpClass(StringRef name, std::string extraClassDeclaration,
                  std::string extraClassDefinition)
-    : Class(name.str()), extraClassDeclaration(extraClassDeclaration),
+    : Class(name.str()),
+      extraClassDeclaration(std::move(extraClassDeclaration)),
       extraClassDefinition(std::move(extraClassDefinition)),
       parent(addParent("::mlir::Op")) {
   parent.addTemplateParam(getClassName().str());