diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -543,10 +543,9 @@
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr);
 
 // Creates a strided layout attribute from given strides and offset.
-MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx,
-                                                          int64_t offset,
-                                                          intptr_t numStrides,
-                                                          int64_t *strides);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides,
+                         const int64_t *strides);
 
 // Returns the offset in the given strided layout layout attribute.
 MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr);
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -1031,6 +1031,45 @@
   }
 };
 
+/// Strided layout attribute subclass.
+class PyStridedLayoutAttribute
+    : public PyConcreteAttribute<PyStridedLayoutAttribute> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout;
+  static constexpr const char *pyClassName = "StridedLayoutAttr";
+  using PyConcreteAttribute::PyConcreteAttribute;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_static(
+        "get",
+        [](int64_t offset, const std::vector<int64_t> &strides,
+           DefaultingPyMlirContext ctx) {
+          MlirAttribute attr = mlirStridedLayoutAttrGet(
+              ctx->get(), offset, strides.size(), strides.data());
+          return PyStridedLayoutAttribute(ctx->getRef(), attr);
+        },
+        py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(),
+        "Gets a strided layout attribute.");
+    c.def_property_readonly(
+        "offset",
+        [](PyStridedLayoutAttribute &self) {
+          return mlirStridedLayoutAttrGetOffset(self);
+        },
+        "Returns the offset of the strided layout attribute.");
+    c.def_property_readonly(
+        "strides",
+        [](PyStridedLayoutAttribute &self) {
+          intptr_t size = mlirStridedLayoutAttrGetNumStrides(self);
+          std::vector<int64_t> strides(size);
+          for (intptr_t i = 0; i < size; i++) {
+            strides[i] = mlirStridedLayoutAttrGetStride(self, i);
+          }
+          return strides;
+        },
+        "Returns the strides of the strided layout attribute as a list.");
+  }
+};
+
 } // namespace
 
 void mlir::python::populateIRAttributes(py::module &m) {
@@ -1065,4 +1104,6 @@
   PyStringAttribute::bind(m);
   PyTypeAttribute::bind(m);
   PyUnitAttribute::bind(m);
+
+  PyStridedLayoutAttribute::bind(m);
 }
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -302,11 +302,11 @@
         },
         "Returns the shape of the ranked shaped type as a list of integers.");
     c.def_static(
-        "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+        "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
         "Returns the value used to indicate dynamic dimensions in shaped "
         "types.");
     c.def_static(
-        "_get_dynamic_stride_or_offset",
+        "get_dynamic_stride_or_offset",
         []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
         "Returns the value used to indicate dynamic strides or offsets in "
         "shaped types.");
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -732,7 +732,8 @@
 }
 
 MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset,
-                                       intptr_t numStrides, int64_t *strides) {
+                                       intptr_t numStrides,
+                                       const int64_t *strides) {
   return wrap(StridedLayoutAttr::get(unwrap(ctx), offset,
                                      ArrayRef<int64_t>(strides, numStrides)));
 }
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -211,7 +211,7 @@
       static_split_point = split_point
       dynamic_split_point = None
     else:
-      static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
+      static_split_point = _get_int64_attr(ShapedType.get_dynamic_size())
       dynamic_split_point = _get_op_result_or_value(split_point)
 
     pdl_operation_type = pdl.OperationType.get()
@@ -255,7 +255,7 @@
         static_sizes.append(size)
       else:
         static_sizes.append(
-            IntegerAttr.get(i64_type, ShapedType._get_dynamic_size()))
+            IntegerAttr.get(i64_type, ShapedType.get_dynamic_size()))
         dynamic_sizes.append(_get_op_result_or_value(size))
 
     sizes_attr = ArrayAttr.get(static_sizes)
diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py
--- a/mlir/test/python/ir/attributes.py
+++ b/mlir/test/python/ir/attributes.py
@@ -523,3 +523,22 @@
     array = array + [StringAttr.get("c")]
     # CHECK: concat: ["a", "b", "c"]
     print("concat: ", array)
+
+
+# CHECK-LABEL: TEST: testStridedLayoutAttr
+@run
+def testStridedLayoutAttr():
+  with Context():
+    attr = StridedLayoutAttr.get(42, [5, 7, 13])
+    # CHECK: strided<[5, 7, 13], offset: 42>
+    print(attr)
+    # CHECK: 42
+    print(attr.offset)
+    # CHECK: 3
+    print(len(attr.strides))
+    # CHECK: 5
+    print(attr.strides[0])
+    # CHECK: 7
+    print(attr.strides[1])
+    # CHECK: 13
+    print(attr.strides[2])
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -487,3 +487,13 @@
     print("dialect namespace:", opaque.dialect_namespace)
     # CHECK: data: type
     print("data:", opaque.data)
+
+
+# CHECK-LABEL: TEST: testShapedTypeConstants
+# Tests that ShapedType exposes magic value constants.
+@run
+def testShapedTypeConstants():
+  # CHECK: <class 'int'>
+  print(type(ShapedType.get_dynamic_size()))
+  # CHECK: <class 'int'>
+  print(type(ShapedType.get_dynamic_stride_or_offset()))