[mlir][tosa] Fold tosa.clamp to its input only for +/-infinity float bounds

ClampIsNoOp used to fold a float clamp away when its min_fp/max_fp bounds
were either infinite or the largest finite value, and it only considered
f32 inputs. For f32, clamping to +/-3.40282347E+38 (the largest finite f32)
turns +/-inf into a finite value, so it is not a no-op. This patch:
  * checks any FloatType element type instead of only f32, and
  * folds only when min_fp is -inf and max_fp is +inf.
Tests are renamed (clamp_i32_not_noop) and extended with f16/f32 cases for
both the finite-bound (kept) and infinite-bound (folded) variants.

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -339,13 +339,12 @@
       return failure();
     }
 
-    if (inputElementType.isF32()) {
+    if (inputElementType.isa<FloatType>()) {
+      // Unlike integer types, floating point types can represent infinity.
       auto minClamp = op.getMinFp();
       auto maxClamp = op.getMaxFp();
-      bool isMin = (minClamp.isLargest() || minClamp.isInfinity()) &&
-                   minClamp.isNegative();
-      bool isMax = (maxClamp.isLargest() || maxClamp.isInfinity()) &&
-                   !maxClamp.isNegative();
+      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
+      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
 
       if (isMin && isMax) {
         rewriter.replaceOp(op, input);
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -39,18 +39,42 @@
   return %0 : tensor<?x1xi32>
 }
 
-// CHECK-LABEL: @clamp_not_noop
-func.func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+// CHECK-LABEL: @clamp_i32_not_noop
+func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
   // CHECK: "tosa.clamp"
   %0 = "tosa.clamp"(%arg0) {min_int = 1 : i64, max_int = 4 : i64, min_fp = 1.0 : f32, max_fp = 4.0 : f32} : (tensor<4xi32>) -> tensor<4xi32>
   return %0 : tensor<4xi32>
 }
 
-// CHECK-LABEL: @clamp_float_is_noop
-func.func @clamp_float_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+// CHECK-LABEL: @clamp_f16_not_noop
+func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
+  // CHECK: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf16>) -> tensor<4xf16>
+  return %0 : tensor<4xf16>
+}
+
+// CHECK-LABEL: @clamp_f32_not_noop
+func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: @clamp_f16_is_noop
+func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
+  // CHECK: return %arg0
+  // CHECK-NOT: "tosa.clamp"
+  // 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf16>) -> tensor<4xf16>
+  return %0 : tensor<4xf16>
+}
+
+// CHECK-LABEL: @clamp_f32_is_noop
+func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: "tosa.clamp"
-  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf32>) -> tensor<4xf32>
+  // 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity.
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf32>) -> tensor<4xf32>
   return %0 : tensor<4xf32>
 }
 