[mlir][sparse] Fold sparse_tensor.convert in codegen when only dimension sizes differ

The codegen pattern for sparse_tensor.convert used to fold the op only when
the source and destination tensor types were identical. With this change it
folds the op whenever both types share the same sparse encoding and the same
element type. Conversions that differ only in static vs. dynamic dimension
sizes, such as tensor<32xf32, #SparseVector> -> tensor<?xf32, #SparseVector>,
therefore reuse the source storage as-is (see @sparse_nop_convert, whose
CHECK lines return the incoming buffers unchanged).

The element type is compared explicitly. An identical encoding alone does not
make the value buffers interchangeable: if an element-type-changing convert
(e.g. f32 -> f64) ever reaches this pattern, it still reports failure instead
of forwarding mistyped storage. Conversions between different encodings are
still expected to be rewritten before codegen.

NOTE(review): the `<?x...>` type parameters in the codegen.mlir hunk were lost
in the copy this patch was reconstructed from and have been restored by hand.
Pointer, index and value memrefs are assumed to be i32, i64 and f32 for
#SparseVector. Confirm these types, and the CHECK-line indentation, against
codegen.mlir before applying.

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -727,7 +727,16 @@
   LogicalResult
   matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getType() != op.getSource().getType()) {
+    SparseTensorEncodingAttr encDst = getSparseTensorEncoding(op.getType());
+    SparseTensorEncodingAttr encSrc =
+        getSparseTensorEncoding(op.getSource().getType());
+    Type dstElemTp = op.getType().cast<ShapedType>().getElementType();
+    Type srcElemTp =
+        op.getSource().getType().cast<ShapedType>().getElementType();
+    // Only a conversion that preserves both the encoding and the element type
+    // can reuse the source storage as-is; dimension sizes (e.g. static vs.
+    // dynamic) are allowed to differ.
+    if (encDst != encSrc || dstElemTp != srcElemTp) {
       // This should be handled by rewriting before codegen.
       return failure();
     }
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -526,7 +526,7 @@
 //  CHECK-SAME: %[[A3:.*]]: memref<?xi64>,
 //  CHECK-SAME: %[[A4:.*]]: memref<?xf32>)
 //       CHECK: return %[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]] : memref<1xindex>, memref<3xindex>, memref<?xi32>, memref<?xi64>, memref<?xf32>
-func.func @sparse_nop_convert(%arg0: tensor<?xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
-  %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector> to tensor<?xf32, #SparseVector>
+func.func @sparse_nop_convert(%arg0: tensor<32xf32, #SparseVector>) -> tensor<?xf32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<?xf32, #SparseVector>
   return %0 : tensor<?xf32, #SparseVector>
 }