diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -50,6 +50,7 @@ DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); +DEFINE_C_API_STRUCT(MlirDialectRegistry, void); DEFINE_C_API_STRUCT(MlirOperation, void); DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); DEFINE_C_API_STRUCT(MlirBlock, void); @@ -108,6 +109,11 @@ MLIR_CAPI_EXPORTED intptr_t mlirContextGetNumRegisteredDialects(MlirContext context); +/// Append the contents of the given dialect registry to the registry associated +/// with the context. +MLIR_CAPI_EXPORTED void +mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry); + /// Returns the number of dialects loaded by the context. MLIR_CAPI_EXPORTED intptr_t @@ -152,6 +158,22 @@ /// Returns the namespace of the given dialect. MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); +//===----------------------------------------------------------------------===// +// DialectRegistry API. +//===----------------------------------------------------------------------===// + +/// Creates a dialect registry and transfers its ownership to the caller. +MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate(void); + +/// Checks if the dialect registry is null. +static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) { + return !registry.ptr; +} + +/// Takes a dialect registry owned by the caller and destroys it. +MLIR_CAPI_EXPORTED void +mlirDialectRegistryDestroy(MlirDialectRegistry registry); + //===----------------------------------------------------------------------===// // Location API.
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h --- a/mlir/include/mlir-c/Registration.h +++ b/mlir/include/mlir-c/Registration.h @@ -44,6 +44,11 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); +/// Inserts the dialect associated with the provided dialect handle into the +/// provided dialect registry. +MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, + MlirDialectRegistry); + /// Registers the dialect associated with the provided dialect handle. MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, MlirContext); diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -22,6 +22,7 @@ DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) +DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry) DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h --- a/mlir/include/mlir/CAPI/Registration.h +++ b/mlir/include/mlir/CAPI/Registration.h @@ -21,12 +21,17 @@ //===----------------------------------------------------------------------===// /// Hooks for dynamic discovery of dialects. +typedef void (*MlirDialectRegistryInsertDialectHook)( + MlirDialectRegistry registry); typedef void (*MlirContextRegisterDialectHook)(MlirContext context); typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context); typedef MlirStringRef (*MlirDialectGetNamespaceHook)(); /// Structure of dialect registration hooks.
struct MlirDialectRegistrationHooks { + MlirDialectRegistryInsertDialectHook insertHook; + // TODO: Remove `registerHook` and implement + // `mlirDialectHandleRegisterDialect` using `mlirDialectHandleInsertDialect` MlirContextRegisterDialectHook registerHook; MlirContextLoadDialectHook loadHook; MlirDialectGetNamespaceHook getNamespaceHook; @@ -34,6 +39,10 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks; #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \ + static void mlirDialectRegistryInsert##Name##Dialect( \ + MlirDialectRegistry registry) { \ + unwrap(registry)->insert<ClassName>(); \ + } \ static void mlirContextRegister##Name##Dialect(MlirContext context) { \ mlir::DialectRegistry registry; \ registry.insert<ClassName>(); \ @@ -47,6 +56,7 @@ } \ MlirDialectHandle mlirGetDialectHandle__##Namespace##__() { \ static MlirDialectRegistrationHooks hooks = { \ + mlirDialectRegistryInsert##Name##Dialect, \ mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect, \ mlir##Name##DialectGetNamespace}; \ return MlirDialectHandle{&hooks}; \ diff --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp --- a/mlir/lib/CAPI/IR/DialectHandle.cpp +++ b/mlir/lib/CAPI/IR/DialectHandle.cpp @@ -17,6 +17,11 @@ return unwrap(handle)->getNamespaceHook(); } +void mlirDialectHandleInsertDialect(MlirDialectHandle handle, + MlirDialectRegistry registry) { + unwrap(handle)->insertHook(registry); +} + void mlirDialectHandleRegisterDialect(MlirDialectHandle handle, MlirContext ctx) { unwrap(handle)->registerHook(ctx); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -53,6 +53,11 @@ return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); } +void mlirContextAppendDialectRegistry(MlirContext ctx, + MlirDialectRegistry registry) { + unwrap(ctx)->appendDialectRegistry(*unwrap(registry)); +} + // TODO: expose a cheaper way than constructing +
sorting a vector only to take // its size. intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { @@ -88,6 +93,18 @@ return wrap(unwrap(dialect)->getNamespace()); } +//===----------------------------------------------------------------------===// +// DialectRegistry API. +//===----------------------------------------------------------------------===// + +MlirDialectRegistry mlirDialectRegistryCreate() { + return wrap(new DialectRegistry()); +} + +void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { + delete unwrap(registry); +} + //===----------------------------------------------------------------------===// // Printing flags API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1904,6 +1904,42 @@ return 0; } +int testDialectRegistry() { + fprintf(stderr, "@testDialectRegistry\n"); + + MlirDialectRegistry registry = mlirDialectRegistryCreate(); + if (mlirDialectRegistryIsNull(registry)) { + fprintf(stderr, "ERROR: Expected registry to be present\n"); + return 1; + } + + MlirDialectHandle stdHandle = mlirGetDialectHandle__std__(); + mlirDialectHandleInsertDialect(stdHandle, registry); + + MlirContext ctx = mlirContextCreate(); + if (mlirContextGetNumRegisteredDialects(ctx) != 0) { + fprintf(stderr, + "ERROR: Expected no dialects to be registered to new context\n"); + mlirContextDestroy(ctx); + mlirDialectRegistryDestroy(registry); + return 2; + } + + mlirContextAppendDialectRegistry(ctx, registry); + if (mlirContextGetNumRegisteredDialects(ctx) != 1) { + fprintf(stderr, "ERROR: Expected the dialect in the registry to be " + "registered to the context\n"); + mlirContextDestroy(ctx); + mlirDialectRegistryDestroy(registry); + return 3; + } + + mlirContextDestroy(ctx); + mlirDialectRegistryDestroy(registry); + + return 0; +} + void testDiagnostics() { MlirContext ctx = mlirContextCreate(); MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler( @@ -1988,6 +2024,8 @@ return 13; if (testSymbolTable(ctx)) return 14; + if (testDialectRegistry()) + return 15;
mlirContextDestroy(ctx);