diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp @@ -199,36 +199,42 @@ //===----------------------------------------------------------------------===// Value sparse_tensor::genCast(OpBuilder &builder, Location loc, Value value, - Type dstTy) { - Type srcTy = value.getType(); - if (srcTy != dstTy) { - // int <=> index - if (dstTy.isa() || srcTy.isa()) - return builder.create(loc, dstTy, value); - - bool ext = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); - - // float => float. - if (srcTy.isa() && dstTy.isa() && ext) - return builder.create(loc, dstTy, value); - - if (srcTy.isa() && dstTy.isa() && !ext) - return builder.create(loc, dstTy, value); - - // int => int - if (srcTy.isUnsignedInteger() && dstTy.isa() && ext) - return builder.create(loc, dstTy, value); - - if (srcTy.isSignedInteger() && dstTy.isa() && ext) - return builder.create(loc, dstTy, value); - - if (srcTy.isa() && dstTy.isa() && !ext) - return builder.create(loc, dstTy, value); + Type dstTp) { + const Type srcTp = value.getType(); + if (srcTp == dstTp) + return value; + + // int <=> index + if (srcTp.isa() || dstTp.isa()) + return builder.create(loc, dstTp, value); + + const bool ext = + srcTp.getIntOrFloatBitWidth() < dstTp.getIntOrFloatBitWidth(); + + // float => float. + if (srcTp.isa() && dstTp.isa()) { + // Can't use `?:` here because of the implicit conversion from + // `Operation*` to `Value`. + if (ext) + return builder.create(loc, dstTp, value); + if (!ext) + return builder.create(loc, dstTp, value); + } - llvm_unreachable("unhandled type casting"); + // int => int + const auto srcIntTp = srcTp.dyn_cast(); + if (srcIntTp && dstTp.isa()) { + // Can't use `?:` here because of the implicit conversion from + // `Operation*` to `Value`. + if (ext && srcIntTp.isUnsigned()) + return builder.create(loc, dstTp, value); + if (ext && srcIntTp.isSigned()) + return builder.create(loc, dstTp, value); + if (!ext) + return builder.create(loc, dstTp, value); } - return value; + llvm_unreachable("unhandled type casting"); } mlir::Attribute mlir::sparse_tensor::getOneAttr(Builder &builder, Type tp) {