diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -26,6 +26,8 @@
 class Type;
 
 using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
+using InterfaceAllocatorFunction =
+    std::function<std::unique_ptr<DialectInterface>(Dialect *)>;
 
 /// Dialects are groups of MLIR operations, types and attributes, as well as
 /// behavior associated with the entire group. For example, hooks into other
@@ -222,6 +224,7 @@
   /// A collection of registered dialect interfaces.
   DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
 
+  friend class DialectRegistry;
   friend void registerDialect();
   friend class MLIRContext;
 };
@@ -234,8 +237,13 @@
 class DialectRegistry {
   using MapTy =
       std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
+  using InterfaceMapTy =
+      DenseMap<TypeID, SmallVector<InterfaceAllocatorFunction, 2>>;
 
 public:
+  explicit DialectRegistry(MLIRContext *context = nullptr)
+      : owningContext(context) {}
+
   template <typename ConcreteDialect>
   void insert() {
     insert(TypeID::get<ConcreteDialect>(),
@@ -254,7 +262,9 @@
     insert<OtherDialect, MoreDialects...>();
   }
 
-  /// Add a new dialect constructor to the registry.
+  /// Add a new dialect constructor to the registry. The constructor must call
+  /// MLIRContext::getOrLoadDialect in order for the context to take
+  /// ownership of the dialect and for delayed interface registration to happen.
   void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
 
   /// Load a dialect for this namespace in the provided context.
@@ -267,6 +277,17 @@
       destination.insert(nameAndRegistrationIt.second.first,
                          nameAndRegistrationIt.first,
                          nameAndRegistrationIt.second.second);
+    // Route delayed interfaces through addDialectInterface: this appends to,
+    // rather than drops in favor of, registrations already pending in the
+    // destination, and attaches them right away if the destination's context
+    // has already loaded the dialect.
+    for (const auto &nameAndRegistrationIt : registry) {
+      auto interfacesIt = interfaces.find(nameAndRegistrationIt.second.first);
+      if (interfacesIt == interfaces.end())
+        continue;
+      for (const InterfaceAllocatorFunction &allocator : interfacesIt->second)
+        destination.addDialectInterface(nameAndRegistrationIt.first, allocator);
+    }
   }
   // Load all dialects available in the registry in the provided context.
   void loadAll(MLIRContext *context) {
@@ -274,11 +295,49 @@
       nameAndRegistrationIt.second.second(context);
   }
 
-  MapTy::const_iterator begin() const { return registry.begin(); }
-  MapTy::const_iterator end() const { return registry.end(); }
+  /// Return the names of dialects known to this registry. The names point into
+  /// storage owned by the registry, so they stay valid while it is alive.
+  auto getDialectNames() const {
+    return llvm::map_range(
+        registry,
+        [](const MapTy::value_type &item) -> StringRef { return item.first; });
+  }
+
+  /// Add an interface constructed with the given allocation function to the
+  /// dialect provided as template parameter. The dialect must be present in
+  /// the registry, but may or may not be loaded. If it is not loaded, the
+  /// interface registration is delayed until the loading.
+  template <typename DialectTy>
+  void addDialectInterface(InterfaceAllocatorFunction allocator) {
+    addDialectInterface(DialectTy::getDialectNamespace(), allocator);
+  }
+
+  /// Add an interface to the dialect, both provided as template parameter. The
+  /// dialect must be present in the registry, but may or may not be loaded. If
+  /// it is not loaded, the interface registration is delayed until the loading.
+  template <typename DialectTy, typename InterfaceTy>
+  void addDialectInterface() {
+    addDialectInterface<DialectTy>([](Dialect *dialect) {
+      return std::make_unique<InterfaceTy>(dialect);
+    });
+  }
+
+  /// Register any interfaces required for the given dialect (based on its
+  /// TypeID). Users are not expected to call this directly.
+  void registerDelayedInterfaces(Dialect *dialect);
 
 private:
+  /// Add an interface constructed with the given allocation function to the
+  /// dialect identified by its namespace.
+  void addDialectInterface(StringRef dialectName,
+                           InterfaceAllocatorFunction allocator);
+
   MapTy registry;
+  InterfaceMapTy interfaces;
+
+  /// If this registry belongs to a context, this points back to the context.
+  /// Useful for checking if a dialect is loaded in the context.
+  MLIRContext *owningContext;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -22,6 +22,29 @@
 
 DialectAsmParser::~DialectAsmParser() {}
 
+//===----------------------------------------------------------------------===//
+// DialectRegistry
+//===----------------------------------------------------------------------===//
+
+void DialectRegistry::addDialectInterface(
+    StringRef dialectName, InterfaceAllocatorFunction allocator) {
+  assert(allocator && "unexpected null interface allocation function");
+
+  // If the dialect is already loaded, directly add the interface.
+  if (Dialect *dialect = owningContext
+                             ? owningContext->getLoadedDialect(dialectName)
+                             : nullptr) {
+    dialect->addInterface(allocator(dialect));
+    return;
+  }
+
+  // Otherwise, store it in the interface map for delayed registration.
+  auto it = registry.find(dialectName.str());
+  assert(it != registry.end() &&
+         "adding an interface for an unregistered dialect");
+  interfaces[it->second.first].push_back(std::move(allocator));
+}
+
 Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
   auto it = registry.find(name.str());
   if (it == registry.end())
@@ -40,6 +63,15 @@
   }
 }
 
