diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -491,9 +491,47 @@
                                            args.front(), zero);
     }
 
-    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy))
-      return rewriter.create<mlir::FPToSIOp>(loc, resultTypes, args,
-                                             mlir::None);
+    if (mlir::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
+      // Float constants are built with the source element type so they
+      // type-check against args[0] for f16/bf16/f64 inputs, not just f32.
+      auto zero = rewriter.create<mlir::ConstantOp>(
+          loc, rewriter.getFloatAttr(srcTy, 0.0));
+      auto half = rewriter.create<mlir::ConstantOp>(
+          loc, rewriter.getFloatAttr(srcTy, 0.5));
+
+      unsigned dstBitWidth = dstTy.getIntOrFloatBitWidth();
+      int64_t dstIntMin = APInt::getSignedMinValue(dstBitWidth).getSExtValue();
+      int64_t dstIntMax = APInt::getSignedMaxValue(dstBitWidth).getSExtValue();
+
+      // Saturation bounds, converted (with rounding) into the source type.
+      auto intMin = rewriter.create<mlir::ConstantOp>(
+          loc, rewriter.getFloatAttr(srcTy, static_cast<double>(dstIntMin)));
+      auto intMax = rewriter.create<mlir::ConstantOp>(
+          loc, rewriter.getFloatAttr(srcTy, static_cast<double>(dstIntMax)));
+
+      // Round half away from zero: add 0.5 to non-negative inputs and
+      // subtract 0.5 from negative ones; fptosi then truncates toward zero.
+      // NOTE(review): the +/-0.5 step is itself rounded, so inputs just below
+      // a .5 boundary (e.g. 0.49999997f) can round the wrong way.
+      auto added = rewriter.create<mlir::AddFOp>(loc, args[0], half);
+      auto subbed = rewriter.create<mlir::SubFOp>(loc, args[0], half);
+      auto negative = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OLT,
+                                                    args[0], zero);
+      auto rounded =
+          rewriter.create<mlir::SelectOp>(loc, negative, subbed, added);
+
+      auto clamped = clampHelper<mlir::CmpFOp>(loc, rounded, intMin, intMax,
+                                               CmpFPredicate::OLT, rewriter);
+      Value converted = rewriter.create<mlir::FPToSIOp>(loc, dstTy, clamped);
+
+      // The float image of the integer maximum can round up past it (in f32,
+      // 2147483647 becomes 2^31), and fptosi cannot represent that value.
+      // Saturate such inputs to the exact integer maximum instead.
+      auto dstMax = rewriter.create<ConstantIntOp>(loc, dstIntMax, dstBitWidth);
+      auto overflow = rewriter.create<mlir::CmpFOp>(loc, CmpFPredicate::OGE,
+                                                    rounded, intMax);
+      return rewriter.create<mlir::SelectOp>(loc, overflow, dstMax, converted);
+    }
 
     // Casting to boolean, integers need to only be checked as not-equal to
     // zero.
@@ -508,9 +546,25 @@
       return rewriter.create<mlir::SignExtendIOp>(loc, resultTypes, args,
                                                   mlir::None);
 
-    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend)
-      return rewriter.create<mlir::TruncateIOp>(loc, resultTypes, args,
-                                                mlir::None);
+    if (srcTy.isa<IntegerType>() && dstTy.isa<IntegerType>() && !bitExtend) {
+      // Saturate to the signed range of dstTy (constants are built at the
+      // source width so they compare against args[0]) before truncating.
+      auto intMin = rewriter.create<ConstantIntOp>(
+          loc,
+          APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
+              .getSExtValue(),
+          srcTy.getIntOrFloatBitWidth());
+
+      auto intMax = rewriter.create<ConstantIntOp>(
+          loc,
+          APInt::getSignedMaxValue(dstTy.getIntOrFloatBitWidth())
+              .getSExtValue(),
+          srcTy.getIntOrFloatBitWidth());
+
+      auto clamped = clampHelper<mlir::CmpIOp>(loc, args[0], intMin, intMax,
+                                               CmpIPredicate::slt, rewriter);
+      return rewriter.create<mlir::TruncateIOp>(loc, dstTy, clamped);
+    }
   }
 
   (void)rewriter.notifyMatchFailure(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -213,6 +213,21 @@
   %20 = "tosa.sigmoid"(%0) : (tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
+  // CHECK: constant 0.000000e+00
+  // CHECK: constant 5.000000e-01
+  // CHECK: constant -2.14748365E+9
+  // CHECK: constant 2.14748365E+9
+  // CHECK: addf
+  // CHECK: subf
+  // CHECK: cmpf olt
+  // CHECK: select
+  // CHECK: cmpf olt
+  // CHECK: select
+  // CHECK: cmpf olt
+  // CHECK: select
   // CHECK: fptosi
+  // CHECK: constant 2147483647
+  // CHECK: cmpf oge
+  // CHECK: select
   %21 = "tosa.cast"(%0) : (tensor<1xf32>) -> tensor<1xi32>
 
@@ -358,6 +373,12 @@
   %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
+  // CHECK: constant -32768
+  // CHECK: constant 32767
+  // CHECK: cmpi slt
+  // CHECK: select
+  // CHECK: cmpi slt
+  // CHECK: select
   // CHECK: trunci
   %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 