diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -535,6 +535,31 @@
 MLIR_CAPI_EXPORTED MlirAttribute
 mlirSparseElementsAttrGetValues(MlirAttribute attr);
 
+//===----------------------------------------------------------------------===//
+// Strided layout attribute.
+//===----------------------------------------------------------------------===//
+
+// Checks whether the given attribute is a strided layout attribute.
+MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr);
+
+// Creates a strided layout attribute with the given offset and strides.
+// `strides` points to `numStrides` values, which are only read.
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides,
+                         const int64_t *strides);
+
+// Returns the offset of `attr`, which must be a strided layout attribute.
+MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);
+
+// Returns the number of strides of `attr`, a strided layout attribute.
+MLIR_CAPI_EXPORTED intptr_t
+mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr);
+
+// Returns the pos-th stride of `attr`, a strided layout attribute. `pos` must
+// be in [0, mlirStridedLayoutAttrGetNumStrides(attr)); it is not checked.
+MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr,
+                                                          intptr_t pos);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -722,3 +722,31 @@
 MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) {
   return wrap(unwrap(attr).cast<SparseElementsAttr>().getValues());
 }
+
+//===----------------------------------------------------------------------===//
+// Strided layout attribute.
+//===----------------------------------------------------------------------===//
+
+bool mlirAttributeIsAStridedLayout(MlirAttribute attr) {
+  return unwrap(attr).isa<StridedLayoutAttr>();
+}
+
+MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
+                                       intptr_t numStrides,
+                                       const int64_t *strides) {
+  return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
+                                     ArrayRef<int64_t>(strides, numStrides)));
+}
+
+int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) {
+  return unwrap(attr).cast<StridedLayoutAttr>().getOffset();
+}
+
+intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) {
+  return static_cast<intptr_t>(
+      unwrap(attr).cast<StridedLayoutAttr>().getStrides().size());
+}
+
+int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) {
+  return unwrap(attr).cast<StridedLayoutAttr>().getStrides()[pos];
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1220,6 +1220,21 @@
       fabs(mlirDenseF64ArrayGetElement(doubleArray, 1) - 1.0) > 1E-6)
     return 21;
 
+  int64_t layoutStrides[3] = {5, 7, 13};
+  MlirAttribute stridedLayoutAttr =
+      mlirStridedLayoutAttrGet(ctx, 42, 3, &layoutStrides[0]);
+
+  // CHECK: strided<[5, 7, 13], offset: 42>
+  mlirAttributeDump(stridedLayoutAttr);
+
+  if (!mlirAttributeIsAStridedLayout(stridedLayoutAttr) ||
+      mlirStridedLayoutAttrGetOffset(stridedLayoutAttr) != 42 ||
+      mlirStridedLayoutAttrGetNumStrides(stridedLayoutAttr) != 3 ||
+      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 0) != 5 ||
+      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 1) != 7 ||
+      mlirStridedLayoutAttrGetStride(stridedLayoutAttr, 2) != 13)
+    return 22;
+
   return 0;
 }
 