diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -345,7 +345,8 @@
     Or<!foreach(allowedtype, allowedTypes, allowedtype.predicate)>,
     !if(!eq(summary, ""),
         !interleave(!foreach(t, allowedTypes, t.summary), " or "),
-        summary)>;
+        summary),
+    cppClassName>;
 
 // Integer types.
 
diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h
--- a/mlir/include/mlir/TableGen/Type.h
+++ b/mlir/include/mlir/TableGen/Type.h
@@ -49,7 +49,9 @@
   Optional<StringRef> getBuilderCall() const;
 
   // Return the C++ class name for this type (which may just be ::mlir::Type).
-  StringRef getCPPClassName() const;
+  // Names that are not already namespace-qualified are prefixed with the C++
+  // namespace of the type's dialect, when the type has one.
+  std::string getCPPClassName() const;
 };
 
 // Wrapper class with helper methods for accessing Types defined in TableGen.
diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp
--- a/mlir/lib/TableGen/Type.cpp
+++ b/mlir/lib/TableGen/Type.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/TableGen/Type.h"
+#include "mlir/TableGen/Dialect.h"
+#include "llvm/ADT/Twine.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/TableGen/Record.h"
 
@@ -54,8 +56,23 @@
 }
 
 // Return the C++ class name for this type (which may just be ::mlir::Type).
-StringRef TypeConstraint::getCPPClassName() const {
-  return def->getValueAsString("cppClassName");
+std::string TypeConstraint::getCPPClassName() const {
+  StringRef className = def->getValueAsString("cppClassName");
+
+  // If the class name is already namespace resolved, use it.
+  if (className.contains("::"))
+    return className.str();
+
+  // Otherwise, check to see if there is a namespace from a dialect to prepend.
+  // Only a `dialect` field holding a def is used; an unset or non-def value
+  // falls back to the unqualified class name.
+  if (const llvm::RecordVal *value = def->getValue("dialect")) {
+    if (const auto *dialectInit = dyn_cast<llvm::DefInit>(value->getValue())) {
+      Dialect dialect(dialectInit->getDef());
+      return (dialect.getCppNamespace() + "::" + className).str();
+    }
+  }
+  return className.str();
 }
 
 Type::Type(const llvm::Record *record) : TypeConstraint(record) {}
diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td
--- a/mlir/test/mlir-tblgen/op-decl.td
+++ b/mlir/test/mlir-tblgen/op-decl.td
@@ -247,6 +247,23 @@
 // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
 
+// Test that type defs have the proper namespaces when used as a constraint.
+// ---
+
+def Test_Dialect2 : Dialect {
+  let name = "test";
+  let cppNamespace = "::mlir::dialect2";
+}
+def TestDialect2Type : TypeDef<Test_Dialect2, "Dialect2Type">;
+
+def NS_ResultWithDialectTypeOp : NS_Op<"op_with_dialect_type", []> {
+  let results = (outs TestDialect2Type);
+}
+
+// CHECK-LABEL: NS::ResultWithDialectTypeOp declarations
+// CHECK: class ResultWithDialectTypeOp :
+// CHECK-SAME: ::mlir::OpTrait::OneTypedResult<::mlir::dialect2::Dialect2TypeType>
+
 // Check that default builders can be suppressed.
 // ---
 