diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -1032,6 +1032,7 @@ Value cstOne = bcast(f32Cst(builder, 1.0f)); Value cstNegOne = bcast(f32Cst(builder, -1.0f)); Value x = op.getOperand(); + Value xIsNaN = builder.create(arith::CmpFPredicate::UNO, x, x); Value u = builder.create(x); Value uEqOne = builder.create(arith::CmpFPredicate::OEQ, u, cstOne); @@ -1049,8 +1050,9 @@ Value expm1 = builder.create( uMinusOne, builder.create(x, logU)); expm1 = builder.create(isInf, u, expm1); + Value uEqOneOrNaN = builder.create(uEqOne, xIsNaN); Value approximation = builder.create( - uEqOne, x, + uEqOneOrNaN, x, builder.create(uMinusOneEqNegOne, cstNegOne, expm1)); rewriter.replaceOp(op, approximation); return success(); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -159,11 +159,12 @@ // CHECK-DAG: %[[CST_MINUSONE:.*]] = arith.constant -1.000000e+00 : f32 // CHECK-DAG: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32 // CHECK-DAG: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[IS_NAN:.*]] = arith.cmpf uno, %[[X]], %[[X]] : f32 // CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32 // CHECK-NOT: exp // CHECK-COUNT-3: select // CHECK: %[[EXP_X:.*]] = arith.select -// CHECK: %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32 +// CHECK: %[[IS_ONE:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32 // CHECK-NOT: log @@ -173,8 +174,9 @@ // CHECK: %[[VAL_105:.*]] = arith.divf %[[X]], %[[LOG_U]] : f32 // CHECK: %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32 // CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32 +// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.ori %[[IS_ONE]], %[[IS_NAN]] : i1 // CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32 -// CHECK: %[[VAL_109:.*]] = arith.select %[[VAL_58]], %[[X]], %[[VAL_108]] : f32 +// CHECK: %[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32 // CHECK: return %[[VAL_109]] : f32 // CHECK: } func @expm1_scalar(%arg0: f32) -> f32 {