diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -154,19 +154,20 @@ void mlir::populateMathToLibmConversionPatterns( RewritePatternSet &patterns, PatternBenefit benefit, llvm::Optional log1pBenefit) { - patterns.add, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, VecOpToScalarOp, - VecOpToScalarOp, + patterns.add, VecOpToScalarOp, + VecOpToScalarOp, VecOpToScalarOp, + VecOpToScalarOp, VecOpToScalarOp, + VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp, VecOpToScalarOp>( patterns.getContext(), benefit); - patterns.add, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32, PromoteOpToF32, - PromoteOpToF32>(patterns.getContext(), benefit); + patterns.add, PromoteOpToF32, + PromoteOpToF32, PromoteOpToF32, + PromoteOpToF32, PromoteOpToF32, + PromoteOpToF32, PromoteOpToF32, + PromoteOpToF32, PromoteOpToF32, + PromoteOpToF32, PromoteOpToF32>( + patterns.getContext(), benefit); patterns.add>(patterns.getContext(), "atanf", "atan", benefit); patterns.add>(patterns.getContext(), diff --git a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir --- a/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir +++ b/mlir/test/Conversion/MathToLibm/convert-to-libm.mlir @@ -246,13 +246,22 @@ // CHECK-LABEL: func @cbrt_caller // CHECK-SAME: %[[FLOAT:.*]]: f32 // CHECK-SAME: %[[DOUBLE:.*]]: f64 -func.func @cbrt_caller(%float: f32, %double: f64) -> (f32, f64) { - // CHECK-DAG: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32 +func.func @cbrt_caller(%float: f32, %double: f64, %half: f16, %bfloat: bf16, + %float_vec: vector<2xf32>) -> (f32, f64, f16, bf16, vector<2xf32>) { + // CHECK: %[[FLOAT_RESULT:.*]] = call @cbrtf(%[[FLOAT]]) : (f32) -> f32 %float_result = math.cbrt %float : f32 - // CHECK-DAG: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64 + // CHECK: %[[DOUBLE_RESULT:.*]] = call @cbrt(%[[DOUBLE]]) : (f64) -> f64 %double_result = math.cbrt %double : f64 + // Just check that these lower successfully: + // CHECK: call @cbrtf + %half_result = math.cbrt %half : f16 + // CHECK: call @cbrtf + %bfloat_result = math.cbrt %bfloat : bf16 + // CHECK: call @cbrtf + %vec_result = math.cbrt %float_vec : vector<2xf32> // CHECK: return %[[FLOAT_RESULT]], %[[DOUBLE_RESULT]] - return %float_result, %double_result : f32, f64 + return %float_result, %double_result, %half_result, %bfloat_result, %vec_result + : f32, f64, f16, bf16, vector<2xf32> } // CHECK-LABEL: func @cos_caller