diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -98,6 +98,16 @@
         }));
   }
 
+  /// Return true if a dialect with the given namespace is loaded in this
+  /// context. A namespace that has no constructed dialect instance yet is
+  /// reported as not loaded.
+  bool isDialectLoaded(StringRef dialectNamespace);
+
+  /// Return true if the dialect `T` is loaded in this context.
+  template <typename T> bool isDialectLoaded() {
+    return isDialectLoaded(T::getDialectNamespace());
+  }
+
   /// Load a dialect in the context.
   template <typename Dialect>
   void loadDialect() {
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -458,6 +458,16 @@
   return dialect.get();
 }
 
+bool MLIRContext::isDialectLoaded(StringRef dialectNamespace) {
+  auto &impl = getImpl();
+  auto it = impl.loadedDialects.find(dialectNamespace);
+  // Require a non-null instance, not just a map entry: if a slot for this
+  // namespace is ever present without a constructed dialect (e.g. reserved
+  // while the dialect is being built), reporting it as loaded would make the
+  // tblgen-generated dialect constructor skip initialize().
+  return it != impl.loadedDialects.end() && it->second != nullptr;
+}
+
 DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
     StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) {
   auto &impl = getImpl();
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -258,7 +258,10 @@
 {0}::{0}(::mlir::MLIRContext *context)
     : ::mlir::{2}(getDialectNamespace(), context, ::mlir::TypeID::get<{0}>()) {{
   {1}
-  initialize();
+
+  // Do not initialize if this dialect was just loaded.
+  if (!context->isDialectLoaded(getDialectNamespace()))
+    initialize();
 }
 )";
 