diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -342,9 +342,12 @@
     bool isOrderedLvl(::mlir::sparse_tensor::Level l) const { return isOrderedDLT(getLvlType(l)); }
     bool isUniqueLvl(::mlir::sparse_tensor::Level l) const { return isUniqueDLT(getLvlType(l)); }
 
-    bool isSlice() const {
-      return !getDimSlices().empty();
-    }
+    /// Returns true if this encoding has a non-empty list of dimension slices.
+    bool isSlice() const;
+
+    /// Returns the slice attribute of dimension `dim`. The encoding must be a
+    /// slice and `dim` must index into its dimension-slice list.
+    ::mlir::sparse_tensor::SparseTensorDimSliceAttr getDimSlice(::mlir::sparse_tensor::Dimension dim) const;
     std::optional<uint64_t> getStaticDimSliceOffset(::mlir::sparse_tensor::Dimension dim) const;
     std::optional<uint64_t> getStaticDimSliceSize(::mlir::sparse_tensor::Dimension dim) const;
     std::optional<uint64_t> getStaticDimSliceStride(::mlir::sparse_tensor::Dimension dim) const;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -165,19 +165,36 @@
   return getDimLevelType()[l];
 }
 
+/// Returns true if this encoding has a non-empty list of dimension slices.
+/// Asserts that the attribute is initialized.
+bool SparseTensorEncodingAttr::isSlice() const {
+  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
+  return !getDimSlices().empty();
+}
+
+/// Returns the slice attribute of dimension `dim`. Asserts that the encoding
+/// is a slice and that `dim` is within bounds of its dimension-slice list.
+SparseTensorDimSliceAttr
+SparseTensorEncodingAttr::getDimSlice(Dimension dim) const {
+  assert(isSlice() && "Is not a slice");
+  const auto dimSlices = getDimSlices();
+  assert(dim < dimSlices.size() && "Dimension is out of bounds");
+  return dimSlices[dim];
+}
+
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticDimSliceOffset(Dimension dim) const {
-  return getDimSlices()[dim].getStaticOffset();
+  return getDimSlice(dim).getStaticOffset();
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticDimSliceSize(Dimension dim) const {
-  return getDimSlices()[dim].getStaticSize();
+  return getDimSlice(dim).getStaticSize();
 }
 
 std::optional<uint64_t>
 SparseTensorEncodingAttr::getStaticDimSliceStride(Dimension dim) const {
-  return getDimSlices()[dim].getStaticStride();
+  return getDimSlice(dim).getStaticStride();
 }
 
 std::optional<uint64_t>