diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -341,6 +341,8 @@
   let results = (outs
     Tosa_Tensor:$output
   );
+
+  let hasCanonicalizer = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -458,6 +458,81 @@
   results.insert(context);
 }
 
+/// Folds a tosa.clamp whose bounds cannot alter any value of the input
+/// element type, replacing the op with its input.
+///
+/// - Floating point: only [-inf, +inf] is an identity. Finite bounds, even
+///   +/-FLT_MAX, still map +/-inf to a finite value.
+/// - Integer: the bounds must cover the full range of the element type.
+struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    Value input = op.input();
+    auto inputType = input.getType().dyn_cast<ShapedType>();
+    if (!inputType)
+      return failure();
+
+    // Replacing the result with the input is only type-correct when both
+    // have exactly the same type.
+    if (input.getType() != op.output().getType())
+      return failure();
+
+    Type inputElementType = inputType.getElementType();
+
+    if (inputElementType.isa<FloatType>()) {
+      // min_fp/max_fp are f32 attributes; +/-inf bounds every float type.
+      auto minClamp = op.min_fp();
+      auto maxClamp = op.max_fp();
+      bool isMin = minClamp.isInfinity() && minClamp.isNegative();
+      bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
+
+      if (isMin && isMax) {
+        rewriter.replaceOp(op, input);
+        return success();
+      }
+      return failure();
+    }
+
+    if (!inputElementType.isa<IntegerType>())
+      return failure();
+
+    // The signless I64 attribute accessors return uint64_t; reinterpret them
+    // as signed so that negative bounds compare correctly against intMin.
+    int64_t minClamp = static_cast<int64_t>(op.min_int());
+    int64_t maxClamp = static_cast<int64_t>(op.max_int());
+    unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+
+    int64_t intMin, intMax;
+    if (inputElementType.isUnsignedInteger()) {
+      // A signed 64-bit bound can never reach the maximum of an unsigned
+      // type that is 64 bits or wider (and the max would not fit int64_t).
+      if (bitWidth >= 64)
+        return failure();
+      intMin = 0;
+      intMax = APInt::getMaxValue(bitWidth).getZExtValue();
+    } else {
+      // Wider signed types do not fit in int64_t (getSExtValue asserts).
+      if (bitWidth > 64)
+        return failure();
+      intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+      intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
+    }
+
+    if (minClamp <= intMin && maxClamp >= intMax) {
+      rewriter.replaceOp(op, input);
+      return success();
+    }
+    return failure();
+  }
+};
+
+void ClampOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
+                                          MLIRContext *context) {
+  results.insert<ClampIsNoOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Operator Folders.
 //===----------------------------------------------------------------------===//
@@ -556,8 +631,7 @@
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
-template <typename T>
-static LogicalResult verifyConvOp(T op) {
+template <typename T> static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = op.input().getType().template dyn_cast<RankedTensorType>();
   auto weightType = op.weight().getType().template dyn_cast<RankedTensorType>();
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -49,6 +49,83 @@
 
 // -----
 
+// CHECK-LABEL: @clamp_not_noop
+func @clamp_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
+  // CHECK: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = 1 : i64, max_int = 4 : i64, min_fp = 1.0 : f32, max_fp = 4.0 : f32} : (tensor<4xi32>) -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_int8_not_noop
+func @clamp_int8_not_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xi8>) -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_f32_not_noop
+func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK: "tosa.clamp"
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_f32_is_noop
+func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NOT: "tosa.clamp"
+  // CHECK: return %arg0
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_f16_is_noop
+func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> {
+  // CHECK-NOT: "tosa.clamp"
+  // CHECK: return %arg0
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf16>) -> tensor<4xf16>
+  return %0 : tensor<4xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_int8_is_noop
+func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK-NOT: "tosa.clamp"
+  // CHECK: return %arg0
+  %0 = "tosa.clamp"(%arg0) {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xi8>) -> tensor<4xi8>
+  return %0 : tensor<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_int16_is_noop
+func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> {
+  // CHECK-NOT: "tosa.clamp"
+  // CHECK: return %arg0
+  %0 = "tosa.clamp"(%arg0) {min_int = -32768 : i64, max_int = 32767 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xi16>) -> tensor<4xi16>
+  return %0 : tensor<4xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @clamp_uint8_is_noop
+func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> {
+  // CHECK-NOT: "tosa.clamp"
+  // CHECK: return %arg0
+  %0 = "tosa.clamp"(%arg0) {min_int = 0 : i64, max_int = 255 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xui8>) -> tensor<4xui8>
+  return %0 : tensor<4xui8>
+}
+
+// -----
+
 // CHECK-LABEL: @concat_fold
 func @concat_fold(%arg0: tensor) -> tensor {
   // CHECK: return %arg0