diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" @@ -339,6 +340,46 @@ // AtanOp approximation. //----------------------------------------------------------------------------// +namespace { +// Fallback pattern for atan2 on floats narrower than 32 bits: extend both +// operands to F32, emit an F32 atan2 (so the F32 expansion applies), and +// truncate the result back to the original type. +struct Atan2Cast : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(math::Atan2Op op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +LogicalResult Atan2Cast::matchAndRewrite(math::Atan2Op op, + PatternRewriter &rewriter) const { + // Operands follow `math.atan2 %y, %x`: operand 0 is y, operand 1 is x. + auto y = op.getOperand(0); + auto x = op.getOperand(1); + // Skip if already F32 (expanded directly) or wider than 32 bits. 
+ Type elementType = getElementTypeOrSelf(y); + if (elementType.isF32() || elementType.getIntOrFloatBitWidth() > 32) + return failure(); + + Type f32Type; + if (auto shaped = y.getType().dyn_cast()) { + f32Type = shaped.clone(rewriter.getF32Type()); + } else if (y.getType().isa()) { + f32Type = rewriter.getF32Type(); + } else { + return rewriter.notifyMatchFailure(op, + "unable to cast to F32 equivalent type"); + } + + y = rewriter.create(op.getLoc(), f32Type, y); + x = rewriter.create(op.getLoc(), f32Type, x); + auto result = rewriter.create(op.getLoc(), f32Type, y, x); + rewriter.replaceOpWithNewOp(op, op.getResult().getType(), + result); + return success(); +} + namespace { struct Atan2Approximation : public OpRewritePattern { public: @@ -1206,10 +1247,10 @@ void mlir::populateMathPolynomialApproximationPatterns( RewritePatternSet &patterns, const MathPolynomialApproximationOptions &options) { - patterns.add, + patterns.add, SinAndCosApproximation>( patterns.getContext()); if (options.enableAvx2) diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -542,7 +542,9 @@ // CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099 // CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987 // CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637 -// CHECK-DAG: %[[RATIO:.+]] = arith.divf %arg0, %arg1 +// CHECK-DAG: %[[ARG0:.+]] = arith.extf %arg0 : f16 to f32 +// CHECK-DAG: %[[ARG1:.+]] = arith.extf %arg1 : f16 to f32 +// CHECK-DAG: %[[RATIO:.+]] = arith.divf %[[ARG0]], %[[ARG1]] // CHECK-DAG: %[[ABS:.+]] = math.abs %[[RATIO]] // CHECK-DAG: %[[DIV:.+]] = arith.divf %cst, %[[ABS]] // CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]] @@ -562,30 +564,31 @@ // CHECK-DAG: %[[SUB_PI:.+]] = arith.subf %[[ATAN]], %[[PI]] // CHECK-DAG: %[[CMP_ATAN:.+]] = arith.cmpf ogt, %[[ATAN]], %[[ZERO]] // CHECK-DAG: 
%[[ATAN_ADJUST:.+]] = arith.select %[[CMP_ATAN]], %[[SUB_PI]], %[[ADD_PI]] -// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]] +// CHECK-DAG: %[[X_NEG:.+]] = arith.cmpf ogt, %[[ARG1]], %[[ZERO]] // CHECK-DAG: %[[ATAN_EST:.+]] = arith.select %[[X_NEG]], %[[ATAN]], %[[ATAN_ADJUST]] // Handle PI / 2 edge case: -// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]] -// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]] +// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %[[ARG1]], %[[ZERO]] +// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %[[ARG0]], %[[ZERO]] // CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]] // CHECK-DAG: %[[EDGE1:.+]] = arith.select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]] // Handle -PI / 2 edge case: // CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637 -// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]] +// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %[[ARG0]], %[[ZERO]] // CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]] // CHECK-DAG: %[[EDGE2:.+]] = arith.select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]] // Handle Nan edgecase: -// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]] +// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %[[ARG0]], %[[ZERO]] // CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]] // CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000 // CHECK-DAG: %[[EDGE3:.+]] = arith.select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]] -// CHECK: return %[[EDGE3]] +// CHECK: %[[RET:.+]] = arith.truncf %[[EDGE3]] +// CHECK: return %[[RET]] -func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 { - %0 = math.atan2 %arg0, %arg1 : f32 - return %0 : f32 +func @atan2_scalar(%arg0: f16, %arg1: f16) -> f16 { + %0 = math.atan2 %arg0, %arg1 : f16 + return %0 : f16 }