diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -895,6 +895,31 @@ namespace { +Value ClampWithNormals(ImplicitLocOpBuilder &builder, + const llvm::ArrayRef shape, Value value, + float lower_bound, float upper_bound) { + assert(!std::isnan(lower_bound)); + assert(!std::isnan(upper_bound)); + + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, shape); + }; + + auto select_cmp = [&builder](auto pred, Value value, Value bound) { + return builder.create( + builder.create(pred, value, bound), value, bound); + }; + + // Note: prefer UGE/ULE vs. UGT/ULT, since they generate vmaxps/vminps vs. + // vcmpleps+vmovaps on x86_64. The latter outcome is also obtained with + // arith::{Max,Min}FOp. + value = select_cmp(arith::CmpFPredicate::UGE, value, + bcast(f32Cst(builder, lower_bound))); + value = select_cmp(arith::CmpFPredicate::ULE, value, + bcast(f32Cst(builder, upper_bound))); + return value; +} + struct ExpApproximation : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -902,122 +927,147 @@ LogicalResult matchAndRewrite(math::ExpOp op, PatternRewriter &rewriter) const final; }; -} // namespace -// Approximate exp(x) using its reduced range exp(y) where y is in the range -// [0, ln(2)], let y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2), exp(x) -// = exp(y) * 2^k. exp(y). 
LogicalResult ExpApproximation::matchAndRewrite(math::ExpOp op, PatternRewriter &rewriter) const { - if (!getElementTypeOrSelf(op.getOperand()).isF32()) + auto shape = vectorShape(op.getOperand().getType()); + auto elementTy = getElementTypeOrSelf(op.getOperand().getType()); + if (!elementTy.isF32()) { return rewriter.notifyMatchFailure(op, "unsupported operand type"); - - ArrayRef shape = vectorShape(op.getOperand()); + } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - // TODO: Consider a common pattern rewriter with all methods below to - // write the approximations. + auto add = [&](Value a, Value b) -> Value { + return builder.create(a, b); + }; auto bcast = [&](Value value) -> Value { return broadcast(builder, value, shape); }; + auto floor = [&](Value a) { return builder.create(a); }; auto fmla = [&](Value a, Value b, Value c) { return builder.create(a, b, c); }; auto mul = [&](Value a, Value b) -> Value { return builder.create(a, b); }; - auto sub = [&](Value a, Value b) -> Value { - return builder.create(a, b); - }; - auto floor = [&](Value a) { return builder.create(a); }; - - Value cstLn2 = bcast(f32Cst(builder, static_cast(LN2_VALUE))); - Value cstLog2E = bcast(f32Cst(builder, static_cast(LOG2E_VALUE))); - - // Polynomial coefficients. - Value cstCephesExpP0 = bcast(f32Cst(builder, 1.0)); - Value cstCephesExpP1 = bcast(f32Cst(builder, 1.0)); - Value cstCephesExpP2 = bcast(f32Cst(builder, 0.49970514590562437052f)); - Value cstCephesExpP3 = bcast(f32Cst(builder, 0.16873890085469545053f)); - Value cstCephesExpP4 = bcast(f32Cst(builder, 0.03668965196652099192f)); - Value cstCephesExpP5 = bcast(f32Cst(builder, 0.01314350012789660196f)); + // Polynomial approximation from Cephes. + // + // To compute e^x, we re-express it as + // + // e^x = e^(a + b) + // = e^(a + n log(2)) + // = e^a * 2^n. + // + // We choose n = round(x / log(2)), restricting the value of `a` to + // (-log(2)/2, log(2)/2). We then use a polynomial to compute e^a. 
The + // relative error between our approximation and the true value of e^a is less + // than 2^-22.5 for all values of `a` within this range. + + // Restrict input to a small range, including some values that evaluate to + // +/- inf. Note that for our lower bound, we choose log(2^-126) instead of + // log(F32_EPSILON). We do so because this routine always flushes denormal + // floating points to 0. Therefore, we only need to worry about exponentiating + // up to the smallest representable non-denormal floating point, which is + // 2^-126. + + // Constants. + Value cst_half = bcast(f32Cst(builder, 0.5f)); + Value cst_one = bcast(f32Cst(builder, 1.0f)); + + // 1/log(2) + Value cst_log2ef = bcast(f32Cst(builder, 1.44269504088896341f)); + + Value cst_exp_c1 = bcast(f32Cst(builder, -0.693359375f)); + Value cst_exp_c2 = bcast(f32Cst(builder, 2.12194440e-4f)); + Value cst_exp_p0 = bcast(f32Cst(builder, 1.9875691500E-4f)); + Value cst_exp_p1 = bcast(f32Cst(builder, 1.3981999507E-3f)); + Value cst_exp_p2 = bcast(f32Cst(builder, 8.3334519073E-3f)); + Value cst_exp_p3 = bcast(f32Cst(builder, 4.1665795894E-2f)); + Value cst_exp_p4 = bcast(f32Cst(builder, 1.6666665459E-1f)); + Value cst_exp_p5 = bcast(f32Cst(builder, 5.0000001201E-1f)); + + // Our computations below aren't particularly sensitive to the exact choices + // here, so we choose values a bit larger/smaller than + // + // log(F32_MAX) = 88.723... + // log(2^-126) = -87.337... 
Value x = op.getOperand(); + x = ClampWithNormals(builder, shape, x, -87.8f, 88.8f); + Value n = floor(fmla(x, cst_log2ef, cst_half)); - Value isNan = builder.create(arith::CmpFPredicate::UNO, x, x); - - // Reduced y = x - floor(x / ln(2)) * ln(2) = x - k * ln(2) - Value xL2Inv = mul(x, cstLog2E); - Value kF32 = floor(xL2Inv); - Value kLn2 = mul(kF32, cstLn2); - Value y = sub(x, kLn2); - - // Use Estrin's evaluation scheme with 3 independent parts: - // P(y)^y : (c0 + c1 y) + (c2 + c3 y) y^2 + (c4 + c5 y) y^4 - Value y2 = mul(y, y); - Value y4 = mul(y2, y2); - - Value q0 = fmla(cstCephesExpP1, y, cstCephesExpP0); - Value q1 = fmla(cstCephesExpP3, y, cstCephesExpP2); - Value q2 = fmla(cstCephesExpP5, y, cstCephesExpP4); - Value expY = fmla(q1, y2, q0); - expY = fmla(q2, y4, expY); - - auto i32Vec = broadcast(builder.getI32Type(), shape); - - // exp2(k) - Value k = builder.create(i32Vec, kF32); - Value exp2KValue = exp2I32(builder, k); - - // exp(x) = exp(y) * exp2(k) - expY = mul(expY, exp2KValue); - - // Handle overflow, inf and underflow of exp(x). exp(x) range is [0, inf], its - // partitioned as the following: - // exp(x) = 0, x <= -inf - // exp(x) = underflow (min_float), x <= -88 - // exp(x) = inf (min_float), x >= 88 - // Note: |k| = 127 is the value where the 8-bits exponent saturates. 
- Value zerof32Const = bcast(f32Cst(builder, 0)); - auto constPosInfinity = - bcast(f32Cst(builder, std::numeric_limits::infinity())); - auto constNegIfinity = - bcast(f32Cst(builder, -std::numeric_limits::infinity())); - auto underflow = bcast(f32Cst(builder, std::numeric_limits::min())); - - Value kMaxConst = bcast(i32Cst(builder, 127)); - Value kMaxNegConst = bcast(i32Cst(builder, -127)); - Value rightBound = - builder.create(arith::CmpIPredicate::sle, k, kMaxConst); - Value leftBound = - builder.create(arith::CmpIPredicate::sge, k, kMaxNegConst); - - Value isNegInfinityX = builder.create( - arith::CmpFPredicate::OEQ, x, constNegIfinity); - Value isPosInfinityX = builder.create( - arith::CmpFPredicate::OEQ, x, constPosInfinity); - Value isPostiveX = - builder.create(arith::CmpFPredicate::OGT, x, zerof32Const); - Value isComputable = builder.create(rightBound, leftBound); - - expY = builder.create( - isNan, x, - builder.create( - isNegInfinityX, zerof32Const, - builder.create( - isPosInfinityX, constPosInfinity, - builder.create( - isComputable, expY, - builder.create(isPostiveX, constPosInfinity, - underflow))))); - - rewriter.replaceOp(op, expY); - - return success(); + // When we eventually do the multiplication in e^a * 2^n, we need to handle + // the case when n > 127, the max fp32 exponent (so 2^n == inf) but e^a < 1 + // (so e^a * 2^n != inf). There's a similar problem for n < -126, the + // smallest fp32 exponent. + // + // A straightforward solution would be to detect n out of range and split it + // up, doing + // + // e^a * 2^n = e^a * 2^(n1 + n2) + // = (2^n1 * e^a) * 2^n2. + // + // But it turns out this approach is quite slow, probably because it + // manipulates subnormal values. + // + // The approach we use instead is to clamp n to [-127, 127]. Let n' be the + // value of n clamped to [-127, 127]. In the case where n' = 127, `a` can grow + // up to as large as 88.8 - 127 * log(2) which is about 0.7703. 
Even though + // this value of `a` is outside our previously specified range, e^a will still + // only have a relative error of approximately 2^-16 at worst. In practice + // this seems to work well enough; it passes our exhaustive tests, breaking + // only one result, and by one ulp (we return exp(88.7228394) = max-float but + // we should return inf). + // + // In the case where n' = -127, the original input value of x is so small that + // e^x, our final answer, is less than 2^-126. Since 2^-126 is the smallest + // normal floating point, and since we flush denormals, we simply return 0. We + // do this in a branchless way by observing that our code for constructing 2^n + // produces 0 if n = -127. + // + // The proof that n' = -127 implies e^x < 2^-126 is as follows: + // + // n' = -127 implies n <= -127 + // implies round(x / log(2)) <= -127 + // implies x/log(2) < -126.5 + // implies x < -126.5 * log(2) + // implies e^x < e^(-126.5 * log(2)) + // implies e^x < 2^-126.5 < 2^-126 + // + // This proves that n' = -127 implies e^x < 2^-126. + n = ClampWithNormals(builder, shape, n, -127.0f, 127.0f); + + // Computes x = x - n' * log(2), the value for `a` + x = fmla(cst_exp_c1, n, x); + x = fmla(cst_exp_c2, n, x); + + // Polynomial to compute z = e^a, accurate for a in (-0.5, 0.5). + Value z = fmla(x, cst_exp_p0, cst_exp_p1); + z = fmla(z, x, cst_exp_p2); + z = fmla(z, x, cst_exp_p3); + z = fmla(z, x, cst_exp_p4); + z = fmla(z, x, cst_exp_p5); + z = fmla(z, mul(x, x), x); + z = add(cst_one, z); + + // Convert n' to an i32. This is safe because we clamped it above. + auto i32_vec = broadcast(builder.getI32Type(), shape); + Value n_i32 = builder.create(i32_vec, n); + + // Creates the value 2^n' if -126 <= n' <= 127 and 0 if n' = -127. + Value pow2 = exp2I32(builder, n_i32); + + // Return z * 2^n' if -126 <= n' <= 127 and 0 if n' = -127. 
+ Value ret = mul(z, pow2); + + rewriter.replaceOp(op, ret); + return mlir::success(); } +} // namespace + //----------------------------------------------------------------------------// // ExpM1 approximation. //----------------------------------------------------------------------------// diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -96,48 +96,47 @@ // CHECK-LABEL: func @exp_scalar( // CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { -// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0.693147182 : f32 -// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.44269502 : f32 -// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0.499705136 : f32 -// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 0.168738902 : f32 -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0.0366896503 : f32 -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1.314350e-02 : f32 -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 23 : i32 -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0x7F800000 : f32 -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0xFF800000 : f32 -// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 1.17549435E-38 : f32 -// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 127 : i32 -// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -127 : i32 -// CHECK: %[[IS_NAN:.*]] = arith.cmpf uno, %[[VAL_0]], %[[VAL_0]] : f32 -// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_0]], %[[VAL_2]] : f32 -// CHECK: %[[VAL_16:.*]] = math.floor %[[VAL_15]] : f32 -// CHECK: %[[VAL_17:.*]] = arith.mulf %[[VAL_16]], %[[VAL_1]] : f32 -// CHECK: %[[VAL_18:.*]] = arith.subf %[[VAL_0]], %[[VAL_17]] : f32 -// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_18]], %[[VAL_18]] : f32 -// CHECK: %[[VAL_20:.*]] = arith.mulf %[[VAL_19]], %[[VAL_19]] : f32 -// CHECK: 
%[[VAL_21:.*]] = math.fma %[[VAL_3]], %[[VAL_18]], %[[VAL_3]] : f32 -// CHECK: %[[VAL_22:.*]] = math.fma %[[VAL_5]], %[[VAL_18]], %[[VAL_4]] : f32 -// CHECK: %[[VAL_23:.*]] = math.fma %[[VAL_7]], %[[VAL_18]], %[[VAL_6]] : f32 -// CHECK: %[[VAL_24:.*]] = math.fma %[[VAL_22]], %[[VAL_19]], %[[VAL_21]] : f32 -// CHECK: %[[VAL_25:.*]] = math.fma %[[VAL_23]], %[[VAL_20]], %[[VAL_24]] : f32 -// CHECK: %[[VAL_26:.*]] = arith.fptosi %[[VAL_16]] : f32 to i32 -// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_28:.*]] = arith.shli %[[VAL_27]], %[[VAL_8]] : i32 -// CHECK: %[[VAL_29:.*]] = arith.bitcast %[[VAL_28]] : i32 to f32 -// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_25]], %[[VAL_29]] : f32 -// CHECK: %[[VAL_31:.*]] = arith.cmpi sle, %[[VAL_26]], %[[VAL_13]] : i32 -// CHECK: %[[VAL_32:.*]] = arith.cmpi sge, %[[VAL_26]], %[[VAL_14]] : i32 -// CHECK: %[[VAL_33:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_11]] : f32 -// CHECK: %[[VAL_34:.*]] = arith.cmpf oeq, %[[VAL_0]], %[[VAL_10]] : f32 -// CHECK: %[[VAL_35:.*]] = arith.cmpf ogt, %[[VAL_0]], %[[VAL_9]] : f32 -// CHECK: %[[VAL_36:.*]] = arith.andi %[[VAL_31]], %[[VAL_32]] : i1 -// CHECK: %[[VAL_37:.*]] = arith.select %[[VAL_35]], %[[VAL_10]], %[[VAL_12]] : f32 -// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_30]], %[[VAL_37]] : f32 -// CHECK: %[[VAL_39:.*]] = arith.select %[[VAL_34]], %[[VAL_10]], %[[VAL_38]] : f32 -// CHECK: %[[VAL_40:.*]] = arith.select %[[VAL_33]], %[[VAL_9]], %[[VAL_39]] : f32 -// CHECK: %[[VAL_41:.*]] = arith.select %[[IS_NAN]], %[[VAL_0]], %[[VAL_40]] : f32 +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1.44269502 : f32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant -0.693359375 : f32 +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 2.12194442E-4 : f32 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 1.98756912E-4 : f32 +// CHECK-DAG: 
%[[VAL_7:.*]] = arith.constant 0.00139819994 : f32 +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.00833345205 : f32 +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.0416657962 : f32 +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.166666657 : f32 +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant -8.780000e+01 : f32 +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant 8.880000e+01 : f32 +// CHECK-DAG: %[[VAL_13:.*]] = arith.constant -1.270000e+02 : f32 +// CHECK-DAG: %[[VAL_14:.*]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 23 : i32 +// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 127 : i32 +// CHECK-DAG: %[[VAL_17:.*]] = arith.cmpf uge, %[[VAL_0]], %[[VAL_11]] : f32 +// CHECK-DAG: %[[VAL_18:.*]] = arith.select %[[VAL_17]], %[[VAL_0]], %[[VAL_11]] : f32 +// CHECK-DAG: %[[VAL_19:.*]] = arith.cmpf ule, %[[VAL_18]], %[[VAL_12]] : f32 +// CHECK-DAG: %[[VAL_20:.*]] = arith.select %[[VAL_19]], %[[VAL_18]], %[[VAL_12]] : f32 +// CHECK-DAG: %[[VAL_21:.*]] = math.fma %[[VAL_20]], %[[VAL_3]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_22:.*]] = math.floor %[[VAL_21]] : f32 +// CHECK-DAG: %[[VAL_23:.*]] = arith.cmpf uge, %[[VAL_22]], %[[VAL_13]] : f32 +// CHECK-DAG: %[[VAL_24:.*]] = arith.select %[[VAL_23]], %[[VAL_22]], %[[VAL_13]] : f32 +// CHECK-DAG: %[[VAL_25:.*]] = arith.cmpf ule, %[[VAL_24]], %[[VAL_14]] : f32 +// CHECK-DAG: %[[VAL_26:.*]] = arith.select %[[VAL_25]], %[[VAL_24]], %[[VAL_14]] : f32 +// CHECK-DAG: %[[VAL_27:.*]] = math.fma %[[VAL_4]], %[[VAL_26]], %[[VAL_20]] : f32 +// CHECK-DAG: %[[VAL_28:.*]] = math.fma %[[VAL_5]], %[[VAL_26]], %[[VAL_27]] : f32 +// CHECK-DAG: %[[VAL_29:.*]] = math.fma %[[VAL_28]], %[[VAL_6]], %[[VAL_7]] : f32 +// CHECK-DAG: %[[VAL_30:.*]] = math.fma %[[VAL_29]], %[[VAL_28]], %[[VAL_8]] : f32 +// CHECK-DAG: %[[VAL_31:.*]] = math.fma %[[VAL_30]], %[[VAL_28]], %[[VAL_9]] : f32 +// CHECK-DAG: %[[VAL_32:.*]] = math.fma %[[VAL_31]], %[[VAL_28]], %[[VAL_10]] : f32 +// CHECK-DAG: %[[VAL_33:.*]] = math.fma %[[VAL_32]], 
%[[VAL_28]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_34:.*]] = arith.mulf %[[VAL_28]], %[[VAL_28]] : f32 +// CHECK-DAG: %[[VAL_35:.*]] = math.fma %[[VAL_33]], %[[VAL_34]], %[[VAL_28]] : f32 +// CHECK-DAG: %[[VAL_36:.*]] = arith.addf %[[VAL_35]], %[[VAL_2]] : f32 +// CHECK-DAG: %[[VAL_37:.*]] = arith.fptosi %[[VAL_26]] : f32 to i32 +// CHECK-DAG: %[[VAL_38:.*]] = arith.addi %[[VAL_37]], %[[VAL_16]] : i32 +// CHECK-DAG: %[[VAL_39:.*]] = arith.shli %[[VAL_38]], %[[VAL_15]] : i32 +// CHECK-DAG: %[[VAL_40:.*]] = arith.bitcast %[[VAL_39]] : i32 to f32 +// CHECK-DAG: %[[VAL_41:.*]] = arith.mulf %[[VAL_36]], %[[VAL_40]] : f32 // CHECK: return %[[VAL_41]] : f32 func.func @exp_scalar(%arg0: f32) -> f32 { %0 = math.exp %arg0 : f32 @@ -146,11 +145,7 @@ // CHECK-LABEL: func @exp_vector( // CHECK-SAME: %[[VAL_0:.*]]: vector<8xf32>) -> vector<8xf32> { -// CHECK: %[[VAL_1:.*]] = arith.constant dense<0.693147182> : vector<8xf32> -// CHECK-NOT: exp -// CHECK-COUNT-4: select -// CHECK: %[[VAL_40:.*]] = arith.select -// CHECK: return %[[VAL_40]] : vector<8xf32> +// CHECK-NOT: math.exp func.func @exp_vector(%arg0: vector<8xf32>) -> vector<8xf32> { %0 = math.exp %arg0 : vector<8xf32> return %0 : vector<8xf32> @@ -158,26 +153,114 @@ // CHECK-LABEL: func @expm1_scalar( // CHECK-SAME: %[[X:.*]]: f32) -> f32 { -// CHECK-DAG: %[[CST_MINUSONE:.*]] = arith.constant -1.000000e+00 : f32 -// CHECK-DAG: %[[CST_LOG2E:.*]] = arith.constant 1.44269502 : f32 -// CHECK-DAG: %[[CST_ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[BEGIN_EXP_X:.*]] = arith.mulf %[[X]], %[[CST_LOG2E]] : f32 -// CHECK-NOT: exp -// CHECK-COUNT-4: select -// CHECK: %[[EXP_X:.*]] = arith.select -// CHECK: %[[IS_ONE_OR_NAN:.*]] = arith.cmpf ueq, %[[EXP_X]], %[[CST_ONE]] : f32 -// CHECK: %[[VAL_59:.*]] = arith.subf %[[EXP_X]], %[[CST_ONE]] : f32 -// CHECK: %[[VAL_60:.*]] = arith.cmpf oeq, %[[VAL_59]], %[[CST_MINUSONE]] : f32 -// CHECK-NOT: log -// CHECK-COUNT-5: select -// CHECK: %[[LOG_U:.*]] = arith.select -// CHECK: 
%[[VAL_104:.*]] = arith.cmpf oeq, %[[LOG_U]], %[[EXP_X]] : f32 -// CHECK: %[[VAL_105:.*]] = arith.divf %[[X]], %[[LOG_U]] : f32 -// CHECK: %[[VAL_106:.*]] = arith.mulf %[[VAL_59]], %[[VAL_105]] : f32 -// CHECK: %[[VAL_107:.*]] = arith.select %[[VAL_104]], %[[EXP_X]], %[[VAL_106]] : f32 -// CHECK: %[[VAL_108:.*]] = arith.select %[[VAL_60]], %[[CST_MINUSONE]], %[[VAL_107]] : f32 -// CHECK: %[[VAL_109:.*]] = arith.select %[[IS_ONE_OR_NAN]], %[[X]], %[[VAL_108]] : f32 -// CHECK: return %[[VAL_109]] : f32 +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant -1.000000e+00 : f32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 1.44269502 : f32 +// CHECK-DAG: %[[VAL_5:.*]] = arith.constant -0.693359375 : f32 +// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 2.12194442E-4 : f32 +// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 1.98756912E-4 : f32 +// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.00139819994 : f32 +// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.00833345205 : f32 +// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.0416657962 : f32 +// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0.166666657 : f32 +// CHECK-DAG: %[[VAL_12:.*]] = arith.constant -8.780000e+01 : f32 +// CHECK-DAG: %[[VAL_13:.*]] = arith.constant 8.880000e+01 : f32 +// CHECK-DAG: %[[VAL_14:.*]] = arith.constant -1.270000e+02 : f32 +// CHECK-DAG: %[[VAL_15:.*]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: %[[VAL_16:.*]] = arith.constant 23 : i32 +// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 127 : i32 +// CHECK-DAG: %[[VAL_18:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[VAL_19:.*]] = arith.constant -5.000000e-01 : f32 +// CHECK-DAG: %[[VAL_20:.*]] = arith.constant 1.17549435E-38 : f32 +// CHECK-DAG: %[[VAL_21:.*]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: %[[VAL_22:.*]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: %[[VAL_23:.*]] = arith.constant 0x7FC00000 : f32 
+// CHECK-DAG: %[[VAL_24:.*]] = arith.constant 0.707106769 : f32 +// CHECK-DAG: %[[VAL_25:.*]] = arith.constant 0.0703768358 : f32 +// CHECK-DAG: %[[VAL_26:.*]] = arith.constant -0.115146101 : f32 +// CHECK-DAG: %[[VAL_27:.*]] = arith.constant 0.116769984 : f32 +// CHECK-DAG: %[[VAL_28:.*]] = arith.constant -0.12420141 : f32 +// CHECK-DAG: %[[VAL_29:.*]] = arith.constant 0.142493233 : f32 +// CHECK-DAG: %[[VAL_30:.*]] = arith.constant -0.166680574 : f32 +// CHECK-DAG: %[[VAL_31:.*]] = arith.constant 0.200007141 : f32 +// CHECK-DAG: %[[VAL_32:.*]] = arith.constant -0.24999994 : f32 +// CHECK-DAG: %[[VAL_33:.*]] = arith.constant 0.333333313 : f32 +// CHECK-DAG: %[[VAL_34:.*]] = arith.constant 1.260000e+02 : f32 +// CHECK-DAG: %[[VAL_35:.*]] = arith.constant -2139095041 : i32 +// CHECK-DAG: %[[VAL_36:.*]] = arith.constant 1056964608 : i32 +// CHECK-DAG: %[[VAL_37:.*]] = arith.constant 0.693147182 : f32 +// CHECK-DAG: %[[VAL_38:.*]] = arith.cmpf uge, %[[X]], %[[VAL_12]] : f32 +// CHECK-DAG: %[[VAL_39:.*]] = arith.select %[[VAL_38]], %[[X]], %[[VAL_12]] : f32 +// CHECK-DAG: %[[VAL_40:.*]] = arith.cmpf ule, %[[VAL_39]], %[[VAL_13]] : f32 +// CHECK-DAG: %[[VAL_41:.*]] = arith.select %[[VAL_40]], %[[VAL_39]], %[[VAL_13]] : f32 +// CHECK-DAG: %[[VAL_42:.*]] = math.fma %[[VAL_41]], %[[VAL_4]], %[[VAL_3]] : f32 +// CHECK-DAG: %[[VAL_43:.*]] = math.floor %[[VAL_42]] : f32 +// CHECK-DAG: %[[VAL_44:.*]] = arith.cmpf uge, %[[VAL_43]], %[[VAL_14]] : f32 +// CHECK-DAG: %[[VAL_45:.*]] = arith.select %[[VAL_44]], %[[VAL_43]], %[[VAL_14]] : f32 +// CHECK-DAG: %[[VAL_46:.*]] = arith.cmpf ule, %[[VAL_45]], %[[VAL_15]] : f32 +// CHECK-DAG: %[[VAL_47:.*]] = arith.select %[[VAL_46]], %[[VAL_45]], %[[VAL_15]] : f32 +// CHECK-DAG: %[[VAL_48:.*]] = math.fma %[[VAL_5]], %[[VAL_47]], %[[VAL_41]] : f32 +// CHECK-DAG: %[[VAL_49:.*]] = math.fma %[[VAL_6]], %[[VAL_47]], %[[VAL_48]] : f32 +// CHECK-DAG: %[[VAL_50:.*]] = math.fma %[[VAL_49]], %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK-DAG: %[[VAL_51:.*]] 
= math.fma %[[VAL_50]], %[[VAL_49]], %[[VAL_9]] : f32 +// CHECK-DAG: %[[VAL_52:.*]] = math.fma %[[VAL_51]], %[[VAL_49]], %[[VAL_10]] : f32 +// CHECK-DAG: %[[VAL_53:.*]] = math.fma %[[VAL_52]], %[[VAL_49]], %[[VAL_11]] : f32 +// CHECK-DAG: %[[VAL_54:.*]] = math.fma %[[VAL_53]], %[[VAL_49]], %[[VAL_3]] : f32 +// CHECK-DAG: %[[VAL_55:.*]] = arith.mulf %[[VAL_49]], %[[VAL_49]] : f32 +// CHECK-DAG: %[[VAL_56:.*]] = math.fma %[[VAL_54]], %[[VAL_55]], %[[VAL_49]] : f32 +// CHECK-DAG: %[[VAL_57:.*]] = arith.addf %[[VAL_56]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_58:.*]] = arith.fptosi %[[VAL_47]] : f32 to i32 +// CHECK-DAG: %[[VAL_59:.*]] = arith.addi %[[VAL_58]], %[[VAL_17]] : i32 +// CHECK-DAG: %[[VAL_60:.*]] = arith.shli %[[VAL_59]], %[[VAL_16]] : i32 +// CHECK-DAG: %[[VAL_61:.*]] = arith.bitcast %[[VAL_60]] : i32 to f32 +// CHECK-DAG: %[[VAL_62:.*]] = arith.mulf %[[VAL_57]], %[[VAL_61]] : f32 +// CHECK-DAG: %[[VAL_63:.*]] = arith.cmpf ueq, %[[VAL_62]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_64:.*]] = arith.subf %[[VAL_62]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_65:.*]] = arith.cmpf oeq, %[[VAL_64]], %[[VAL_2]] : f32 +// CHECK-DAG: %[[VAL_66:.*]] = arith.cmpf ugt, %[[VAL_62]], %[[VAL_20]] : f32 +// CHECK-DAG: %[[VAL_67:.*]] = arith.select %[[VAL_66]], %[[VAL_62]], %[[VAL_20]] : f32 +// CHECK-DAG: %[[VAL_68:.*]] = arith.bitcast %[[VAL_67]] : f32 to i32 +// CHECK-DAG: %[[VAL_69:.*]] = arith.andi %[[VAL_68]], %[[VAL_35]] : i32 +// CHECK-DAG: %[[VAL_70:.*]] = arith.ori %[[VAL_69]], %[[VAL_36]] : i32 +// CHECK-DAG: %[[VAL_71:.*]] = arith.bitcast %[[VAL_70]] : i32 to f32 +// CHECK-DAG: %[[VAL_72:.*]] = arith.bitcast %[[VAL_67]] : f32 to i32 +// CHECK-DAG: %[[VAL_73:.*]] = arith.shrui %[[VAL_72]], %[[VAL_16]] : i32 +// CHECK-DAG: %[[VAL_74:.*]] = arith.sitofp %[[VAL_73]] : i32 to f32 +// CHECK-DAG: %[[VAL_75:.*]] = arith.subf %[[VAL_74]], %[[VAL_34]] : f32 +// CHECK-DAG: %[[VAL_76:.*]] = arith.cmpf olt, %[[VAL_71]], %[[VAL_24]] : f32 +// CHECK-DAG: %[[VAL_77:.*]] = 
arith.select %[[VAL_76]], %[[VAL_71]], %[[VAL_18]] : f32 +// CHECK-DAG: %[[VAL_78:.*]] = arith.subf %[[VAL_71]], %[[VAL_1]] : f32 +// CHECK-DAG: %[[VAL_79:.*]] = arith.select %[[VAL_76]], %[[VAL_1]], %[[VAL_18]] : f32 +// CHECK-DAG: %[[VAL_80:.*]] = arith.subf %[[VAL_75]], %[[VAL_79]] : f32 +// CHECK-DAG: %[[VAL_81:.*]] = arith.addf %[[VAL_78]], %[[VAL_77]] : f32 +// CHECK-DAG: %[[VAL_82:.*]] = arith.mulf %[[VAL_81]], %[[VAL_81]] : f32 +// CHECK-DAG: %[[VAL_83:.*]] = arith.mulf %[[VAL_82]], %[[VAL_81]] : f32 +// CHECK-DAG: %[[VAL_84:.*]] = math.fma %[[VAL_25]], %[[VAL_81]], %[[VAL_26]] : f32 +// CHECK-DAG: %[[VAL_85:.*]] = math.fma %[[VAL_28]], %[[VAL_81]], %[[VAL_29]] : f32 +// CHECK-DAG: %[[VAL_86:.*]] = math.fma %[[VAL_31]], %[[VAL_81]], %[[VAL_32]] : f32 +// CHECK-DAG: %[[VAL_87:.*]] = math.fma %[[VAL_84]], %[[VAL_81]], %[[VAL_27]] : f32 +// CHECK-DAG: %[[VAL_88:.*]] = math.fma %[[VAL_85]], %[[VAL_81]], %[[VAL_30]] : f32 +// CHECK-DAG: %[[VAL_89:.*]] = math.fma %[[VAL_86]], %[[VAL_81]], %[[VAL_33]] : f32 +// CHECK-DAG: %[[VAL_90:.*]] = math.fma %[[VAL_87]], %[[VAL_83]], %[[VAL_88]] : f32 +// CHECK-DAG: %[[VAL_91:.*]] = math.fma %[[VAL_90]], %[[VAL_83]], %[[VAL_89]] : f32 +// CHECK-DAG: %[[VAL_92:.*]] = arith.mulf %[[VAL_91]], %[[VAL_83]] : f32 +// CHECK-DAG: %[[VAL_93:.*]] = math.fma %[[VAL_19]], %[[VAL_82]], %[[VAL_92]] : f32 +// CHECK-DAG: %[[VAL_94:.*]] = arith.addf %[[VAL_81]], %[[VAL_93]] : f32 +// CHECK-DAG: %[[VAL_95:.*]] = math.fma %[[VAL_80]], %[[VAL_37]], %[[VAL_94]] : f32 +// CHECK-DAG: %[[VAL_96:.*]] = arith.cmpf ult, %[[VAL_62]], %[[VAL_18]] : f32 +// CHECK-DAG: %[[VAL_97:.*]] = arith.cmpf oeq, %[[VAL_62]], %[[VAL_18]] : f32 +// CHECK-DAG: %[[VAL_98:.*]] = arith.cmpf oeq, %[[VAL_62]], %[[VAL_22]] : f32 +// CHECK-DAG: %[[VAL_99:.*]] = arith.select %[[VAL_98]], %[[VAL_22]], %[[VAL_95]] : f32 +// CHECK-DAG: %[[VAL_100:.*]] = arith.select %[[VAL_96]], %[[VAL_23]], %[[VAL_99]] : f32 +// CHECK-DAG: %[[VAL_101:.*]] = arith.select %[[VAL_97]], %[[VAL_21]], 
%[[VAL_100]] : f32 +// CHECK-DAG: %[[VAL_102:.*]] = arith.cmpf oeq, %[[VAL_101]], %[[VAL_62]] : f32 +// CHECK-DAG: %[[VAL_103:.*]] = arith.divf %[[X]], %[[VAL_101]] : f32 +// CHECK-DAG: %[[VAL_104:.*]] = arith.mulf %[[VAL_64]], %[[VAL_103]] : f32 +// CHECK-DAG: %[[VAL_105:.*]] = arith.select %[[VAL_102]], %[[VAL_62]], %[[VAL_104]] : f32 +// CHECK-DAG: %[[VAL_106:.*]] = arith.select %[[VAL_65]], %[[VAL_2]], %[[VAL_105]] : f32 +// CHECK-DAG: %[[VAL_107:.*]] = arith.select %[[VAL_63]], %[[X]], %[[VAL_106]] : f32 +// CHECK-DAG: return %[[VAL_107]] : f32 // CHECK: } func.func @expm1_scalar(%arg0: f32) -> f32 { %0 = math.expm1 %arg0 : f32 @@ -186,16 +269,9 @@ // CHECK-LABEL: func @expm1_vector( // CHECK-SAME: %[[VAL_0:.*]]: vector<8x8xf32>) -> vector<8x8xf32> { -// CHECK: %[[VAL_1:.*]] = arith.constant dense<-1.000000e+00> : vector<8x8xf32> // CHECK-NOT: exp -// CHECK-COUNT-5: select // CHECK-NOT: log -// CHECK-COUNT-5: select // CHECK-NOT: expm1 -// CHECK-COUNT-3: select -// CHECK: %[[VAL_115:.*]] = arith.select -// CHECK: return %[[VAL_115]] : vector<8x8xf32> -// CHECK: } func.func @expm1_vector(%arg0: vector<8x8xf32>) -> vector<8x8xf32> { %0 = math.expm1 %arg0 : vector<8x8xf32> return %0 : vector<8x8xf32> diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir --- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir @@ -293,7 +293,7 @@ %f0 = arith.constant 1.0 : f32 call @exp_f32(%f0) : (f32) -> () - // CHECK: 0.778802, 2.117, 2.71828, 3.85742 + // CHECK: 0.778801, 2.117, 2.71828, 3.85743 %v1 = arith.constant dense<[-0.25, 0.75, 1.0, 1.35]> : vector<4xf32> call @exp_4xf32(%v1) : (vector<4xf32>) -> () @@ -301,7 +301,7 @@ %zero = arith.constant 0.0 : f32 call @exp_f32(%zero) : (f32) -> () - // CHECK: 1.17549e-38, 1.38879e-11, 7.20049e+10, inf + // CHECK: 0, 1.38879e-11, 7.20049e+10, inf %special_vec = arith.constant dense<[-89.0, -25.0, 25.0, 
89.0]> : vector<4xf32> call @exp_4xf32(%special_vec) : (vector<4xf32>) -> () @@ -349,7 +349,7 @@ %f0 = arith.constant 1.0e-10 : f32 call @expm1_f32(%f0) : (f32) -> () - // CHECK: -0.00995016, 0.0100502, 0.648721, 6.38905 + // CHECK: -0.00995017, 0.0100502, 0.648721, 6.38906 %v1 = arith.constant dense<[-0.01, 0.01, 0.5, 2.0]> : vector<4xf32> call @expm1_4xf32(%v1) : (vector<4xf32>) -> () @@ -701,5 +701,3 @@ call @ceilf() : () -> () return } - -