diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -630,11 +630,165 @@
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Sin and Cos approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+
+// Rewrite pattern that replaces `OpTy` with a polynomial approximation of
+// sin(x) when `isSine` is true and of cos(x) otherwise.
+template <bool isSine, typename OpTy>
+struct SinAndCosApproximation : public OpRewritePattern<OpTy> {
+public:
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+#define TWO_OVER_PI                                                            \
+  0.6366197723675813430755350534900574481378385829618257949906693762L
+#define PI_OVER_2                                                              \
+  1.5707963267948966192313216916397514420985846996875529104874722961L
+
+// Approximates sin(x) or cos(x) by finding the best approximation polynomial in
+// the reduced range [0, pi/2] for both sin(x) and cos(x). Then given y in the
+// reduced range sin(x) will be computed as sin(y), -sin(y), cos(y) or -cos(y).
+//
+// With k = floor(x * 2/pi) and y = x - k * pi/2, k mod 4 selects:
+//   k mod 4 | sin(x)  | cos(x)
+//   --------+---------+---------
+//      0    |  sin(y) |  cos(y)
+//      1    |  cos(y) | -sin(y)
+//      2    | -sin(y) | -cos(y)
+//      3    | -cos(y) |  sin(y)
+//
+// NOTE(review): k is converted to i32 below, so this presumably only holds
+// while floor(x * 2/pi) fits in a signed 32-bit integer; confirm the intended
+// input range.
+template <bool isSine, typename OpTy>
+LogicalResult SinAndCosApproximation<isSine, OpTy>::matchAndRewrite(
+    OpTy op, PatternRewriter &rewriter) const {
+  auto width = vectorWidth(op.operand().getType(), isF32);
+  if (!width.hasValue())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, *width);
+  };
+  auto mul = [&](Value a, Value b) -> Value {
+    return builder.create<MulFOp>(a, b);
+  };
+  auto sub = [&](Value a, Value b) -> Value {
+    return builder.create<SubFOp>(a, b);
+  };
+  auto floor = [&](Value a) { return builder.create<FloorFOp>(a); };
+
+  auto i32Vec = broadcast(builder.getI32Type(), *width);
+  auto fPToSignedInteger = [&](Value a) -> Value {
+    return builder.create<FPToSIOp>(a, i32Vec);
+  };
+
+  // Masking with 3 keeps the two low bits, i.e. the value modulo 4.
+  auto modulo4 = [&](Value a) -> Value {
+    return builder.create<AndOp>(a, bcast(i32Cst(builder, 3)));
+  };
+
+  auto isEqualTo = [&](Value a, Value b) -> Value {
+    return builder.create<CmpIOp>(CmpIPredicate::eq, a, b);
+  };
+
+  auto isGreaterThan = [&](Value a, Value b) -> Value {
+    return builder.create<CmpIOp>(CmpIPredicate::sgt, a, b);
+  };
+
+  auto select = [&](Value cond, Value t, Value f) -> Value {
+    return builder.create<SelectOp>(cond, t, f);
+  };
+
+  auto fmla = [&](Value a, Value b, Value c) {
+    return builder.create<FmaFOp>(a, b, c);
+  };
+
+  auto Or = [&](Value a, Value b) { return builder.create<OrOp>(a, b); };
+
+  Value twoOverPi = bcast(f32Cst(builder, TWO_OVER_PI));
+  Value piOverTwo = bcast(f32Cst(builder, PI_OVER_2));
+
+  Value x = op.operand();
+
+  // Range reduction: k counts whole multiples of pi/2 in x, y is the remainder.
+  Value k = floor(mul(x, twoOverPi));
+
+  Value y = sub(x, mul(k, piOverTwo));
+
+  Value cstOne = bcast(f32Cst(builder, 1.0));
+  Value cstNegativeOne = bcast(f32Cst(builder, -1.0));
+
+  // Coefficients of y^2, y^4, ..., y^10 in the sin(y)/y polynomial.
+  Value cstSC2 = bcast(f32Cst(builder, -0.16666667163372039794921875f));
+  Value cstSC4 = bcast(f32Cst(builder, 8.333347737789154052734375e-3f));
+  Value cstSC6 = bcast(f32Cst(builder, -1.9842604524455964565277099609375e-4f));
+  Value cstSC8 =
+      bcast(f32Cst(builder, 2.760012648650445044040679931640625e-6f));
+  Value cstSC10 =
+      bcast(f32Cst(builder, -2.50293279435709337121807038784027099609375e-8f));
+
+  // Coefficients of y^2, y^4, ..., y^10 in the cos(y) polynomial.
+  Value cstCC2 = bcast(f32Cst(builder, -0.5f));
+  Value cstCC4 = bcast(f32Cst(builder, 4.166664183139801025390625e-2f));
+  Value cstCC6 = bcast(f32Cst(builder, -1.388833043165504932403564453125e-3f));
+  Value cstCC8 = bcast(f32Cst(builder, 2.47562347794882953166961669921875e-5f));
+  Value cstCC10 =
+      bcast(f32Cst(builder, -2.59630184018533327616751194000244140625e-7f));
+
+  // Quadrant of x: number of pi/2 steps taken by the reduction, modulo 4.
+  Value kMod4 = modulo4(fPToSignedInteger(k));
+
+  Value kR0 = isEqualTo(kMod4, bcast(i32Cst(builder, 0)));
+  Value kR1 = isEqualTo(kMod4, bcast(i32Cst(builder, 1)));
+  Value kR2 = isEqualTo(kMod4, bcast(i32Cst(builder, 2)));
+  Value kR3 = isEqualTo(kMod4, bcast(i32Cst(builder, 3)));
+
+  // Pick the polynomial (sin or cos) and the sign according to the quadrant.
+  Value useCosPolynomial = isSine ? Or(kR1, kR3) : Or(kR0, kR2);
+  Value negativeRange =
+      isSine ? isGreaterThan(kMod4, bcast(i32Cst(builder, 1))) : Or(kR1, kR2);
+
+  Value y2 = mul(y, y);
+
+  Value base = select(useCosPolynomial, cstOne, y);
+  Value cstC2 = select(useCosPolynomial, cstCC2, cstSC2);
+  Value cstC4 = select(useCosPolynomial, cstCC4, cstSC4);
+  Value cstC6 = select(useCosPolynomial, cstCC6, cstSC6);
+  Value cstC8 = select(useCosPolynomial, cstCC8, cstSC8);
+  Value cstC10 = select(useCosPolynomial, cstCC10, cstSC10);
+
+  // Horner evaluation of 1 + C2*y^2 + C4*y^4 + ... + C10*y^10, then scale by
+  // `base` (y for the sin polynomial, 1 for the cos polynomial).
+  Value v1 = fmla(y2, cstC10, cstC8);
+  Value v2 = fmla(y2, v1, cstC6);
+  Value v3 = fmla(y2, v2, cstC4);
+  Value v4 = fmla(y2, v3, cstC2);
+  Value v5 = fmla(y2, v4, cstOne);
+  Value v6 = mul(base, v5);
+
+  Value approximation = select(negativeRange, mul(cstNegativeOne, v6), v6);
+
+  rewriter.replaceOp(op, approximation);
+
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns) {
   patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               Log1pApproximation, ExpApproximation, ExpM1Approximation>(
+               Log1pApproximation, ExpApproximation, ExpM1Approximation,
+               SinAndCosApproximation<true, math::SinOp>,
+               SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
 }
diff --git 
a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
--- a/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math_polynomial_approx.mlir
@@ -219,6 +219,88 @@
 
   return
 }
