diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -630,11 +630,158 @@
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Sin and Cos approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+/// Rewrite pattern shared by the sine and cosine approximations: `isSine`
+/// selects which of the two functions is approximated for the matched `OpTy`.
+template <bool isSine, typename OpTy>
+struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Range reduction constants 2/pi and pi/2.
+#define TWO_OVER_PI                                                            \
+  0.6366197723675813430755350534900574481378385829618257949906693762L
+#define PI_OVER_2                                                              \
+  1.5707963267948966192313216916397514420985846996875529104874722961L
+
+// Approximates sin(x) or cos(x) by finding the best approximation polynomial in
+// the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
+// reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
+//
+// The emitted computation is:
+//   k = floor(x * 2/pi)
+//   y = x - k * pi/2
+//   q = fptosi(k) & 3
+// followed by the evaluation of either the sine or the cosine polynomial in y
+// (selected by q), whose result is negated where q requires it.
+template <bool isSine, typename OpTy>
+LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<MulFOp>(a, b);
+  };
+  auto sub = [&](Value a, Value b) -> Value {
+    return builder.create<SubFOp>(a, b);
+  };
+  auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
+
+  auto i32Vec = broadcast(builder.getI32Type(), *width);
+  auto cast_f32_to_i32 = [&](Value a) -> Value {
+    return builder.create<FPToSIOp>(a, i32Vec);
+  };
+
+  // Masking with 3 keeps the two low bits, i.e. the value modulo 4.
+  auto modulo_4 = [&](Value a) -> Value {
+    return builder.create<AndOp>(a, bcast(i32Cst(builder, 3)));
+  };
+
+  auto isEqualTo = [&](Value a, Value b) -> Value {
+    return builder.create<CmpIOp>(CmpIPredicate::eq, a, b);
+  };
+
+  auto isGreaterThan = [&](Value a, Value b) -> Value {
+    return builder.create<CmpIOp>(CmpIPredicate::sgt, a, b);
+  };
+
+  auto select = [&](Value cond, Value t, Value f) -> Value {
+    return builder.create<SelectOp>(cond, t, f);
+  };
+
+  auto fmla = [&](Value a, Value b, Value c) {
+    return builder.create<FmaFOp>(a, b, c);
+  };
+
+  auto Or = [&](Value a, Value b) { return builder.create<OrOp>(a, b); };
+
+  Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI));
+  Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2));
+
+  Value x = op.operand();
+
+  // Range reduction: k counts whole multiples of pi/2, y is the remainder.
+  Value k = floor(mul(x, twoOverPi));
+
+  Value y = sub(x, mul(k, piOverTwo));
+
+  Value One = bcast(f32Cst(builder, 1.0));
+  Value nOne = bcast(f32Cst(builder, -1.0));
+
+  // Coefficients of the sine polynomial (in y^2, result is scaled by y).
+  Value SC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
+  Value SC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
+  Value SC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
+  Value SC8 = bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
+  Value SC10 =
+      bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
+
+  // Coefficients of the cosine polynomial (in y^2, result is scaled by 1).
+  Value CC2 = bcast(f32Cst(builder, -0.5f));
+  Value CC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
+  Value CC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
+  Value CC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
+  Value CC10 =
+      bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
+
+  Value kmod4 = modulo_4(cast_f32_to_i32(k));
+
+  Value KR0 = isEqualTo(kmod4, bcast(i32Cst(builder, 0)));
+  Value KR1 = isEqualTo(kmod4, bcast(i32Cst(builder, 1)));
+  Value KR2 = isEqualTo(kmod4, bcast(i32Cst(builder, 2)));
+  Value KR3 = isEqualTo(kmod4, bcast(i32Cst(builder, 3)));
+
+  // Pick the polynomial and the sign from the quadrant index k mod 4.
+  Value sinuseCos = isSine ? Or(KR1, KR3) : Or(KR0, KR2);
+  Value negativeRange =
+      isSine ? isGreaterThan(kmod4, bcast(i32Cst(builder, 1))) : Or(KR1, KR2);
+
+  Value y2 = mul(y, y);
+
+  Value base = select(sinuseCos, One, y);
+  Value C2 = select(sinuseCos, CC2, SC2);
+  Value C4 = select(sinuseCos, CC4, SC4);
+  Value C6 = select(sinuseCos, CC6, SC6);
+  Value C8 = select(sinuseCos, CC8, SC8);
+  Value C10 = select(sinuseCos, CC10, SC10);
+
+  // Horner evaluation in y^2 using fused multiply-adds, then scale by base.
+  Value v1 = fmla(y2, C10, C8);
+  Value v2 = fmla(y2, v1, C6);
+  Value v3 = fmla(y2, v2, C4);
+  Value v4 = fmla(y2, v3, C2);
+  Value v5 = fmla(y2, v4, One);
+  Value v6 = mul(base, v5);
+
+  Value approximation = select(negativeRange, mul(nOne, v6), v6);
+
+  rewriter.replaceOp(op, approximation);
+
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               Log1pApproximation, ExpApproximation, ExpM1Approximation>(
+               Log1pApproximation, ExpApproximation, ExpM1Approximation,
+               SinAndCosApproximation<true, math::SinOp>,
+               SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
 }
