diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1075,6 +1075,30 @@
   return success();
 }
 
+//----------------------------------------------------------------------------//
+// Tan decomposition to sin / cos.
+//----------------------------------------------------------------------------//
+
+namespace {
+/// Rewrites `math.tan(x)` as `math.sin(x) / math.cos(x)`; the three new ops
+/// all use the operand's type as their result type.
+struct TanDecomposition : public OpRewritePattern<math::TanOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::TanOp op,
+                                PatternRewriter &rewriter) const final {
+    ImplicitLocOpBuilder b(op->getLoc(), rewriter);
+    Value operand = op.getOperand();
+    Type type = operand.getType();
+    Value sin = b.create<math::SinOp>(type, operand);
+    Value cos = b.create<math::CosOp>(type, operand);
+    Value div = b.create<arith::DivFOp>(type, sin, cos);
+    rewriter.replaceOp(op, div);
+    return success();
+  }
+};
+} // namespace
+
 //----------------------------------------------------------------------------//
 // Sin and Cos approximation.
 //----------------------------------------------------------------------------//
@@ -1291,7 +1315,7 @@
   patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
                LogApproximation, Log2Approximation, Log1pApproximation,
                ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
-               ReuseF32Expansion<math::Atan2Op>,
+               ReuseF32Expansion<math::Atan2Op>, TanDecomposition,
                SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -593,3 +593,10 @@
   %0 = math.atan2 %arg0, %arg1 : f16
   return %0 : f16
 }
+
+// CHECK-LABEL: @tan_scalar
+func.func @tan_scalar(%arg0 : f32) -> f32 {
+  // CHECK-NOT: math.tan
+  %0 = math.tan %arg0 : f32
+  return %0 : f32
+}