diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -203,6 +203,10 @@
     MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType,
     MlirAttribute encoding);
 
+/// Gets the 'encoding' attribute of a ranked tensor type, returning a null
+/// attribute if none. The type must be a ranked tensor type.
+MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type);
+
 /// Creates an unranked tensor type with the given element type in the same
 /// context as the element type. The type is owned by the context.
 MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -338,10 +338,11 @@
     c.def_static(
         "get",
         [](std::vector<int64_t> shape, PyType &elementType,
+           llvm::Optional<PyAttribute> &encodingAttr,
            DefaultingPyLocation loc) {
-          MlirAttribute encodingAttr = mlirAttributeGetNull();
           MlirType t = mlirRankedTensorTypeGetChecked(
-              loc, shape.size(), shape.data(), elementType, encodingAttr);
+              loc, shape.size(), shape.data(), elementType,
+              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
           // TODO: Rework error reporting once diagnostic engine is exposed
           // in C API.
           if (mlirTypeIsNull(t)) {
@@ -355,8 +356,17 @@
           }
           return PyRankedTensorType(elementType.getContext(), t);
         },
-        py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
+        py::arg("shape"), py::arg("element_type"),
+        py::arg("encoding") = py::none(), py::arg("loc") = py::none(),
         "Create a ranked tensor type");
+    c.def_property_readonly(
+        "encoding",
+        [](PyRankedTensorType &self) -> llvm::Optional<PyAttribute> {
+          MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+          if (mlirAttributeIsNull(encoding))
+            return llvm::None;
+          return PyAttribute(self.getContext(), encoding);
+        });
   }
 };
 
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -206,6 +206,10 @@
       unwrap(elementType), unwrap(encoding)));
 }
 
+MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
+  return wrap(unwrap(type).cast<RankedTensorType>().getEncoding());
+}
+
 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
   return wrap(UnrankedTensorType::get(unwrap(elementType)));
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -690,7 +690,8 @@
   MlirType rankedTensor = mlirRankedTensorTypeGet(
       sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
-      !mlirTypeIsARankedTensor(rankedTensor))
+      !mlirTypeIsARankedTensor(rankedTensor) ||
+      !mlirAttributeIsNull(mlirRankedTensorTypeGetEncoding(rankedTensor)))
     return 16;
   mlirTypeDump(rankedTensor);
   fprintf(stderr, "\n");
diff --git a/mlir/test/python/dialects/sparse_tensor/dialect.py b/mlir/test/python/dialects/sparse_tensor/dialect.py
--- a/mlir/test/python/dialects/sparse_tensor/dialect.py
+++ b/mlir/test/python/dialects/sparse_tensor/dialect.py
@@ -73,3 +73,18 @@
     print(created)
     # CHECK: created_equal: True
     print(f"created_equal: {created == casted}")
+
+
+# CHECK-LABEL: TEST: testEncodingAttrOnTensor
+@run
+def testEncodingAttrOnTensor():
+  with Context() as ctx, Location.unknown():
+    encoding = st.EncodingAttr(Attribute.parse(
+        '#sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], '
+        'pointerBitWidth = 16, indexBitWidth = 32 }>'))
+    tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)
+    # CHECK: tensor<1024xf32, #sparse_tensor
+    print(tt)
+    # CHECK: #sparse_tensor.encoding
+    print(tt.encoding)
+    assert tt.encoding == encoding
diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py
--- a/mlir/test/python/ir/builtin_types.py
+++ b/mlir/test/python/ir/builtin_types.py
@@ -294,6 +294,9 @@
     else:
       print("Exception not produced")
 
+    # Encoding should be None.
+    assert RankedTensorType.get(shape, f32).encoding is None
+
 
 # CHECK-LABEL: TEST: testUnrankedTensorType
 @run