diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1718,8 +1718,9 @@ // NativeOpTrait corresponds to the MLIR C++ OpTrait mechanism. The // purpose to wrap around C++ symbol string with this class is to make // traits specified for ops in TableGen less alien and more integrated. -class NativeOpTrait : OpTrait { - string trait = "::mlir::OpTrait::" # prop; +class NativeOpTrait : OpTrait { + string trait = name; + string cppNamespace = "::mlir::OpTrait"; } // ParamNativeOpTrait corresponds to the template-parameterized traits in the @@ -1852,6 +1853,7 @@ class OpInterfaceTrait : NativeOpTrait<""> { let trait = name # "::Trait"; + let cppNamespace = ""; // Specify the body of the verification function. `$_op` will be replaced with // the operation being verified. diff --git a/mlir/include/mlir/TableGen/OpTrait.h b/mlir/include/mlir/TableGen/OpTrait.h --- a/mlir/include/mlir/TableGen/OpTrait.h +++ b/mlir/include/mlir/TableGen/OpTrait.h @@ -63,7 +63,7 @@ class NativeOpTrait : public OpTrait { public: // Returns the trait corresponding to a C++ trait class. - StringRef getTrait() const; + std::string getTrait() const; static bool classof(const OpTrait *t) { return t->getKind() == Kind::Native; } }; diff --git a/mlir/lib/TableGen/OpTrait.cpp b/mlir/lib/TableGen/OpTrait.cpp --- a/mlir/lib/TableGen/OpTrait.cpp +++ b/mlir/lib/TableGen/OpTrait.cpp @@ -35,8 +35,11 @@ OpTrait::OpTrait(Kind kind, const llvm::Record *def) : def(def), kind(kind) {} -llvm::StringRef NativeOpTrait::getTrait() const { - return def->getValueAsString("trait"); +std::string NativeOpTrait::getTrait() const { + llvm::StringRef trait = def->getValueAsString("trait"); + llvm::StringRef cppNamespace = def->getValueAsString("cppNamespace"); + return cppNamespace.empty() ? trait.str() + : (cppNamespace + "::" + trait).str(); } llvm::StringRef InternalOpTrait::getTrait() const { diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -248,6 +248,19 @@ // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// Check native OpTrait usage +// --- + +def NS_TestTrait : NativeOpTrait<"TestTrait"> { + let cppNamespace = "SomeNamespace"; +} + +def NS_KWithTraitOp : NS_Op<"KWithTrait", [NS_TestTrait]>; + +// CHECK-LABEL: NS::KWithTraitOp declarations +// CHECK: class KWithTraitOp : public ::mlir::Op