diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include +#include #include #include "mlir/Dialect/Arith/IR/Arith.h" @@ -171,7 +172,7 @@ builder.getFloatAttr(elementType, value)); } -static Value f32Cst(ImplicitLocOpBuilder &builder, float value) { +static Value f32Cst(ImplicitLocOpBuilder &builder, double value) { return builder.create(builder.getF32FloatAttr(value)); } @@ -380,35 +381,76 @@ ArrayRef shape = vectorShape(op.getOperand()); ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - auto one = broadcast(builder, f32Cst(builder, 1.0f), shape); - - // Remap the problem over [0.0, 1.0] by looking at the absolute value and the - // handling symmetry. Value abs = builder.create(operand); - Value reciprocal = builder.create(one, abs); - Value compare = - builder.create(arith::CmpFPredicate::OLT, abs, reciprocal); - Value x = builder.create(compare, abs, reciprocal); - // Perform the Taylor series approximation for atan over the range - // [-1.0, 1.0]. 
- auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape); - auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape); - auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape); - auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape); - - Value p = builder.create(x, n1, n2); - p = builder.create(x, p, n3); - p = builder.create(x, p, n4); - p = builder.create(x, p); + auto one = broadcast(builder, f32Cst(builder, 1.0), shape); - // Remap the solution for over [0.0, 1.0] to [0.0, inf] - auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape); - Value sub = builder.create(halfPi, p); - Value select = builder.create(compare, p, sub); + // For 0.66 < x <= 2.41 use atan(x) = pi/4 + atan((x-1)/(x+1)); else x/1: + auto twoThirds = broadcast(builder, f32Cst(builder, 0.66), shape); + Value cmp2 = + builder.create(arith::CmpFPredicate::OGT, abs, twoThirds); + Value addone = builder.create(abs, one); + Value subone = builder.create(abs, one); + Value xnum = builder.create(cmp2, subone, abs); + Value xden = builder.create(cmp2, addone, one); + + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, shape); + }; + + // For x > tan(3*pi/8) ~= 2.41, use atan(x) = pi/2 - atan(1/x) instead: + auto tan3pio8 = bcast(f32Cst(builder, 2.41421356237309504880)); + Value cmp1 = + builder.create(arith::CmpFPredicate::OGT, abs, tan3pio8); + xnum = builder.create(cmp1, one, xnum); + xden = builder.create(cmp1, abs, xden); + + Value x = builder.create(xnum, xden); + Value xx = builder.create(x, x); + + // Rational (Cephes atan) approximation on the reduced range |x| <= 0.66: + // atan(x) ~= x + x * xx * P(xx) / Q(xx).
+ auto p0 = bcast(f32Cst(builder, -8.750608600031904122785e-01)); + auto p1 = bcast(f32Cst(builder, -1.615753718733365076637e+01)); + auto p2 = bcast(f32Cst(builder, -7.500855792314704667340e+01)); + auto p3 = bcast(f32Cst(builder, -1.228866684490136173410e+02)); + auto p4 = bcast(f32Cst(builder, -6.485021904942025371773e+01)); + auto q0 = bcast(f32Cst(builder, +2.485846490142306297962e+01)); + auto q1 = bcast(f32Cst(builder, +1.650270098316988542046e+02)); + auto q2 = bcast(f32Cst(builder, +4.328810604912902668951e+02)); + auto q3 = bcast(f32Cst(builder, +4.853903996359136964868e+02)); + auto q4 = bcast(f32Cst(builder, +1.945506571482613964425e+02)); + + // Apply the polynomial approximation for the numerator: + Value n = p0; + n = builder.create(xx, n, p1); + n = builder.create(xx, n, p2); + n = builder.create(xx, n, p3); + n = builder.create(xx, n, p4); + n = builder.create(n, xx); + + // Denominator is monic (Cephes p1evl): xx^5 + q0*xx^4 + ... + q3*xx + q4. + Value d = builder.create<arith::AddFOp>(xx, q0); + d = builder.create(xx, d, q1); + d = builder.create(xx, d, q2); + d = builder.create(xx, d, q3); + d = builder.create(xx, d, q4); + + // Compute approximation of theta: + Value ans0 = builder.create(n, d); + ans0 = builder.create(ans0, x, x); + + // Correct for the input mapping's angles: + Value mpi4 = bcast(f32Cst(builder, llvm::numbers::pi / 4)); + Value ans2 = builder.create(mpi4, ans0); + Value ans = builder.create(cmp2, ans2, ans0); + + Value mpi2 = bcast(f32Cst(builder, llvm::numbers::pi / 2)); + Value ans1 = builder.create(mpi2, ans0); + ans = builder.create(cmp1, ans1, ans); // Correct for signing of the input. 
+ rewriter.replaceOpWithNewOp(op, ans, operand); return success(); } diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -511,24 +511,51 @@ } // CHECK-LABEL: @atan_scalar -// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 -// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831 -// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335 -// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099 -// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987 -// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637 -// CHECK-DAG: %[[ABS:.+]] = math.absf %arg0 -// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]] -// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]] -// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]] -// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]] -// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]] -// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]] -// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]] -// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]] -// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]] -// CHECK-DAG: %[[RES:.+]] = math.copysign %[[EST]], %arg0 -// CHECK: return %[[RES]] +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 6.600000e-01 : f32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2.41421366 : f32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -0.875060856 : f32 +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -16.1575375 : f32 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant -75.0085601 : f32 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant -122.886665 : f32 +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 
-64.8502197 : f32 +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 24.8584652 : f32 +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 165.027008 : f32 +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 432.881073 : f32 +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 485.390411 : f32 +// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 194.550659 : f32 +// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 0.785398185 : f32 +// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 1.57079637 : f32 +// CHECK-DAG: %[[VAL_16:.*]] = math.absf %[[VAL_0]] : f32 +// CHECK-DAG: %[[VAL_17:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_2]] : f32 +// CHECK-DAG: %[[VAL_18:.*]] = arith.addf %[[VAL_16]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_19:.*]] = arith.subf %[[VAL_16]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_20:.*]] = arith.select %[[VAL_17]], %[[VAL_19]], %[[VAL_16]] : f32 +// CHECK-DAG: %[[VAL_21:.*]] = arith.select %[[VAL_17]], %[[VAL_18]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_22:.*]] = arith.cmpf ogt, %[[VAL_16]], %[[VAL_3]] : f32 +// CHECK-DAG: %[[VAL_23:.*]] = arith.select %[[VAL_22]], %[[VAL_1]], %[[VAL_20]] : f32 +// CHECK-DAG: %[[VAL_24:.*]] = arith.select %[[VAL_22]], %[[VAL_16]], %[[VAL_21]] : f32 +// CHECK-DAG: %[[VAL_25:.*]] = arith.divf %[[VAL_23]], %[[VAL_24]] : f32 +// CHECK-DAG: %[[VAL_26:.*]] = arith.mulf %[[VAL_25]], %[[VAL_25]] : f32 +// CHECK-DAG: %[[VAL_27:.*]] = math.fma %[[VAL_26]], %[[VAL_4]], %[[VAL_5]] : f32 +// CHECK-DAG: %[[VAL_28:.*]] = math.fma %[[VAL_26]], %[[VAL_27]], %[[VAL_6]] : f32 +// CHECK-DAG: %[[VAL_29:.*]] = math.fma %[[VAL_26]], %[[VAL_28]], %[[VAL_7]] : f32 +// CHECK-DAG: %[[VAL_30:.*]] = math.fma %[[VAL_26]], %[[VAL_29]], %[[VAL_8]] : f32 +// CHECK-DAG: %[[VAL_31:.*]] = arith.mulf %[[VAL_30]], %[[VAL_26]] : f32 +// CHECK-DAG: %[[VAL_DEN:.*]] = arith.addf %[[VAL_26]], %[[VAL_9]] : f32 +// CHECK-DAG: %[[VAL_32:.*]] = math.fma %[[VAL_26]], %[[VAL_DEN]], %[[VAL_10]] : f32 +// CHECK-DAG: %[[VAL_33:.*]] = math.fma %[[VAL_26]], %[[VAL_32]], %[[VAL_11]] : f32 +// CHECK-DAG: %[[VAL_34:.*]] = math.fma %[[VAL_26]], %[[VAL_33]], 
%[[VAL_12]] : f32 +// 
CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_26]], %[[VAL_34]], %[[VAL_13]] : f32 +// CHECK-DAG: %[[VAL_36:.*]] = arith.divf %[[VAL_31]], %[[VAL_35]] : f32 +// CHECK-DAG: %[[VAL_37:.*]] = math.fma %[[VAL_36]], %[[VAL_25]], %[[VAL_25]] : f32 +// CHECK-DAG: %[[VAL_38:.*]] = arith.addf %[[VAL_37]], %[[VAL_14]] : f32 +// CHECK-DAG: %[[VAL_39:.*]] = arith.select %[[VAL_17]], %[[VAL_38]], %[[VAL_37]] : f32 +// CHECK-DAG: %[[VAL_40:.*]] = arith.subf %[[VAL_15]], %[[VAL_37]] : f32 +// CHECK-DAG: %[[VAL_41:.*]] = arith.select %[[VAL_22]], %[[VAL_40]], %[[VAL_39]] : f32 +// CHECK-DAG: %[[VAL_42:.*]] = math.copysign %[[VAL_41]], %[[VAL_0]] : f32 +// CHECK: return %[[VAL_42]] : f32 func.func @atan_scalar(%arg0: f32) -> f32 { %0 = math.atan %arg0 : f32 return %0 : f32 @@ -536,59 +563,76 @@ // CHECK-LABEL: @atan2_scalar - -// ATan approximation: -// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00 -// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831 -// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335 -// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099 -// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987 -// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637 -// CHECK-DAG: %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32 -// CHECK-DAG: %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32 -// CHECK-DAG: %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]] -// CHECK-DAG: %[[ABS:.+]] = math.absf %[[RATIO]] -// CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]] -// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]] -// CHECK-DAG: %[[SEL:.+]] = arith.select %[[CMP]], %[[ABS]], %[[DIV]] -// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]] -// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]] -// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]] -// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]] -// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]] -// CHECK-DAG: %[[EST:.+]] = arith.select %[[CMP]], %[[P3]], %[[SUB]] -// CHECK-DAG: 
%[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]] - -// Handle the case of x < 0: -// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 -// CHECK-DAG: %[[PI:.+]] = arith.constant 3.14159274 -// CHECK-DAG: %[[ADD_PI:.+]] = arith.addf %[[ATAN]], %[[PI]] -// CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]] -// CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]] -// CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]] -// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]] -// CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]] - -// Handle PI / 2 edge case: -// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]] -// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]] -// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]] -// CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]] - -// Handle -PI / 2 edge case: -// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637 -// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]] -// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]] -// CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]] - -// Handle Nan edgecase: -// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]] -// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]] -// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000 -// CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]] -// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]] -// CHECK: return %[[RET]] - +// CHECK-SAME: %[[VAL_0:.*]]: f16, +// CHECK-SAME: %[[VAL_1:.*]]: f16) +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6.600000e-01 : f32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 2.41421366 : f32 +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -0.875060856 : 
f32 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant -16.1575375 : f32 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant -75.0085601 : f32 +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant -122.886665 : f32 +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant -64.8502197 : f32 +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 24.8584652 : f32 +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 165.027008 : f32 +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 432.881073 : f32 +// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 485.390411 : f32 +// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 194.550659 : f32 +// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 0.785398185 : f32 +// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 1.57079637 : f32 +// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 3.14159274 : f32 +// CHECK-DAG: %[[VAL_19:.*]] = arith.constant -1.57079637 : f32 +// CHECK-DAG: %[[VAL_20:.*]] = arith.constant 0x7FC00000 : f32 +// CHECK-DAG: %[[VAL_21:.*]] = arith.extf %[[VAL_0]] : f16 to f32 +// CHECK-DAG: %[[VAL_22:.*]] = arith.extf %[[VAL_1]] : f16 to f32 +// CHECK-DAG: %[[VAL_23:.*]] = arith.divf %[[VAL_21]], %[[VAL_22]] : f32 +// CHECK-DAG: %[[VAL_24:.*]] = math.absf %[[VAL_23]] : f32 +// CHECK-DAG: %[[VAL_25:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_3]] : f32 +// CHECK-DAG: %[[VAL_26:.*]] = arith.addf %[[VAL_24]], %[[VAL_2]] : f32 +// CHECK-DAG: %[[VAL_27:.*]] = arith.subf %[[VAL_24]], %[[VAL_2]] : f32 +// CHECK-DAG: %[[VAL_28:.*]] = arith.select %[[VAL_25]], %[[VAL_27]], %[[VAL_24]] : f32 +// CHECK-DAG: %[[VAL_29:.*]] = arith.select %[[VAL_25]], %[[VAL_26]], %[[VAL_2]] : f32 +// CHECK-DAG: %[[VAL_30:.*]] = arith.cmpf ogt, %[[VAL_24]], %[[VAL_4]] : f32 +// CHECK-DAG: %[[VAL_31:.*]] = arith.select %[[VAL_30]], %[[VAL_2]], %[[VAL_28]] : f32 +// CHECK-DAG: %[[VAL_32:.*]] = arith.select %[[VAL_30]], %[[VAL_24]], %[[VAL_29]] : f32 +// CHECK-DAG: %[[VAL_33:.*]] = arith.divf %[[VAL_31]], %[[VAL_32]] : f32 +// CHECK-DAG: %[[VAL_34:.*]] = 
arith.mulf %[[VAL_33]], %[[VAL_33]] : f32 +// CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_34]], %[[VAL_5]], %[[VAL_6]] : f32 +// CHECK-DAG: %[[VAL_36:.*]] = math.fma %[[VAL_34]], %[[VAL_35]], %[[VAL_7]] : f32 +// CHECK-DAG: %[[VAL_37:.*]] = math.fma %[[VAL_34]], %[[VAL_36]], %[[VAL_8]] : f32 +// CHECK-DAG: %[[VAL_38:.*]] = math.fma %[[VAL_34]], %[[VAL_37]], %[[VAL_9]] : f32 +// CHECK-DAG: %[[VAL_39:.*]] = arith.mulf %[[VAL_38]], %[[VAL_34]] : f32 +// CHECK-DAG: %[[VAL_DEN:.*]] = arith.addf %[[VAL_34]], %[[VAL_10]] : f32 +// CHECK-DAG: %[[VAL_40:.*]] = math.fma %[[VAL_34]], %[[VAL_DEN]], %[[VAL_11]] : f32 +// CHECK-DAG: %[[VAL_41:.*]] = math.fma %[[VAL_34]], %[[VAL_40]], %[[VAL_12]] : f32 +// CHECK-DAG: %[[VAL_42:.*]] = math.fma %[[VAL_34]], %[[VAL_41]], %[[VAL_13]] : f32 +// CHECK-DAG: %[[VAL_43:.*]] = math.fma %[[VAL_34]], %[[VAL_42]], %[[VAL_14]] : f32 +// CHECK-DAG: %[[VAL_44:.*]] = arith.divf %[[VAL_39]], %[[VAL_43]] : f32 +// CHECK-DAG: %[[VAL_45:.*]] = math.fma %[[VAL_44]], %[[VAL_33]], %[[VAL_33]] : f32 +// CHECK-DAG: %[[VAL_46:.*]] = arith.addf %[[VAL_45]], %[[VAL_15]] : f32 +// CHECK-DAG: %[[VAL_47:.*]] = arith.select %[[VAL_25]], %[[VAL_46]], %[[VAL_45]] : f32 +// CHECK-DAG: %[[VAL_48:.*]] = arith.subf %[[VAL_16]], %[[VAL_45]] : f32 +// CHECK-DAG: %[[VAL_49:.*]] = arith.select %[[VAL_30]], %[[VAL_48]], %[[VAL_47]] : f32 +// CHECK-DAG: %[[VAL_50:.*]] = math.copysign %[[VAL_49]], %[[VAL_23]] : f32 +// CHECK-DAG: %[[VAL_51:.*]] = arith.addf %[[VAL_50]], %[[VAL_18]] : f32 +// CHECK-DAG: %[[VAL_52:.*]] = arith.subf %[[VAL_50]], %[[VAL_18]] : f32 +// CHECK-DAG: %[[VAL_53:.*]] = arith.cmpf ogt, %[[VAL_50]], %[[VAL_17]] : f32 +// CHECK-DAG: %[[VAL_54:.*]] = arith.select %[[VAL_53]], %[[VAL_52]], %[[VAL_51]] : f32 +// CHECK-DAG: %[[VAL_55:.*]] = arith.cmpf ogt, %[[VAL_22]], %[[VAL_17]] : f32 +// CHECK-DAG: %[[VAL_56:.*]] = arith.select %[[VAL_55]], %[[VAL_50]], %[[VAL_54]] : f32 +// CHECK-DAG: %[[VAL_57:.*]] = arith.cmpf oeq, %[[VAL_22]], %[[VAL_17]] : f32 +// CHECK-DAG: %[[VAL_58:.*]] = 
arith.cmpf ogt, %[[VAL_21]], %[[VAL_17]] : f32 +// 
CHECK-DAG: %[[VAL_59:.*]] = arith.andi %[[VAL_57]], %[[VAL_58]] : i1 +// CHECK-DAG: %[[VAL_60:.*]] = arith.select %[[VAL_59]], %[[VAL_16]], %[[VAL_56]] : f32 +// CHECK-DAG: %[[VAL_61:.*]] = arith.cmpf olt, %[[VAL_21]], %[[VAL_17]] : f32 +// CHECK-DAG: %[[VAL_62:.*]] = arith.andi %[[VAL_57]], %[[VAL_61]] : i1 +// CHECK-DAG: %[[VAL_63:.*]] = arith.select %[[VAL_62]], %[[VAL_19]], %[[VAL_60]] : f32 +// CHECK-DAG: %[[VAL_64:.*]] = arith.cmpf oeq, %[[VAL_21]], %[[VAL_17]] : f32 +// CHECK-DAG: %[[VAL_65:.*]] = arith.andi %[[VAL_57]], %[[VAL_64]] : i1 +// CHECK-DAG: %[[VAL_66:.*]] = arith.select %[[VAL_65]], %[[VAL_20]], %[[VAL_63]] : f32 +// CHECK-DAG: %[[VAL_67:.*]] = arith.truncf %[[VAL_66]] : f32 to f16 +// CHECK: return %[[VAL_67]] : f16 func.func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 { %0 = math.atan2 %arg0, %arg1 : f16 return %0 : f16 diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -471,19 +471,19 @@ } func.func @atan() { - // CHECK: -0.785184 + // CHECK: -0.785398 %0 = arith.constant -1.0 : f32 call @atan_f32(%0) : (f32) -> () - // CHECK: 0.785184 + // CHECK: 0.785398 %1 = arith.constant 1.0 : f32 call @atan_f32(%1) : (f32) -> () - // CHECK: -0.463643 + // CHECK: -0.463648 %2 = arith.constant -0.5 : f32 call @atan_f32(%2) : (f32) -> () - // CHECK: 0.463643 + // CHECK: 0.463648 %3 = arith.constant 0.5 : f32 call @atan_f32(%3) : (f32) -> () @@ -548,7 +548,7 @@ // CHECK: -1.10715 call @atan2_f32(%neg_two, %one) : (f32, f32) -> () - // CHECK: 0.463643 + // CHECK: 0.463648 call @atan2_f32(%one, %two) : (f32, f32) -> () // CHECK: 2.67795 @@ -561,7 +561,7 @@ %y11 = arith.constant -1.0 : f32 call @atan2_f32(%neg_one, %neg_two) : (f32, f32) -> () - // CHECK: -0.463643 + // CHECK: -0.463648 call @atan2_f32(%neg_one, %two) : (f32, f32) -> () return