[mlir][math] Make the ExpM1 approximation select the input when exp(x) is NaN

In the ExpM1 approximation, the comparison of u = exp(x) against 1.0 now uses
the unordered predicate UEQ instead of the ordered predicate OEQ.
- UEQ is true when u == 1.0 OR when either operand is NaN.
- So the final select now picks x directly for a NaN u, rather than relying on
  the log / divide / multiply / select chain below it to carry the NaN through.
- The boolean is renamed from uEqOne to uEqOneOrNaN to reflect this.

The FileCheck expectations for the expm1 scalar test are updated to match.
They now expect `arith.cmpf ueq` and bind its result to IS_ONE_OR_NAN, which
is the condition of the last arith.select (the value that is returned).

NOTE(review): this copy of the patch appears to be whitespace-collapsed onto
two lines:
- The CHECK lines seem to have lost their original spacing.
- The `builder.create(...)` calls show no template op type. These were
  presumably stripped in transit, e.g. `<arith::CmpFOp>`.
As shown, it will likely not apply. Regenerate it from the source revision
before use.

NOTE(review): for a NaN x, whether u is actually NaN depends on the expanded
exp approximation (the test checks that no `exp` op remains). Confirm that the
expansion propagates NaN; otherwise UEQ on u alone does not cover NaN inputs.

diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1033,8 +1033,8 @@ Value cstNegOne = bcast(f32Cst(builder, -1.0f)); Value x = op.getOperand(); Value u = builder.create(x); - Value uEqOne = - builder.create(arith::CmpFPredicate::OEQ, u, cstOne); + Value uEqOneOrNaN = + builder.create(arith::CmpFPredicate::UEQ, u, cstOne); Value uMinusOne = builder.create(u, cstOne); Value uMinusOneEqNegOne = builder.create( arith::CmpFPredicate::OEQ, uMinusOne, cstNegOne); @@ -1050,7 +1050,7 @@ uMinusOne, builder.create(x, logU)); expm1 = builder.create(isInf, u, expm1); Value approximation = builder.create( - uEqOne, x, + uEqOneOrNaN, x, builder.create(uMinusOneEqNegOne, cstNegOne, expm1)); rewriter.replaceOp(op, approximation); return success(); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -163,7 +163,7 @@ // CHECK-NOT: exp // CHECK-COUNT-3: select // CHECK: %[[EXP_X:.*]] = arith.select -// CHECK: %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32 +// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32 // CHECK-NOT: log @@ -174,7 +174,7 @@ // CHECK: %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32 // CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32 // CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32 -// CHECK: %[[VAL_109:.*]] = arith.select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32 +// CHECK: 
%[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32 // CHECK: return %[[VAL_109]] : f32 // CHECK: } func @expm1_scalar(%arg0: f32) -> f32 {