diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -49,6 +49,29 @@
           (ConstantLikeMatcher APIntAttr:$c1)),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
 
+// Whether $0 is a scalar or splat integer attribute equal to -1 (all ones).
+// The first predicate guards the dereference performed by the second.
+def IsScalarOrSplatNegativeOne :
+    Constraint<And<[
+      CPred<"succeeded(getIntOrSplatIntValue($0))">,
+      CPred<"getIntOrSplatIntValue($0)->isAllOnes()">]>>;
+
+// addi(x, muli(y, -1)) -> subi(x, y)
+def AddIMulNegativeOneRhs :
+    Pat<(Arith_AddIOp
+          $x,
+          (Arith_MulIOp $y, (ConstantLikeMatcher AnyAttr:$c0))),
+        (Arith_SubIOp $x, $y),
+        [(IsScalarOrSplatNegativeOne $c0)]>;
+
+// addi(muli(x, -1), y) -> subi(y, x)
+def AddIMulNegativeOneLhs :
+    Pat<(Arith_AddIOp
+          (Arith_MulIOp $x, (ConstantLikeMatcher AnyAttr:$c0)),
+          $y),
+        (Arith_SubIOp $y, $x),
+        [(IsScalarOrSplatNegativeOne $c0)]>;
+
 //===----------------------------------------------------------------------===//
 // AddUIExtendedOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -258,8 +258,8 @@
 
 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                 MLIRContext *context) {
-  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
-      context);
+  patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS,
+               AddIMulNegativeOneRhs, AddIMulNegativeOneLhs>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -735,6 +735,94 @@
   return %add : index
 }
 
+// CHECK-LABEL: @addiMuliToSubiRhsI32
+// CHECK-SAME:  (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK-NEXT:  %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+// CHECK-NEXT:  return %[[SUB]]
+func.func @addiMuliToSubiRhsI32(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 : i32
+  %add = arith.addi %arg0, %neg : i32
+  return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsIndex
+// CHECK-SAME:  (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+// CHECK-NEXT:  %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+// CHECK-NEXT:  return %[[SUB]]
+func.func @addiMuliToSubiRhsIndex(%arg0: index, %arg1: index) -> index {
+  %c-1 = arith.constant -1 : index
+  %neg = arith.muli %arg1, %c-1 : index
+  %add = arith.addi %arg0, %neg : index
+  return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiRhsVector
+// CHECK-SAME:  (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+// CHECK-NEXT:  %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+// CHECK-NEXT:  return %[[SUB]]
+func.func @addiMuliToSubiRhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+  %c-1 = arith.constant dense<-1> : vector<3xi64>
+  %neg = arith.muli %arg1, %c-1 : vector<3xi64>
+  %add = arith.addi %arg0, %neg : vector<3xi64>
+  return %add : vector<3xi64>
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsI32
+// CHECK-SAME:  (%[[ARG0:.+]]: i32, %[[ARG1:.+]]: i32)
+// CHECK-NEXT:  %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : i32
+// CHECK-NEXT:  return %[[SUB]]
+func.func @addiMuliToSubiLhsI32(%arg0: i32, %arg1: i32) -> i32 {
+  %c-1 = arith.constant -1 : i32
+  %neg = arith.muli %arg1, %c-1 : i32
+  %add = arith.addi %neg, %arg0 : i32
+  return %add : i32
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsIndex
+// CHECK-SAME:  (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+// CHECK-NEXT:  %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : index
+// CHECK-NEXT:  return %[[SUB]]
+func.func @addiMuliToSubiLhsIndex(%arg0: index, %arg1: index) -> index {
+  %c-1 = arith.constant -1 : index
+  %neg = arith.muli %arg1, %c-1 : index
+  %add = arith.addi %neg, %arg0 : index
+  return %add : index
+}
+
+// CHECK-LABEL: @addiMuliToSubiLhsVector
+// CHECK-SAME:  (%[[ARG0:.+]]: vector<3xi64>, %[[ARG1:.+]]: vector<3xi64>)
+// CHECK-NEXT:  %[[SUB:.+]] = arith.subi %[[ARG0]], %[[ARG1]] : vector<3xi64>
+// CHECK-NEXT:  return %[[SUB]]
+func.func @addiMuliToSubiLhsVector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+  %c-1 = arith.constant dense<-1> : vector<3xi64>
+  %neg = arith.muli %arg1, %c-1 : vector<3xi64>
+  %add = arith.addi %neg, %arg0 : vector<3xi64>
+  return %add : vector<3xi64>
+}
+
+// Negative test: a multiplier other than -1 must not be rewritten to subi.
+// CHECK-LABEL: @addiMuliNotNegativeOne
+// CHECK:       arith.muli
+// CHECK:       arith.addi
+func.func @addiMuliNotNegativeOne(%arg0: i32, %arg1: i32) -> i32 {
+  %c-2 = arith.constant -2 : i32
+  %mul = arith.muli %arg1, %c-2 : i32
+  %add = arith.addi %arg0, %mul : i32
+  return %add : i32
+}
+
+// Negative test: a non-splat vector constant is not a uniform -1.
+// CHECK-LABEL: @addiMuliNonSplatNegativeOne
+// CHECK:       arith.muli
+// CHECK:       arith.addi
+func.func @addiMuliNonSplatNegativeOne(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi64> {
+  %cst = arith.constant dense<[-1, 1, -1]> : vector<3xi64>
+  %mul = arith.muli %arg1, %cst : vector<3xi64>
+  %add = arith.addi %arg0, %mul : vector<3xi64>
+  return %add : vector<3xi64>
+}
+
 // CHECK-LABEL: @adduiExtendedZeroRhs
 // CHECK-NEXT:  %[[false:.+]] = arith.constant false
 // CHECK-NEXT:  return %arg0, %[[false]]