diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -182,16 +182,21 @@
 // Helper functions to build math functions approximations.
 //----------------------------------------------------------------------------//
 
-static Value min(ImplicitLocOpBuilder &builder, Value a, Value b) {
+// Returns the minimum of `value` and `bound`, or `value` if either is NaN.
+static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound) {
   return builder.create<arith::SelectOp>(
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, b), a, b);
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, value, bound),
+      value, bound);
 }
 
-static Value max(ImplicitLocOpBuilder &builder, Value a, Value b) {
+// Returns the maximum of `value` and `bound`, or `value` if either is NaN.
+static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound) {
   return builder.create<arith::SelectOp>(
-      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, b), a, b);
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, value, bound),
+      value, bound);
 }
 
+// Clamps `value` to [lowerBound, upperBound]; a NaN `value` is propagated.
 static Value clamp(ImplicitLocOpBuilder &builder, Value value, Value lowerBound,
                    Value upperBound) {
   return max(builder, min(builder, value, upperBound), lowerBound);
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -225,7 +225,7 @@
 // CHECK: %[[VAL_20:.*]] = arith.constant 1056964608 : i32
 // CHECK: %[[VAL_21:.*]] = arith.constant 23 : i32
 // CHECK: %[[VAL_22:.*]] = arith.constant 0.693147182 : f32
-// CHECK: %[[VAL_23:.*]] = arith.cmpf ogt, %[[X]], %[[VAL_4]] : f32
+// CHECK: %[[VAL_23:.*]] = arith.cmpf ugt, %[[X]], %[[VAL_4]] : f32
 // CHECK: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[X]], %[[VAL_4]] : f32
 // CHECK-NOT: frexp
 // CHECK: %[[VAL_25:.*]] = arith.bitcast %[[VAL_24]] : f32 to i32
@@ -355,9 +355,9 @@
 // CHECK: %[[VAL_12:.*]] = arith.constant 0.00226843474 : f32
 // CHECK: %[[VAL_13:.*]] = arith.constant 1.18534706E-4 : f32
 // CHECK: %[[VAL_14:.*]] = arith.constant 1.19825836E-6 : f32
-// CHECK: %[[VAL_15:.*]] = arith.cmpf olt, %[[VAL_0]], %[[VAL_2]] : f32
+// CHECK: %[[VAL_15:.*]] = arith.cmpf ult, %[[VAL_0]], %[[VAL_2]] : f32
 // CHECK: %[[VAL_16:.*]] = arith.select %[[VAL_15]], %[[VAL_0]], %[[VAL_2]] : f32
-// CHECK: %[[VAL_17:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_1]] : f32
+// CHECK: %[[VAL_17:.*]] = arith.cmpf ugt, %[[VAL_16]], %[[VAL_1]] : f32
 // CHECK: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_16]], %[[VAL_1]] : f32
 // CHECK: %[[VAL_19:.*]] = math.abs %[[VAL_0]] : f32
 // CHECK: %[[VAL_20:.*]] = arith.cmpf olt, %[[VAL_19]], %[[VAL_3]] : f32
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -29,6 +29,11 @@
   %5 = math.tanh %4 : vector<8xf32>
   vector.print %5 : vector<8xf32>
 
+  // CHECK-NEXT: nan
+  %nan = arith.constant 0x7fc00000 : f32
+  %6 = math.tanh %nan : f32
+  vector.print %6 : f32
+
   return
 }
 