diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -335,13 +335,23 @@
 def IsSparseTensorPred
   : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
 
+// Holds for tensors whose sparse encoding additionally describes a slice.
+def IsSparseTensorSlicePred
+  : CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
+          "  ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
+
 // The following four follow the same idiom as `TensorOf`, `AnyTensor`,
 // `RankedTensorOf`, `AnyRankedTensor`.
 
 class SparseTensorOf<list<Type> allowedTypes>
   : TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
 
+// Like `SparseTensorOf`, but additionally requires the encoding to be a slice.
+class SparseTensorSliceOf<list<Type> allowedTypes>
+  : TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
+
 def AnySparseTensor : SparseTensorOf<[AnyType]>;
+def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>;
 
 class RankedSparseTensorOf<list<Type> allowedTypes>
   : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "ranked sparse tensor">;
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -294,6 +294,60 @@
   let hasVerifier = 1;
 }
 
+def SparseTensor_ToSliceOffsetOp : SparseTensor_Op<"slice.offset", [Pure]>,
+    Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>,
+    Results<(outs Index:$offset)> {
+  let summary = "Extracts the offset of the sparse tensor slice at the given dimension";
+  let description = [{
+    Extracts the offset of the sparse tensor slice at the given dimension.
+
+    Currently, sparse tensor slices are still a work in progress, and they
+    only work when the runtime library is disabled (i.e., when running the
+    sparse compiler with `enable-runtime-library=false`).
+
+    Example:
+
+    ```mlir
+    %0 = tensor.extract_slice %s[%v1, %v2][64, 64][1, 1] : tensor<128x128xf64, #DCSR>
+                                                        to tensor<64x64xf64, #Slice>
+
+    %1 = sparse_tensor.slice.offset %0 at 0 : tensor<64x64xf64, #Slice>
+    %2 = sparse_tensor.slice.offset %0 at 1 : tensor<64x64xf64, #Slice>
+    // %1 = %v1
+    // %2 = %v2
+    ```
+  }];
+  let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)";
+  let hasVerifier = 1;
+}
+
+def SparseTensor_ToSliceStrideOp : SparseTensor_Op<"slice.stride", [Pure]>,
+    Arguments<(ins AnySparseTensorSlice:$slice, IndexAttr:$dim)>,
+    Results<(outs Index:$stride)> {
+  let summary = "Extracts the stride of the sparse tensor slice at the given dimension";
+  let description = [{
+    Extracts the stride of the sparse tensor slice at the given dimension.
+
+    Currently, sparse tensor slices are still a work in progress, and they
+    only work when the runtime library is disabled (i.e., when running the
+    sparse compiler with `enable-runtime-library=false`).
+
+    Example:
+
+    ```mlir
+    %0 = tensor.extract_slice %s[%v1, %v2][64, 64][%s1, %s2] : tensor<128x128xf64, #DCSR>
+                                                            to tensor<64x64xf64, #Slice>
+
+    %1 = sparse_tensor.slice.stride %0 at 0 : tensor<64x64xf64, #Slice>
+    %2 = sparse_tensor.slice.stride %0 at 1 : tensor<64x64xf64, #Slice>
+    // %1 = %s1
+    // %2 = %s2
+    ```
+  }];
+  let assemblyFormat = "$slice `at` $dim attr-dict `:` type($slice)";
+  let hasVerifier = 1;
+}
+
 def SparseTensor_StorageSpecifierInitOp : SparseTensor_Op<"storage_specifier.init", [Pure]>,
     Results<(outs SparseTensorStorageSpecifier:$result)> {
   let summary = "";
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -765,6 +765,23 @@
   return success();
 }
 
+/// Verifies that `dim` names an existing dimension of `slice`, i.e. that
+/// 0 <= dim < rank. Shared by the slice offset and stride operations.
+static LogicalResult verifySliceDim(Operation *op, Value slice, int64_t dim) {
+  const int64_t rank = getRankedTensorType(slice).getRank();
+  if (dim < 0 || dim >= rank)
+    return op->emitError("requested dimension out of bound");
+  return success();
+}
+
+LogicalResult ToSliceOffsetOp::verify() {
+  return verifySliceDim(getOperation(), getSlice(), getDim().getSExtValue());
+}
+
+LogicalResult ToSliceStrideOp::verify() {
+  return verifySliceDim(getOperation(), getSlice(), getDim().getSExtValue());
+}
+
 LogicalResult GetStorageSpecifierOp::verify() {
   RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
       getSpecifierKind(), getDim(), getSpecifier(), getOperation()))
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -224,6 +224,32 @@
 
 // -----
 
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  // expected-error@+1 {{requested dimension out of bound}}
+  %0 = sparse_tensor.slice.offset %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  // expected-error@+1 {{requested dimension out of bound}}
+  %0 = sparse_tensor.slice.stride %arg0 at 2 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 func.func @sparse_get_md(%arg0: !sparse_tensor.storage_specifier<#SparseVector>) -> i64 {
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -148,6 +148,38 @@
 
 // -----
 
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func @sparse_slice_offset(
+//  CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.slice.offset %[[A]] at 1 : tensor<2x8xf64, #{{.*}}>
+//       CHECK: return %[[T]] : index
+func.func @sparse_slice_offset(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  %0 = sparse_tensor.slice.offset %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
+#CSR_SLICE = #sparse_tensor.encoding<{
+  dimLevelType = [ "dense", "compressed" ],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+// CHECK-LABEL: func @sparse_slice_stride(
+//  CHECK-SAME: %[[A:.*]]: tensor<2x8xf64, #{{.*}}>)
+//       CHECK: %[[T:.*]] = sparse_tensor.slice.stride %[[A]] at 1 : tensor<2x8xf64, #{{.*}}>
+//       CHECK: return %[[T]] : index
+func.func @sparse_slice_stride(%arg0: tensor<2x8xf64, #CSR_SLICE>) -> index {
+  %0 = sparse_tensor.slice.stride %arg0 at 1 : tensor<2x8xf64, #CSR_SLICE>
+  return %0 : index
+}
+
+// -----
+
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
 // CHECK-LABEL: func @sparse_metadata_init(