[mlir][math] Lower math.cos and math.sin to libm

Adds math::CosOp and math::SinOp to the VecOpToScalarOp, PromoteOpToF32
and ScalarOpToLibmCall pattern lists in
populateMathToLibmConversionPatterns. Scalar f32 operands become calls
to @cosf/@sinf and f64 operands become calls to @cos/@sin, as checked
by the new @cos_caller and @sin_caller tests.

NOTE(review): the new tests cover only scalar f32/f64 operands; the
vector and PromoteOpToF32 registrations for cos/sin are untested here.
NOTE(review): the C++ template argument lists in this patch were
restored after being stripped from a copy of it; confirm the context
lines match the upstream MathToLibm.cpp before applying.

diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
--- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
+++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
@@ -141,9 +141,11 @@
 void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit) {
   patterns.add<VecOpToScalarOp<math::Atan2Op>, VecOpToScalarOp<math::ExpM1Op>,
-               VecOpToScalarOp<math::TanhOp>>(patterns.getContext(), benefit);
+               VecOpToScalarOp<math::TanhOp>, VecOpToScalarOp<math::CosOp>,
+               VecOpToScalarOp<math::SinOp>>(patterns.getContext(), benefit);
   patterns.add<PromoteOpToF32<math::Atan2Op>, PromoteOpToF32<math::ExpM1Op>,
-               PromoteOpToF32<math::TanhOp>>(patterns.getContext(), benefit);
+               PromoteOpToF32<math::TanhOp>, PromoteOpToF32<math::CosOp>,
+               PromoteOpToF32<math::SinOp>>(patterns.getContext(), benefit);
   patterns.add<ScalarOpToLibmCall<math::Atan2Op>>(patterns.getContext(),
                                                   "atan2f", "atan2", benefit);
   patterns.add<ScalarOpToLibmCall<math::ErfOp>>(patterns.getContext(), "erff",
@@ -154,6 +156,10 @@
                                                  "tanh", benefit);
   patterns.add<ScalarOpToLibmCall<math::RoundOp>>(patterns.getContext(),
                                                   "roundf", "round", benefit);
+  patterns.add<ScalarOpToLibmCall<math::CosOp>>(patterns.getContext(), "cosf",
+                                                "cos", benefit);
+  patterns.add<ScalarOpToLibmCall<math::SinOp>>(patterns.getContext(), "sinf",
+                                                "sin", benefit);
 }
 
 namespace {
diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
--- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
+++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir
@@ -10,6 +10,10 @@
 // CHECK-DAG: @tanhf(f32) -> f32
 // CHECK-DAG: @round(f64) -> f64
 // CHECK-DAG: @roundf(f32) -> f32
+// CHECK-DAG: @cos(f64) -> f64
+// CHECK-DAG: @cosf(f32) -> f32
+// CHECK-DAG: @sin(f64) -> f64
+// CHECK-DAG: @sinf(f32) -> f32
 
 // CHECK-LABEL: func @tanh_caller
 // CHECK-SAME: %[[FLOAT:.*]]: f32
@@ -129,3 +133,27 @@
   // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
   return %float_result, %double_result : f32, f64
 }
+
+// CHECK-LABEL: func @cos_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @cos_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cosf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.cos %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cos(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.cos %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}
+
+// CHECK-LABEL: func @sin_caller
+// CHECK-SAME: %[[FLOAT:.*]]: f32
+// CHECK-SAME: %[[DOUBLE:.*]]: f64
+func.func @sin_caller(%float: f32, %double: f64) -> (f32, f64) {
+  // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @sinf(%[[FLOAT]]) : (f32) -> f32
+  %float_result = math.sin %float : f32
+  // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @sin(%[[DOUBLE]]) : (f64) -> f64
+  %double_result = math.sin %double : f64
+  // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]]
+  return %float_result, %double_result : f32, f64
+}