+// -------------------------------------------------------------------------- //
+// Sin.
+// -------------------------------------------------------------------------- //
+func @sin() {
+  // NOTE(review): FileCheck matches substrings, so the checks for a bare 0 in
+  // this function and in @cos also accept tiny non-zero results such as
+  // 1.0e-08; tighten them if an exact zero is intended.
+  // CHECK: 0
+  %0 = constant 0.0 : f32
+  %sin_0 = math.sin %0 : f32
+  vector.print %sin_0 : f32
+
+  // CHECK: 0.707107
+  %pi_over_4 = constant 0.78539816339 : f32
+  %sin_pi_over_4 = math.sin %pi_over_4 : f32
+  vector.print %sin_pi_over_4 : f32
+
+  // CHECK: 1
+  %pi_over_2 = constant 1.57079632679 : f32
+  %sin_pi_over_2 = math.sin %pi_over_2 : f32
+  vector.print %sin_pi_over_2 : f32
+
+
+  // CHECK: 0
+  %pi = constant 3.14159265359 : f32
+  %sin_pi = math.sin %pi : f32
+  vector.print %sin_pi : f32
+
+  // CHECK: -1
+  %pi_3_over_2 = constant 4.71238898038 : f32
+  %sin_pi_3_over_2 = math.sin %pi_3_over_2 : f32
+  vector.print %sin_pi_3_over_2 : f32
+
+  // Inputs are 3*pi, 2*pi/3 and -pi/2.
+  // CHECK: 0, 0.866025, -1
+  %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+  %sin_vec_x = math.sin %vec_x : vector<3xf32>
+  vector.print %sin_vec_x : vector<3xf32>
+
+  return
+}
+
+// -------------------------------------------------------------------------- //
+// Cos.
+// -------------------------------------------------------------------------- //
+
+func @cos() {
+  // CHECK: 1
+  %0 = constant 0.0 : f32
+  %cos_0 = math.cos %0 : f32
+  vector.print %cos_0 : f32
+
+  // CHECK: 0.707107
+  %pi_over_4 = constant 0.78539816339 : f32
+  %cos_pi_over_4 = math.cos %pi_over_4 : f32
+  vector.print %cos_pi_over_4 : f32
+
+  // CHECK: 0
+  %pi_over_2 = constant 1.57079632679 : f32
+  %cos_pi_over_2 = math.cos %pi_over_2 : f32
+  vector.print %cos_pi_over_2 : f32
+
+  // CHECK: -1
+  %pi = constant 3.14159265359 : f32
+  %cos_pi = math.cos %pi : f32
+  vector.print %cos_pi : f32
+
+  // CHECK: 0
+  %pi_3_over_2 = constant 4.71238898038 : f32
+  %cos_pi_3_over_2 = math.cos %pi_3_over_2 : f32
+  vector.print %cos_pi_3_over_2 : f32
+
+  // Inputs are 3*pi, 2*pi/3 and -pi/2.
+  // CHECK: -1, -0.5, 0
+  %vec_x = constant dense<[9.42477796077, 2.09439510239, -1.57079632679]> : vector<3xf32>
+  %cos_vec_x = math.cos %vec_x : vector<3xf32>
+  vector.print %cos_vec_x : vector<3xf32>
+
+
+  return
+}
+
 
 func @main() {
   call @tanh(): () -> ()
@@ -227,5 +309,7 @@
   call @log1p(): () -> ()
   call @exp(): () -> ()
   call @expm1(): () -> ()
+  call @sin(): () -> ()
+  call @cos(): () -> ()
   return
 }