diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2253,6 +2253,9 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::ShLIOp::fold(ArrayRef operands) { + // shli(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( @@ -2268,6 +2271,9 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::ShRUIOp::fold(ArrayRef operands) { + // shrui(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( @@ -2283,6 +2289,9 @@ //===----------------------------------------------------------------------===// OpFoldResult arith::ShRSIOp::fold(ArrayRef operands) { + // shrsi(x, 0) -> x + if (matchPattern(getRhs(), m_Zero())) + return getLhs(); // Don't fold if shifting more than the bit width. bool bounded = false; auto result = constFoldBinaryOp( diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2107,3 +2107,30 @@ %hi = arith.trunci %sh: i64 to i32 return %hi : i32 } + +// CHECK-LABEL: @foldShli0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShli0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shli %x, %c0 : i64 + return %r : i64 +} + +// CHECK-LABEL: @foldShrui0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShrui0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shrui %x, %c0 : i64 + return %r : i64 +} + +// CHECK-LABEL: @foldShrsi0 +// CHECK-SAME: (%[[ARG:.*]]: i64) +// CHECK: return %[[ARG]] : i64 +func.func @foldShrsi0(%x : i64) -> i64 { + %c0 = arith.constant 0 : i64 + %r = arith.shrsi %x, %c0 : i64 + return %r : i64 +}