Review notes (free-form text placed before the first "diff --git" header is
ignored by git apply / patch):

* This copy of the patch appears to have lost its line breaks and its
  angle-bracketed template arguments (e.g. "template void loadDialect()",
  "std::unique_ptr &dialect", "function_ref ctor"), while "<{0}>" in the
  DialectGen hunk survived. It will not apply or compile as-is; regenerate it
  from the source tree rather than hand-repairing it.
* Release builds: the new "dialectIt->second == nullptr" check is wrapped in
  #ifndef NDEBUG. With NDEBUG defined, a getOrLoadDialect call for a dialect
  that is still loading skips the check and reaches
  "dialect->getTypeID()" with dialect == nullptr. Consider making the
  report_fatal_error unconditional (it costs one pointer compare) so release
  builds fail with a diagnostic instead of a null dereference.
* "std::unique_ptr &dialect = impl.loadedDialects[dialectNamespace] = ctor();"
  performs the map lookup only after ctor() returns, because C++17 sequences
  the right operand of "=" before the left. ctor() can insert into
  loadedDialects recursively (the DialectGen template loads dependent dialects
  from inside the dialect ctor), so a reference or iterator obtained before
  ctor() may be invalidated if the container rehashes (the container type is
  not visible in this patch; confirm). A comment stating this, or splitting it
  into "auto owned = ctor();" followed by the lookup, would make the ordering
  requirement explicit.
* loadDialect<D>() now returns without doing anything when D is currently
  loading, and it returns void, so callers that need the Dialect* back must
  still use getOrLoadDialect (which the new fatal error message points to).

diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -98,16 +98,22 @@ })); } + /// Return true if the given dialect is currently loading. + bool isDialectLoading(StringRef dialectNamespace); + /// Load a dialect in the context. template void loadDialect() { - getOrLoadDialect(); + // Do not load the dialect if it is currently loading. This can happen if a + // dialect initializer triggers loading the same dialect recursively. + if (!isDialectLoading(Dialect::getDialectNamespace())) + getOrLoadDialect(); } /// Load a list dialects in the context. template void loadDialect() { - getOrLoadDialect(); + loadDialect(); loadDialect(); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -429,9 +429,11 @@ ") while in a multi-threaded execution context (maybe " "the PassManager): this can indicate a " "missing `dependentDialects` in a pass for example."); -#endif - std::unique_ptr &dialect = - impl.loadedDialects.insert({dialectNamespace, ctor()}).first->second; +#endif // NDEBUG + // nullptr indicates that the dialect is currently being loaded. + impl.loadedDialects[dialectNamespace] = nullptr; + std::unique_ptr &dialect = impl.loadedDialects[dialectNamespace] = + ctor(); assert(dialect && "dialect ctor failed"); // Refresh all the identifiers dialect field, this catches cases where a @@ -449,6 +451,14 @@ return dialect.get(); } +#ifndef NDEBUG + if (dialectIt->second == nullptr) + llvm::report_fatal_error( + "Loading (and getting) a dialect (" + dialectNamespace + + ") while the same dialect is still loading: use loadDialect instead " + "of getOrLoadDialect."); +#endif // NDEBUG + // Abort if dialect with namespace has already been registered.
std::unique_ptr &dialect = dialectIt->second; if (dialect->getTypeID() != dialectID) @@ -458,6 +468,12 @@ return dialect.get(); } +bool MLIRContext::isDialectLoading(StringRef dialectNamespace) { + auto it = getImpl().loadedDialects.find(dialectNamespace); + // nullptr indicates that the dialect is currently being loaded. + return it != getImpl().loadedDialects.end() && it->second == nullptr; +} + DynamicDialect *MLIRContext::getOrLoadDynamicDialect( StringRef dialectNamespace, function_ref ctor) { auto &impl = getImpl(); diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -107,7 +107,7 @@ /// Registration for a single dependent dialect: to be inserted in the ctor /// above for each dependent dialect. const char *const dialectRegistrationTemplate = R"( - getContext()->getOrLoadDialect<{0}>(); + getContext()->loadDialect<{0}>(); )"; /// The code block for the attribute parser/printer hooks.