diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -98,16 +98,22 @@ })); } + /// Return true if the dialect with the given namespace is currently loading. + bool isDialectLoading(StringRef dialectNamespace); + /// Load a dialect in the context. template void loadDialect() { - getOrLoadDialect(); + // Do not load the dialect if it is currently loading. This can happen if a + // dialect initializer triggers loading the same dialect recursively. + if (!isDialectLoading(Dialect::getDialectNamespace())) + getOrLoadDialect(); } /// Load a list dialects in the context. template void loadDialect() { - getOrLoadDialect(); + loadDialect(); loadDialect(); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -175,6 +175,10 @@ DenseMap> loadedDialects; DialectRegistry dialectsRegistry; + /// This is the set of dialects that are currently in the process of loading, + /// i.e., their constructor/initializer is still executing. + DenseSet loadingDialects; + /// An allocator used for AbstractAttribute and AbstractType objects. 
llvm::BumpPtrAllocator abstractDialectSymbolAllocator; @@ -429,10 +433,19 @@ ") while in a multi-threaded execution context (maybe " "the PassManager): this can indicate a " "missing `dependentDialects` in a pass for example."); + if (impl.loadingDialects.contains(dialectNamespace)) + llvm::report_fatal_error( + "Loading (and getting) a dialect (" + dialectNamespace + + ") while the same dialect is still loading: use loadDialect instead " + "of getOrLoadDialect."); #endif + impl.loadingDialects.insert(dialectNamespace); std::unique_ptr &dialect = impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second; assert(dialect && "dialect ctor failed"); + // Erase by key rather than by the iterator returned from insert(): ctor() + // may recursively load other dialects and rehash loadingDialects. + impl.loadingDialects.erase(dialectNamespace); // Refresh all the identifiers dialect field, this catches cases where a // dialect may be loaded after identifier prefixed with this dialect name @@ -458,6 +471,10 @@ return dialect.get(); } +bool MLIRContext::isDialectLoading(StringRef dialectNamespace) { + return getImpl().loadingDialects.contains(dialectNamespace); +} + DynamicDialect *MLIRContext::getOrLoadDynamicDialect( StringRef dialectNamespace, function_ref ctor) { auto &impl = getImpl(); diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -107,7 +107,7 @@ /// Registration for a single dependent dialect: to be inserted in the ctor /// above for each dependent dialect. const char *const dialectRegistrationTemplate = R"( - getContext()->getOrLoadDialect<{0}>(); + getContext()->loadDialect<{0}>(); )"; /// The code block for the attribute parser/printer hooks.