diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -331,7 +331,8 @@ SmallVector operands; for (auto operand : op->getOperands()) operands.push_back(rewriter.create(loc, newType, operand)); - auto result = rewriter.create(loc, newType, operands); + auto result = + rewriter.create(loc, TypeRange{newType}, operands, op->getAttrs()); rewriter.replaceOpWithNewOp(op, origType, result); return success(); } @@ -1381,11 +1382,20 @@ void mlir::populateMathPolynomialApproximationPatterns( RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options) { + // Patterns for leveraging existing f32 lowerings on other data types. + patterns + .add, ReuseF32Expansion, + ReuseF32Expansion, ReuseF32Expansion, + ReuseF32Expansion, ReuseF32Expansion, + ReuseF32Expansion, ReuseF32Expansion, + ReuseF32Expansion, ReuseF32Expansion, + ReuseF32Expansion, ReuseF32Expansion>( + patterns.getContext()); + patterns.add, - SinAndCosApproximation, + CbrtApproximation, SinAndCosApproximation, SinAndCosApproximation>( patterns.getContext()); if (options.enableAvx2) diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -642,4 +642,47 @@ func.func @cbrt_vector(%arg0: vector<4xf32>) -> vector<4xf32> { %0 = "math.cbrt"(%arg0) : (vector<4xf32>) -> vector<4xf32> func.return %0 : vector<4xf32> -} \ No newline at end of file +} + + +// CHECK-LABEL: @math_f16 +func.func @math_f16(%arg0 : vector<4xf16>) -> vector<4xf16> { + + // CHECK-NOT: math.atan + %0 = "math.atan"(%arg0) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.atan2 + %1 = "math.atan2"(%0, %arg0) : (vector<4xf16>, vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.tanh + %2 = "math.tanh"(%1) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.log + %3 = "math.log"(%2) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.log2 + %4 = "math.log2"(%3) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.log1p + %5 = "math.log1p"(%4) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.erf + %6 = "math.erf"(%5) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.exp + %7 = "math.exp"(%6) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.expm1 + %8 = "math.expm1"(%7) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.cbrt + %9 = "math.cbrt"(%8) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.sin + %10 = "math.sin"(%9) : (vector<4xf16>) -> vector<4xf16> + + // CHECK-NOT: math.cos + %11 = "math.cos"(%10) : (vector<4xf16>) -> vector<4xf16> + + return %11 : vector<4xf16> +}