diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -247,7 +247,10 @@
 
 /// Registers a specific dialect creation function with the system, typically
 /// used through the DialectRegistration template.
-void registerDialectAllocator(const DialectAllocatorFunction &function);
+/// Registrations are deduplicated by dialect ClassID and only the first
+/// registration will be used.
+void registerDialectAllocator(const ClassID *classId,
+                              const DialectAllocatorFunction &function);
 
 /// Registers all dialects with the specified MLIRContext.
 void registerAllDialects(MLIRContext *context);
@@ -255,10 +258,12 @@
 /// Utility to register a dialect. Client can register their dialect with the
 /// global registry by calling registerDialect();
 template <typename ConcreteDialect> void registerDialect() {
-  registerDialectAllocator([](MLIRContext *ctx) {
-    // Just allocate the dialect, the context takes ownership of it.
-    new ConcreteDialect(ctx);
-  });
+  registerDialectAllocator(ClassID::getID<ConcreteDialect>(),
+                           [](MLIRContext *ctx) {
+                             // Just allocate the dialect, the context takes
+                             // ownership of it.
+                             new ConcreteDialect(ctx);
+                           });
 }
 
 /// DialectRegistration provides a global initializer that registers a Dialect
diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h
--- a/mlir/include/mlir/IR/DialectHooks.h
+++ b/mlir/include/mlir/IR/DialectHooks.h
@@ -39,32 +39,36 @@
 
 /// Registers a function that will set hooks in the registered dialects
 /// based on information coming from DialectHooksRegistration.
-void registerDialectHooksSetter(const DialectHooksSetter &function);
+void registerDialectHooksSetter(const ClassID *classId,
+                                const DialectHooksSetter &function);
 
 /// DialectHooksRegistration provides a global initializer that registers
 /// a dialect hooks setter routine.
+/// Registrations are deduplicated by dialect ClassID and only the first
+/// registration will be used.
 /// Usage:
 ///
 ///   // At namespace scope.
-///   static DialectHooksRegistration unused;
+///   static DialectHooksRegistration unused("dialect_namespace");
 template <typename ConcreteHooks> struct DialectHooksRegistration {
   DialectHooksRegistration(StringRef dialectName) {
-    registerDialectHooksSetter([dialectName](MLIRContext *ctx) {
-      Dialect *dialect = ctx->getRegisteredDialect(dialectName);
-      if (!dialect) {
-        llvm::errs() << "error: cannot register hooks for unknown dialect '"
-                     << dialectName << "'\n";
-        abort();
-      }
-      // Set hooks.
-      ConcreteHooks hooks;
-      if (auto h = hooks.getConstantFoldHook())
-        dialect->constantFoldHook = h;
-      if (auto h = hooks.getDecodeHook())
-        dialect->decodeHook = h;
-      if (auto h = hooks.getExtractElementHook())
-        dialect->extractElementHook = h;
-    });
+    registerDialectHooksSetter(
+        ClassID::getID<ConcreteHooks>(), [dialectName](MLIRContext *ctx) {
+          Dialect *dialect = ctx->getRegisteredDialect(dialectName);
+          if (!dialect) {
+            llvm::errs() << "error: cannot register hooks for unknown dialect '"
+                         << dialectName << "'\n";
+            abort();
+          }
+          // Set hooks.
+          ConcreteHooks hooks;
+          if (auto h = hooks.getConstantFoldHook())
+            dialect->constantFoldHook = h;
+          if (auto h = hooks.getDecodeHook())
+            dialect->decodeHook = h;
+          if (auto h = hooks.getExtractElementHook())
+            dialect->extractElementHook = h;
+        });
   }
 };
 
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -27,38 +27,40 @@
 //===----------------------------------------------------------------------===//
 
 // Registry for all dialect allocation functions.
-static llvm::ManagedStatic>
+static llvm::ManagedStatic>
     dialectRegistry;
 
 // Registry for functions that set dialect hooks.
-static llvm::ManagedStatic>
+static llvm::ManagedStatic>
     dialectHooksRegistry;
 
 /// Registers a specific dialect creation function with the system, typically
 /// used through the DialectRegistration template.
-void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) {
+void mlir::registerDialectAllocator(const ClassID *classId,
+                                    const DialectAllocatorFunction &function) {
   assert(function &&
          "Attempting to register an empty dialect initialize function");
-  dialectRegistry->push_back(function);
+  dialectRegistry->try_emplace(classId, function);
 }
 
 /// Registers a function to set specific hooks for a specific dialect, typically
 /// used through the DialectHooksRegistration template.
-void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) {
+void mlir::registerDialectHooksSetter(const ClassID *classId,
+                                      const DialectHooksSetter &function) {
   assert(
       function &&
       "Attempting to register an empty dialect hooks initialization function");
-  dialectHooksRegistry->push_back(function);
+  dialectHooksRegistry->try_emplace(classId, function);
 }
 
 /// Registers all dialects and their const folding hooks with the specified
 /// MLIRContext.
 void mlir::registerAllDialects(MLIRContext *context) {
-  for (const auto &fn : *dialectRegistry)
-    fn(context);
-  for (const auto &fn : *dialectHooksRegistry) {
-    fn(context);
+  for (const auto &it : *dialectRegistry)
+    it.second(context);
+  for (const auto &it : *dialectHooksRegistry) {
+    it.second(context);
   }
 }