diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -142,10 +142,12 @@ PatternBenefit benefit) { patterns.add, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp>(patterns.getContext(), benefit); + VecOpToScalarOp, VecOpToScalarOp, + VecOpToScalarOp>(patterns.getContext(), benefit); patterns.add, PromoteOpToF32, PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32>(patterns.getContext(), benefit); + PromoteOpToF32, PromoteOpToF32, + PromoteOpToF32>(patterns.getContext(), benefit); patterns.add>(patterns.getContext(), "atan2f", "atan2", benefit); patterns.add>(patterns.getContext(), "erff", diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -63,6 +63,30 @@ return %float_result, %double_result : f32, f64 } +// CHECK-LABEL: func @erf_vec_caller( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { +func.func @erf_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { + // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> + // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> + // CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32> + // CHECK: %[[OUT0_F32:.*]] = call @erff(%[[IN0_F32]]) : (f32) -> f32 + // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32> + // CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32> + // CHECK: %[[OUT1_F32:.*]] = call @erff(%[[IN1_F32]]) : (f32) -> f32 + // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32> + %float_result = math.erf %float : vector<2xf32> + // CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64> + // CHECK: %[[OUT0_F64:.*]] = call @erf(%[[IN0_F64]]) : (f64) -> f64 + // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64> + // CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64> + // CHECK: %[[OUT1_F64:.*]] = call @erf(%[[IN1_F64]]) : (f64) -> f64 + // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64> + %double_result = math.erf %double : vector<2xf64> + // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> + return %float_result, %double_result : vector<2xf32>, vector<2xf64> +} + // CHECK-LABEL: func @expm1_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 @@ -157,3 +181,27 @@ // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] return %float_result, %double_result : f32, f64 } + +// CHECK-LABEL: func @round_vec_caller( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, +// CHECK-SAME: %[[VAL_1:.*]]: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { +func.func @round_vec_caller(%float: vector<2xf32>, %double: vector<2xf64>) -> (vector<2xf32>, vector<2xf64>) { + // CHECK-DAG: %[[CVF:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32> + // CHECK-DAG: %[[CVD:.*]] = arith.constant dense<0.000000e+00> : vector<2xf64> + // CHECK: %[[IN0_F32:.*]] = vector.extract %[[VAL_0]][0] : vector<2xf32> + // CHECK: %[[OUT0_F32:.*]] = call @roundf(%[[IN0_F32]]) : (f32) -> f32 + // CHECK: %[[VAL_8:.*]] = vector.insert %[[OUT0_F32]], %[[CVF]] [0] : f32 into vector<2xf32> + // CHECK: %[[IN1_F32:.*]] = vector.extract %[[VAL_0]][1] : vector<2xf32> + // CHECK: %[[OUT1_F32:.*]] = call @roundf(%[[IN1_F32]]) : (f32) -> f32 + // CHECK: %[[VAL_11:.*]] = vector.insert %[[OUT1_F32]], %[[VAL_8]] [1] : f32 into vector<2xf32> + %float_result = math.round %float : vector<2xf32> + // CHECK: %[[IN0_F64:.*]] = vector.extract %[[VAL_1]][0] : vector<2xf64> + // CHECK: %[[OUT0_F64:.*]] = call @round(%[[IN0_F64]]) : (f64) -> f64 + // CHECK: %[[VAL_14:.*]] = vector.insert %[[OUT0_F64]], %[[CVD]] [0] : f64 into vector<2xf64> + // CHECK: %[[IN1_F64:.*]] = vector.extract %[[VAL_1]][1] : vector<2xf64> + // CHECK: %[[OUT1_F64:.*]] = call @round(%[[IN1_F64]]) : (f64) -> f64 + // CHECK: %[[VAL_17:.*]] = vector.insert %[[OUT1_F64]], %[[VAL_14]] [1] : f64 into vector<2xf64> + %double_result = math.round %double : vector<2xf64> + // CHECK: return %[[VAL_11]], %[[VAL_17]] : vector<2xf32>, vector<2xf64> + return %float_result, %double_result : vector<2xf32>, vector<2xf64> +}