diff --git a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -219,6 +219,82 @@
 
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Sin.
+// -------------------------------------------------------------------------- //
+func @sin() {
+  // CHECK: 0
+  %0 = constant 0.0 : f32
+  %sin_0 = math.sin %0 : f32
+  vector.print %sin_0 : f32
+
+  // CHECK: 0.707107
+  %pi_over_4 = constant 0.78539816339 : f32
+  %sin_pi_over_4 = math.sin %pi_over_4 : f32
+  vector.print %sin_pi_over_4 : f32
+
+  // CHECK: 1
+  %pi_over_2 = constant 1.57079632679 : f32
+  %sin_pi_over_2 = math.sin %pi_over_2 : f32
+  vector.print %sin_pi_over_2 : f32
+
+  // CHECK: 0
+  %pi = constant 3.14159265359 : f32
+  %sin_pi = math.sin %pi : f32
+  vector.print %sin_pi : f32
+
+  // CHECK: -1
+  %pi_3_over_2 = constant 4.71238898038 : f32
+  %sin_pi_3_over_2 = math.sin %pi_3_over_2 : f32
+  vector.print %sin_pi_3_over_2 : f32
+
+  // Vector operand holding 3*pi, 2*pi/3 and the negative input -pi/2.
+  // CHECK: 0, 0.866025, -1
+  %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+  %sin_vec_x = math.sin %vec_x : vector<3xf32>
+  vector.print %sin_vec_x : vector<3xf32>
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Cos.
+// -------------------------------------------------------------------------- //
+func @cos() {
+  // CHECK: 1
+  %0 = constant 0.0 : f32
+  %cos_0 = math.cos %0 : f32
+  vector.print %cos_0 : f32
+
+  // CHECK: 0.707107
+  %pi_over_4 = constant 0.78539816339 : f32
+  %cos_pi_over_4 = math.cos %pi_over_4 : f32
+  vector.print %cos_pi_over_4 : f32
+
+  // CHECK: 0
+  %pi_over_2 = constant 1.57079632679 : f32
+  %cos_pi_over_2 = math.cos %pi_over_2 : f32
+  vector.print %cos_pi_over_2 : f32
+
+  // CHECK: -1
+  %pi = constant 3.14159265359 : f32
+  %cos_pi = math.cos %pi : f32
+  vector.print %cos_pi : f32
+
+  // CHECK: 0
+  %pi_3_over_2 = constant 4.71238898038 : f32
+  %cos_pi_3_over_2 = math.cos %pi_3_over_2 : f32
+  vector.print %cos_pi_3_over_2 : f32
+
+  // Vector operand holding 3*pi, 2*pi/3 and the negative input -pi/2.
+  // CHECK: -1, -0.5, 0
+  %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+  %cos_vec_x = math.cos %vec_x : vector<3xf32>
+  vector.print %cos_vec_x : vector<3xf32>
+
+  return
+}
+
 func @main() {
   call @tanh(): () -> ()
@@ -227,5 +303,7 @@
   call @log1p(): () -> ()
   call @exp(): () -> ()
   call @expm1(): () -> ()
+  call @sin(): () -> ()
+  call @cos(): () -> ()
   return
 }