diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md --- a/mlir/docs/CAPI.md +++ b/mlir/docs/CAPI.md @@ -97,10 +97,25 @@ its first argument is `Y`, and it is the responsibility of the caller to ensure it is indeed the case. +### Returning String References + +Numerous MLIR functions return instances of `StringRef` to refer to a non-owning +segment of a string. This segment may or may not be null-terminated. In the C +API, these functions take an additional callback argument of type +`MlirStringCallback` (pointer to a function with signature `void (*)(const char +*, intptr_t, void *)`) and a pointer to user-defined data. The callback is +invoked with a pointer to the string segment, its size, and the user-defined +data. The caller is in charge of managing the string segment +according to its memory model: for strings owned by the object (e.g., string +attributes), the caller can store the pointer and the size and use them directly +as long as the parent object is live, or copy the string to a new location with +a null terminator if one is needed; for generated strings (e.g., in printing), +the caller is expected to copy the string segment if it intends to use it later. + ### Conversion To String and Printing IR objects can be converted to a string representation, for example for -printing, using `mlirXPrint(MlirX, MlirPrintCallback, void *)` functions. These +printing, using `mlirXPrint(MlirX, MlirStringCallback, void *)` functions. These functions accept take arguments a callback with signature `void (*)(const char *, intptr_t, void *)` and a pointer to user-defined data. They call the callback and supply it with chunks of the string representation, provided as a pointer to diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -67,16 +67,16 @@ }; typedef struct MlirNamedAttribute MlirNamedAttribute; -/** A callback for printing to IR objects. +/** A callback for returning string references.
* - * This function is called back by the printing functions with the following - * arguments: + * This function is called back by the functions that need to return a reference + * to the portion of the string with the following arguments: * - a pointer to the beginning of a string; * - the length of the string (the pointer may point to a larger buffer, not * necessarily null-terminated); * - a pointer to user data forwarded from the printing call. */ -typedef void (*MlirPrintCallback)(const char *, intptr_t, void *); +typedef void (*MlirStringCallback)(const char *, intptr_t, void *); /*============================================================================*/ /* Context API. */ @@ -103,7 +103,7 @@ /** Prints a location by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ -void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, +void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData); /*============================================================================*/ @@ -221,7 +221,7 @@ /** Prints an operation by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ -void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, +void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, void *userData); /** Prints an operation to stderr. */ @@ -289,7 +289,7 @@ /** Prints a block by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. 
*/ -void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, +void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); /*============================================================================*/ @@ -302,7 +302,7 @@ /** Prints a value by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ -void mlirValuePrint(MlirValue value, MlirPrintCallback callback, +void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); /*============================================================================*/ @@ -318,7 +318,7 @@ /** Prints a location by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ -void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData); +void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData); /** Prints the type to the standard error stream. */ void mlirTypeDump(MlirType type); @@ -330,10 +330,13 @@ /** Parses an attribute. The attribute is owned by the context. */ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr); +/** Checks if two attributes are equal. */ +int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2); + /** Prints an attribute by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ -void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, +void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData); /** Prints the attrbute to the standard error stream. 
*/ diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/StandardAttributes.h @@ -0,0 +1,430 @@ +/*===-- mlir-c/StandardAttributes.h - C API for Std Attributes-----*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* This header declares the C interface to MLIR Standard attributes. *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef MLIR_C_STANDARDATTRIBUTES_H +#define MLIR_C_STANDARDATTRIBUTES_H + +#include "mlir-c/AffineMap.h" +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*============================================================================*/ +/* Affine map attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is an affine map attribute. */ +int mlirAttributeIsAAffineMap(MlirAttribute attr); + +/** Creates an affine map attribute wrapping the given map. The attribute + * belongs to the same context as the affine map. */ +MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map); + +/** Returns the affine map wrapped in the given affine map attribute. */ +MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr); + +/*============================================================================*/ +/* Array attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is an array attribute. */ +int mlirAttributeIsAArray(MlirAttribute attr); + +/** Creates an array attribute containing the given list of elements in the + * given context.
*/ +MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, + MlirAttribute *elements); + +/** Returns the number of elements stored in the given array attribute. */ +intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr); + +/** Returns pos-th element stored in the given array attribute. */ +MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos); + +/*============================================================================*/ +/* Dictionary attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a dictionary attribute. */ +int mlirAttributeIsADictionary(MlirAttribute attr); + +/** Creates a dictionary attribute containing the given list of elements in the + * provided context. */ +MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, + MlirNamedAttribute *elements); + +/** Returns the number of attributes contained in a dictionary attribute. */ +intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr); + +/** Returns pos-th element of the given dictionary attribute. */ +MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, + intptr_t pos); + +/** Returns the dictionary attribute element with the given name or NULL if the + * given name does not exist in the dictionary. */ +MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, + const char *name); + +/*============================================================================*/ +/* Floating point attribute. */ +/*============================================================================*/ + +/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the + * relevant functions here. */ + +/** Checks whether the given attribute is a floating point attribute. 
*/ +int mlirAttributeIsAFloat(MlirAttribute attr); + +/** Creates a floating point attribute of the given type holding the given + * double value. Note: `ctx` is currently unused; `type` determines the context. */ +MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, + double value); + +/** Returns the value stored in the given floating point attribute, interpreting + * the value as double. */ +double mlirFloatAttrGetValueDouble(MlirAttribute attr); + +/*============================================================================*/ +/* Integer attribute. */ +/*============================================================================*/ + +/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the + * relevant functions here. */ + +/** Checks whether the given attribute is an integer attribute. */ +int mlirAttributeIsAInteger(MlirAttribute attr); + +/** Creates an integer attribute of the given type with the given integer + * value. */ +MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value); + +/** Returns the value stored in the given integer attribute, assuming the value + * fits into a 64-bit integer. */ +int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr); + +/*============================================================================*/ +/* Bool attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a bool attribute. */ +int mlirAttributeIsABool(MlirAttribute attr); + +/** Creates a bool attribute in the given context with the given value. */ +MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value); + +/** Returns the value stored in the given bool attribute. */ +int mlirBoolAttrGetValue(MlirAttribute attr); + +/*============================================================================*/ +/* Integer set attribute.
*/ +/*============================================================================*/ + +/** Checks whether the given attribute is an integer set attribute. */ +int mlirAttributeIsAIntegerSet(MlirAttribute attr); + +/*============================================================================*/ +/* Opaque attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is an opaque attribute. */ +int mlirAttributeIsAOpaque(MlirAttribute attr); + +/** Creates an opaque attribute in the given context associated with the dialect + * identified by its namespace. The attribute contains opaque byte data of the + * specified length (data need not be null-terminated). */ +MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace, + intptr_t dataLength, const char *data, + MlirType type); + +/** Returns the namespace of the dialect with which the given opaque attribute + * is associated. The namespace string is owned by the context. */ +const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr); + +/** Calls the provided callback with the opaque byte data stored in the given + * opaque attribute. The callback is invoked once, and the data it receives is + * not necessarily null terminated. */ +void mlirOpaqueAttrGetData(MlirAttribute attr, MlirStringCallback callback, + void *userData); + +/*============================================================================*/ +/* String attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a string attribute. */ +int mlirAttributeIsAString(MlirAttribute attr); + +/** Creates a string attribute in the given context containing the given string. + * The string need not be null-terminated and its length must be specified.
*/ +MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length, + const char *data); + +/** Creates a string attribute in the given context containing the given string. + * The string need not be null-terminated and its length must be specified. + * Additionally, the attribute has the given type. */ +MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length, + const char *data); + +/** Calls the provided callback with the string stored in the given string + * attribute. The callback is invoked once, and the data it receives is not + * necessarily null terminated. */ +void mlirStringAttrGetValue(MlirAttribute attr, MlirStringCallback callback, + void *userData); + +/*============================================================================*/ +/* SymbolRef attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a symbol reference attribute. */ +int mlirAttributeIsASymbolRef(MlirAttribute attr); + +/** Creates a symbol reference attribute in the given context referencing a + * symbol identified by the given string inside a list of nested references. + * Each of the references in the list must not be nested. The string need not be + * null-terminated and its length must be specified. */ +MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length, + const char *symbol, intptr_t numReferences, + MlirAttribute *references); + +/** Calls the provided callback with the string containing the root referenced + * symbol. The callback is invoked once, and the data it receives is not + * necessarily null terminated. */ +void mlirSymbolRefAttrGetRootReference(MlirAttribute attr, + MlirStringCallback callback, + void *userData); + +/** Calls the provided callback with the string containing the leaf referenced + * symbol. The callback is invoked once, and the data it receives is not + * necessarily null terminated. 
*/ +void mlirSymbolRefAttrGetLeafReference(MlirAttribute attr, + MlirStringCallback callback, + void *userData); + +/** Returns the number of references nested in the given symbol reference + * attribute. */ +intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr); + +/** Returns pos-th reference nested in the given symbol reference attribute. */ +MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, + intptr_t pos); + +/*============================================================================*/ +/* Flat SymbolRef attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a flat symbol reference attribute. */ +int mlirAttributeIsAFlatSymbolRef(MlirAttribute attr); + +/** Creates a flat symbol reference attribute in the given context referencing a + * symbol identified by the given string. The string need not be null-terminated + * and its length must be specified. */ +MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length, + const char *symbol); + +/** Calls the provided callback with the string containing the referenced + * symbol (invoked once; data not necessarily null terminated). NOTE(review): + * "Float" looks like a typo for "Flat"; rename together with the definition. */ +void mlirFloatSymbolRefAttrGetValue(MlirAttribute attr, + MlirStringCallback callback, + void *userData); + +/*============================================================================*/ +/* Type attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a type attribute. */ +int mlirAttributeIsAType(MlirAttribute attr); + +/** Creates a type attribute wrapping the given type in the same context as the + * type. */ +MlirAttribute mlirTypeAttrGet(MlirType type); + +/** Returns the type stored in the given type attribute.
*/ +MlirType mlirTypeAttrGetValue(MlirAttribute attr); + +/*============================================================================*/ +/* Unit attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a unit attribute. */ +int mlirAttributeIsAUnit(MlirAttribute attr); + +/** Creates a unit attribute in the given context. */ +MlirAttribute mlirUnitAttrGet(MlirContext ctx); + +/*============================================================================*/ +/* Elements attributes. */ +/*============================================================================*/ + +/** Checks whether the given attribute is an elements attribute. */ +int mlirAttributeIsAElements(MlirAttribute attr); + +/** Returns the element at the given rank-dimensional index. */ +MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, + uint64_t *idxs); + +/** Checks whether the given rank-dimensional index is valid in the given + * elements attribute. */ +int mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, + uint64_t *idxs); + +/** Gets the total number of elements in the given elements attribute. In order + * to iterate over the attribute, obtain its type, which must be a statically + * shaped type and use its sizes to build a multi-dimensional index. */ +int64_t mlirElementsAttrGetNumElements(MlirAttribute attr); + +/*============================================================================*/ +/* Dense elements attribute. */ +/*============================================================================*/ + +/* TODO: decide on the interface and add support for complex elements. */ +/* TODO: add support for APFloat and APInt to LLVM IR C API, then expose the + * relevant functions here. */ + +/** Checks whether the given attribute is a dense elements attribute. 
*/ +int mlirAttributeIsADenseElements(MlirAttribute attr); +int mlirAttributeIsADenseIntElements(MlirAttribute attr); +int mlirAttributeIsADenseFPElements(MlirAttribute attr); + +/** Creates a dense elements attribute with the given Shaped type and elements + * in the same context as the type. */ +MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, + intptr_t numElements, + MlirAttribute *elements); + +/** Creates a dense elements attribute with the given Shaped type containing a + * single replicated element (splat). */ +MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, + MlirAttribute element); +MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, + int element); +MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, + uint32_t element); +MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, + int32_t element); +MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, + uint64_t element); +MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, + int64_t element); +MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, + float element); +MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, + double element); + +/** Creates a dense elements attribute with the given shaped type from elements + * of a specific type. Expects the element type of the shaped type to match the + * data element type. 
*/ +MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, + intptr_t numElements, int *elements); +MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, + intptr_t numElements, + uint32_t *elements); +MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType, + intptr_t numElements, + int32_t *elements); +MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType, + intptr_t numElements, + uint64_t *elements); +MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType, + intptr_t numElements, + int64_t *elements); +MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType, + intptr_t numElements, + float *elements); +MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, + intptr_t numElements, + double *elements); + +/** Creates a dense elements attribute with the given shaped type from string + * elements. The strings need not be null-terminated and their lengths are + * provided as a separate argument co-indexed with the strs argument. */ +MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, + intptr_t numElements, + intptr_t *strLengths, + const char **strs); +/** Creates a dense elements attribute that has the same data as the given dense + * elements attribute and a different shaped type. The new type must have the + * same total number of elements. */ +MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, + MlirType shapedType); + +/** Checks whether the given dense elements attribute contains a single + * replicated value (splat). */ +int mlirDenseElementsAttrIsSplat(MlirAttribute attr); + +/** Returns the single replicated value (splat) of a specific type contained by + * the given dense elements attribute. 
*/ +MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr); +int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr); +int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr); +uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr); +int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr); +uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr); +float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr); +double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr); +void mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr, + MlirStringCallback callback, + void *userData); + +/** Returns the pos-th value (flat contiguous indexing) of a specific type + * contained by the given dense elements attribute. */ +int mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos); +int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos); +uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos); +int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos); +uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos); +float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos); +double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos); +void mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos, + MlirStringCallback callback, + void *userData); + +/*============================================================================*/ +/* Opaque elements attribute. */ +/*============================================================================*/ + +/* TODO: expose Dialect to the bindings and implement accessors here. */ + +/** Checks whether the given attribute is an opaque elements attribute. 
*/ +int mlirAttributeIsAOpaqueElements(MlirAttribute attr); + +/*============================================================================*/ +/* Sparse elements attribute. */ +/*============================================================================*/ + +/** Checks whether the given attribute is a sparse elements attribute. */ +int mlirAttributeIsASparseElements(MlirAttribute attr); + +/** Creates a sparse elements attribute of the given shape from a list of + * indices and a list of associated values. Both lists are expected to be dense + * elements attributes with the same number of elements. The list of indices is + * expected to contain 64-bit integers. The attribute is created in the same + * context as the type. */ +MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, + MlirAttribute denseIndices, + MlirAttribute denseValues); + +/** Returns the dense elements attribute containing 64-bit integer indices of + * non-null elements in the given sparse elements attribute. */ +MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr); + +/** Returns the dense elements attribute containing the non-null elements in the + * given sparse elements attribute. */ +MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_STANDARDATTRIBUTES_H diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -40,13 +40,13 @@ namespace { /// Accumulates into a python string from a method that accepts an -/// MlirPrintCallback. +/// MlirStringCallback. 
struct PyPrintAccumulator { py::list parts; void *getUserData() { return this; } - MlirPrintCallback getCallback() { + MlirStringCallback getCallback() { return [](const char *part, intptr_t size, void *userData) { PyPrintAccumulator *printAccum = static_cast(userData); diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIRCAPIIR AffineMap.cpp IR.cpp + StandardAttributes.cpp StandardTypes.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -70,7 +70,7 @@ return wrap(UnknownLoc::get(unwrap(context))); } -void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, +void mlirLocationPrint(MlirLocation location, MlirStringCallback callback, void *userData) { CallbackOstream stream(callback, userData); unwrap(location).print(stream); @@ -237,7 +237,7 @@ return wrap(unwrap(op)->getAttr(name)); } -void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, +void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, void *userData) { CallbackOstream stream(callback, userData); unwrap(op)->print(stream); @@ -319,7 +319,7 @@ return wrap(unwrap(block)->getArgument(static_cast(pos))); } -void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, +void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData) { CallbackOstream stream(callback, userData); unwrap(block)->print(stream); @@ -334,7 +334,7 @@ return wrap(unwrap(value).getType()); } -void mlirValuePrint(MlirValue value, MlirPrintCallback callback, +void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData) { CallbackOstream stream(callback, userData); unwrap(value).print(stream); @@ -351,7 +351,7 @@ int mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } -void 
mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { +void mlirTypePrint(MlirType type, MlirStringCallback callback, void *userData) { CallbackOstream stream(callback, userData); unwrap(type).print(stream); stream.flush(); @@ -367,7 +367,11 @@ return wrap(mlir::parseAttribute(attr, unwrap(context))); } -void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, +int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { + return unwrap(a1) == unwrap(a2); +} + +void mlirAttributePrint(MlirAttribute attr, MlirStringCallback callback, void *userData) { CallbackOstream stream(callback, userData); unwrap(attr).print(stream); diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp @@ -0,0 +1,561 @@ +//===- StandardAttributes.cpp - C Interface to MLIR Standard Attributes ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/StandardAttributes.h" +#include "mlir/CAPI/AffineMap.h" +#include "mlir/CAPI/IR.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; + +/*============================================================================*/ +/* Affine map attribute. 
*/ +/*============================================================================*/ + +int mlirAttributeIsAAffineMap(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { + return wrap(AffineMapAttr::get(unwrap(map))); +} + +MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { + return wrap(unwrap(attr).cast().getValue()); +} + +/*============================================================================*/ +/* Array attribute. */ +/*============================================================================*/ + +int mlirAttributeIsAArray(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, + MlirAttribute *elements) { + SmallVector attrs; + return wrap(ArrayAttr::get( + unwrapList(static_cast(numElements), elements, attrs), + unwrap(ctx))); +} + +intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { + return static_cast(unwrap(attr).cast().size()); +} + +MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { + return wrap(unwrap(attr).cast().getValue()[pos]); +} + +/*============================================================================*/ +/* Dictionary attribute. 
*/ +/*============================================================================*/ + +int mlirAttributeIsADictionary(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, + MlirNamedAttribute *elements) { + SmallVector attributes; + attributes.reserve(numElements); + for (intptr_t i = 0; i < numElements; ++i) + attributes.emplace_back(Identifier::get(elements[i].name, unwrap(ctx)), + unwrap(elements[i].attribute)); + return wrap(DictionaryAttr::get(attributes, unwrap(ctx))); +} + +intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { + return static_cast(unwrap(attr).cast().size()); +} + +MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, + intptr_t pos) { + NamedAttribute attribute = + unwrap(attr).cast().getValue()[pos]; + return {attribute.first.c_str(), wrap(attribute.second)}; +} + +MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, + const char *name) { + return wrap(unwrap(attr).cast().get(name)); +} + +/*============================================================================*/ +/* Floating point attribute. */ +/*============================================================================*/ + +int mlirAttributeIsAFloat(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, + double value) { + return wrap(FloatAttr::get(unwrap(type), value)); +} + +double mlirFloatAttrGetValueDouble(MlirAttribute attr) { + return unwrap(attr).cast().getValueAsDouble(); +} + +/*============================================================================*/ +/* Integer attribute. 
*/
+/*============================================================================*/
+
+int mlirAttributeIsAInteger(MlirAttribute attr) {
+  return unwrap(attr).isa<IntegerAttr>();
+}
+
+MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) {
+  return wrap(IntegerAttr::get(unwrap(type), value));
+}
+
+int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
+  return unwrap(attr).cast<IntegerAttr>().getInt();
+}
+
+/*============================================================================*/
+/* Bool attribute. */
+/*============================================================================*/
+
+int mlirAttributeIsABool(MlirAttribute attr) {
+  return unwrap(attr).isa<BoolAttr>();
+}
+
+MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) {
+  return wrap(BoolAttr::get(value, unwrap(ctx)));
+}
+
+int mlirBoolAttrGetValue(MlirAttribute attr) {
+  return unwrap(attr).cast<BoolAttr>().getValue();
+}
+
+/*============================================================================*/
+/* Integer set attribute. */
+/*============================================================================*/
+
+int mlirAttributeIsAIntegerSet(MlirAttribute attr) {
+  return unwrap(attr).isa<IntegerSetAttr>();
+}
+
+/*============================================================================*/
+/* Opaque attribute.
*/
+/*============================================================================*/
+
+int mlirAttributeIsAOpaque(MlirAttribute attr) {
+  return unwrap(attr).isa<OpaqueAttr>();
+}
+
+MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace,
+                                intptr_t dataLength, const char *data,
+                                MlirType type) {
+  return wrap(OpaqueAttr::get(Identifier::get(dialectNamespace, unwrap(ctx)),
+                              StringRef(data, dataLength), unwrap(type),
+                              unwrap(ctx)));
+}
+
+const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
+  return unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().c_str();
+}
+
+void mlirOpaqueAttrGetData(MlirAttribute attr, MlirStringCallback callback,
+                           void *userData) {
+  StringRef data = unwrap(attr).cast<OpaqueAttr>().getAttrData();
+  callback(data.data(), static_cast<intptr_t>(data.size()), userData);
+}
+
+/*============================================================================*/
+/* String attribute. */
+/*============================================================================*/
+
+int mlirAttributeIsAString(MlirAttribute attr) {
+  return unwrap(attr).isa<StringAttr>();
+}
+
+MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length,
+                                const char *data) {
+  return wrap(StringAttr::get(StringRef(data, length), unwrap(ctx)));
+}
+
+MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length,
+                                     const char *data) {
+  return wrap(StringAttr::get(StringRef(data, length), unwrap(type)));
+}
+
+void mlirStringAttrGetValue(MlirAttribute attr, MlirStringCallback callback,
+                            void *userData) {
+  StringRef data = unwrap(attr).cast<StringAttr>().getValue();
+  callback(data.data(), static_cast<intptr_t>(data.size()), userData);
+}
+
+/*============================================================================*/
+/* SymbolRef attribute.
*/
+/*============================================================================*/
+
+int mlirAttributeIsASymbolRef(MlirAttribute attr) {
+  return unwrap(attr).isa<SymbolRefAttr>();
+}
+
+MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length,
+                                   const char *symbol, intptr_t numReferences,
+                                   MlirAttribute *references) {
+  SmallVector<FlatSymbolRefAttr, 4> refs;
+  refs.reserve(numReferences);
+  for (intptr_t i = 0; i < numReferences; ++i)
+    refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
+  return wrap(SymbolRefAttr::get(StringRef(symbol, length), refs, unwrap(ctx)));
+}
+
+void mlirSymbolRefAttrGetRootReference(MlirAttribute attr,
+                                       MlirStringCallback callback,
+                                       void *userData) {
+  StringRef ref = unwrap(attr).cast<SymbolRefAttr>().getRootReference();
+  callback(ref.data(), static_cast<intptr_t>(ref.size()), userData);
+}
+
+void mlirSymbolRefAttrGetLeafReference(MlirAttribute attr,
+                                       MlirStringCallback callback,
+                                       void *userData) {
+  StringRef ref = unwrap(attr).cast<SymbolRefAttr>().getLeafReference();
+  callback(ref.data(), static_cast<intptr_t>(ref.size()), userData);
+}
+
+intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
+  return static_cast<intptr_t>(
+      unwrap(attr).cast<SymbolRefAttr>().getNestedReferences().size());
+}
+
+MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr,
+                                                  intptr_t pos) {
+  return wrap(unwrap(attr).cast<SymbolRefAttr>().getNestedReferences()[pos]);
+}
+
+/*============================================================================*/
+/* Flat SymbolRef attribute.
*/
+/*============================================================================*/
+
+int mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) {
+  return unwrap(attr).isa<FlatSymbolRefAttr>();
+}
+
+MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length,
+                                       const char *symbol) {
+  return wrap(FlatSymbolRefAttr::get(StringRef(symbol, length), unwrap(ctx)));
+}
+
+void mlirFloatSymbolRefAttrGetValue(MlirAttribute attr,
+                                    MlirStringCallback callback,
+                                    void *userData) {
+  StringRef symbol = unwrap(attr).cast<FlatSymbolRefAttr>().getValue();
+  callback(symbol.data(), static_cast<intptr_t>(symbol.size()), userData);
+}
+
+/*============================================================================*/
+/* Type attribute. */
+/*============================================================================*/
+
+int mlirAttributeIsAType(MlirAttribute attr) {
+  return unwrap(attr).isa<TypeAttr>();
+}
+
+MlirAttribute mlirTypeAttrGet(MlirType type) {
+  return wrap(TypeAttr::get(unwrap(type)));
+}
+
+MlirType mlirTypeAttrGetValue(MlirAttribute attr) {
+  return wrap(unwrap(attr).cast<TypeAttr>().getValue());
+}
+
+/*============================================================================*/
+/* Unit attribute. */
+/*============================================================================*/
+
+int mlirAttributeIsAUnit(MlirAttribute attr) {
+  return unwrap(attr).isa<UnitAttr>();
+}
+
+MlirAttribute mlirUnitAttrGet(MlirContext ctx) {
+  return wrap(UnitAttr::get(unwrap(ctx)));
+}
+
+/*============================================================================*/
+/* Elements attributes.
*/
+/*============================================================================*/
+
+int mlirAttributeIsAElements(MlirAttribute attr) {
+  return unwrap(attr).isa<ElementsAttr>();
+}
+
+MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank,
+                                       uint64_t *idxs) {
+  return wrap(unwrap(attr).cast<ElementsAttr>().getValue(
+      llvm::makeArrayRef(idxs, rank)));
+}
+
+int mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank,
+                                 uint64_t *idxs) {
+  return unwrap(attr).cast<ElementsAttr>().isValidIndex(
+      llvm::makeArrayRef(idxs, rank));
+}
+
+int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) {
+  return unwrap(attr).cast<ElementsAttr>().getNumElements();
+}
+
+/*============================================================================*/
+/* Dense elements attribute. */
+/*============================================================================*/
+
+//===----------------------------------------------------------------------===//
+// IsA support.
+
+int mlirAttributeIsADenseElements(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseElementsAttr>();
+}
+int mlirAttributeIsADenseIntElements(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseIntElementsAttr>();
+}
+int mlirAttributeIsADenseFPElements(MlirAttribute attr) {
+  return unwrap(attr).isa<DenseFPElementsAttr>();
+}
+
+//===----------------------------------------------------------------------===//
+// Constructors.
+
+MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType,
+                                       intptr_t numElements,
+                                       MlirAttribute *elements) {
+  SmallVector<Attribute, 8> attributes;
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+                             unwrapList(numElements, elements, attributes)));
+}
+
+MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType,
+                                            MlirAttribute element) {
+  return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+                                     unwrap(element)));
+}
+MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType,
+                                                int element) {
+  return wrap(DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+                                     static_cast<bool>(element)));
+}
+MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType,
+                                                  uint32_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType,
+                                                 int32_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType,
+                                                  uint64_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType,
+                                                 int64_t element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType,
+                                                 float element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType,
+                                                  double element) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), element));
+}
+
+MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType,
+                                           intptr_t numElements,
+                                           int *elements) {
+  SmallVector<bool, 8> values(elements, elements + numElements);
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
+}
+
+/// Creates a dense attribute with elements of the type deduced by templates.
+template <typename T>
+static MlirAttribute getDenseAttribute(MlirType shapedType,
+                                       intptr_t numElements, T *elements) {
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+                             llvm::makeArrayRef(elements, numElements)));
+}
+
+MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType,
+                                             intptr_t numElements,
+                                             uint32_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt32Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            int32_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrUInt64Get(MlirType shapedType,
+                                             intptr_t numElements,
+                                             uint64_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrInt64Get(MlirType shapedType,
+                                            intptr_t numElements,
+                                            int64_t *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrFloatGet(MlirType shapedType,
+                                            intptr_t numElements,
+                                            float *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType,
+                                             intptr_t numElements,
+                                             double *elements) {
+  return getDenseAttribute(shapedType, numElements, elements);
+}
+
+MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
+                                             intptr_t numElements,
+                                             intptr_t *strLengths,
+                                             const char **strs) {
+  SmallVector<StringRef, 8> values;
+  values.reserve(numElements);
+  for (intptr_t i = 0; i < numElements; ++i)
+    values.push_back(StringRef(strs[i], strLengths[i]));
+
+  return wrap(
+      DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
+}
+
+MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr,
+                                              MlirType shapedType) {
+  return wrap(unwrap(attr).cast<DenseElementsAttr>().reshape(
+      unwrap(shapedType).cast<ShapedType>()));
+}
+
+//===----------------------------------------------------------------------===//
+// Splat accessors.
+
+int mlirDenseElementsAttrIsSplat(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().isSplat();
+}
+
+MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) {
+  return wrap(unwrap(attr).cast<DenseElementsAttr>().getSplatValue());
+}
+int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<bool>();
+}
+int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int32_t>();
+}
+uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint32_t>();
+}
+int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<int64_t>();
+}
+uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<uint64_t>();
+}
+float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<float>();
+}
+double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) {
+  return unwrap(attr).cast<DenseElementsAttr>().getSplatValue<double>();
+}
+void mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr,
+                                              MlirStringCallback callback,
+                                              void *userData) {
+  StringRef str =
+      unwrap(attr).cast<DenseElementsAttr>().getSplatValue<StringRef>();
+  callback(str.data(), static_cast<intptr_t>(str.size()), userData);
+}
+
+//===----------------------------------------------------------------------===//
+// Indexed accessors.
+
+int mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<bool>().begin() +
+           pos);
+}
+int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int32_t>().begin() +
+           pos);
+}
+uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) {
+  return *(
+      unwrap(attr).cast<DenseElementsAttr>().getValues<uint32_t>().begin() +
+      pos);
+}
+int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<int64_t>().begin() +
+           pos);
+}
+uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) {
+  return *(
+      unwrap(attr).cast<DenseElementsAttr>().getValues<uint64_t>().begin() +
+      pos);
+}
+float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<float>().begin() +
+           pos);
+}
+double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) {
+  return *(unwrap(attr).cast<DenseElementsAttr>().getValues<double>().begin() +
+           pos);
+}
+void mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos,
+                                         MlirStringCallback callback,
+                                         void *userData) {
+  StringRef str =
+      *(unwrap(attr).cast<DenseElementsAttr>().getValues<StringRef>().begin() +
+        pos);
+  callback(str.data(), static_cast<intptr_t>(str.size()), userData);
+}
+
+/*============================================================================*/
+/* Opaque elements attribute. */
+/*============================================================================*/
+
+int mlirAttributeIsAOpaqueElements(MlirAttribute attr) {
+  return unwrap(attr).isa<OpaqueElementsAttr>();
+}
+
+/*============================================================================*/
+/* Sparse elements attribute.
*/
+/*============================================================================*/
+
+int mlirAttributeIsASparseElements(MlirAttribute attr) {
+  return unwrap(attr).isa<SparseElementsAttr>();
+}
+
+MlirAttribute mlirSparseElementsAttribute(MlirType shapedType,
+                                          MlirAttribute denseIndices,
+                                          MlirAttribute denseValues) {
+  return wrap(
+      SparseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(),
+                              unwrap(denseIndices).cast<DenseElementsAttr>(),
+                              unwrap(denseValues).cast<DenseElementsAttr>()));
+}
+
+MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) {
+  return wrap(unwrap(attr).cast<SparseElementsAttr>().getIndices());
+}
+
+MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
+  return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -12,11 +12,14 @@
 
 #include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
