diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -567,15 +567,20 @@ Value isNegInfinityX = builder.create( arith::CmpFPredicate::OEQ, x, constNegIfinity); + Value isPosInfinityX = builder.create( + arith::CmpFPredicate::OEQ, x, constPosInfinity); Value isPostiveX = builder.create(arith::CmpFPredicate::OGT, x, zerof32Const); Value isComputable = builder.create(rightBound, leftBound); expY = builder.create( - isComputable, expY, + isNegInfinityX, zerof32Const, builder.create( - isPostiveX, constPosInfinity, - builder.create(isNegInfinityX, zerof32Const, underflow))); + isPosInfinityX, constPosInfinity, + builder.create(isComputable, expY, + builder.create(isPostiveX, + constPosInfinity, + underflow)))); rewriter.replaceOp(op, expY); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -38,13 +38,14 @@ // CHECK: %[[VAL_31:.*]] = arith.cmpi sle, %[[VAL_26]], %[[VAL_13]] : i32 // CHECK: %[[VAL_32:.*]] = arith.cmpi sge, %[[VAL_26]], %[[VAL_14]] : i32 // CHECK: %[[VAL_33:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_34:.*]] = arith.cmpf ogt, %[[VAL_0]], %[[VAL_9]] : f32 -// CHECK: %[[VAL_35:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1 -// CHECK: %[[VAL_36:.*]] = select %[[VAL_33]], %[[VAL_9]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_37:.*]] = select %[[VAL_34]], %[[VAL_10]], %[[VAL_36]] : f32 -// CHECK: %[[VAL_38:.*]] = select %[[VAL_35]], %[[VAL_30]], %[[VAL_37]] : f32 -// CHECK: return %[[VAL_38]] : f32 -// CHECK: } +// CHECK: %[[VAL_34:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_35:.*]] = arith.cmpf ogt, %[[VAL_0]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_36:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1 +// CHECK: %[[VAL_37:.*]] = select %[[VAL_35]], %[[VAL_10]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_38:.*]] = select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32 +// CHECK: %[[VAL_39:.*]] = select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32 +// CHECK: return %[[VAL_40]] : f32 func @exp_scalar(%arg0: f32) -> f32 { %0 = math.exp %arg0 : f32 return %0 : f32 @@ -54,10 +55,9 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32> // CHECK-NOT: exp -// CHECK-COUNT-2: select -// CHECK: %[[VAL_38:.*]] = select -// CHECK: return %[[VAL_38]] : vector<8xf32> -// CHECK: } +// CHECK-COUNT-3: select +// CHECK: %[[VAL_40:.*]] = select +// CHECK: return %[[VAL_40]] func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> { %0 = math.exp %arg0 : vector<8xf32> return %0 : vector<8xf32> @@ -70,7 +70,7 @@ // CHECK-DAG: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32 // CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32 // CHECK-NOT: exp -// CHECK-COUNT-2: select +// CHECK-COUNT-3: select // CHECK: %[[EXP_X:.*]] = select // CHECK: %[[VAL_58:.*]] = arith.cmpf oeq, %[[EXP_X]], %[[CST_ONE]] : f32 // CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32 @@ -95,7 +95,7 @@ // CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8xf32> // CHECK-NOT: exp -// CHECK-COUNT-3: select +// CHECK-COUNT-4: select // CHECK-NOT: log // CHECK-COUNT-5: select // CHECK-NOT: expm1 diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir --- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir +++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir @@ -10,9 +10,6 @@ // RUN: -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// XFAIL: s390x -// (see https://bugs.llvm.org/show_bug.cgi?id=51204) - // -------------------------------------------------------------------------- // // Tanh. // -------------------------------------------------------------------------- //