diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -532,6 +532,7 @@
     %3 = arith.shrui %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -556,6 +557,7 @@
     %5 = arith.shrsi %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1871,6 +1871,42 @@
   return bounded ? result : Attribute();
 }
 
+//===----------------------------------------------------------------------===//
+// ShRUIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
+  // Don't fold if any shift amount is >= the bit width: as with LLVM `lshr`,
+  // such a shift yields poison rather than a well-defined value.
+  bool bounded = true;
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      operands, [&](const APInt &a, const APInt &b) {
+        // The callback may run once per element of a vector/tensor constant,
+        // so accumulate the check instead of keeping only the last element's.
+        bounded = bounded && b.ult(b.getBitWidth());
+        return a.lshr(b);
+      });
+  return bounded ? result : Attribute();
+}
+
+//===----------------------------------------------------------------------===//
+// ShRSIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
+  // Don't fold if any shift amount is >= the bit width: as with LLVM `ashr`,
+  // such a shift yields poison rather than a well-defined value.
+  bool bounded = true;
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      operands, [&](const APInt &a, const APInt &b) {
+        // The callback may run once per element of a vector/tensor constant,
+        // so accumulate the check instead of keeping only the last element's.
+        bounded = bounded && b.ult(b.getBitWidth());
+        return a.ashr(b);
+      });
+  return bounded ? result : Attribute();
+}
+
 //===----------------------------------------------------------------------===//
 // Atomic Enum
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1022,3 +1022,113 @@
   %r = arith.shli %c1, %cm32 : i64
   return %r : i64
 }
+
+// CHECK-LABEL: @foldShru(
+// CHECK: %[[res:.+]] = arith.constant 2 : i64
+// CHECK: return %[[res]]
+func @foldShru() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c2 = arith.constant 2 : i64
+  %r = arith.shrui %c8, %c2 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShru2(
+// CHECK: %[[res:.+]] = arith.constant 9223372036854775807 : i64
+// CHECK: return %[[res]]
+func @foldShru2() -> i64 {
+  %cm2 = arith.constant -2 : i64
+  %c1 = arith.constant 1 : i64
+  %r = arith.shrui %cm2, %c1 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShru(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func @nofoldShru() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c132 = arith.constant 132 : i64
+  %r = arith.shrui %c8, %c132 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShru2(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func @nofoldShru2() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %cm32 = arith.constant -32 : i64
+  %r = arith.shrui %c8, %cm32 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShru3(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func @nofoldShru3() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c64 = arith.constant 64 : i64
+  %r = arith.shrui %c8, %c64 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShruVec(
+// CHECK: %[[res:.+]] = arith.shrui
+// CHECK: return %[[res]]
+func @nofoldShruVec() -> vector<2xi64> {
+  %c8 = arith.constant dense<8> : vector<2xi64>
+  %cshift = arith.constant dense<[64, 1]> : vector<2xi64>
+  %r = arith.shrui %c8, %cshift : vector<2xi64>
+  return %r : vector<2xi64>
+}
+
+// CHECK-LABEL: @foldShrs(
+// CHECK: %[[res:.+]] = arith.constant 2 : i64
+// CHECK: return %[[res]]
+func @foldShrs() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c2 = arith.constant 2 : i64
+  %r = arith.shrsi %c8, %c2 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @foldShrs2(
+// CHECK: %[[res:.+]] = arith.constant -1 : i64
+// CHECK: return %[[res]]
+func @foldShrs2() -> i64 {
+  %cm2 = arith.constant -2 : i64
+  %c1 = arith.constant 1 : i64
+  %r = arith.shrsi %cm2, %c1 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShrs(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func @nofoldShrs() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c132 = arith.constant 132 : i64
+  %r = arith.shrsi %c8, %c132 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShrs2(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func @nofoldShrs2() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %cm32 = arith.constant -32 : i64
+  %r = arith.shrsi %c8, %cm32 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShrs3(
+// CHECK: %[[res:.+]] = arith.shrsi
+// CHECK: return %[[res]]
+func @nofoldShrs3() -> i64 {
+  %c8 = arith.constant 8 : i64
+  %c64 = arith.constant 64 : i64
+  %r = arith.shrsi %c8, %c64 : i64
+  return %r : i64
+}