diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h
--- a/mlir/include/mlir-c/Registration.h
+++ b/mlir/include/mlir-c/Registration.h
@@ -23,47 +23,37 @@
 // API name (i.e. "Standard", "Tensor", "Linalg") and namespace (i.e. "std",
 // "tensor", "linalg"). The following declarations are produced:
 //
-//   /// Registers the dialect with the given context. This allows the
-//   /// dialect to be loaded dynamically if needed when parsing. */
-//   void mlirContextRegister{NAME}Dialect(MlirContext);
-//
-//   /// Loads the dialect into the given context. The dialect does _not_
-//   /// have to be registered in advance.
-//   MlirDialect mlirContextLoad{NAME}Dialect(MlirContext context);
-//
-//   /// Returns the namespace of the Standard dialect, suitable for loading it.
-//   MlirStringRef mlir{NAME}DialectGetNamespace();
-//
-//   /// Gets the above hook methods in struct form for a dialect by namespace.
+//   /// Returns an opaque handle for the dialect with namespace {NAMESPACE}.
 //   /// This is intended to facilitate dynamic lookup and registration of
 //   /// dialects via a plugin facility based on shared library symbol lookup.
-//   const MlirDialectRegistrationHooks *mlirGetDialectHooks__{NAMESPACE}__();
+//   MlirDialectHandle mlirGetDialectHandle__{NAMESPACE}__(void);
 //
 // This is done via a common macro to facilitate future expansion to
 // registration schemes.
 //===----------------------------------------------------------------------===//
 
+/// Opaque handle to a dialect, used to query its namespace and to register or
+/// load it. Obtained from mlirGetDialectHandle__{NAMESPACE}__().
+struct MlirDialectHandle {
+  const void *ptr;
+};
+typedef struct MlirDialectHandle MlirDialectHandle;
+
 #define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace)                \
-  MLIR_CAPI_EXPORTED void mlirContextRegister##Name##Dialect(                  \
-      MlirContext context);                                                    \
-  MLIR_CAPI_EXPORTED MlirDialect mlirContextLoad##Name##Dialect(               \
-      MlirContext context);                                                    \
-  MLIR_CAPI_EXPORTED MlirStringRef mlir##Name##DialectGetNamespace();          \
-  MLIR_CAPI_EXPORTED const MlirDialectRegistrationHooks                        \
-      *mlirGetDialectHooks__##Namespace##__()
+  MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__(  \
+      void)
 
-/// Hooks for dynamic discovery of dialects.
-typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
-typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
-typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
+/// Returns the namespace associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED
+MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle);
 
-/// Structure of dialect registration hooks.
-struct MlirDialectRegistrationHooks {
-  MlirContextRegisterDialectHook registerHook;
-  MlirContextLoadDialectHook loadHook;
-  MlirDialectGetNamespaceHook getNamespaceHook;
-};
-typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
+/// Registers the dialect associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle,
+                                                         MlirContext);
+
+/// Loads the dialect associated with the provided dialect handle.
+MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle,
+                                                            MlirContext);
 
 /// Registers all dialects known to core MLIR with the provided Context.
 /// This is needed before creating IR for these Dialects.
diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h
--- a/mlir/include/mlir/CAPI/Registration.h
+++ b/mlir/include/mlir/CAPI/Registration.h
@@ -20,21 +20,34 @@
 // of the dialect class.
 //===----------------------------------------------------------------------===//
 
+/// Hooks for dynamic discovery of dialects.
+typedef void (*MlirContextRegisterDialectHook)(MlirContext context);
+typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context);
+typedef MlirStringRef (*MlirDialectGetNamespaceHook)();
+
+/// Registration hooks of a dialect; MlirDialectHandle::ptr points to one.
+struct MlirDialectRegistrationHooks {
+  MlirContextRegisterDialectHook registerHook;
+  MlirContextLoadDialectHook loadHook;
+  MlirDialectGetNamespaceHook getNamespaceHook;
+};
+typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks;
+
 #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName)      \
