diff --git a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
--- a/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
+++ b/mlir/lib/Dialect/Arithmetic/IR/ArithmeticOps.cpp
@@ -1299,6 +1299,15 @@
   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
 
+  // A comparison with a NaN operand has a result that depends only on the
+  // predicate, never on the other operand. If one operand is a constant NaN,
+  // reuse it for the other side too: the result is unchanged, and the null
+  // check below no longer bails out when the other operand is not a constant.
+  if (lhs && lhs.getValue().isNaN())
+    rhs = lhs;
+  if (rhs && rhs.getValue().isNaN())
+    lhs = rhs;
+
   if (!lhs || !rhs)
     return {};
 
diff --git a/mlir/test/Dialect/Arithmetic/canonicalize.mlir b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
--- a/mlir/test/Dialect/Arithmetic/canonicalize.mlir
+++ b/mlir/test/Dialect/Arithmetic/canonicalize.mlir
@@ -691,3 +691,20 @@
   %res = arith.maxf %const, %min : f32
   return %res : f32
 }
+
+// -----
+// A constant NaN operand decides the comparison even though the other operand
+// is a function argument: ordered `olt` folds to false, unordered `ugt` to true.
+// CHECK-LABEL: @cmpf_nan(
+func @cmpf_nan(%arg0 : f32) -> (i1, i1, i1, i1) {
+// CHECK-DAG: %[[T:.*]] = arith.constant true
+// CHECK-DAG: %[[F:.*]] = arith.constant false
+// CHECK: return %[[F]], %[[F]], %[[T]], %[[T]]
+  // 0x7fffffff is an f32 NaN bit pattern (all-ones exponent, non-zero mantissa).
+  %nan = arith.constant 0x7fffffff : f32
+  %0 = arith.cmpf olt, %nan, %arg0 : f32
+  %1 = arith.cmpf olt, %arg0, %nan : f32
+  %2 = arith.cmpf ugt, %nan, %arg0 : f32
+  %3 = arith.cmpf ugt, %arg0, %nan : f32
+  return %0, %1, %2, %3 : i1, i1, i1, i1
+}