diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -125,7 +125,8 @@
     MLIRContext *ctx, TypeID interfaceKind, StringRef interfaceName) {
   for (auto *dialect : ctx->getLoadedDialects()) {
 #ifndef NDEBUG
-    dialect->handleUseOfUndefinedPromisedInterface(interfaceKind, interfaceName);
+    dialect->handleUseOfUndefinedPromisedInterface(interfaceKind,
+                                                   interfaceName);
 #endif
     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
       interfaces.insert(interface);
@@ -248,8 +249,11 @@
     extension.apply(ctx, requiredDialects);
   };
 
-  for (const auto &extension : extensions)
-    applyExtension(*extension);
+  // Note: Additional extensions may be added while applying an extension,
+  // which may reallocate `extensions`. Iterate by index (re-reading size())
+  // rather than with iterators, which would be invalidated.
+  for (size_t i = 0; i < extensions.size(); ++i)
+    applyExtension(*extensions[i]);
 }
 
 void DialectRegistry::applyExtensions(MLIRContext *ctx) const {
@@ -274,8 +278,11 @@
     extension.apply(ctx, requiredDialects);
   };
 
-  for (const auto &extension : extensions)
-    applyExtension(*extension);
+  // Note: Additional extensions may be added while applying an extension,
+  // which may reallocate `extensions`. Iterate by index (re-reading size())
+  // rather than with iterators, which would be invalidated.
+  for (size_t i = 0; i < extensions.size(); ++i)
+    applyExtension(*extensions[i]);
 }
 
 bool DialectRegistry::isSubsetOf(const DialectRegistry &rhs) const {
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -136,4 +136,50 @@
   EXPECT_TRUE(testDialectInterface != nullptr);
 }
 
+namespace {
+/// A dummy extension that increases a counter when being applied and
+/// recursively adds additional extensions.
+struct DummyExtension : DialectExtension<DummyExtension, TestDialect> {
+  DummyExtension(int *counter, int numRecursive)
+      : DialectExtension(), counter(counter), numRecursive(numRecursive) {}
+
+  void apply(MLIRContext *ctx, TestDialect *dialect) const final {
+    ++(*counter);
+    DialectRegistry nestedRegistry;
+    for (int i = 0; i < numRecursive; ++i)
+      nestedRegistry.addExtension(
+          std::make_unique<DummyExtension>(counter, /*numRecursive=*/0));
+    // Adding additional extensions may trigger a reallocation of the
+    // `extensions` vector in the dialect registry.
+    ctx->appendDialectRegistry(nestedRegistry);
+  }
+
+private:
+  int *counter;
+  int numRecursive;
+};
+} // namespace
+
+TEST(Dialect, NestedDialectExtension) {
+  DialectRegistry registry;
+  registry.insert<TestDialect>();
+
+  // Add an extension that adds 100 more extensions.
+  int counter1 = 0;
+  registry.addExtension(std::make_unique<DummyExtension>(&counter1, 100));
+  // Add one more extension. This should not crash.
+  int counter2 = 0;
+  registry.addExtension(std::make_unique<DummyExtension>(&counter2, 0));
+
+  // Load dialect and apply extensions.
+  MLIRContext context(registry);
+  Dialect *testDialect = context.getOrLoadDialect<TestDialect>();
+  ASSERT_TRUE(testDialect != nullptr);
+
+  // Extensions may be applied multiple times. Make sure that each expected
+  // extension was applied at least once.
+  EXPECT_GE(counter1, 101);
+  EXPECT_GE(counter2, 1);
+}
+
 } // namespace