diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -241,6 +241,23 @@
   }
 };
 
+/// Sparse codegen rule for trivial tensor casts.
+class SparseCastConverter : public OpConversionPattern<tensor::CastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only rewrite identically annotated source/dest.
+    auto encDst = getSparseTensorEncoding(op.getType());
+    auto encSrc = getSparseTensorEncoding(op.getSource().getType());
+    if (!encDst || encDst != encSrc)
+      return failure();
+    rewriter.replaceOp(op, adaptor.getOperands());
+    return success();
+  }
+};
+
 /// Sparse conversion rule for pointer accesses.
 class SparseToPointersConverter : public OpConversionPattern<ToPointersOp> {
 public:
@@ -314,7 +331,7 @@
 /// the sparsification of linear algebra operations.
 void mlir::populateSparseTensorCodegenPatterns(TypeConverter &typeConverter,
                                                RewritePatternSet &patterns) {
-  patterns.add<SparseReturnConverter, SparseDimOpConverter,
-               SparseToPointersConverter, SparseToIndicesConverter,
-               SparseToValuesConverter>(typeConverter, patterns.getContext());
+  patterns.add<SparseReturnConverter, SparseCastConverter, SparseDimOpConverter,
+               SparseToPointersConverter, SparseToIndicesConverter,
+               SparseToValuesConverter>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -36,12 +36,20 @@
 }>
 
 // CHECK-LABEL: func @sparse_nop(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>) -> tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>)
 //       CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf64>>
 func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
   return %arg0 : tensor<?xf64, #SparseVector>
 }
 
+// CHECK-LABEL: func @sparse_nop_cast(
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>)
+//       CHECK: return %[[A]] : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>
+func.func @sparse_nop_cast(%arg0: tensor<64xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+  return %0 : tensor<?xf32, #SparseVector>
+}
+
 // CHECK-LABEL: func @sparse_dense_2d(
 //  CHECK-SAME: %[[A:.*]]: tuple<memref<2xindex>, memref<?xf64>>)
 func.func @sparse_dense_2d(%arg0: tensor<?x?xf64, #Dense2D>) {
@@ -71,7 +79,7 @@
 // fold using the original static dimension sizes.
 //
 // CHECK-LABEL: func @sparse_dense_3d(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<6000xf64>>) -> index {
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<6000xf64>>)
 //       CHECK: %[[C:.*]] = arith.constant 20 : index
 //       CHECK: return %[[C]] : index
 func.func @sparse_dense_3d(%arg0: tensor<10x20x30xf64, #Dense3D>) -> index {
@@ -86,7 +94,7 @@
 // since the latter honors the dimOrdering.
 //
 // CHECK-LABEL: func @sparse_dense_3d_dyn(
-//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>) -> index {
+//  CHECK-SAME: %[[A:.*]]: tuple<memref<3xindex>, memref<?xf64>>)
 //       CHECK: %[[C:.*]] = arith.constant 2 : index
 //       CHECK: %[[F:.*]] = sparse_tensor.storage_get %[[A]][0] : tuple<memref<3xindex>, memref<?xf64>> to memref<3xindex>
 //       CHECK: %[[L:.*]] = memref.load %[[F]][%[[C]]] : memref<3xindex>
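Note: a minimal sketch of the rule's effect on the sparse_nop_cast test above, assuming the usual pipeline of this test file (e.g. mlir-opt --sparse-tensor-codegen). Because source and destination carry the same #SparseVector encoding, both tensor types lower to the identical storage tuple, so the cast folds into a direct use of the converted operand:

  // Before codegen: the cast only relaxes the static dimension size;
  // the sparse annotation is unchanged on both sides.
  %0 = tensor.cast %arg0 : tensor<64xf32, #SparseVector> to tensor<?xf32, #SparseVector>
  return %0 : tensor<?xf32, #SparseVector>

  // After codegen: no cast is materialized; the storage tuple of the
  // argument is returned as is (element types per this file's encoding).
  return %arg0 : tuple<memref<1xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>>

A cast that changes the encoding, or casts between an annotated and an unannotated tensor, makes matchAndRewrite return failure(), leaving the op to other conversion rules.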