diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -477,7 +477,6 @@ Tosa_Tensor:$output ); - let hasCanonicalizer = 1; let hasFolder = 1; } @@ -796,7 +795,6 @@ Tosa_Tensor:$output ); - let hasCanonicalizer = 1; let hasFolder = 1; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -246,92 +246,6 @@ results.add(context); } -struct AddZeroOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::AddOp op, - PatternRewriter &rewriter) const override { - auto input1 = op.getInput1(); - auto input2 = op.getInput2(); - - DenseElementsAttr input1Attr; - if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && - input2.getType() == op.getType()) { - if (input1Attr.getType().getElementType().isa() && - input1Attr.getSplatValue().isZero()) { - rewriter.replaceOp(op, op.getInput2()); - return success(); - } - } - - DenseElementsAttr input2Attr; - if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && - input1.getType() == op.getType()) { - if (input2Attr.getType().getElementType().isa() && - input2Attr.getSplatValue().isZero()) { - rewriter.replaceOp(op, op.getInput1()); - return success(); - } - } - - return failure(); - } -}; - -void AddOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -struct MulOneOptimization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(tosa::MulOp op, - PatternRewriter &rewriter) const override { - auto input1 = op.getInput1(); - auto input2 = op.getInput2(); - - DenseElementsAttr input1Attr; - if (matchPattern(input1, m_Constant(&input1Attr)) && input1Attr.isSplat() && - input2.getType() == op.getType()) { - if (input1Attr.getType().getElementType().isa() && - input1Attr.getSplatValue().isExactlyValue(1)) { - rewriter.replaceOp(op, op.getInput2()); - return success(); - } - - if (input1Attr.getType().getElementType().isa() && - matchPattern(input1, m_One())) { - rewriter.replaceOp(op, op.getInput2()); - return success(); - } - } - - DenseElementsAttr input2Attr; - if (matchPattern(input2, m_Constant(&input2Attr)) && input2Attr.isSplat() && - input1.getType() == op.getType()) { - if (input2Attr.getType().getElementType().isa() && - input2Attr.getSplatValue().isExactlyValue(1)) { - rewriter.replaceOp(op, op.getInput1()); - return success(); - } - - if (input2Attr.getType().getElementType().isa() && - matchPattern(input2, m_One())) { - rewriter.replaceOp(op, op.getInput1()); - return success(); - } - } - - return failure(); - } -}; - -void MulOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - struct MaterializePadValue : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -609,44 +523,47 @@ return {}; } +static bool isSplatZero(Type elemType, DenseElementsAttr val) { + if (elemType.isa()) + return val && val.isSplat() && val.getSplatValue().isZero(); + if (elemType.isa()) + return val && val.isSplat() && val.getSplatValue().isZero(); + return false; +} + +static bool isSplatOne(Type elemType, DenseElementsAttr val, int64_t shift) { + if (elemType.isa()) + return val && val.isSplat() && + val.getSplatValue().isExactlyValue(1.0); + if (elemType.isa()) { + const int64_t shifted = 1LL << shift; + return val && val.isSplat() && + val.getSplatValue().getSExtValue() == shifted; + } + return false; +} + OpFoldResult AddOp::fold(FoldAdaptor adaptor) { auto lhsTy = getInput1().getType().dyn_cast(); auto rhsTy = getInput2().getType().dyn_cast(); auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); - if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { - if (lhsAttr.getSplatValue().isZero()) - return getInput2(); - } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } - - if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { - if (lhsAttr.getSplatValue().isZero()) - return getInput2(); - } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } + if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) + return getInput1(); + if (rhsTy == resultTy && isSplatZero(resultETy, lhsAttr)) + return getInput2(); if (!lhsAttr || !rhsAttr) return {}; return binaryFolder, std::plus>(lhsAttr, rhsAttr, - lhsTy); + resultTy); } OpFoldResult DivOp::fold(FoldAdaptor adaptor) { @@ -724,50 +641,26 @@ auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); - if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { - auto val = lhsAttr.getSplatValue(); - if (val.isZero()) + const int64_t shift = resultETy.isa() ? getShift() : 0; + if (rhsTy == resultTy) { + if (isSplatZero(resultETy, lhsAttr)) return lhsAttr; - if (val.isExactlyValue(1.0)) + if (isSplatOne(resultETy, lhsAttr, shift)) return rhs; } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - auto val = rhsAttr.getSplatValue(); - if (val.isZero()) - return rhsAttr; - if (val.isExactlyValue(1.0)) - return lhs; - } - - if (lhsAttr && lhsAttr.isSplat() && resultETy.isa()) { - auto val = lhsAttr.getSplatValue(); - if (val.isZero()) - return lhsAttr; - const int64_t shift = getShift(); - const int64_t shifted = 1LL << shift; - if (val.getSExtValue() == shifted) - return rhs; - } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - auto val = rhsAttr.getSplatValue(); - const int64_t shift = getShift(); - const int64_t shifted = 1LL << shift; - if (val.isZero()) + if (lhsTy == resultTy) { + if (isSplatZero(resultETy, rhsAttr)) return rhsAttr; - if (val.getSExtValue() == shifted) + if (isSplatOne(resultETy, rhsAttr, shift)) return lhs; } - return mulBinaryFolder(lhsAttr, rhsAttr, lhsTy, getShift()); + return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift()); } OpFoldResult SubOp::fold(FoldAdaptor adaptor) { @@ -776,28 +669,19 @@ auto resultTy = getType().dyn_cast(); if (!lhsTy || !rhsTy || !resultTy) return {}; - if (lhsTy != rhsTy) - return {}; auto resultETy = resultTy.getElementType(); auto lhsAttr = adaptor.getInput1().dyn_cast_or_null(); auto rhsAttr = adaptor.getInput2().dyn_cast_or_null(); - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } - - if (rhsAttr && rhsAttr.isSplat() && resultETy.isa()) { - if (rhsAttr.getSplatValue().isZero()) - return getInput1(); - } + if (lhsTy == resultTy && isSplatZero(resultETy, rhsAttr)) + return getInput1(); if (!lhsAttr || !rhsAttr) return {}; return binaryFolder, std::minus>(lhsAttr, rhsAttr, - lhsTy); + resultTy); } namespace { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -7,15 +7,15 @@ return %0 : tensor } -// CHECK-LABEL: @add_zero_different_shape -func.func @add_zero_different_shape(%arg0: tensor<2x3xi32>) -> tensor<4x2x3xi32> { - // CHECK: tosa.add - %zeros = "tosa.const"() {value = dense<0> : tensor<4x2x3xi32>} : () -> tensor<4x2x3xi32> - %1 = "tosa.add"(%arg0, %zeros) : (tensor<2x3xi32>, tensor<4x2x3xi32>) -> tensor<4x2x3xi32> +// CHECK-LABEL: @add_bcast_zero_int +func.func @add_bcast_zero_int(%arg0: tensor<4x2x3xi32>) -> tensor<4x2x3xi32> { + // CHECK-NOT: tosa.add + // CHECK: return %arg0 + %zeros = "tosa.const"() {value = dense<0> : tensor<1x1x1xi32>} : () -> tensor<1x1x1xi32> + %1 = "tosa.add"(%arg0, %zeros) : (tensor<4x2x3xi32>, tensor<1x1x1xi32>) -> tensor<4x2x3xi32> return %1 : tensor<4x2x3xi32> } - // CHECK-LABEL: @add_zero_int func.func @add_zero_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK: return %arg0 @@ -176,14 +176,6 @@ return %1 : tensor } -// CHECK-LABEL: @mul_one_different_shape -func.func @mul_one_different_shape(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> { - // CHECK: tosa.mul - %ones = "tosa.const"() {value = dense<1.0> : tensor<4x2x3xf32>} : () -> tensor<4x2x3xf32> - %1 = "tosa.mul"(%arg0, %ones) {shift = 0 : i32} : (tensor<2x3xf32>, tensor<4x2x3xf32>) -> tensor<4x2x3xf32> - return %1 : tensor<4x2x3xf32> -} - // CHECK-LABEL: @mul_one_float func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %arg0 @@ -193,6 +185,15 @@ return %1 : tensor<2x3xf32> } +// CHECK-LABEL: @mul_bcast_one_float +func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: return %arg0 + // CHECK-NOT: tosa.mul + %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> + %1 = "tosa.mul"(%ones, %arg0) {shift = 0 : i32} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> +} + // CHECK-LABEL: @mul_one_int func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK: return %arg0