diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -239,7 +239,8 @@ using MapTy = std::map>; using InterfaceMapTy = - DenseMap>; + DenseMap, 2>>; public: explicit DialectRegistry() {} @@ -292,17 +293,20 @@ /// dialect provided as template parameter. The dialect must be present in /// the registry. template - void addDialectInterface(InterfaceAllocatorFunction allocator) { - addDialectInterface(DialectTy::getDialectNamespace(), allocator); + void addDialectInterface(TypeID interfaceTypeID, + InterfaceAllocatorFunction allocator) { + addDialectInterface(DialectTy::getDialectNamespace(), interfaceTypeID, + allocator); } /// Add an interface to the dialect, both provided as template parameter. The /// dialect must be present in the registry. template void addDialectInterface() { - addDialectInterface([](Dialect *dialect) { - return std::make_unique(dialect); - }); + addDialectInterface( + InterfaceTy::getInterfaceID(), [](Dialect *dialect) { + return std::make_unique(dialect); + }); } /// Register any interfaces required for the given dialect (based on its @@ -312,7 +316,7 @@ private: /// Add an interface constructed with the given allocation function to the /// dialect identified by its namespace. - void addDialectInterface(StringRef dialectName, + void addDialectInterface(StringRef dialectName, TypeID interfaceTypeID, InterfaceAllocatorFunction allocator); MapTy registry; diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -14,9 +14,12 @@ #include "mlir/IR/Operation.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" +#define DEBUG_TYPE "dialect" + using namespace mlir; using namespace detail; @@ -27,12 +30,28 @@ //===----------------------------------------------------------------------===// void DialectRegistry::addDialectInterface( - StringRef dialectName, InterfaceAllocatorFunction allocator) { + StringRef dialectName, TypeID interfaceTypeID, + InterfaceAllocatorFunction allocator) { assert(allocator && "unexpected null interface allocation function"); auto it = registry.find(dialectName.str()); assert(it != registry.end() && "adding an interface for an unregistered dialect"); - interfaces[it->second.first].push_back(allocator); + + // Bail out if the interface with the given ID is already in the registry for + // the given dialect. We expect a small number (dozens) of interfaces so a + // linear search is fine here. + auto &dialectInterfaces = interfaces[it->second.first]; + for (const auto &kvp : dialectInterfaces) { + if (kvp.first == interfaceTypeID) { + LLVM_DEBUG(llvm::dbgs() + << "[" DEBUG_TYPE + "] repeated interface registration for dialect " + << dialectName); + return; + } + } + + dialectInterfaces.emplace_back(interfaceTypeID, allocator); } DialectAllocatorFunctionRef @@ -59,8 +78,12 @@ if (it == interfaces.end()) return; - for (const InterfaceAllocatorFunction &createInterface : it->second) - dialect->addInterface(createInterface(dialect)); + // Add an interface if it is not already present. + for (const auto &kvp : it->second) { + if (dialect->getRegisteredInterface(kvp.first)) + continue; + dialect->addInterface(kvp.second(dialect)); + } } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -94,4 +94,29 @@ EXPECT_TRUE(secondTestDialectInterface != nullptr); } +TEST(Dialect, RepeatedDelayedRegistration) { + // Set up the delayed registration. + DialectRegistry registry; + registry.insert(); + registry.addDialectInterface(); + MLIRContext context(registry); + + // Load the TestDialect and check that the interface got registered for it. + auto *testDialect = context.getOrLoadDialect(); + ASSERT_TRUE(testDialect != nullptr); + auto *testDialectInterface = + testDialect->getRegisteredInterface(); + EXPECT_TRUE(testDialectInterface != nullptr); + + // Try adding the same dialect interface again and check that we don't crash + // on repeated interface registration. + DialectRegistry secondRegistry; + secondRegistry.insert(); + secondRegistry.addDialectInterface(); + context.appendDialectRegistry(secondRegistry); + testDialectInterface = + testDialect->getRegisteredInterface(); + EXPECT_TRUE(testDialectInterface != nullptr); +} + } // end namespace