diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -208,28 +208,9 @@
   if (srcTp.isa<IndexType>() || dstTp.isa<IndexType>())
     return builder.create<arith::IndexCastOp>(loc, dstTp, value);
 
-  const bool ext =
-      srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth();
-
-  // float => float.
-  if (srcTp.isa<FloatType>() && dstTp.isa<FloatType>()) {
-    if (ext)
-      return builder.create<arith::ExtFOp>(loc, dstTp, value);
-    return builder.create<arith::TruncFOp>(loc, dstTp, value);
-  }
-
-  // int => int
-  const auto srcIntTp = srcTp.dyn_cast<IntegerType>();
-  if (srcIntTp && dstTp.isa<IntegerType>()) {
-    if (!ext)
-      return builder.create<arith::TruncIOp>(loc, dstTp, value);
-    if (srcIntTp.isUnsigned())
-      return builder.create<arith::ExtUIOp>(loc, dstTp, value);
-    if (srcIntTp.isSigned())
-      return builder.create<arith::ExtSIOp>(loc, dstTp, value);
-  }
-
-  llvm_unreachable("unhandled type casting");
+  const auto srcIntTp = srcTp.dyn_cast_or_null<IntegerType>();
+  const bool isUnsignedCast = srcIntTp ? srcIntTp.isUnsigned() : false;
+  return mlir::convertScalarToDtype(builder, loc, value, dstTp, isUnsignedCast);
 }
 
 mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -663,6 +663,22 @@
   return %0 : tensor<?xf32, #SparseVector>
 }
 
+// CHECK-LABEL: func.func @sparse_convert_element_type(
+// CHECK-SAME: %[[A1:.*]]: memref<?xi32>,
+// CHECK-SAME: %[[A2:.*]]: memref<?xi64>,
+// CHECK-SAME: %[[A3:.*]]: memref<?xf32>,
+// CHECK-SAME: %[[A4:.*]]: !sparse_tensor.storage_specifier
+// CHECK: scf.for
+// CHECK: %[[FValue:.*]] = memref.load
+// CHECK: %[[IValue:.*]] = arith.fptosi %[[FValue]]
+// CHECK: memref.store %[[IValue]]
+// CHECK: return %{{.*}}, %{{.*}}, %{{.*}}, %[[A4]] :
+// CHECK-SAME: memref<?xi32>, memref<?xi64>, memref<?xi32>, !sparse_tensor.storage_specifier
+func.func @sparse_convert_element_type(%arg0: tensor<32xf32, #SparseVector>) -> tensor<32xi32, #SparseVector> {
+  %0 = sparse_tensor.convert %arg0 : tensor<32xf32, #SparseVector> to tensor<32xi32, #SparseVector>
+  return %0 : tensor<32xi32, #SparseVector>
+}
+
 // CHECK-LABEL: func.func @sparse_new_coo(
 // CHECK-SAME: %[[A0:.*]]: !llvm.ptr<i8>) -> (memref<?xindex>, memref<?xindex>, memref<?xf32>, !sparse_tensor.storage_specifier<#sparse_tensor.encoding<{ dimLevelType = [ "compressed", "singleton" ] }>>) {
 // CHECK-DAG: %[[A1:.*]] = arith.constant false