diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -14,6 +14,7 @@ class RewritePatternSet; void populateExpandCtlzPattern(RewritePatternSet &patterns); +void populateExpandTanPattern(RewritePatternSet &patterns); void populateExpandTanhPattern(RewritePatternSet &patterns); void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Math/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -54,6 +55,19 @@ return success(); } +/// Expands `math.tan(x)` into `arith.divf(math.sin(x), math.cos(x))`. +/// No special handling is done for inputs where cos(x) is zero. +static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) { + ImplicitLocOpBuilder b(op->getLoc(), rewriter); + Value operand = op.getOperand(); + Type type = operand.getType(); + Value sin = b.create<math::SinOp>(type, operand); + Value cos = b.create<math::CosOp>(type, operand); + Value div = b.create<arith::DivFOp>(type, sin, cos); + rewriter.replaceOp(op, div); + return success(); +} + static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter) { auto operand = op.getOperand(); @@ -107,6 +121,10 @@ patterns.add(convertCtlzOp); } +void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { + patterns.add(convertTanOp); +} + void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { patterns.add(convertTanhOp); } diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@
-24,6 +24,20 @@ // ----- +// CHECK-LABEL: func @tan +func.func @tan(%arg: f32) -> f32 { + %res = math.tan %arg : f32 + return %res : f32 +} + +// CHECK-SAME: %[[ARG0:.+]]: f32 +// CHECK: %[[SIN:.+]] = math.sin %[[ARG0]] +// CHECK: %[[COS:.+]] = math.cos %[[ARG0]] +// CHECK: %[[DIV:.+]] = arith.divf %[[SIN]], %[[COS]] +// CHECK: return %[[DIV]] + +// ----- + // CHECK-LABEL: func @ctlz func.func @ctlz(%arg: i32) -> i32 { // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i32 diff --git a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp --- a/mlir/test/lib/Dialect/Math/TestExpandMath.cpp +++ b/mlir/test/lib/Dialect/Math/TestExpandMath.cpp @@ -35,6 +35,7 @@ void TestExpandMathPass::runOnOperation() { RewritePatternSet patterns(&getContext()); populateExpandCtlzPattern(patterns); + populateExpandTanPattern(patterns); populateExpandTanhPattern(patterns); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); }