diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -510,6 +510,7 @@
     %3 = arith.shli %1, %2 : (i8, i8) -> i8    // %3 is 0b00101000
     ```
   }];
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1850,6 +1850,24 @@
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// ShLIOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
+  // Don't fold if any shift amount is >= the bit width: such shifts are not
+  // well-defined (cf. LLVM `shl`, which yields poison). Negative amounts
+  // compare as large unsigned values and are rejected too. For elementwise
+  // constants the callback runs per element, so every element must be bounded.
+  bool bounded = true;
+  auto result = constFoldBinaryOp<IntegerAttr>(
+      operands, [&](const APInt &a, const APInt &b) {
+        bounded = bounded && b.ult(b.getBitWidth());
+        return a.shl(b);
+      });
+  return bounded ? result : Attribute();
+}
 
 //===----------------------------------------------------------------------===//
 // Atomic Enum
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -948,3 +948,55 @@
   // CHECK: %[[c3:.+]] = arith.constant 3 : i32
   // CHECK: arith.cmpi ugt, %[[arg0]], %[[c3]] : i32
 }
+
+// -----
+
+// CHECK-LABEL: @foldShl(
+// CHECK: %[[res:.+]] = arith.constant 4294967296 : i64
+// CHECK: return %[[res]]
+func @foldShl() -> i64 {
+  %c1 = arith.constant 1 : i64
+  %c32 = arith.constant 32 : i64
+  %r = arith.shli %c1, %c32 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShl(
+// CHECK: %[[res:.+]] = arith.shli
+// CHECK: return %[[res]]
+func @nofoldShl() -> i64 {
+  %c1 = arith.constant 1 : i64
+  %c132 = arith.constant 132 : i64
+  %r = arith.shli %c1, %c132 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShl2(
+// CHECK: %[[res:.+]] = arith.shli
+// CHECK: return %[[res]]
+func @nofoldShl2() -> i64 {
+  %c1 = arith.constant 1 : i64
+  %cm32 = arith.constant -32 : i64
+  %r = arith.shli %c1, %cm32 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShl3(
+// CHECK: %[[res:.+]] = arith.shli
+// CHECK: return %[[res]]
+func @nofoldShl3() -> i64 {
+  %c1 = arith.constant 1 : i64
+  %c64 = arith.constant 64 : i64
+  %r = arith.shli %c1, %c64 : i64
+  return %r : i64
+}
+
+// CHECK-LABEL: @nofoldShlVec(
+// CHECK: %[[res:.+]] = arith.shli
+// CHECK: return %[[res]]
+func @nofoldShlVec() -> vector<2xi64> {
+  %c1 = arith.constant dense<1> : vector<2xi64>
+  %amt = arith.constant dense<[64, 1]> : vector<2xi64>
+  %r = arith.shli %c1, %amt : vector<2xi64>
+  return %r : vector<2xi64>
+}