diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -930,6 +930,8 @@ Value x = op.getOperand(); + Value isNan = builder.create<arith::CmpFOp>(arith::CmpFPredicate::UNO, x, x); + // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2) Value xL2Inv = mul(x, cstLog2E); Value kF32 = floor(xL2Inv); @@ -985,13 +987,15 @@ Value isComputable = builder.create<arith::AndIOp>(rightBound, leftBound); expY = builder.create<arith::SelectOp>( - isNegInfinityX, zerof32Const, + isNan, x, builder.create<arith::SelectOp>( - isPosInfinityX, constPosInfinity, + isNegInfinityX, zerof32Const, builder.create<arith::SelectOp>( - isComputable, expY, - builder.create<arith::SelectOp>(isPostiveX, constPosInfinity, - underflow)))); + isPosInfinityX, constPosInfinity, + builder.create<arith::SelectOp>( + isComputable, expY, + builder.create<arith::SelectOp>(isPostiveX, constPosInfinity, + underflow))))); rewriter.replaceOp(op, expY); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -110,6 +110,7 @@ // CHECK-DAG: %[[VAL_12:.*]] = arith.constant 1.17549435E-38 : f32 // CHECK-DAG: %[[VAL_13:.*]] = arith.constant 127 : i32 // CHECK-DAG: %[[VAL_14:.*]] = arith.constant -127 : i32 +// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[VAL_0]], %[[VAL_0]] : f32 // CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_0]], %[[VAL_2]] : f32 // CHECK: %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f32 // CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_16]], %[[VAL_1]] : f32 @@ -136,7 +137,8 @@ // CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32 // CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32 // CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32 -// CHECK: 
return %[[VAL_40]] : f32 +// CHECK: %[[VAL_41:.*]] = arith.select %[[IS_NAN]], %[[VAL_0]], %[[VAL_40]] : f32 +// CHECK: return %[[VAL_41]] : f32 func @exp_scalar(%arg0: f32) -> f32 { %0 = math.exp %arg0 : f32 return %0 : f32 @@ -146,7 +148,7 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32> // CHECK-NOT: exp -// CHECK-COUNT-3: select +// CHECK-COUNT-4: select // CHECK: %[[VAL_40:.*]] = arith.select // CHECK: return %[[VAL_40]] : vector<8xf32> func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> { @@ -161,7 +163,7 @@ // CHECK-DAG: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32 // CHECK-NOT: exp -// CHECK-COUNT-3: select +// CHECK-COUNT-4: select // CHECK: %[[EXP_X:.*]] = arith.select // CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32 @@ -186,7 +188,7 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32> // CHECK-NOT: exp -// CHECK-COUNT-4: select +// CHECK-COUNT-5: select // CHECK-NOT: log // CHECK-COUNT-5: select // CHECK-NOT: expm1 diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -258,6 +258,11 @@ %exp_negative_inf = math.exp %negative_inf : f32 vector.print %exp_negative_inf : f32 + // CHECK: nan + %nan = arith.constant 0x7fc00000 : f32 + %exp_nan = math.exp %nan : f32 + vector.print %exp_nan : f32 + return } @@ -292,6 +297,11 @@ %log_special_vec = math.expm1 %special_vec : vector<3xf32> vector.print %log_special_vec : vector<3xf32> + // CHECK: nan + %nan = arith.constant 0x7fc00000 : f32 + %exp_nan = 
math.expm1 %nan : f32 + vector.print %exp_nan : f32 + return } // -------------------------------------------------------------------------- //