diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -300,13 +300,21 @@
 def IsSparseTensorPred
   : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
 
+def IsSparseTensorSlicePred
+  : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
+          "  ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
+
 // The following four follow the same idiom as `TensorOf`, `AnyTensor`,
 // `RankedTensorOf`, `AnyRankedTensor`.
 class SparseTensorOf<list<Type> allowedTypes>
   : TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
 
+class SparseTensorSliceOf<list<Type> allowedTypes>
+  : TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
+
 def AnySparseTensor : SparseTensorOf<[AnyType]>;
+def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>;
 
 class RankedSparseTensorOf<list<Type> allowedTypes>
   : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "ranked sparse tensor">;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -205,6 +205,36 @@
   let hasVerifier = 1;
 }
 
+def SparseTensor_ToSliceOffsetOp : SparseTensor_Op<"slice.offset", [Pure]>,
+    Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>,
+    Results<(outs Index:$offset)> {
+  let summary = "Extracts the offset of the sparse tensor slice at the given dimension";
+  let description = [{
+    Example:
+
+    ```mlir
+    %1 = sparse_tensor.slice.offset %0 at 1 : tensor<64x64xf64, #Slice>
+    ```
+  }];
+  let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)";
+  let hasVerifier = 1;
+}
+
+def SparseTensor_ToSliceStrideOp : SparseTensor_Op<"slice.stride", [Pure]>,
+    Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>,
+    Results<(outs Index:$stride)> {
+  let summary = "Extracts the stride of the sparse tensor slice at the given dimension";
+  let description = [{
+    Example:
+
+    ```mlir
+    %1 = sparse_tensor.slice.stride %0 at 1 : tensor<64x64xf64, #Slice>
+    ```
+  }];
+  let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)";
+  let hasVerifier = 1;
+}
+
 def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>,
     Results<(outs SparseTensorStorageSpecifier:$result)> {
   let summary = "";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -662,6 +662,20 @@
   return success();
 }
 
+LogicalResult ToSliceOffsetOp::verify() {
+  auto rank = getSlice().getType().cast<RankedTensorType>().getRank();
+  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
+    return emitError("requested dimension out of bound");
+  return success();
+}
+
+LogicalResult ToSliceStrideOp::verify() {
+  auto rank = getSlice().getType().cast<RankedTensorType>().getRank();
+  if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
+    return emitError("requested dimension out of bound");
+  return success();
+}
+
 LogicalResult GetStorageSpecifierOp::verify() {
   if (failed(verifySparsifierGetterSetter(getSpecifierKind(), getDim(),
                                           getSpecifier(), getOperation()))) {
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -116,6 +116,32 @@
 
 // -----
 
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  // expected-error@+1 {{requested dimension out of bound}}
+  %0 = sparse_tensor.slice.offset %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  // expected-error@+1 {{requested dimension out of bound}}
+  %0 = sparse_tensor.slice.stride %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -117,6 +117,38 @@
 
 // -----
 
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func @sparse_slice_offset(
+//  CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.slice.offset %[[A]] at 1 : tensor<2x8xf64, #{{.*}}>
+//       CHECK: return %[[T]] : index
+func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  %0 = sparse_tensor.slice.offset %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func @sparse_slice_stride(
+//  CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.slice.stride %[[A]] at 1 : tensor<2x8xf64, #{{.*}}>
+//       CHECK: return %[[T]] : index
+func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  %0 = sparse_tensor.slice.stride %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_metadata_init(
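
For illustration only (not part of the patch): a minimal sketch of how the two new query ops can be used together, reusing the `#CSR_SLICE` encoding and operand type from the round-trip tests above. The function name `@slice_metadata` and the way the results are consumed are hypothetical.

```mlir
#CSR_SLICE = #sparse_tensor.encoding<{
  dimLevelType = [ "dense", "compressed" ],
  slice = [ (1, 4, 1), (1, 4, 2) ]
}>

// Query the dynamic offset and stride of dimension 1 of a sparse tensor
// slice; both results are plain `index` values that downstream code
// (e.g. loop-bound computation) can consume directly.
func.func @slice_metadata(%s: tensor<2x8xf64, #CSR_SLICE>) -> (index, index) {
  %off = sparse_tensor.slice.offset %s at 1 : tensor<2x8xf64, #CSR_SLICE>
  %str = sparse_tensor.slice.stride %s at 1 : tensor<2x8xf64, #CSR_SLICE>
  return %off, %str : index, index
}
```

Both ops return `index`, and the verifiers added above reject any `dim` outside `[0, rank)` of the slice type.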