diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h --- a/mlir/include/mlir/CAPI/Registration.h +++ b/mlir/include/mlir/CAPI/Registration.h @@ -22,7 +22,9 @@ #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \ void mlirContextRegister##Name##Dialect(MlirContext context) { \ - unwrap(context)->getDialectRegistry().insert(); \ + mlir::DialectRegistry registry; \ + registry.insert(); \ + unwrap(context)->appendDialectRegistry(registry); \ } \ MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \ return wrap(unwrap(context)->getOrLoadDialect()); \ diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -241,8 +241,7 @@ DenseMap>; public: - explicit DialectRegistry(MLIRContext *context = nullptr) - : owningContext(context) {} + explicit DialectRegistry() {} template void insert() { @@ -267,42 +266,37 @@ /// ownership of the dialect and for delayed interface registration to happen. void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor); - /// Load a dialect for this namespace in the provided context. - Dialect *loadByName(StringRef name, MLIRContext *context); + /// Return an allocation function for constructing the dialect identified by + /// its namespace, or nullptr if the namespace is not in this registry. + DialectAllocatorFunction getDialectAllocator(StringRef name) const; // Register all dialects available in the current registry with the registry // in the provided context. 
- void appendTo(DialectRegistry &destination) { + void appendTo(DialectRegistry &destination) const { for (const auto &nameAndRegistrationIt : registry) destination.insert(nameAndRegistrationIt.second.first, nameAndRegistrationIt.first, nameAndRegistrationIt.second.second); destination.interfaces.insert(interfaces.begin(), interfaces.end()); } - // Load all dialects available in the registry in the provided context. - void loadAll(MLIRContext *context) { - for (const auto &nameAndRegistrationIt : registry) - nameAndRegistrationIt.second.second(context); - } /// Return the names of dialects known to this registry. - auto getDialectNames() { + auto getDialectNames() const { return llvm::map_range( - registry, [](const MapTy::value_type &item) { return item.first; }); + registry, + [](const MapTy::value_type &item) -> StringRef { return item.first; }); } /// Add an interface constructed with the given allocation function to the /// dialect provided as template parameter. The dialect must be present in - /// the registry, but may or may not be loaded. If it is not loaded, the - /// interface registration is delayed until the loading. + /// the registry. template void addDialectInterface(InterfaceAllocatorFunction allocator) { addDialectInterface(DialectTy::getDialectNamespace(), allocator); } /// Add an interface to the dialect, both provided as template parameter. The - /// dialect must be present in the registry, but may or may not be loaded. If - /// it is not loaded, the interface registration is delayed until the loading. + /// dialect must be present in the registry. template void addDialectInterface() { addDialectInterface([](Dialect *dialect) { @@ -312,7 +306,7 @@ /// Register any interfaces required for the given dialect (based on its /// TypeID). Users are not expected to call this directly. 
- void registerDelayedInterfaces(Dialect *dialect); + void registerDelayedInterfaces(Dialect *dialect) const; private: /// Add an interface constructed with the given allocation function to the @@ -322,10 +316,6 @@ MapTy registry; InterfaceMapTy interfaces; - - /// If this registry belongs to a context, this points back to the context. - /// Useful for checking if a dialect is loaded in the context. - MLIRContext *owningContext; }; } // namespace mlir diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h --- a/mlir/include/mlir/IR/MLIRContext.h +++ b/mlir/include/mlir/IR/MLIRContext.h @@ -36,17 +36,19 @@ class MLIRContext { public: /// Create a new Context. - /// The loadAllDialects parameters allows to load all dialects from the global - /// registry on Context construction. It is deprecated and will be removed - /// soon. explicit MLIRContext(); + explicit MLIRContext(const DialectRegistry &registry); ~MLIRContext(); /// Return information about all IR dialects loaded in the context. std::vector getLoadedDialects(); /// Return the dialect registry associated with this context. - DialectRegistry &getDialectRegistry(); + const DialectRegistry &getDialectRegistry(); + + /// Appends the contents of the given dialect registry to the registry + /// associated with this context. + void appendDialectRegistry(const DialectRegistry &registry); /// Return information about all available dialects in the registry in this /// context. @@ -87,6 +89,9 @@ loadDialect(); } + /// Loads all dialects available in the registry in this context. + void loadAllAvailableDialects(); + /// Get (or create) a dialect for the given derived dialect name. /// The dialect will be loaded from the registry if no dialect is found. 
/// If no dialect is loaded for this name and none is available in the diff --git a/mlir/lib/CAPI/Registration/Registration.cpp b/mlir/lib/CAPI/Registration/Registration.cpp --- a/mlir/lib/CAPI/Registration/Registration.cpp +++ b/mlir/lib/CAPI/Registration/Registration.cpp @@ -12,7 +12,9 @@ #include "mlir/InitAllDialects.h" void mlirRegisterAllDialects(MlirContext context) { - registerAllDialects(unwrap(context)->getDialectRegistry()); + mlir::DialectRegistry registry; + registerAllDialects(registry); + unwrap(context)->appendDialectRegistry(registry); // TODO: we may not want to eagerly load here. - unwrap(context)->getDialectRegistry().loadAll(unwrap(context)); + unwrap(context)->loadAllAvailableDialects(); } diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -330,8 +330,9 @@ } } - MLIRContext context; - registerAllDialects(context.getDialectRegistry()); + DialectRegistry registry; + registerAllDialects(registry); + MLIRContext context(registry); auto m = parseMLIRInput(options.inputFilename, &context); if (!m) { diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -29,27 +29,18 @@ void DialectRegistry::addDialectInterface( StringRef dialectName, InterfaceAllocatorFunction allocator) { assert(allocator && "unexpected null interface allocation function"); - - // If the dialect is already loaded, directly add the interface. - if (Dialect *dialect = owningContext - ? owningContext->getLoadedDialect(dialectName) - : nullptr) { - dialect->addInterface(allocator(dialect)); - return; - } - - // Otherwise, store it in the interface map for delayed registration. 
auto it = registry.find(dialectName.str()); assert(it != registry.end() && "adding an interface for an unregistered dialect"); interfaces[it->second.first].push_back(allocator); } -Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) { +DialectAllocatorFunction +DialectRegistry::getDialectAllocator(StringRef name) const { auto it = registry.find(name.str()); if (it == registry.end()) return nullptr; - return it->second.second(context); + return it->second.second; } void DialectRegistry::insert(TypeID typeID, StringRef name, @@ -63,7 +54,7 @@ } } -void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) { +void DialectRegistry::registerDelayedInterfaces(Dialect *dialect) const { auto it = interfaces.find(dialect->getTypeID()); if (it == interfaces.end()) return; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -326,8 +326,7 @@ DictionaryAttr emptyDictionaryAttr; public: - MLIRContextImpl(MLIRContext *ctx) - : dialectsRegistry(ctx), identifiers(identifierAllocator) {} + MLIRContextImpl() : identifiers(identifierAllocator) {} ~MLIRContextImpl() { for (auto typeMapping : registeredTypes) typeMapping.second->~AbstractType(); @@ -337,7 +336,10 @@ }; } // end namespace mlir -MLIRContext::MLIRContext() : impl(new MLIRContextImpl(this)) { +MLIRContext::MLIRContext() : MLIRContext(DialectRegistry()) {} + +MLIRContext::MLIRContext(const DialectRegistry &registry) + : impl(new MLIRContextImpl) { // Initialize values based on the command line flags if they were provided. if (clOptions.isConstructed()) { disableMultithreading(clOptions->disableThreading); @@ -348,6 +350,9 @@ // Ensure the builtin dialect is always pre-loaded. getOrLoadDialect(); + // Pre-populate the registry. + registry.appendTo(impl->dialectsRegistry); + // Initialize several common attributes and types to avoid the need to lock // the context when accessing them. 
@@ -424,7 +429,15 @@ // Dialect and Operation Registration //===----------------------------------------------------------------------===// -DialectRegistry &MLIRContext::getDialectRegistry() { +void MLIRContext::appendDialectRegistry(const DialectRegistry &registry) { + registry.appendTo(impl->dialectsRegistry); + + // For the already loaded dialects, register the interfaces immediately. + for (const auto &kvp : impl->loadedDialects) + registry.registerDelayedInterfaces(kvp.second.get()); +} + +const DialectRegistry &MLIRContext::getDialectRegistry() { return impl->dialectsRegistry; } @@ -459,7 +472,9 @@ Dialect *dialect = getLoadedDialect(name); if (dialect) return dialect; - return impl->dialectsRegistry.loadByName(name, this); + DialectAllocatorFunction allocator = + impl->dialectsRegistry.getDialectAllocator(name); + return allocator ? allocator(this) : nullptr; } /// Get a dialect for the provided namespace and TypeID: abort the program if a @@ -507,6 +522,11 @@ return dialect.get(); } +void MLIRContext::loadAllAvailableDialects() { + for (StringRef name : getAvailableDialects()) + getOrLoadDialect(name); +} + llvm::hash_code MLIRContext::getRegistryHash() { llvm::hash_code hash(0); // Factor in number of loaded dialects, attributes, operations, types. diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -866,7 +866,9 @@ // Register all dialects for the current pipeline. DialectRegistry dependentDialects; getDependentDialects(dependentDialects); - dependentDialects.loadAll(context); + context->appendDialectRegistry(dependentDialects); + for (auto name : dependentDialects.getDialectNames()) + context->getOrLoadDialect(name); // Initialize all of the passes within the pass manager with a new generation. 
llvm::hash_code newInitKey = context->getRegistryHash(); diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp --- a/mlir/lib/Support/MlirOptMain.cpp +++ b/mlir/lib/Support/MlirOptMain.cpp @@ -95,10 +95,9 @@ sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Parse the input file. - MLIRContext context; - registry.appendTo(context.getDialectRegistry()); + MLIRContext context(registry); if (preloadDialectsInContext) - registry.loadAll(&context); + context.loadAllAvailableDialects(); context.allowUnregisteredDialects(allowUnregisteredDialects); context.printOpOnDiagnostic(!verifyDiagnostics); diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -136,8 +136,8 @@ if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo))) return failure(); - MLIRContext deserializationContext; - context->getDialectRegistry().loadAll(&deserializationContext); + MLIRContext deserializationContext(context->getDialectRegistry()); + deserializationContext.loadAllAvailableDialects(); // Then deserialize to get back a SPIR-V module. 
spirv::OwningSPIRVModuleRef spirvModule = spirv::deserialize(binary, &deserializationContext); diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp --- a/mlir/lib/Translation/Translation.cpp +++ b/mlir/lib/Translation/Translation.cpp @@ -13,6 +13,7 @@ #include "mlir/Translation.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" #include "mlir/Parser.h" #include "mlir/Support/FileUtilities.h" @@ -97,7 +98,9 @@ registerTranslation(name, [function, dialectRegistration]( llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { - dialectRegistration(context->getDialectRegistry()); + DialectRegistry registry; + dialectRegistration(registry); + context->appendDialectRegistry(registry); auto module = OwningModuleRef(parseSourceFile(sourceMgr, context)); if (!module) return failure(); diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp --- a/mlir/unittests/IR/DialectTest.cpp +++ b/mlir/unittests/IR/DialectTest.cpp @@ -65,8 +65,7 @@ // Delayed registration of an interface for TestDialect. registry.addDialectInterface(); - MLIRContext context; - registry.appendTo(context.getDialectRegistry()); + MLIRContext context(registry); // Load the TestDialect and check that the interface got registered for it. auto *testDialect = context.getOrLoadDialect(); @@ -85,8 +84,11 @@ // Use the same mechanism as for delayed registration but for an already // loaded dialect and check that the interface is now registered. - context.getDialectRegistry() + DialectRegistry secondRegistry; + secondRegistry.insert(); + secondRegistry .addDialectInterface(); + context.appendDialectRegistry(secondRegistry); secondTestDialectInterface = secondTestDialect->getRegisteredInterface(); EXPECT_TRUE(secondTestDialectInterface != nullptr);