diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
--- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td
@@ -781,6 +781,7 @@
 
 def Arith_RemFOp : Arith_FloatBinaryOp<"remf"> {
   let summary = "floating point division remainder operation";
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -756,6 +756,23 @@
       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
 }
 
+//===----------------------------------------------------------------------===//
+// RemFOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
+  return constFoldBinaryOp<FloatAttr>(
+      operands, [](const APFloat &a, const APFloat &b) {
+        // Fold with C fmod semantics (APFloat::mod): the result takes the sign
+        // of the LHS operand, matching llvm.frem. APFloat::remainder() is the
+        // IEEE-754 remainder (quotient rounded to nearest), which would fold
+        // `remf 3.0, 2.0` to -1.0 instead of 1.0.
+        APFloat result(a);
+        (void)result.mod(b);
+        return result;
+      });
+}
+
 //===----------------------------------------------------------------------===//
 // Utility functions for verifying cast ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -1374,3 +1374,36 @@
   %0 = arith.remsi %arg, %v : vector<4xi32>
   return %0 : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @test_remf(
+// CHECK:           %[[res:.+]] = arith.constant 1.000000e+00 : f32
+// CHECK:           return %[[res]]
+func.func @test_remf() -> (f32) {
+  %v1 = arith.constant 3.0 : f32
+  %v2 = arith.constant 2.0 : f32
+  %0 = arith.remf %v1, %v2 : f32
+  return %0 : f32
+}
+
+// remf folds with fmod semantics: the result has the sign of the LHS.
+// CHECK-LABEL: @test_remf_neg(
+// CHECK:           %[[res:.+]] = arith.constant -1.000000e+00 : f32
+// CHECK:           return %[[res]]
+func.func @test_remf_neg() -> (f32) {
+  %v1 = arith.constant -3.0 : f32
+  %v2 = arith.constant 2.0 : f32
+  %0 = arith.remf %v1, %v2 : f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: @test_remf_vec(
+// CHECK:           %[[res:.+]] = arith.constant dense<[1.000000e+00, 0.000000e+00, 1.000000e+00, 0.000000e+00]> : vector<4xf32>
+// CHECK:           return %[[res]]
+func.func @test_remf_vec() -> (vector<4xf32>) {
+  %v1 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : vector<4xf32>
+  %v2 = arith.constant dense<[2.0, 2.0, 2.0, 2.0]> : vector<4xf32>
+  %0 = arith.remf %v1, %v2 : vector<4xf32>
+  return %0 : vector<4xf32>
+}