+void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) {
+  auto it = interfaces.find(dialect->getTypeID());
+  if (it == interfaces.end())
+    return;
+
+  for (const InterfaceAllocatorFunction &createInterface : it->second)
+    dialect->addInterface(createInterface(dialect));
+}
+
 //===----------------------------------------------------------------------===//
 // Dialect
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -326,7 +326,8 @@
   DictionaryAttr emptyDictionaryAttr;
 
 public:
-  MLIRContextImpl() : identifiers(identifierAllocator) {}
+  explicit MLIRContextImpl(MLIRContext *ctx)
+      : dialectsRegistry(ctx), identifiers(identifierAllocator) {}
   ~MLIRContextImpl() {
     for (auto typeMapping : registeredTypes)
       typeMapping.second->~AbstractType();
@@ -336,7 +337,7 @@
 };
 } // end namespace mlir
 
-MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) {
   // Initialize values based on the command line flags if they were provided.
   if (clOptions.isConstructed()) {
     disableMultithreading(clOptions->disableThreading);
@@ -441,8 +442,8 @@
 }
 std::vector<StringRef> MLIRContext::getAvailableDialects() {
   std::vector<StringRef> result;
-  for (auto &dialect : impl->dialectsRegistry)
-    result.push_back(dialect.first);
+  for (StringRef dialect : impl->dialectsRegistry.getDialectNames())
+    result.push_back(dialect);
   return result;
 }
 
@@ -493,6 +494,8 @@
           identifierEntry.first().startswith(dialectNamespace))
         identifierEntry.second = dialect.get();
 
+    // Attach interfaces that were registered before this dialect was loaded.
+    impl.dialectsRegistry.registerDelayedInterfaces(dialect.get());
     return dialect.get();
   }
 
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -201,10 +201,8 @@
   {
     llvm::raw_string_ostream os(helpHeader);
     MLIRContext context;
-    interleaveComma(registry, os, [&](auto &registryEntry) {
-      StringRef name = registryEntry.first;
-      os << name;
-    });
+    interleaveComma(registry.getDialectNames(), os,
+                    [&](auto name) { os << name; });
   }
   // Parse pass names in main to ensure static initialization completed.
   cl::ParseCommandLineOptions(argc, argv, helpHeader);
@@ -212,8 +210,8 @@
   if (showDialects) {
     llvm::outs() << "Available Dialects:\n";
     interleave(
-        registry, llvm::outs(),
-        [](auto &registryEntry) { llvm::outs() << registryEntry.first; }, "\n");
+        registry.getDialectNames(), llvm::outs(),
+        [](auto name) { llvm::outs() << name; }, "\n");
     return success();
   }
 
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectInterface.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -34,4 +35,77 @@
   ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
 }
 
+struct SecondTestDialect : public Dialect {
+  static StringRef getDialectNamespace() { return "test2"; }
+  SecondTestDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context,
+                TypeID::get<SecondTestDialect>()) {}
+};
+
+struct TestDialectInterfaceBase
+    : public DialectInterface::Base<TestDialectInterfaceBase> {
+  TestDialectInterfaceBase(Dialect *dialect) : Base(dialect) {}
+  virtual int function() const { return 42; }
+};
+
+struct TestDialectInterface : public TestDialectInterfaceBase {
+  using TestDialectInterfaceBase::TestDialectInterfaceBase;
+  int function() const final { return 56; }
+};
+
+struct SecondTestDialectInterface : public TestDialectInterfaceBase {
+  using TestDialectInterfaceBase::TestDialectInterfaceBase;
+  int function() const final { return 78; }
+};
+
+TEST(Dialect, DelayedInterfaceRegistration) {
+  DialectRegistry registry;
+  registry.insert<TestDialect, SecondTestDialect>();
+
+  // Delayed registration of an interface for TestDialect.
+  registry.addDialectInterface<TestDialect, TestDialectInterface>();
+
+  MLIRContext context;
+  registry.appendTo(context.getDialectRegistry());
+
+  // Load the TestDialect and check that the interface got registered for it.
+  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  ASSERT_TRUE(testDialect != nullptr);
+  auto *testDialectInterface =
+      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  EXPECT_TRUE(testDialectInterface != nullptr);
+
+  // Load the SecondTestDialect and check that the interface is not registered
+  // for it.
+  auto *secondTestDialect = context.getOrLoadDialect<SecondTestDialect>();
+  ASSERT_TRUE(secondTestDialect != nullptr);
+  auto *secondTestDialectInterface =
+      secondTestDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  EXPECT_TRUE(secondTestDialectInterface == nullptr);
+
+  // Use the same mechanism as for delayed registration but for an already
+  // loaded dialect and check that the interface is now registered.
+  context.getDialectRegistry()
+      .addDialectInterface<SecondTestDialect, SecondTestDialectInterface>();
+  secondTestDialectInterface =
+      secondTestDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  EXPECT_TRUE(secondTestDialectInterface != nullptr);
+}
+
+TEST(Dialect, DelayedInterfaceRegistrationForLoadedDialect) {
+  MLIRContext context;
+  auto *testDialect = context.getOrLoadDialect<TestDialect>();
+  ASSERT_TRUE(testDialect != nullptr);
+
+  // The dialect is already loaded when the registry is appended, so appendTo
+  // must attach the interface immediately instead of deferring it.
+  DialectRegistry registry;
+  registry.insert<TestDialect>();
+  registry.addDialectInterface<TestDialect, TestDialectInterface>();
+  registry.appendTo(context.getDialectRegistry());
+  auto *testDialectInterface =
+      testDialect->getRegisteredInterface<TestDialectInterfaceBase>();
+  EXPECT_TRUE(testDialectInterface != nullptr);
+}
+
 } // end namespace