diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -701,6 +701,11 @@
   if (auto tp2 = getDest().getType().dyn_cast<RankedTensorType>()) {
     if (tp1.getRank() != tp2.getRank())
       return emitError("unexpected conversion mismatch in rank");
+    auto dstEnc =
+        tp2.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
+    if (dstEnc && dstEnc.isSlice())
+      return emitError("cannot convert to a sparse tensor slice");
+
     auto shape1 = tp1.getShape();
     auto shape2 = tp2.getShape();
     // Accept size matches between the source and the destination type
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1058,9 +1058,14 @@
     SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
     SparseTensorEncodingAttr encSrc =
         getSparseTensorEncoding(op.getSource().getType());
+    // The output tensor cannot be a slice; such cases should have been
+    // rejected by ConvertOp::verify() already.
+    assert(!encDst.isSlice() && "Cannot convert to a sparse tensor slice.");
     // Different encoding (except for different bitwidth) should be handled by
     // rewriting.
-    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths()) {
+    // A convert from a sparse tensor slice must also be handled by rewriting.
+    if (encDst.withoutBitWidths() != encSrc.withoutBitWidths() ||
+        encSrc.isSlice()) {
       return failure();
     }
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -618,7 +618,7 @@
                                 PatternRewriter &rewriter) const override {
     auto encDst = getSparseTensorEncoding(op.getType());
     auto encSrc = getSparseTensorEncoding(op.getSource().getType());
-    if (encDst && encSrc &&
+    if (encDst && encSrc && !encSrc.isSlice() &&
         encSrc.withoutBitWidths() == encDst.withoutBitWidths()) {
       // Trivial tensor conversion and simple element type conversion is handled
      // in codegen.
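Taken together, these changes make direct codegen bail out on any sparse_tensor.convert whose source is a slice, so that the foreach-based rewriting handles it, and reject outright any convert whose destination is a slice. A minimal MLIR sketch of the two cases (the #COO and #Slice encodings and the %coo/%slice values are illustrative, mirroring the tests below):

#COO = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed-nu", "singleton" ]
}>
#Slice = #sparse_tensor.encoding<{
  dimLevelType = [ "compressed-nu", "singleton" ],
  slice = [ (2, 2, 1), (12, 13, 1) ]
}>

// Source is a slice: no longer claimed by direct codegen; lowered via the
// foreach-based rewriting instead.
%a = sparse_tensor.convert %slice : tensor<2x13xi32, #Slice> to tensor<2x13xi32, #COO>

// Destination is a slice: now rejected by ConvertOp::verify() with
// "cannot convert to a sparse tensor slice".
%b = sparse_tensor.convert %coo : tensor<2x13xi32, #COO> to tensor<2x13xi32, #Slice>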
diff --git a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
--- a/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir
@@ -25,6 +25,10 @@
   dimLevelType = ["compressed"]
 }>
 
+#SortedCOO2D = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ]
+}>
+
 #SortedCOO3D = #sparse_tensor.encoding<{
   dimLevelType = [ "compressed-nu", "singleton-nu", "singleton" ]
 
@@ -35,6 +39,11 @@
   dimOrdering = affine_map<(i,j,k) -> (k,i,j)>
 }>
 
+#COOSlice = #sparse_tensor.encoding<{
+  dimLevelType = [ "compressed-nu", "singleton" ],
+  slice = [ (2, 2, 1), (12, 13, 1) ]
+}>
+
 // CHECK-LABEL: func @sparse_nop_convert(
 //  CHECK-SAME: %[[A:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8>
 //       CHECK: return %[[A]] : !llvm.ptr<i8>
@@ -185,3 +194,20 @@
   %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseSingleton64> to tensor<?xf32, #SparseSingleton32>
   return %0 : tensor<?xf32, #SparseSingleton32>
 }
+
+// CHECK-RWT-LABEL: func.func @sparse_convert_slice(
+// CHECK-RWT-SAME:  %[[VAL_0:.*]]: tensor<2x13xi32, #{{.*}}>) -> tensor<2x13xi32, #{{.*}}> {
+// CHECK-RWT:       %[[VAL_1:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<2x13xi32, #{{.*}}>
+// CHECK-RWT:       %[[VAL_2:.*]] = bufferization.alloc_tensor() size_hint=%[[VAL_1]] : tensor<2x13xi32, #{{.*}}>
+// CHECK-RWT:       %[[VAL_3:.*]] = sparse_tensor.foreach in %[[VAL_0]] init(%[[VAL_2]]) : tensor<2x13xi32, #{{.*}}>, tensor<2x13xi32, #{{.*}}> -> tensor<2x13xi32, #{{.*}}> do {
+// CHECK-RWT:       ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: i32, %[[VAL_7:.*]]: tensor<2x13xi32, #{{.*}}>):
+// CHECK-RWT:         %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]] : tensor<2x13xi32, #{{.*}}>
+// CHECK-RWT:         sparse_tensor.yield %[[VAL_8]] : tensor<2x13xi32, #{{.*}}>
+// CHECK-RWT:       }
+// CHECK-RWT:       %[[VAL_9:.*]] = sparse_tensor.load %[[VAL_10:.*]] hasInserts : tensor<2x13xi32, #{{.*}}>
+// CHECK-RWT:       %[[VAL_11:.*]] = sparse_tensor.convert %[[VAL_9]] : tensor<2x13xi32, #{{.*}}> to tensor<2x13xi32, #{{.*}}>
+// CHECK-RWT:       return %[[VAL_11]] : tensor<2x13xi32, #{{.*}}>
+func.func @sparse_convert_slice(%arg0: tensor<2x13xi32, #COOSlice>) -> (tensor<2x13xi32, #SortedCOO2D>) {
+  %0 = sparse_tensor.convert %arg0 : tensor<2x13xi32, #COOSlice> to tensor<2x13xi32, #SortedCOO2D>
+  return %0 : tensor<2x13xi32, #SortedCOO2D>
+}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -423,6 +423,19 @@
 
 // -----
 
+#CSR = #sparse_tensor.encoding<{
+  dimLevelType = ["dense", "compressed"],
+  slice = [ (1, 4, 1), (1, 4, 2) ]
+}>
+
+func.func @sparse_convert_to_slice(%arg0: tensor<10x?xf32>) -> tensor<10x10xf32, #CSR> {
+  // expected-error@+1 {{cannot convert to a sparse tensor slice}}
+  %0 = sparse_tensor.convert %arg0 : tensor<10x?xf32> to tensor<10x10xf32, #CSR>
+  return %0 : tensor<10x10xf32, #CSR>
+}
+
+// -----
+
 func.func @invalid_binary_num_args_mismatch_overlap(%arg0: f64, %arg1: f64) -> f64 {
   // expected-error@+1 {{overlap region must have exactly 2 arguments}}
   %r = sparse_tensor.binary %arg0, %arg1 : f64, f64 to f64
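For local verification, a hedged example of re-running just the touched tests with lit (the build/ directory name is an assumption about your local CMake layout):

  build/bin/llvm-lit -v mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir mlir/test/Dialect/SparseTensor/invalid.mlir

The first file exercises the slice-to-COO conversion under the CHECK-RWT (rewriting) prefix; the second checks that the verifier emits the new "cannot convert to a sparse tensor slice" diagnostic.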