diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -481,12 +481,14 @@ if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) { auto intMin = rewriter.create( - loc, rewriter.getF32FloatAttr( + loc, rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); auto intMax = rewriter.create( - loc, rewriter.getF32FloatAttr( + loc, rewriter.getFloatAttr( + getElementTypeOrSelf(srcTy), APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth()) .getSExtValue())); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -551,6 +551,14 @@ // CHECK: arith.extf %0 = tosa.cast %arg0 : (tensor<1xf16>) -> tensor<1xf32> + // CHECK: linalg.generic + // CHECK: arith.constant -1.280000e+02 + // CHECK: arith.constant 1.270000e+02 + // CHECK: math.roundeven + // CHECK: arith.minf + // CHECK: arith.maxf + // CHECK: arith.fptosi + %1 = "tosa.cast"(%arg0) : (tensor<1xf16>) -> tensor<1xi8> return }