diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -87,16 +87,9 @@ /// /// {0}: The name of the dialect class. /// {1}: The dialect namespace. -/// {2}: initialization code that is emitted in the ctor body before calling -/// initialize() static const char *const dialectDeclBeginStr = R"( class {0} : public ::mlir::Dialect { - explicit {0}(::mlir::MLIRContext *context) - : ::mlir::Dialect(getDialectNamespace(), context, - ::mlir::TypeID::get<{0}>()) {{ - {2} - initialize(); - } + explicit {0}(::mlir::MLIRContext *context); void initialize(); friend class ::mlir::MLIRContext; @@ -190,23 +183,13 @@ const iterator_range &dialectAttrs, const iterator_range &dialectTypes, raw_ostream &os) { - /// Build the list of dependent dialects - std::string dependentDialectRegistrations; - { - llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); - for (StringRef dependentDialect : dialect.getDependentDialects()) - dialectsOs << llvm::formatv(dialectRegistrationTemplate, - dependentDialect); - } - // Emit all nested namespaces. { NamespaceEmitter nsEmitter(os, dialect); // Emit the start of the decl. std::string cppName = dialect.getCppClassName(); - os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), - dependentDialectRegistrations); + os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName()); // Check for any attributes/types registered to this dialect. If there are, // add the hooks for parsing/printing. @@ -262,6 +245,19 @@ // GEN: Dialect definitions //===----------------------------------------------------------------------===// +/// The code block to generate a dialect constructor definition. +/// +/// {0}: The name of the dialect class. +/// {1}: initialization code that is emitted in the ctor body before calling +/// initialize(). 
+static const char *const dialectConstructorStr = R"( +{0}::{0}(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{ + {1} + initialize(); +} +)"; + /// The code block to generate a default desturctor definition. /// /// {0}: The name of the dialect class. @@ -271,16 +267,30 @@ )"; static void emitDialectDef(Dialect &dialect, raw_ostream &os) { + std::string cppClassName = dialect.getCppClassName(); + // Emit the TypeID explicit specializations to have a single symbol def. if (!dialect.getCppNamespace().empty()) os << "MLIR_DEFINE_EXPLICIT_TYPE_ID(" << dialect.getCppNamespace() - << "::" << dialect.getCppClassName() << ")\n"; + << "::" << cppClassName << ")\n"; // Emit all nested namespaces. NamespaceEmitter nsEmitter(os, dialect); + // Build the list of dependent dialects. + std::string dependentDialectRegistrations; + { + llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); + for (StringRef dependentDialect : dialect.getDependentDialects()) + dialectsOs << llvm::formatv(dialectRegistrationTemplate, + dependentDialect); + } + + // Emit the constructor and destructor. + os << llvm::formatv(dialectConstructorStr, cppClassName, + dependentDialectRegistrations); if (!dialect.hasNonDefaultDestructor()) - os << llvm::formatv(dialectDestructorStr, dialect.getCppClassName()); + os << llvm::formatv(dialectDestructorStr, cppClassName); } static bool emitDialectDefs(const llvm::RecordKeeper &recordKeeper,