diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -86,6 +86,7 @@
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
+  let hasFolder = 1;
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -333,6 +333,17 @@
   return emitError("unexpected type in convert");
 }
 
+OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
+  Type dstType = getType();
+  // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse
+  // convert for codegen to remove. This is because we use trivial
+  // sparse-to-sparse to tell bufferization that the sparse codegen will expand
+  // the tensor buffer into sparse tensor storage.
+  if (!getSparseTensorEncoding(dstType) && dstType == getSource().getType())
+    return getSource();
+  return {};
+}
+
 LogicalResult ToPointersOp::verify() {
   auto e = getSparseTensorEncoding(getTensor().getType());
   if (failed(isInBounds(getDimension().getZExtValue(), getTensor())))
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -727,7 +727,10 @@
   LogicalResult
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getType() != op.getSource().getType()) {
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
+    SparseTensorEncodingAttr encSrc =
+        getSparseTensorEncoding(op.getSource().getType());
+    if (encDst != encSrc) {
       // This should be handled by rewriting before codegen.
       return failure();
     }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -526,7 +526,7 @@
 //  CHECK-SAME: %[[A3:.*]]: memref<?xi64>,
 //  CHECK-SAME: %[[A4:.*]]: memref<?xf32>)
 //       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
-func.func @sparse_nop_convert(%arg0: tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
-  %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
 }
diff --git a/mlir/test/Dialect/SparseTensor/fold.mlir b/mlir/test/Dialect/SparseTensor/fold.mlir
--- a/mlir/test/Dialect/SparseTensor/fold.mlir
+++ b/mlir/test/Dialect/SparseTensor/fold.mlir
@@ -2,6 +2,15 @@
 
 #SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>
 
+// CHECK-LABEL: func @sparse_nop_dense2dense_convert(
+//  CHECK-SAME: %[[A:.*]]: tensor<64xf32>)
+//   CHECK-NOT: sparse_tensor.convert
+//       CHECK: return %[[A]] : tensor<64xf32>
+func.func @sparse_nop_dense2dense_convert(%arg0: tensor<64xf32>) -> tensor<64xf32> {
+  %0 = sparse_tensor.convert %arg0 : tensor<64xf32> to tensor<64xf32>
+  return %0 : tensor<64xf32>
+}
+
 // CHECK-LABEL: func @sparse_dce_convert(
 //  CHECK-SAME: %[[A:.*]]: tensor<64xf32>)
 //   CHECK-NOT: sparse_tensor.convert
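
Net effect, sketched below (illustrative only, not part of the patch: %d, %s, and the 8xf32 shapes are made up; #SparseVector is the encoding from the tests above). Canonicalization now folds a trivial dense-to-dense convert, while a trivial sparse-to-sparse convert deliberately survives folding so bufferization can still see that codegen will expand the tensor buffer into sparse storage; the codegen pattern then erases it by comparing only the encodings.

  // Folds to %d: neither type has an encoding and the two types match exactly.
  %0 = sparse_tensor.convert %d : tensor<8xf32> to tensor<8xf32>
  // Not folded: both types carry #SparseVector. Left in place for
  // bufferization; removed later by sparse codegen, whose pattern now
  // succeeds whenever the encodings match, even if the shapes differ.
  %1 = sparse_tensor.convert %s : tensor<8xf32, #SparseVector> to tensor<8xf32, #SparseVector>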