NOTE(review): reviewer notes on the patch that follows. They are placed above
the first "diff --git" header so the diff content itself is unchanged.

* This copy of the patch cannot be applied as-is. Its newlines are collapsed
  into single spaces, and the text between angle brackets appears to have been
  lost. Examples: "using DialectAllocatorFunction = std::function;",
  "template void registerDialect()", "ClassID::getID()" and
  "static llvm::ManagedStatic> dialectHooksRegistry;". Restore the template
  arguments and line structure from the original change before applying.

* DialectHooks.h: just after the DialectHooks class, the patch adds a
  namespace-scope declaration
    void registerDialectHooksSetter(const ClassID *classId,
                                    const DialectHooksSetter &function);
  The only definition in this patch (Dialect.cpp) is the static member
  DialectHooks::registerDialectHooksSetter. registerDialectHooks also calls
  the member. The free declaration looks like a leftover, and a call to it
  would fail to link unless another file defines it. Consider dropping it.

* DialectHooks.h: the comment on the private registerDialectHooksSetter says
  registrations are "deduplicated by dialect ClassID". However,
  registerDialectHooks passes ClassID::getID() with a template argument that
  presumably is the hooks type (ConcreteHooks is the only type its body uses),
  not the dialect. TODO confirm, then reword, e.g. "deduplicated by the hooks
  type's ClassID".

* DialectHooks.h: registerDialectHooks captures dialectName, an llvm::StringRef,
  by value. The capturing setter is stored in the global dialectHooksRegistry
  and invoked later by registerAllDialects. StringRef does not own its
  characters, so a name backed by a temporary string would dangle by then.
  The removed DialectHooksRegistration code captured it the same way.
  Capturing an owning copy, e.g. [name = dialectName.str()], avoids this.

* Dialect.cpp: the comment above DialectHooks::registerDialectHooksSetter still
  says it is "typically used through the DialectHooksRegistration template".
  After this patch the direct caller is registerDialectHooks, which
  DialectHooksRegistration now forwards to.

* Dialect.cpp: both registries become MapVectors keyed by ClassID pointer and
  are filled with insert(). A later registration under an already-present key
  is therefore silently ignored, which matches the "only the first
  registration will be used" comments. registerAllDialects then iterates in
  first-registration order.

* Dialect.cpp: in registerAllDialects the second loop keeps braces around its
  single-statement body while the first loop does not. Dropping them would
  make the two loops consistent.

diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h --- a/mlir/include/mlir/IR/Dialect.h +++ b/mlir/include/mlir/IR/Dialect.h @@ -28,6 +28,7 @@ Operation *, ArrayRef, SmallVectorImpl &)>; using DialectExtractElementHook = std::function)>; +using DialectAllocatorFunction = std::function; /// Dialects are groups of MLIR operations and behavior associated with the /// entire group. For example, hooks into other systems for constant folding, @@ -241,24 +242,30 @@ /// A collection of registered dialect interfaces. DenseMap> registeredInterfaces; -}; - -using DialectAllocatorFunction = std::function; - -/// Registers a specific dialect creation function with the system, typically -/// used through the DialectRegistration template. -void registerDialectAllocator(const DialectAllocatorFunction &function); -/// Registers all dialects with the specified MLIRContext. + /// Registers a specific dialect creation function with the global registry. + /// Used through the registerDialect template. + /// Registrations are deduplicated by dialect ClassID and only the first + /// registration will be used. + static void + registerDialectAllocator(const ClassID *classId, + const DialectAllocatorFunction &function); + template + friend void registerDialect(); +}; +/// Registers all dialects and hooks from the global registries with the +/// specified MLIRContext. void registerAllDialects(MLIRContext *context); /// Utility to register a dialect. Client can register their dialect with the /// global registry by calling registerDialect(); template void registerDialect() { - registerDialectAllocator([](MLIRContext *ctx) { - // Just allocate the dialect, the context takes ownership of it. - new ConcreteDialect(ctx); - }); + Dialect::registerDialectAllocator(ClassID::getID(), + [](MLIRContext *ctx) { + // Just allocate the dialect, the context + // takes ownership of it.
+ new ConcreteDialect(ctx); + }); } /// DialectRegistration provides a global initializer that registers a Dialect diff --git a/mlir/include/mlir/IR/DialectHooks.h b/mlir/include/mlir/IR/DialectHooks.h --- a/mlir/include/mlir/IR/DialectHooks.h +++ b/mlir/include/mlir/IR/DialectHooks.h @@ -35,36 +35,53 @@ DialectConstantDecodeHook getDecodeHook() { return nullptr; } // Returns hook to extract an element of an opaque constant tensor. DialectExtractElementHook getExtractElementHook() { return nullptr; } + +private: + /// Registers a function that will set hooks in the registered dialects. + /// Registrations are deduplicated by dialect ClassID and only the first + /// registration will be used. + static void registerDialectHooksSetter(const ClassID *classId, + const DialectHooksSetter &function); + template + friend void registerDialectHooks(StringRef dialectName); }; -/// Registers a function that will set hooks in the registered dialects -/// based on information coming from DialectHooksRegistration. -void registerDialectHooksSetter(const DialectHooksSetter &function); +void registerDialectHooksSetter(const ClassID *classId, + const DialectHooksSetter &function); + +/// Utility to register dialect hooks. Client can register their dialect hooks +/// with the global registry by calling +/// registerDialectHooks("dialect_namespace"); +template +void registerDialectHooks(StringRef dialectName) { + DialectHooks::registerDialectHooksSetter( + ClassID::getID(), [dialectName](MLIRContext *ctx) { + Dialect *dialect = ctx->getRegisteredDialect(dialectName); + if (!dialect) { + llvm::errs() << "error: cannot register hooks for unknown dialect '" + << dialectName << "'\n"; + abort(); + } + // Set hooks.
+ ConcreteHooks hooks; + if (auto h = hooks.getConstantFoldHook()) + dialect->constantFoldHook = h; + if (auto h = hooks.getDecodeHook()) + dialect->decodeHook = h; + if (auto h = hooks.getExtractElementHook()) + dialect->extractElementHook = h; + }); +} /// DialectHooksRegistration provides a global initializer that registers /// a dialect hooks setter routine. /// Usage: /// /// // At namespace scope. -/// static DialectHooksRegistration unused; +/// static DialectHooksRegistration Unused("dialect_namespace"); template struct DialectHooksRegistration { DialectHooksRegistration(StringRef dialectName) { - registerDialectHooksSetter([dialectName](MLIRContext *ctx) { - Dialect *dialect = ctx->getRegisteredDialect(dialectName); - if (!dialect) { - llvm::errs() << "error: cannot register hooks for unknown dialect '" - << dialectName << "'\n"; - abort(); - } - // Set hooks. - ConcreteHooks hooks; - if (auto h = hooks.getConstantFoldHook()) - dialect->constantFoldHook = h; - if (auto h = hooks.getDecodeHook()) - dialect->decodeHook = h; - if (auto h = hooks.getExtractElementHook()) - dialect->extractElementHook = h; - }); + registerDialectHooks(dialectName); } }; diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/DialectInterface.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" @@ -26,39 +27,40 @@ // Dialect Registration //===----------------------------------------------------------------------===// -// Registry for all dialect allocation functions. -static llvm::ManagedStatic> +/// Registry for all dialect allocation functions. +static llvm::ManagedStatic< + llvm::MapVector> dialectRegistry; -// Registry for functions that set dialect hooks.
-static llvm::ManagedStatic> +/// Registry for functions that set dialect hooks. +static llvm::ManagedStatic> dialectHooksRegistry; -/// Registers a specific dialect creation function with the system, typically -/// used through the DialectRegistration template. -void mlir::registerDialectAllocator(const DialectAllocatorFunction &function) { +void Dialect::registerDialectAllocator( + const ClassID *classId, const DialectAllocatorFunction &function) { assert(function && "Attempting to register an empty dialect initialize function"); - dialectRegistry->push_back(function); + dialectRegistry->insert({classId, function}); } /// Registers a function to set specific hooks for a specific dialect, typically /// used through the DialectHooksRegistration template. -void mlir::registerDialectHooksSetter(const DialectHooksSetter &function) { +void DialectHooks::registerDialectHooksSetter( + const ClassID *classId, const DialectHooksSetter &function) { assert( function && "Attempting to register an empty dialect hooks initialization function"); - dialectHooksRegistry->push_back(function); + dialectHooksRegistry->insert({classId, function}); } -/// Registers all dialects and their const folding hooks with the specified -/// MLIRContext. +/// Registers all dialects and hooks from the global registries with the +/// specified MLIRContext. void mlir::registerAllDialects(MLIRContext *context) { - for (const auto &fn : *dialectRegistry) - fn(context); - for (const auto &fn : *dialectHooksRegistry) { - fn(context); + for (const auto &it : *dialectRegistry) + it.second(context); + for (const auto &it : *dialectHooksRegistry) { + it.second(context); } }