+#include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"
 
 #include <assert.h>
+#include <math.h>
 #include <stdio.h>
 #include <stdlib.h>
+#include <string.h>
 
 void populateLoopBody(MlirContext ctx, MlirBlock loopBody,
                       MlirLocation location, MlirBlock funcBody) {
@@ -380,6 +383,210 @@
   return 0;
 }
 
+void callbackSetFixedLengthString(const char *data, intptr_t len,
+                                  void *userData) {
+  memcpy(userData, data, (size_t)len);
+}
+
+int printStandardAttributes(MlirContext ctx) {
+  MlirAttribute floating =
+      mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
+  if (!mlirAttributeIsAFloat(floating) ||
+      fabs(mlirFloatAttrGetValueDouble(floating) - 2.0) > 1E-6)
+    return 1;
+  mlirAttributeDump(floating);
+
+  MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
+  if (!mlirAttributeIsAInteger(integer) ||
+      mlirIntegerAttrGetValueInt(integer) != 42)
+    return 2;
+  mlirAttributeDump(integer);
+
+  MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
+  if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
+    return 3;
+  mlirAttributeDump(boolean);
+
+  const char data[] = "abcdefghijklmnopqestuvwxyz";
+  char buffer[10];
+  MlirAttribute opaque =
+      mlirOpaqueAttrGet(ctx, "std", 3, data, mlirNoneTypeGet(ctx));
+  if (!mlirAttributeIsAOpaque(opaque) ||
+      strcmp("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
+    return 4;
+  mlirOpaqueAttrGetData(opaque, callbackSetFixedLengthString, buffer);
+  if (buffer[0] != 'a' || buffer[1] != 'b' || buffer[2] != 'c')
+    return 5;
+  mlirAttributeDump(opaque);
+
+  MlirAttribute string = mlirStringAttrGet(ctx, 2, data + 3);
+  if (!mlirAttributeIsAString(string))
+    return 6;
+  mlirStringAttrGetValue(string, callbackSetFixedLengthString, buffer);
+  if (buffer[0] != 'd' || buffer[1] != 'e')
+    return 7;
+  mlirAttributeDump(string);
+
+  MlirAttribute flatSymbolRef = mlirFlatSymbolRefAttrGet(ctx, 3, data + 5);
+  if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
+    return 8;
+  mlirFloatSymbolRefAttrGetValue(flatSymbolRef, callbackSetFixedLengthString,
+                                 buffer);
+  if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h')
+    return 9;
+  mlirAttributeDump(flatSymbolRef);
+
+  MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
+  MlirAttribute symbolRef = mlirSymbolRefAttrGet(ctx, 2, data + 8, 2, symbols);
+  if (!mlirAttributeIsASymbolRef(symbolRef) ||
+      mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
+      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),
+                          flatSymbolRef) ||
+      !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 1),
+                          flatSymbolRef))
+    return 10;
+  mlirSymbolRefAttrGetLeafReference(symbolRef, callbackSetFixedLengthString,
+                                    buffer);
+  mlirSymbolRefAttrGetRootReference(symbolRef, callbackSetFixedLengthString,
+                                    buffer + 3);
+  if (buffer[0] != 'f' || buffer[1] != 'g' || buffer[2] != 'h' ||
+      buffer[3] != 'i' || buffer[4] != 'j')
+    return 11;
+  mlirAttributeDump(symbolRef);
+
+  MlirAttribute type = mlirTypeAttrGet(mlirF32TypeGet(ctx));
+  if (!mlirAttributeIsAType(type) ||
+      !mlirTypeEqual(mlirF32TypeGet(ctx), mlirTypeAttrGetValue(type)))
+    return 12;
+  mlirAttributeDump(type);
+
+  MlirAttribute unit = mlirUnitAttrGet(ctx);
+  if (!mlirAttributeIsAUnit(unit))
+    return 13;
+  mlirAttributeDump(unit);
+
+  int64_t shape[] = {1, 2};
+
+  int bools[] = {0, 1};
+  uint32_t uints32[] = {0u, 1u};
+  int32_t ints32[] = {0, 1};
+  uint64_t uints64[] = {0u, 1u};
+  int64_t ints64[] = {0, 1};
+  float floats[] = {0.0f, 1.0f};
+  double doubles[] = {0.0, 1.0};
+  MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools);
+  MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2,
+      uints32);
+  MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2,
+      ints32);
+  MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2,
+      uints64);
+  MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2,
+      ints64);
+  MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats);
+  MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
+      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles);
+
+  if (!mlirAttributeIsADenseElements(boolElements) ||
+      !mlirAttributeIsADenseElements(uint32Elements) ||
+      !mlirAttributeIsADenseElements(int32Elements) ||
+      !mlirAttributeIsADenseElements(uint64Elements) ||
+      !mlirAttributeIsADenseElements(int64Elements) ||
+      !mlirAttributeIsADenseElements(floatElements) ||
+      !mlirAttributeIsADenseElements(doubleElements))
+    return 14;
+
+  if (mlirDenseElementsAttrGetBoolValue(boolElements, 1) != 1 ||
+      mlirDenseElementsAttrGetUInt32Value(uint32Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetInt32Value(int32Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetUInt64Value(uint64Elements, 1) != 1 ||
+      mlirDenseElementsAttrGetInt64Value(int64Elements, 1) != 1 ||
+      fabsf(mlirDenseElementsAttrGetFloatValue(floatElements, 1) - 1.0f) >
+          1E-6f ||
+      fabs(mlirDenseElementsAttrGetDoubleValue(doubleElements, 1) - 1.0) > 1E-6)
+    return 15;
+
+  mlirAttributeDump(boolElements);
+  mlirAttributeDump(uint32Elements);
+  mlirAttributeDump(int32Elements);
+  mlirAttributeDump(uint64Elements);
+  mlirAttributeDump(int64Elements);
+  mlirAttributeDump(floatElements);
+  mlirAttributeDump(doubleElements);
+
+  MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1);
+  MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
+  MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
+  MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
+  MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
+  MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f);
+  MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
+      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0);
+
+  if (!mlirAttributeIsADenseElements(splatBool) ||
+      !mlirDenseElementsAttrIsSplat(splatBool) ||
+      !mlirAttributeIsADenseElements(splatUInt32) ||
+      !mlirDenseElementsAttrIsSplat(splatUInt32) ||
+      !mlirAttributeIsADenseElements(splatInt32) ||
+      !mlirDenseElementsAttrIsSplat(splatInt32) ||
+      !mlirAttributeIsADenseElements(splatUInt64) ||
+      !mlirDenseElementsAttrIsSplat(splatUInt64) ||
+      !mlirAttributeIsADenseElements(splatInt64) ||
+      !mlirDenseElementsAttrIsSplat(splatInt64) ||
+      !mlirAttributeIsADenseElements(splatFloat) ||
+      !mlirDenseElementsAttrIsSplat(splatFloat) ||
+      !mlirAttributeIsADenseElements(splatDouble) ||
+      !mlirDenseElementsAttrIsSplat(splatDouble))
+    return 16;
+
+  if (mlirDenseElementsAttrGetBoolSplatValue(splatBool) != 1 ||
+      mlirDenseElementsAttrGetUInt32SplatValue(splatUInt32) != 1 ||
+      mlirDenseElementsAttrGetInt32SplatValue(splatInt32) != 1 ||
+      mlirDenseElementsAttrGetUInt64SplatValue(splatUInt64) != 1 ||
+      mlirDenseElementsAttrGetInt64SplatValue(splatInt64) != 1 ||
+      fabsf(mlirDenseElementsAttrGetFloatSplatValue(splatFloat) - 1.0f) >
+          1E-6f ||
+      fabs(mlirDenseElementsAttrGetDoubleSplatValue(splatDouble) - 1.0) > 1E-6)
+    return 17;
+
+  mlirAttributeDump(splatBool);
+  mlirAttributeDump(splatUInt32);
+  mlirAttributeDump(splatInt32);
+  mlirAttributeDump(splatUInt64);
+  mlirAttributeDump(splatInt64);
+  mlirAttributeDump(splatFloat);
+  mlirAttributeDump(splatDouble);
+
+  mlirAttributeDump(mlirElementsAttrGetValue(floatElements, 2, uints64));
+  mlirAttributeDump(mlirElementsAttrGetValue(doubleElements, 2, uints64));
+
+  int64_t indices[] = {4, 7};
+  int64_t two = 2;
+  MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
+      mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2,
+      indices);
+  MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
+      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats);
+  MlirAttribute sparseAttr = mlirSparseElementsAttribute(
+      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr,
+      valuesAttr);
+  mlirAttributeDump(sparseAttr);
+
+  return 0;
+}
+
 int main() {
   mlirRegisterAllDialects();
   MlirContext ctx = mlirContextCreate();
@@ -454,10 +661,43 @@
   // CHECK: tuple, f32>
   // CHECK: 0
   // clang-format on
-  fprintf(stderr, "@types");
+  fprintf(stderr, "@types\n");
   int errcode = printStandardTypes(ctx);
   fprintf(stderr, "%d\n", errcode);
 
+  // clang-format off
+  // CHECK-LABEL: @attrs
+  // CHECK: 2.000000e+00 : f64
+  // CHECK: 42 : i32
+  // CHECK: true
+  // CHECK: #std.abc
+  // CHECK: "de"
+  // CHECK: @fgh
+  // CHECK: @ij::@fgh::@fgh
+  // CHECK: f32
+  // CHECK: unit
+  // CHECK: dense<{{\[}}[false, true]]> : tensor<1x2xi1>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui32>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi32>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xui64>
+  // CHECK: dense<{{\[}}[0, 1]]> : tensor<1x2xi64>
+  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf32>
+  // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>
+  // CHECK: dense<true> : tensor<1x2xi1>
+  // CHECK: dense<1> : tensor<1x2xi32>
+  // CHECK: dense<1> : tensor<1x2xi32>
+  // CHECK: dense<1> : tensor<1x2xi64>
+  // CHECK: dense<1> : tensor<1x2xi64>
+  // CHECK: dense<1.000000e+00> : tensor<1x2xf32>
+  // CHECK: dense<1.000000e+00> : tensor<1x2xf64>
+  // CHECK: 1.000000e+00 : f32
+  // CHECK: 1.000000e+00 : f64
+  // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
+  // clang-format on
+  fprintf(stderr, "@attrs\n");
+  errcode = printStandardAttributes(ctx);
+  fprintf(stderr, "%d\n", errcode);
+
   mlirContextDestroy(ctx);
 
   return 0;