diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -265,6 +265,7 @@
   ];
 
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 
   let extraClassDeclaration = [{
     ::llvm::Optional<::llvm::SmallVector<int64_t, 4>> getShapeForUnroll();
diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
--- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
+++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td
@@ -49,6 +49,18 @@
           (Arith_ConstantOp APIntAttr:$c1)),
         (Arith_SubIOp (Arith_ConstantOp (AddIntAttrs $res, $c0, $c1)), $x)>;
 
+//===----------------------------------------------------------------------===//
+// AddUIExtendedOp
+//===----------------------------------------------------------------------===//
+
+// addui_extended(x, y) -> [addi(x, y), x], when the `overflow` result has no
+// uses. Since the `overflow` result is unused, any replacement value will do;
+// `x` is used even though its type differs from that of the `overflow` result.
+def AddUIExtendedToAddI:
+    Pattern<(Arith_AddUIExtendedOp:$res $x, $y),
+             [(Arith_AddIOp $x, $y), (replaceWithValue $x)],
+             [(Constraint<CPred<"$0.use_empty()">> $res__1)]>;
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -295,6 +295,12 @@
   return failure();
 }
 
+// Rewrites `addui_extended` to `addi` when its `overflow` result is unused.
+void arith::AddUIExtendedOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.add<AddUIExtendedToAddI>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // SubIOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -662,6 +662,24 @@
   return %sum, %overflow : i32, i1
 }
 
+// CHECK-LABEL: @adduiExtendedUnusedOverflowScalar
+// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) -> i32
+// CHECK-NEXT: %[[RES:.+]] = arith.addi %[[LHS]], %[[RHS]] : i32
+// CHECK-NEXT: return %[[RES]] : i32
+func.func @adduiExtendedUnusedOverflowScalar(%arg0: i32, %arg1: i32) -> i32 {
+  %sum, %overflow = arith.addui_extended %arg0, %arg1: i32, i1
+  return %sum : i32
+}
+
+// CHECK-LABEL: @adduiExtendedUnusedOverflowVector
+// CHECK-SAME: (%[[LHS:.+]]: vector<3xi32>, %[[RHS:.+]]: vector<3xi32>) -> vector<3xi32>
+// CHECK-NEXT: %[[RES:.+]] = arith.addi %[[LHS]], %[[RHS]] : vector<3xi32>
+// CHECK-NEXT: return %[[RES]] : vector<3xi32>
+func.func @adduiExtendedUnusedOverflowVector(%arg0: vector<3xi32>, %arg1: vector<3xi32>) -> vector<3xi32> {
+  %sum, %overflow = arith.addui_extended %arg0, %arg1: vector<3xi32>, vector<3xi1>
+  return %sum : vector<3xi32>
+}
+
 // CHECK-LABEL: @adduiExtendedConstants
 // CHECK-DAG: %[[false:.+]] = arith.constant false
 // CHECK-DAG: %[[c50:.+]] = arith.constant 50 : i32