-  void mlirContextRegister##Name##Dialect(MlirContext context) {               \
+  static void mlirContextRegister##Name##Dialect(MlirContext context) {        \
     unwrap(context)->getDialectRegistry().insert<ClassName>();                 \
   }                                                                            \
-  MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {            \
+  static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) {     \
     return wrap(unwrap(context)->getOrLoadDialect<ClassName>());               \
   }                                                                            \
-  MlirStringRef mlir##Name##DialectGetNamespace() {                            \
+  static MlirStringRef mlir##Name##DialectGetNamespace() {                     \
     return wrap(ClassName::getDialectNamespace());                             \
   }                                                                            \
-  const MlirDialectRegistrationHooks *mlirGetDialectHooks__##Namespace##__() { \
-    static MlirDialectRegistrationHooks hooks = {                              \
+  MlirDialectHandle mlirGetDialectHandle__##Namespace##__() {                  \
+    static const MlirDialectRegistrationHooks hooks = {                        \
       mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect,      \
       mlir##Name##DialectGetNamespace};                                        \
-    return &hooks;                                                             \
+    return MlirDialectHandle{&hooks};                                          \
   }
 
 #endif // MLIR_CAPI_REGISTRATION_H
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -5,6 +5,7 @@
   BuiltinAttributes.cpp
   BuiltinTypes.cpp
   Diagnostics.cpp
+  DialectHandle.cpp
   IntegerSet.cpp
   IR.cpp
   Pass.cpp
diff --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/CAPI/IR/DialectHandle.cpp
@@ -0,0 +1,31 @@
+//===- DialectHandle.cpp - C Interface for MLIR Dialect Handles -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/CAPI/Registration.h"
+
+/// Recovers the registration hooks that a dialect handle points to. Handles
+/// are produced by mlirGetDialectHandle__{NAMESPACE}__() and wrap the address
+/// of a static MlirDialectRegistrationHooks; the pointer is not validated.
+static const MlirDialectRegistrationHooks *
+unwrapHooks(MlirDialectHandle handle) {
+  return static_cast<const MlirDialectRegistrationHooks *>(handle.ptr);
+}
+
+MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) {
+  return unwrapHooks(handle)->getNamespaceHook();
+}
+
+void mlirDialectHandleRegisterDialect(MlirDialectHandle handle,
+                                      MlirContext ctx) {
+  unwrapHooks(handle)->registerHook(ctx);
+}
+
+MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle,
+                                         MlirContext ctx) {
+  return unwrapHooks(handle)->loadHook(ctx);
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1412,23 +1412,26 @@
   if (mlirContextGetNumLoadedDialects(ctx) != 1)
     return 1;
 
-  MlirDialect std =
-      mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  MlirDialectHandle stdHandle = mlirGetDialectHandle__std__();
+
+  MlirDialect std = mlirContextGetOrLoadDialect(
+      ctx, mlirDialectHandleGetNamespace(stdHandle));
   if (!mlirDialectIsNull(std))
     return 2;
 
-  mlirContextRegisterStandardDialect(ctx);
+  mlirDialectHandleRegisterDialect(stdHandle, ctx);
 
-  std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace());
+  std = mlirContextGetOrLoadDialect(ctx,
+                                    mlirDialectHandleGetNamespace(stdHandle));
   if (mlirDialectIsNull(std))
     return 3;
 
-  MlirDialect alsoStd = mlirContextLoadStandardDialect(ctx);
+  MlirDialect alsoStd = mlirDialectHandleLoadDialect(stdHandle, ctx);
   if (!mlirDialectEqual(std, alsoStd))
     return 4;
 
   MlirStringRef stdNs = mlirDialectGetNamespace(std);
-  MlirStringRef alsoStdNs = mlirStandardDialectGetNamespace();
+  MlirStringRef alsoStdNs = mlirDialectHandleGetNamespace(stdHandle);
   if (stdNs.length != alsoStdNs.length ||
       strncmp(stdNs.data, alsoStdNs.data, stdNs.length))
     return 5;