diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -24,6 +24,8 @@ extern "C" { #endif +struct MlirStringRef; + /*============================================================================*/ /** Opaque type declarations. * @@ -46,6 +48,7 @@ typedef struct name name DEFINE_C_API_STRUCT(MlirContext, void); +DEFINE_C_API_STRUCT(MlirDialect, void); DEFINE_C_API_STRUCT(MlirOperation, void); DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); @@ -91,6 +94,37 @@ /** Takes an MLIR context owned by the caller and destroys it. */ void mlirContextDestroy(MlirContext context); +/** Returns the number of dialects registered with the given context. A + * registered dialect may be loaded explicitly or as a pass dependency. */ +intptr_t mlirContextGetNumRegisteredDialects(MlirContext context); + +/** Returns the number of dialects loaded by the context. A dialect must be + * loaded before its operations, types and attributes can be used. */ +intptr_t mlirContextGetNumLoadedDialects(MlirContext context); + +/** Gets the dialect instance owned by the given context using the dialect + * namespace to identify it, loading (i.e., constructing the instance of) the + * dialect if necessary. If the dialect is not registered with the context, + * returns null. */ +MlirDialect mlirContextGetOrLoadDialect(MlirContext context, + struct MlirStringRef name); + +/*============================================================================*/ +/* Dialect API. */ +/*============================================================================*/ + +/** Returns the context that owns the dialect. */ +MlirContext mlirDialectGetContext(MlirDialect dialect); + +/** Checks if the dialect is null. */ +int mlirDialectIsNull(MlirDialect dialect); + +/** Checks if two dialects are equal.
*/ +int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2); + +/** Returns the namespace of the given dialect. */ +struct MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); + /*============================================================================*/ /* Location API. */ /*============================================================================*/ diff --git a/mlir/include/mlir-c/StandardDialect.h b/mlir/include/mlir-c/StandardDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/StandardDialect.h @@ -0,0 +1,39 @@ +/*===-- mlir-c/StandardDialect.h - C API for Standard dialect -----*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* This header declares the C interface for registering and accessing the *| +|* Standard dialect. A dialect should be registered with a context to make it *| +|* available to users of the context. These users must load the dialect *| +|* before using any of its attributes, operations or types. Parser and pass *| +|* manager can load registered dialects automatically. *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef MLIR_C_STANDARDDIALECT_H +#define MLIR_C_STANDARDDIALECT_H + +#ifdef __cplusplus +extern "C" { +#endif + +struct MlirDialect; +struct MlirContext; +struct MlirStringRef; + +/** Registers (but does not load) the Standard dialect with the context. */ +void mlirContextRegisterStandardDialect(struct MlirContext context); + +/** Returns the namespace of the Standard dialect, suitable for loading it.
*/ +struct MlirStringRef mlirStandardDialectGetNamespace(void); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_STANDARDDIALECT_H diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -21,6 +21,7 @@ #include "mlir/IR/Operation.h" DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) +DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) add_subdirectory(Registration) +add_subdirectory(Standard) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -7,8 +7,10 @@ //===----------------------------------------------------------------------===// #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" @@ -34,6 +36,41 @@ void mlirContextDestroy(MlirContext context) { delete unwrap(context); } +intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { + return static_cast<intptr_t>(unwrap(context)->getAvailableDialects().size()); +} + +// TODO: expose a cheaper way than constructing + sorting a vector only to take +// its size. +intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { + return static_cast<intptr_t>(unwrap(context)->getLoadedDialects().size()); +} + +MlirDialect mlirContextGetOrLoadDialect(MlirContext context, + MlirStringRef name) { + return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); +} + +/* ========================================================================== */ +/* Dialect API.
*/ +/* ========================================================================== */ + +MlirContext mlirDialectGetContext(MlirDialect dialect) { + return wrap(unwrap(dialect)->getContext()); +} + +int mlirDialectIsNull(MlirDialect dialect) { + return unwrap(dialect) == nullptr; +} + +int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { + return unwrap(dialect1) == unwrap(dialect2); +} + +MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { + return wrap(unwrap(dialect)->getNamespace()); +} + /* ========================================================================== */ /* Location API. */ /* ========================================================================== */ diff --git a/mlir/lib/CAPI/Standard/CMakeLists.txt b/mlir/lib/CAPI/Standard/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Standard/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(MLIRCAPIStandard + + StandardDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir-c + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRStandardOps + ) diff --git a/mlir/lib/CAPI/Standard/StandardDialect.cpp b/mlir/lib/CAPI/Standard/StandardDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Standard/StandardDialect.cpp @@ -0,0 +1,21 @@ +//===- StandardDialect.cpp - C Interface for Standard dialect -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/StandardDialect.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +void mlirContextRegisterStandardDialect(MlirContext context) { + unwrap(context)->getDialectRegistry().insert<mlir::StandardOpsDialect>(); +} + +MlirStringRef mlirStandardDialectGetNamespace() { + return wrap(mlir::StandardOpsDialect::getDialectNamespace()); +} diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -15,6 +15,7 @@ #include "mlir-c/Registration.h" #include "mlir-c/StandardAttributes.h" #include "mlir-c/StandardTypes.h" +#include "mlir-c/StandardDialect.h" #include <assert.h> #include <math.h> @@ -709,6 +710,42 @@ return 0; } +// Checks dialect registration and loading on the given freshly created +// context. Returns 0 on success or the number of the first failed check. +static int checkRegisterOnlyStd(MlirContext ctx) { + // The built-in dialect is always loaded. + if (mlirContextGetNumLoadedDialects(ctx) != 1) + return 1; + + MlirDialect std = + mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace()); + if (!mlirDialectIsNull(std)) + return 2; + + mlirContextRegisterStandardDialect(ctx); + if (mlirContextGetNumRegisteredDialects(ctx) != 1) + return 3; + if (mlirContextGetNumLoadedDialects(ctx) != 1) + return 4; + + std = mlirContextGetOrLoadDialect(ctx, mlirStandardDialectGetNamespace()); + if (mlirDialectIsNull(std)) + return 5; + if (mlirContextGetNumLoadedDialects(ctx) != 2) + return 6; + + return 0; +} + +// Runs the registration checks on a dedicated context that is destroyed on +// every path, including early failures, so that the test does not leak it. +int registerOnlyStd() { + MlirContext ctx = mlirContextCreate(); + int result = checkRegisterOnlyStd(ctx); + mlirContextDestroy(ctx); + return result; +} + int main() { MlirContext ctx = mlirContextCreate(); mlirRegisterAllDialects(ctx); @@ -837,6 +874,14 @@ errcode = printAffineMap(ctx); fprintf(stderr, "%d\n", errcode); + fprintf(stderr, "@registration\n"); + errcode = registerOnlyStd(); + fprintf(stderr, "%d\n", errcode); + // clang-format off + // CHECK-LABEL: @registration + // CHECK: 0 + // clang-format on + mlirContextDestroy(ctx); return 0;