diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -526,9 +526,51 @@
   }
 };
 
+/// Folds two back-to-back tosa.clamp ops into a single clamp whose bounds are
+/// the intersection of the two ranges:
+///   clamp(clamp(x, a, b), c, d) == clamp(x, max(a, c), min(b, d))
+/// The rewrite is skipped when the intersection is empty, since the merged
+/// clamp would then have min > max. Clamping is elementwise, so the operand
+/// shape does not matter.
+struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
+  using OpRewritePattern<tosa::ClampOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::ClampOp op,
+                                PatternRewriter &rewriter) const override {
+    // Only fire when the operand is itself produced by a clamp.
+    auto clampOp = op.input().getDefiningOp<tosa::ClampOp>();
+    if (!clampOp)
+      return failure();
+
+    // The I64Attr accessors may return uint64_t; compare the integer bounds
+    // as signed values so that negative bounds are ordered correctly.
+    int64_t minInt = std::max(static_cast<int64_t>(op.min_int()),
+                              static_cast<int64_t>(clampOp.min_int()));
+    int64_t maxInt = std::min(static_cast<int64_t>(op.max_int()),
+                              static_cast<int64_t>(clampOp.max_int()));
+    auto minFp = std::max(op.min_fp(), clampOp.min_fp());
+    auto maxFp = std::min(op.max_fp(), clampOp.max_fp());
+
+    // Disjoint ranges cannot be expressed as a single valid clamp (the result
+    // would be a constant); leave the pair alone in that case. This is
+    // conservative: it also bails if only the bounds unused by the element
+    // type are disjoint.
+    if (minInt > maxInt || maxFp < minFp)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+        op, op.getType(), clampOp.input(), rewriter.getI64IntegerAttr(minInt),
+        rewriter.getI64IntegerAttr(maxInt),
+        rewriter.getF32FloatAttr(minFp.convertToFloat()),
+        rewriter.getF32FloatAttr(maxFp.convertToFloat()));
+    return success();
+  }
+};
+
 void ClampOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                           MLIRContext *context) {
   results.insert<ClampIsNoOp>(context);
+  results.insert<ClampClampOptimization>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -98,6 +98,27 @@
 
 // -----
 
+// CHECK-LABEL: @clamp_twice_is_single_clamp
+func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: "tosa.clamp"(%arg0) {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64}
+  %0 = "tosa.clamp"(%arg0) {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  %1 = "tosa.clamp"(%0) {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
+// Mixed-sign integer bounds: the merged lower bound must be max(-4, 0) = 0.
+// CHECK-LABEL: @clamp_twice_mixed_sign_bounds
+func @clamp_twice_mixed_sign_bounds(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: "tosa.clamp"(%arg0) {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64}
+  %0 = "tosa.clamp"(%arg0) {max_fp = 3.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  %1 = "tosa.clamp"(%0) {max_fp = 5.0 : f32, max_int = 5 : i64, min_fp = 0.0 : f32, min_int = 0 : i64} : (tensor<4xi8>) -> tensor<4xi8>
+  return %1 : tensor<4xi8>
+}
+
+// -----
+
 // CHECK-LABEL: @concat_fold
 func @concat_fold(%arg0: tensor) -> tensor {
   // CHECK: return %arg0