diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -444,23 +444,24 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!rhs)
-    return {};
-  auto rhsValue = rhs.getValue();
-
-  // x % 1 = 0
-  if (rhsValue.isOneValue())
-    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+  // remui (x, 1) -> 0.
+  if (matchPattern(getRhs(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
 
-  // Don't fold if it requires division by zero.
-  if (rhsValue.isNullValue())
-    return {};
+  // Don't fold if it would require a division by zero. The fold callback
+  // cannot report failure, so it flags `div0` and returns a placeholder; the
+  // whole result is then discarded.
+  bool div0 = false;
+  auto result =
+      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+        if (div0 || b.isNullValue()) {
+          div0 = true;
+          return a;
+        }
+        return a.urem(b);
+      });
 
-  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs)
-    return {};
-  return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
+  return div0 ? Attribute() : result;
 }
 
 //===----------------------------------------------------------------------===//
@@ -468,23 +469,24 @@
 //===----------------------------------------------------------------------===//
 
 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
-  auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
-  if (!rhs)
-    return {};
-  auto rhsValue = rhs.getValue();
-
-  // x % 1 = 0
-  if (rhsValue.isOneValue())
-    return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
+  // remsi (x, 1) -> 0.
+  if (matchPattern(getRhs(), m_One()))
+    return Builder(getContext()).getZeroAttr(getType());
 
-  // Don't fold if it requires division by zero.
-  if (rhsValue.isNullValue())
-    return {};
+  // Don't fold if it would require a division by zero. The fold callback
+  // cannot report failure, so it flags `div0` and returns a placeholder; the
+  // whole result is then discarded.
+  bool div0 = false;
+  auto result =
+      constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
+        if (div0 || b.isNullValue()) {
+          div0 = true;
+          return a;
+        }
+        return a.srem(b);
+      });
 
-  auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
-  if (!lhs)
-    return {};
-  return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
+  return div0 ? Attribute() : result;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1319,3 +1319,73 @@
   %0 = arith.negf %c : f32
   return %0: f32
 }
+
+// -----
+
+// CHECK-LABEL: @test_remui(
+// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remui() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %0 = arith.remui %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remui_1(
+// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remui_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
+  %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32>
+  %0 = arith.remui %arg, %v : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remui_div0(
+// CHECK: %[[res:.+]] = arith.remui
+// CHECK: return %[[res]]
+func @test_remui_div0() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 0, 5, 7]> : vector<4xi32>
+  %0 = arith.remui %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remsi(
+// CHECK: %[[res:.+]] = arith.constant dense<[0, 0, 4, 2]> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remsi() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %0 = arith.remsi %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remsi_1(
+// CHECK: %[[res:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: return %[[res]]
+func @test_remsi_1(%arg : vector<4xi32>) -> (vector<4xi32>) {
+  %v = arith.constant dense<[1, 1, 1, 1]> : vector<4xi32>
+  %0 = arith.remsi %arg, %v : vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @test_remsi_div0(
+// CHECK: %[[res:.+]] = arith.remsi
+// CHECK: return %[[res]]
+func @test_remsi_div0() -> (vector<4xi32>) {
+  %v1 = arith.constant dense<[9, 9, 9, 9]> : vector<4xi32>
+  %v2 = arith.constant dense<[1, 0, 5, 7]> : vector<4xi32>
+  %0 = arith.remsi %v1, %v2 : vector<4xi32>
+  return %0 : vector<4xi32>
+}