diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2123,6 +2123,9 @@
   // The mnemonic of the op.
   string opName = mnemonic;
 
+  // The C++ namespace to use for this op.
+  string cppNamespace = dialect.cppNamespace;
+
   // One-line human-readable description of what the op does.
   string summary = "";
 
diff --git a/mlir/include/mlir/TableGen/CodeGenHelpers.h b/mlir/include/mlir/TableGen/CodeGenHelpers.h
--- a/mlir/include/mlir/TableGen/CodeGenHelpers.h
+++ b/mlir/include/mlir/TableGen/CodeGenHelpers.h
@@ -45,6 +45,14 @@
     for (StringRef ns : namespaces)
       os << "namespace " << ns << " {\n";
   }
+
+  /// Opens one `namespace <name> {` per `::`-separated component of
+  /// `cppNamespace`; a leading `::` is skipped and an empty string opens none.
+  NamespaceEmitter(raw_ostream &os, StringRef cppNamespace) : os(os) {
+    llvm::SplitString(cppNamespace, namespaces, "::");
+    for (StringRef ns : namespaces)
+      os << "namespace " << ns << " {\n";
+  }
 
   ~NamespaceEmitter() {
     for (StringRef ns : llvm::reverse(namespaces))
diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h
--- a/mlir/include/mlir/TableGen/Operator.h
+++ b/mlir/include/mlir/TableGen/Operator.h
@@ -58,6 +58,9 @@
   // Returns this op's C++ class name prefixed with namespaces.
   std::string getQualCppClassName() const;
 
+  // Returns this op's C++ namespace.
+  StringRef getCppNamespace() const;
+
   // Returns the name of op's adaptor C++ class.
   std::string getAdaptorName() const;
 
@@ -304,6 +307,9 @@
   // The unqualified C++ class name of the op.
   StringRef cppClassName;
 
+  // The C++ namespace for this op.
+  StringRef cppNamespace;
+
   // The operands of the op.
   SmallVector<NamedTypeConstraint, 4> operands;
 
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -50,6 +50,8 @@
     cppClassName = prefix;
   }
 
+  cppNamespace = def.getValueAsString("cppNamespace");
+
   populateOpStructure();
 }
 
@@ -70,12 +72,13 @@
 StringRef Operator::getCppClassName() const { return cppClassName; }
 
 std::string Operator::getQualCppClassName() const {
-  auto prefix = dialect.getCppNamespace();
-  if (prefix.empty())
+  if (cppNamespace.empty())
     return std::string(cppClassName);
-  return std::string(llvm::formatv("{0}::{1}", prefix, cppClassName));
+  return std::string(llvm::formatv("{0}::{1}", cppNamespace, cppClassName));
 }
 
+StringRef Operator::getCppNamespace() const { return cppNamespace; }
+
 int Operator::getNumResults() const {
   DagInit *results = def.getValueAsDag("results");
   return results->getNumArgs();
diff --git a/mlir/test/mlir-tblgen/dialect.td b/mlir/test/mlir-tblgen/dialect.td
--- a/mlir/test/mlir-tblgen/dialect.td
+++ b/mlir/test/mlir-tblgen/dialect.td
@@ -34,20 +34,37 @@
 
 def D_DSomeOp : Op<D_Dialect, "some_op", []>;
 
+// Check op with namespace override.
+def E_Dialect : Dialect {
+  let name = "e";
+  let cppNamespace = "ENS";
+}
+
+def E_SomeOp : Op<E_Dialect, "some_op", []>;
+def E_SpecialNSOp : Op<E_Dialect, "special_ns_op", []> {
+  let cppNamespace = "::E::SPECIAL_NS";
+}
+
 // DEF-LABEL: GET_OP_LIST
 // DEF: a::SomeOp
 // DEF-NEXT: BNS::SomeOp
 // DEF-NEXT: ::C::CC::SomeOp
 // DEF-NEXT: DSomeOp
+// DEF-NEXT: ENS::SomeOp
+// DEF-NEXT: ::E::SPECIAL_NS::SpecialNSOp
 
 // DEF-LABEL: GET_OP_CLASSES
 // DEF: a::SomeOp definitions
 // DEF: BNS::SomeOp definitions
 // DEF: ::C::CC::SomeOp definitions
 // DEF: DSomeOp definitions
+// DEF: ENS::SomeOp definitions
+// DEF: ::E::SPECIAL_NS::SpecialNSOp definitions
 
 // DECL-LABEL: GET_OP_CLASSES
 // DECL: a::SomeOp declarations
 // DECL: BNS::SomeOp declarations
 // DECL: ::C::CC::SomeOp declarations
 // DECL: DSomeOp declarations
+// DECL: ENS::SomeOp declarations
+// DECL: ::E::SPECIAL_NS::SpecialNSOp declarations
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -174,7 +174,11 @@
   llvm::Optional<NamespaceEmitter> namespaceEmitter;
   if (!emitDecl) {
     os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
-    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getDialect());
+    // FIXME: The local utility helpers are emitted only into the first op's
+    // C++ namespace. If ops in this file override cppNamespace differently,
+    // ops outside that namespace may be unable to find these helpers via
+    // unqualified lookup; confirm and consider a namespace-neutral location.
+    namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
   }
 
   emitTypeConstraintMethods(opDefs, os, emitDecl);
@@ -2423,7 +2427,7 @@
   os << "#undef GET_OP_FWD_DEFINES\n";
   for (auto *def : defs) {
     Operator op(*def);
-    NamespaceEmitter emitter(os, op.getDialect());
+    NamespaceEmitter emitter(os, op.getCppNamespace());
     os << "class " << op.getCppClassName() << ";\n";
   }
   os << "#endif\n\n";
@@ -2438,7 +2442,7 @@
                                                       emitDecl);
   for (auto *def : defs) {
     Operator op(*def);
-    NamespaceEmitter emitter(os, op.getDialect());
+    NamespaceEmitter emitter(os, op.getCppNamespace());
     if (emitDecl) {
       os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations");
       OpOperandAdaptorEmitter::emitDecl(op, os);