diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md --- a/mlir/docs/CAPI.md +++ b/mlir/docs/CAPI.md @@ -75,6 +75,28 @@ expect null objects as arguments unless explicitly stated otherwise. API functions _may_ return null objects. +### Type Hierarchies + +MLIR objects can form type hierarchies in C++. For example, all IR classes +representing types are derived from `mlir::Type`, some of them may also be +derived from common base classes such as `mlir::ShapedType` or dialect-specific +base classes. Type hierarchies are exposed to C API through naming conventions +as follows. + +- Only the top-level class of each hierarchy is exposed, e.g. `MlirType` is + defined as a type but `MlirShapedType` is not. This avoids the need for + explicit upcasting when passing an object of a derived type to a function + that expects a base type (this happens more often in core/standard APIs, + while downcasting usually involves further checks anyway). +- A type `Y` that derives from `X` provides a function `int mlirXIsAY(MlirX)` + that returns a non-zero value if the given dynamic instance of `X` is also + an instance of `Y`. For example, `int mlirTypeIsAInteger(MlirType)`. +- A function that expects a derived type as its first argument takes the base + type instead and documents the expectation by using `Y` in its name + `mlirY<...>(MlirX, ...)`. This function asserts that the dynamic instance of + its first argument is `Y`, and it is the responsibility of the caller to + ensure it is indeed the case. + ### Conversion To String and Printing IR objects can be converted to a string representation, for example for @@ -96,11 +118,11 @@ For convenience, `mlirXDump(MlirX)` functions are provided to print the given object to the standard error stream. -### Common Patterns +## Common Patterns The API adopts the following patterns for recurrent functionality in MLIR. 
-#### Indexed Components +### Indexed Components An object has an _indexed component_ if it has fields accessible using a zero-based contiguous integer index, typically arrays. For example, an @@ -120,7 +142,7 @@ type of the subobject. For example, `mlirOperationGetOperand` returns a `MlirValue`. -#### Iterable Components +### Iterable Components An object has an _iterable component_ if it has iterators accessing its fields in some order other than integer indexing, typically linked lists. For example, @@ -146,3 +168,17 @@ /* User 'iter'. */ } ``` + +## Extending the API + +### Extensions for Dialect Attributes and Types + +Dialect attributes and types can follow the example of standard attributes and +types, provided that implementations live in separate directories, i.e. +`include/mlir-c/<...>Dialect/` and `lib/CAPI/<...>Dialect/`. The core APIs +provide implementation-private headers in `include/mlir/CAPI/IR` that allow one +to convert between opaque C structures for core IR components and their C++ +counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does +the inverse conversion. Once the C++ object is available, the API +implementation should rely on `isa` to implement `mlirXIsAY` and is expected to +use `cast` inside other API calls. diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/AffineMap.h @@ -0,0 +1,25 @@ +/*===-- mlir-c/AffineMap.h - C API for MLIR Affine maps -----------*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. 
*| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef MLIR_C_AFFINEMAP_H +#define MLIR_C_AFFINEMAP_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +DEFINE_C_API_STRUCT(MlirAffineMap, const void); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_AFFINEMAP_H diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -56,8 +56,6 @@ DEFINE_C_API_STRUCT(MlirLocation, const void); DEFINE_C_API_STRUCT(MlirModule, const void); -#undef DEFINE_C_API_STRUCT - /** Named MLIR attribute. * * A named attribute is essentially a (name, attribute) pair where the name is @@ -314,6 +312,9 @@ /** Parses a type. The type is owned by the context. */ MlirType mlirTypeParseGet(MlirContext context, const char *type); +/** Checks if two types are equal. */ +int mlirTypeEqual(MlirType t1, MlirType t2); + /** Prints a location by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ diff --git a/mlir/include/mlir-c/StandardTypes.h b/mlir/include/mlir-c/StandardTypes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/StandardTypes.h @@ -0,0 +1,249 @@ +/*===-- mlir-c/StandardTypes.h - C API for MLIR Standard types ----*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. 
*| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef MLIR_C_STANDARDTYPES_H +#define MLIR_C_STANDARDTYPES_H + +#include "mlir-c/AffineMap.h" +#include "mlir-c/IR.h" +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +/*============================================================================*/ +/* Integer types. */ +/*============================================================================*/ + +/** Checks whether the given type is an integer type. */ +int mlirTypeIsAInteger(MlirType type); + +/** Creates a signless integer type of the given bitwidth in the context. The + * type is owned by the context. */ +MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth); + +/** Creates a signed integer type of the given bitwidth in the context. The type + * is owned by the context. */ +MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth); + +/** Creates an unsigned integer type of the given bitwidth in the context. The + * type is owned by the context. */ +MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth); + +/** Returns the bitwidth of an integer type. */ +unsigned mlirIntegerTypeGetWidth(MlirType type); + +/** Checks whether the given integer type is signless. */ +int mlirIntegerTypeIsSignless(MlirType type); + +/** Checks whether the given integer type is signed. */ +int mlirIntegerTypeIsSigned(MlirType type); + +/** Checks whether the given integer type is unsigned. */ +int mlirIntegerTypeIsUnsigned(MlirType type); + +/*============================================================================*/ +/* Index type. */ +/*============================================================================*/ + +/** Checks whether the given type is an index type. */ +int mlirTypeIsAIndex(MlirType type); + +/** Creates an index type in the given context. The type is owned by the + * context. 
*/ +MlirType mlirIndexTypeGet(MlirContext ctx); + +/*============================================================================*/ +/* Floating-point types. */ +/*============================================================================*/ + +/** Checks whether the given type is a bf16 type. */ +int mlirTypeIsABF16(MlirType type); + +/** Creates a bf16 type in the given context. The type is owned by the + * context. */ +MlirType mlirBF16TypeGet(MlirContext ctx); + +/** Checks whether the given type is an f16 type. */ +int mlirTypeIsAF16(MlirType type); + +/** Creates an f16 type in the given context. The type is owned by the + * context. */ +MlirType mlirF16TypeGet(MlirContext ctx); + +/** Checks whether the given type is an f32 type. */ +int mlirTypeIsAF32(MlirType type); + +/** Creates an f32 type in the given context. The type is owned by the + * context. */ +MlirType mlirF32TypeGet(MlirContext ctx); + +/** Checks whether the given type is an f64 type. */ +int mlirTypeIsAF64(MlirType type); + +/** Creates a f64 type in the given context. The type is owned by the + * context. */ +MlirType mlirF64TypeGet(MlirContext ctx); + +/*============================================================================*/ +/* None type. */ +/*============================================================================*/ + +/** Checks whether the given type is a None type. */ +int mlirTypeIsANone(MlirType type); + +/** Creates a None type in the given context. The type is owned by the + * context. */ +MlirType mlirNoneTypeGet(MlirContext ctx); + +/*============================================================================*/ +/* Complex type. */ +/*============================================================================*/ + +/** Checks whether the given type is a Complex type. */ +int mlirTypeIsAComplex(MlirType type); + +/** Creates a complex type with the given element type in the same context as + * the element type. The type is owned by the context. 
*/ +MlirType mlirComplexTypeGet(MlirType elementType); + +/** Returns the element type of the given complex type. */ +MlirType mlirComplexTypeGetElementType(MlirType type); + +/*============================================================================*/ +/* Shaped type. */ +/*============================================================================*/ + +/** Checks whether the given type is a Shaped type. */ +int mlirTypeIsAShaped(MlirType type); + +/** Returns the element type of the shaped type. */ +MlirType mlirShapedTypeGetElementType(MlirType type); + +/** Checks whether the given shaped type is ranked. */ +int mlirShapedTypeHasRank(MlirType type); + +/** Returns the rank of the given ranked shaped type. */ +int64_t mlirShapedTypeGetRank(MlirType type); + +/** Checks whether the given shaped type has a static shape. */ +int mlirShapedTypeHasStaticShape(MlirType type); + +/** Checks whether the dim-th dimension of the given shaped type is dynamic. */ +int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim); + +/** Returns the dim-th dimension of the given ranked shaped type. */ +int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim); + +/** Checks whether the given value is used as a placeholder for dynamic sizes + * in shaped types. */ +int mlirShapedTypeIsDynamicSize(int64_t size); + +/** Checks whether the given value is used as a placeholder for dynamic strides + * and offsets in shaped types. */ +int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); + +/*============================================================================*/ +/* Vector type. */ +/*============================================================================*/ + +/** Checks whether the given type is a Vector type. */ +int mlirTypeIsAVector(MlirType type); + +/** Creates a vector type of the shape identified by its rank and dimensions, + * with the given element type in the same context as the element type. The type + * is owned by the context. 
*/ +MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, MlirType elementType); + +/*============================================================================*/ +/* Ranked / Unranked Tensor type. */ +/*============================================================================*/ + +/** Checks whether the given type is a Tensor type. */ +int mlirTypeIsATensor(MlirType type); + +/** Checks whether the given type is a ranked tensor type. */ +int mlirTypeIsARankedTensor(MlirType type); + +/** Checks whether the given type is an unranked tensor type. */ +int mlirTypeIsAUnrankedTensor(MlirType type); + +/** Creates a tensor type of a fixed rank with the given shape and element type + * in the same context as the element type. The type is owned by the context. */ +MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape, + MlirType elementType); + +/** Creates an unranked tensor type with the given element type in the same + * context as the element type. The type is owned by the context. */ +MlirType mlirUnrankedTensorTypeGet(MlirType elementType); + +/*============================================================================*/ +/* Ranked / Unranked MemRef type. */ +/*============================================================================*/ + +/** Checks whether the given type is a MemRef type. */ +int mlirTypeIsAMemRef(MlirType type); + +/** Checks whether the given type is an UnrankedMemRef type. */ +int mlirTypeIsAUnrankedMemRef(MlirType type); + +/** Creates a MemRef type with the given rank and shape, a potentially empty + * list of affine layout maps, the given memory space and element type, in the + * same context as element type. The type is owned by the context. 
*/ +MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape, + intptr_t numMaps, MlirAffineMap *affineMaps, + unsigned memorySpace); + +/** Creates a MemRef type with the given rank, shape, memory space and element + * type in the same context as the element type. The type has no affine maps, + * i.e. represents a default row-major contiguous memref. The type is owned by + * the context. */ +MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, + int64_t *shape, unsigned memorySpace); + +/** Creates an Unranked MemRef type with the given element type and in the given + * memory space. The type is owned by the context of element type. */ +MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace); + +/** Returns the number of affine layout maps in the given MemRef type. */ +intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); + +/** Returns the pos-th affine map of the given MemRef type. */ +MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos); + +/** Returns the memory space of the given MemRef type. */ +unsigned mlirMemRefTypeGetMemorySpace(MlirType type); + +/** Returns the memory space of the given Unranked MemRef type. */ +unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type); + +/*============================================================================*/ +/* Tuple type. */ +/*============================================================================*/ + +/** Checks whether the given type is a tuple type. */ +int mlirTypeIsATuple(MlirType type); + +/** Creates a tuple type that consists of the given list of elemental types. The + * type is owned by the context. */ +MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, + MlirType *elements); + +/** Returns the number of types contained in a tuple. */ +intptr_t mlirTupleTypeGetNumTypes(MlirType type); + +/** Returns the pos-th type in the tuple type. 
*/ +MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_STANDARDTYPES_H diff --git a/mlir/include/mlir/CAPI/AffineMap.h b/mlir/include/mlir/CAPI/AffineMap.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/CAPI/AffineMap.h @@ -0,0 +1,24 @@ +//===- AffineMap.h - C API Utils for Affine Maps ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// MLIR Affine maps. This file should not be included from C++ code other than +// C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_AFFINEMAP_H +#define MLIR_CAPI_AFFINEMAP_H + +#include "mlir-c/AffineMap.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/IR/AffineMap.h" + +DEFINE_C_API_METHODS(MlirAffineMap, mlir::AffineMap) + +#endif // MLIR_CAPI_AFFINEMAP_H diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/CAPI/IR.h @@ -0,0 +1,34 @@ +//===- IR.h - C API Utils for Core MLIR classes -----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// core MLIR classes. This file should not be included from C++ code other than +// C API implementation nor from C code. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INCLUDE_MLIR_CAPI_IR_H +#define MLIR_INCLUDE_MLIR_CAPI_IR_H + +#include "mlir/CAPI/Wrap.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" + +DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) +DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) +DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) +DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) + +DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) +DEFINE_C_API_METHODS(MlirLocation, mlir::Location) +DEFINE_C_API_METHODS(MlirType, mlir::Type) +DEFINE_C_API_METHODS(MlirValue, mlir::Value) +DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) + +#endif // MLIR_INCLUDE_MLIR_CAPI_IR_H diff --git a/mlir/include/mlir/CAPI/Wrap.h b/mlir/include/mlir/CAPI/Wrap.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/CAPI/Wrap.h @@ -0,0 +1,56 @@ +//===- Wrap.h - C API Utilities ---------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains common definitions for wrapping opaque C++ pointers into +// C structures for the purpose of C API. This file should not be included from +// C++ code other than C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_WRAP_H +#define MLIR_CAPI_WRAP_H + +#include "mlir-c/IR.h" +#include "mlir/Support/LLVM.h" + +/* ========================================================================== */ +/* Definitions of methods for non-owning structures used in C API. 
*/ +/* ========================================================================== */ + +#define DEFINE_C_API_PTR_METHODS(name, cpptype) \ + static inline name wrap(cpptype *cpp) { return name{cpp}; } \ + static inline cpptype *unwrap(name c) { \ + return static_cast(c.ptr); \ + } + +#define DEFINE_C_API_METHODS(name, cpptype) \ + static inline name wrap(cpptype cpp) { \ + return name{cpp.getAsOpaquePointer()}; \ + } \ + static inline cpptype unwrap(name c) { \ + return cpptype::getFromOpaquePointer(c.ptr); \ + } + +template +static llvm::ArrayRef unwrapList(size_t size, CTy *first, + llvm::SmallVectorImpl &storage) { + static_assert( + std::is_same())), CppTy>::value, + "incompatible C and C++ types"); + + if (size == 0) + return llvm::None; + + assert(storage.empty() && "expected to populate storage"); + storage.reserve(size); + for (size_t i = 0; i < size; ++i) + storage.push_back(unwrap(*(first + i))); + return storage; +} + +#endif // MLIR_CAPI_WRAP_H diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -196,6 +196,14 @@ friend ::llvm::hash_code hash_value(AffineMap arg); + /// Methods supporting C API. + const void *getAsOpaquePointer() const { + return static_cast(map); + } + static AffineMap getFromOpaquePointer(const void *pointer) { + return AffineMap(reinterpret_cast(const_cast(pointer))); + } + private: ImplType *map; diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -0,0 +1,15 @@ +//===- AffineMap.cpp - C API for MLIR Affine Maps -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/AffineMap.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/AffineMap.h" +#include "mlir/IR/AffineMap.h" + +// This is a placeholder for affine map bindings. The file is here to serve as a +// compilation unit that includes the headers. diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -1,6 +1,8 @@ # Main API. add_mlir_library(MLIRCAPIIR + AffineMap.cpp IR.cpp + StandardTypes.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -8,6 +8,7 @@ #include "mlir-c/IR.h" +#include "mlir/CAPI/IR.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" @@ -17,46 +18,6 @@ using namespace mlir; -/* ========================================================================== */ -/* Definitions of methods for non-owning structures used in C API. 
*/ -/* ========================================================================== */ - -#define DEFINE_C_API_PTR_METHODS(name, cpptype) \ - static name wrap(cpptype *cpp) { return name{cpp}; } \ - static cpptype *unwrap(name c) { return static_cast(c.ptr); } - -DEFINE_C_API_PTR_METHODS(MlirContext, MLIRContext) -DEFINE_C_API_PTR_METHODS(MlirOperation, Operation) -DEFINE_C_API_PTR_METHODS(MlirBlock, Block) -DEFINE_C_API_PTR_METHODS(MlirRegion, Region) - -#define DEFINE_C_API_METHODS(name, cpptype) \ - static name wrap(cpptype cpp) { return name{cpp.getAsOpaquePointer()}; } \ - static cpptype unwrap(name c) { return cpptype::getFromOpaquePointer(c.ptr); } - -DEFINE_C_API_METHODS(MlirAttribute, Attribute) -DEFINE_C_API_METHODS(MlirLocation, Location); -DEFINE_C_API_METHODS(MlirType, Type) -DEFINE_C_API_METHODS(MlirValue, Value) -DEFINE_C_API_METHODS(MlirModule, ModuleOp) - -template -static ArrayRef unwrapList(intptr_t size, CTy *first, - SmallVectorImpl &storage) { - static_assert( - std::is_same())), CppTy>::value, - "incompatible C and C++ types"); - - if (size == 0) - return llvm::None; - - assert(storage.empty() && "expected to populate storage"); - storage.reserve(size); - for (intptr_t i = 0; i < size; ++i) - storage.push_back(unwrap(*(first + i))); - return storage; -} - /* ========================================================================== */ /* Printing helper. 
*/ /* ========================================================================== */ @@ -388,6 +349,8 @@ return wrap(mlir::parseType(type, unwrap(context))); } +int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } + void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { CallbackOstream stream(callback, userData); unwrap(type).print(stream); diff --git a/mlir/lib/CAPI/IR/StandardTypes.cpp b/mlir/lib/CAPI/IR/StandardTypes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/IR/StandardTypes.cpp @@ -0,0 +1,263 @@ +//===- StandardTypes.cpp - C Interface to MLIR Standard Types -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/StandardTypes.h" +#include "mlir-c/AffineMap.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/AffineMap.h" +#include "mlir/CAPI/IR.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; + +/* ========================================================================== */ +/* Integer types. 
*/ +/* ========================================================================== */ + +int mlirTypeIsAInteger(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { + return wrap(IntegerType::get(bitwidth, unwrap(ctx))); +} + +MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) { + return wrap(IntegerType::get(bitwidth, IntegerType::Signed, unwrap(ctx))); +} + +MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) { + return wrap(IntegerType::get(bitwidth, IntegerType::Unsigned, unwrap(ctx))); +} + +unsigned mlirIntegerTypeGetWidth(MlirType type) { + return unwrap(type).cast().getWidth(); +} + +int mlirIntegerTypeIsSignless(MlirType type) { + return unwrap(type).cast().isSignless(); +} + +int mlirIntegerTypeIsSigned(MlirType type) { + return unwrap(type).cast().isSigned(); +} + +int mlirIntegerTypeIsUnsigned(MlirType type) { + return unwrap(type).cast().isUnsigned(); +} + +/* ========================================================================== */ +/* Index type. */ +/* ========================================================================== */ + +int mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa(); } + +MlirType mlirIndexTypeGet(MlirContext ctx) { + return wrap(IndexType::get(unwrap(ctx))); +} + +/* ========================================================================== */ +/* Floating-point types. 
*/ +/* ========================================================================== */ + +int mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } + +MlirType mlirBF16TypeGet(MlirContext ctx) { + return wrap(FloatType::getBF16(unwrap(ctx))); +} + +int mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } + +MlirType mlirF16TypeGet(MlirContext ctx) { + return wrap(FloatType::getF16(unwrap(ctx))); +} + +int mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } + +MlirType mlirF32TypeGet(MlirContext ctx) { + return wrap(FloatType::getF32(unwrap(ctx))); +} + +int mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } + +MlirType mlirF64TypeGet(MlirContext ctx) { + return wrap(FloatType::getF64(unwrap(ctx))); +} + +/* ========================================================================== */ +/* None type. */ +/* ========================================================================== */ + +int mlirTypeIsANone(MlirType type) { return unwrap(type).isa(); } + +MlirType mlirNoneTypeGet(MlirContext ctx) { + return wrap(NoneType::get(unwrap(ctx))); +} + +/* ========================================================================== */ +/* Complex type. */ +/* ========================================================================== */ + +int mlirTypeIsAComplex(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirComplexTypeGet(MlirType elementType) { + return wrap(ComplexType::get(unwrap(elementType))); +} + +MlirType mlirComplexTypeGetElementType(MlirType type) { + return wrap(unwrap(type).cast().getElementType()); +} + +/* ========================================================================== */ +/* Shaped type. 
*/ +/* ========================================================================== */ + +int mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa(); } + +MlirType mlirShapedTypeGetElementType(MlirType type) { + return wrap(unwrap(type).cast().getElementType()); +} + +int mlirShapedTypeHasRank(MlirType type) { + return unwrap(type).cast().hasRank(); +} + +int64_t mlirShapedTypeGetRank(MlirType type) { + return unwrap(type).cast().getRank(); +} + +int mlirShapedTypeHasStaticShape(MlirType type) { + return unwrap(type).cast().hasStaticShape(); +} + +int mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { + return unwrap(type).cast().isDynamicDim( + static_cast(dim)); +} + +int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { + return unwrap(type).cast().getDimSize(static_cast(dim)); +} + +int mlirShapedTypeIsDynamicSize(int64_t size) { + return ShapedType::isDynamic(size); +} + +int mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { + return ShapedType::isDynamicStrideOrOffset(val); +} + +/* ========================================================================== */ +/* Vector type. */ +/* ========================================================================== */ + +int mlirTypeIsAVector(MlirType type) { return unwrap(type).isa(); } + +MlirType mlirVectorTypeGet(intptr_t rank, int64_t *shape, + MlirType elementType) { + return wrap( + VectorType::get(llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(elementType))); +} + +/* ========================================================================== */ +/* Ranked / Unranked tensor type. 
*/ +/* ========================================================================== */ + +int mlirTypeIsATensor(MlirType type) { return unwrap(type).isa(); } + +int mlirTypeIsARankedTensor(MlirType type) { + return unwrap(type).isa(); +} + +int mlirTypeIsAUnrankedTensor(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirRankedTensorTypeGet(intptr_t rank, int64_t *shape, + MlirType elementType) { + return wrap(RankedTensorType::get( + llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(elementType))); +} + +MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { + return wrap(UnrankedTensorType::get(unwrap(elementType))); +} + +/* ========================================================================== */ +/* Ranked / Unranked MemRef type. */ +/* ========================================================================== */ + +int mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } + +MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, int64_t *shape, + intptr_t numMaps, MlirAffineMap *affineMaps, + unsigned memorySpace) { + SmallVector maps; + (void)unwrapList(numMaps, affineMaps, maps); + return wrap( + MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(elementType), maps, memorySpace)); +} + +MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, + int64_t *shape, unsigned memorySpace) { + return wrap( + MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(elementType), llvm::None, memorySpace)); +} + +intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { + return static_cast( + unwrap(type).cast().getAffineMaps().size()); +} + +MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) { + return wrap(unwrap(type).cast().getAffineMaps()[pos]); +} + +unsigned mlirMemRefTypeGetMemorySpace(MlirType type) { + return unwrap(type).cast().getMemorySpace(); +} + +int mlirTypeIsAUnrankedMemRef(MlirType type) { + return unwrap(type).isa(); +} + +MlirType 
mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) { + return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); +} + +unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { + return unwrap(type).cast().getMemorySpace(); +} + +/* ========================================================================== */ +/* Tuple type. */ +/* ========================================================================== */ + +int mlirTypeIsATuple(MlirType type) { return unwrap(type).isa(); } + +MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, + MlirType *elements) { + SmallVector types; + ArrayRef typeRef = unwrapList(numElements, elements, types); + return wrap(TupleType::get(typeRef, unwrap(ctx))); +} + +intptr_t mlirTupleTypeGetNumTypes(MlirType type) { + return unwrap(type).cast().size(); +} + +MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { + return wrap(unwrap(type).cast().getType(static_cast(pos))); +} diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -12,6 +12,7 @@ #include "mlir-c/IR.h" #include "mlir-c/Registration.h" +#include "mlir-c/StandardTypes.h" #include #include @@ -240,6 +241,146 @@ fprintf(stderr, "\n"); } +/// Dumps instances of all standard types to check that C API works correctly. +/// Additionally, performs simple identity checks that a standard type +/// constructed with C API can be inspected and has the expected type. The +/// latter achieves full coverage of C API for standard types. Returns 0 on +/// success and a non-zero error code on failure. +static int printStandardTypes(MlirContext ctx) { + // Integer types. 
+ MlirType i32 = mlirIntegerTypeGet(ctx, 32); + MlirType si32 = mlirIntegerTypeSignedGet(ctx, 32); + MlirType ui32 = mlirIntegerTypeUnsignedGet(ctx, 32); + if (!mlirTypeIsAInteger(i32) || mlirTypeIsAF32(i32)) + return 1; + if (!mlirTypeIsAInteger(si32) || !mlirIntegerTypeIsSigned(si32)) + return 2; + if (!mlirTypeIsAInteger(ui32) || !mlirIntegerTypeIsUnsigned(ui32)) + return 3; + if (mlirTypeEqual(i32, ui32) || mlirTypeEqual(i32, si32)) + return 4; + if (mlirIntegerTypeGetWidth(i32) != mlirIntegerTypeGetWidth(si32)) + return 5; + mlirTypeDump(i32); + fprintf(stderr, "\n"); + mlirTypeDump(si32); + fprintf(stderr, "\n"); + mlirTypeDump(ui32); + fprintf(stderr, "\n"); + + // Index type. + MlirType index = mlirIndexTypeGet(ctx); + if (!mlirTypeIsAIndex(index)) + return 6; + mlirTypeDump(index); + fprintf(stderr, "\n"); + + // Floating-point types. + MlirType bf16 = mlirBF16TypeGet(ctx); + MlirType f16 = mlirF16TypeGet(ctx); + MlirType f32 = mlirF32TypeGet(ctx); + MlirType f64 = mlirF64TypeGet(ctx); + if (!mlirTypeIsABF16(bf16)) + return 7; + if (!mlirTypeIsAF16(f16)) + return 9; + if (!mlirTypeIsAF32(f32)) + return 10; + if (!mlirTypeIsAF64(f64)) + return 11; + mlirTypeDump(bf16); + fprintf(stderr, "\n"); + mlirTypeDump(f16); + fprintf(stderr, "\n"); + mlirTypeDump(f32); + fprintf(stderr, "\n"); + mlirTypeDump(f64); + fprintf(stderr, "\n"); + + // None type. + MlirType none = mlirNoneTypeGet(ctx); + if (!mlirTypeIsANone(none)) + return 12; + mlirTypeDump(none); + fprintf(stderr, "\n"); + + // Complex type. + MlirType cplx = mlirComplexTypeGet(f32); + if (!mlirTypeIsAComplex(cplx) || + !mlirTypeEqual(mlirComplexTypeGetElementType(cplx), f32)) + return 13; + mlirTypeDump(cplx); + fprintf(stderr, "\n"); + + // Vector (and Shaped) type. ShapedType is a common base class for vectors, + // memrefs and tensors, one cannot create instances of this class so it is + // tested on an instance of vector type. 
+ int64_t shape[] = {2, 3}; + MlirType vector = + mlirVectorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32); + if (!mlirTypeIsAVector(vector) || !mlirTypeIsAShaped(vector)) + return 14; + if (!mlirTypeEqual(mlirShapedTypeGetElementType(vector), f32) || + !mlirShapedTypeHasRank(vector) || mlirShapedTypeGetRank(vector) != 2 || + mlirShapedTypeGetDimSize(vector, 0) != 2 || + mlirShapedTypeIsDynamicDim(vector, 0) || + mlirShapedTypeGetDimSize(vector, 1) != 3 || + !mlirShapedTypeHasStaticShape(vector)) + return 15; + mlirTypeDump(vector); + fprintf(stderr, "\n"); + + // Ranked tensor type. + MlirType rankedTensor = + mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32); + if (!mlirTypeIsATensor(rankedTensor) || + !mlirTypeIsARankedTensor(rankedTensor)) + return 16; + mlirTypeDump(rankedTensor); + fprintf(stderr, "\n"); + + // Unranked tensor type. + MlirType unrankedTensor = mlirUnrankedTensorTypeGet(f32); + if (!mlirTypeIsATensor(unrankedTensor) || + !mlirTypeIsAUnrankedTensor(unrankedTensor) || + mlirShapedTypeHasRank(unrankedTensor)) + return 17; + mlirTypeDump(unrankedTensor); + fprintf(stderr, "\n"); + + // MemRef type. + MlirType memRef = mlirMemRefTypeContiguousGet( + f32, sizeof(shape) / sizeof(int64_t), shape, 2); + if (!mlirTypeIsAMemRef(memRef) || + mlirMemRefTypeGetNumAffineMaps(memRef) != 0 || + mlirMemRefTypeGetMemorySpace(memRef) != 2) + return 18; + mlirTypeDump(memRef); + fprintf(stderr, "\n"); + + // Unranked MemRef type. + MlirType unrankedMemRef = mlirUnrankedMemRefTypeGet(f32, 4); + if (!mlirTypeIsAUnrankedMemRef(unrankedMemRef) || + mlirTypeIsAMemRef(unrankedMemRef) || + mlirUnrankedMemrefGetMemorySpace(unrankedMemRef) != 4) + return 19; + mlirTypeDump(unrankedMemRef); + fprintf(stderr, "\n"); + + // Tuple type. 
+ MlirType types[] = {unrankedMemRef, f32}; + MlirType tuple = mlirTupleTypeGet(ctx, 2, types); + if (!mlirTypeIsATuple(tuple) || + mlirTupleTypeGetNumTypes(tuple) != 2 || + !mlirTypeEqual(mlirTupleTypeGetType(tuple, 0), unrankedMemRef) || + !mlirTypeEqual(mlirTupleTypeGetType(tuple, 1), f32)) + return 20; + mlirTypeDump(tuple); + fprintf(stderr, "\n"); + + return 0; +} + int main() { mlirRegisterAllDialects(); MlirContext ctx = mlirContextCreate(); @@ -293,6 +434,31 @@ // clang-format on mlirModuleDestroy(moduleOp); + + // clang-format off + // CHECK-LABEL: @types + // CHECK: i32 + // CHECK: si32 + // CHECK: ui32 + // CHECK: index + // CHECK: bf16 + // CHECK: f16 + // CHECK: f32 + // CHECK: f64 + // CHECK: none + // CHECK: complex<f32> + // CHECK: vector<2x3xf32> + // CHECK: tensor<2x3xf32> + // CHECK: tensor<*xf32> + // CHECK: memref<2x3xf32, 2> + // CHECK: memref<*xf32, 4> + // CHECK: tuple<memref<*xf32, 4>, f32> + // CHECK: 0 + // clang-format on + fprintf(stderr, "@types"); + int errcode = printStandardTypes(ctx); + fprintf(stderr, "%d\n", errcode); + mlirContextDestroy(ctx); return 0;