diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -278,6 +278,148 @@
 }
 } // namespace
 
+//----------------------------------------------------------------------------//
+// AtanOp approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct AtanApproximation : public OpRewritePattern<math::AtanOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::AtanOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximates atan(x) for f32 scalars and vectors.
+//
+// The argument is reduced to [0, 1] with two identities:
+//   atan(-x) = -atan(x)
+//   atan(x)  = pi/2 - atan(1/x)   for x > 0
+// A degree-4 polynomial approximates atan on [0, 1], and the same identities
+// map the result back to the full range.
+LogicalResult
+AtanApproximation::matchAndRewrite(math::AtanOp op,
+                                   PatternRewriter &rewriter) const {
+  Value operand = op.getOperand();
+  if (!getElementTypeOrSelf(operand).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ArrayRef<int64_t> shape = vectorShape(op.getOperand());
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto one = broadcast(builder, f32Cst(builder, 1.0f), shape);
+
+  // Reduce the problem to [0.0, 1.0]: take |x| (the sign is restored at the
+  // end), then the smaller of |x| and 1/|x|.
+  Value abs = builder.create<math::AbsOp>(operand);
+  Value reciprocal = builder.create<arith::DivFOp>(one, abs);
+  Value compare =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, abs, reciprocal);
+  Value x = builder.create<SelectOp>(compare, abs, reciprocal);
+
+  // Evaluate a degree-4 polynomial approximation of atan over [0.0, 1.0] in
+  // Horner form: p(x) = x * (n4 + x * (n3 + x * (n2 + x * n1))).
+  // This is not the Taylor series, whose x^2 and x^3 coefficients would be
+  // 0 and -1/3.
+  auto n1 = broadcast(builder, f32Cst(builder, 0.14418283f), shape);
+  auto n2 = broadcast(builder, f32Cst(builder, -0.34999234f), shape);
+  auto n3 = broadcast(builder, f32Cst(builder, -0.01067831f), shape);
+  auto n4 = broadcast(builder, f32Cst(builder, 1.00209986f), shape);
+
+  Value p = builder.create<math::FmaOp>(x, n1, n2);
+  p = builder.create<math::FmaOp>(x, p, n3);
+  p = builder.create<math::FmaOp>(x, p, n4);
+  p = builder.create<arith::MulFOp>(x, p);
+
+  // Map the result from [0.0, 1.0] back to [0.0, inf): when the reciprocal
+  // was selected, atan(|x|) = pi/2 - atan(1/|x|).
+  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
+  Value sub = builder.create<arith::SubFOp>(halfPi, p);
+  Value select = builder.create<SelectOp>(compare, p, sub);
+
+  // Restore the sign of the input, since atan is odd.
+  rewriter.replaceOpWithNewOp<math::CopySignOp>(op, select, operand);
+  return success();
+}
+
+//----------------------------------------------------------------------------//
+// Atan2Op approximation.
+//----------------------------------------------------------------------------//
+
+namespace {
+struct Atan2Approximation : public OpRewritePattern<math::Atan2Op> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(math::Atan2Op op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+// Approximates atan2(y, x) for f32 scalars and vectors. The core is a
+// math.atan op, which AtanApproximation (registered alongside this pattern)
+// lowers in turn.
+//
+// NOTE(review): for x = 0, y = 0 this lowering returns NaN, whereas C's atan2
+// returns +-0 or +-pi; the tests currently pin the NaN result.
+LogicalResult
+Atan2Approximation::matchAndRewrite(math::Atan2Op op,
+                                    PatternRewriter &rewriter) const {
+  Value y = op.getOperand(0);
+  Value x = op.getOperand(1);
+  if (!getElementTypeOrSelf(x).isF32())
+    return rewriter.notifyMatchFailure(op, "unsupported operand type");
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  ArrayRef<int64_t> shape = vectorShape(op.getResult());
+
+  // atan(y / x) is already the answer when x > 0.
+  Value ratio = builder.create<arith::DivFOp>(y, x);
+  Value atan = builder.create<math::AtanOp>(ratio);
+
+  // For x < 0 the angle lies in the opposite half-plane: rotate by pi towards
+  // the sign of y. Taking the sign from y (rather than from atan) keeps the
+  // correct quadrant when y / x is a signed zero, e.g. atan2(-1, -inf) = -pi.
+  auto zero = broadcast(builder, f32Cst(builder, 0.0f), shape);
+  auto pi = broadcast(builder, f32Cst(builder, 3.14159265359f), shape);
+  Value signedPi = builder.create<math::CopySignOp>(pi, y);
+  Value flippedAtan = builder.create<arith::AddFOp>(atan, signedPi);
+
+  // Use atan directly when x > 0 and the rotated angle otherwise.
+  Value xPositive =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, x, zero);
+  Value result = builder.create<SelectOp>(xPositive, atan, flippedAtan);
+
+  // Handle x = 0, y > 0: the result is +pi/2.
+  Value xZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, x, zero);
+  Value yPositive =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, y, zero);
+  Value isHalfPi = builder.create<arith::AndIOp>(xZero, yPositive);
+  auto halfPi = broadcast(builder, f32Cst(builder, 1.57079632679f), shape);
+  result = builder.create<SelectOp>(isHalfPi, halfPi, result);
+
+  // Handle x = 0, y < 0: the result is -pi/2.
+  Value yNegative =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, y, zero);
+  Value isNegativeHalfPi = builder.create<arith::AndIOp>(xZero, yNegative);
+  auto negativeHalfPi =
+      broadcast(builder, f32Cst(builder, -1.57079632679f), shape);
+  result = builder.create<SelectOp>(isNegativeHalfPi, negativeHalfPi, result);
+
+  // Handle x = 0, y = 0: produce a quiet NaN.
+  Value yZero =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, y, zero);
+  Value isNan = builder.create<arith::AndIOp>(xZero, yZero);
+  Value cstNan = broadcast(builder, f32FromBits(builder, 0x7fc00000), shape);
+  result = builder.create<SelectOp>(isNan, cstNan, result);
+
+  rewriter.replaceOp(op, result);
+  return success();
+}
+
 //----------------------------------------------------------------------------//
 // TanhOp approximation.
 //----------------------------------------------------------------------------//
@@ -1074,9 +1216,10 @@
 void mlir::populateMathPolynomialApproximationPatterns(
     RewritePatternSet &patterns,
     const MathPolynomialApproximationOptions &options) {
-  patterns.add<TanhApproximation, LogApproximation, Log2Approximation,
-               Log1pApproximation, ExpApproximation, ExpM1Approximation,
-               SinAndCosApproximation<true, math::SinOp>,
+  patterns.add<AtanApproximation, Atan2Approximation, TanhApproximation,
+               LogApproximation, Log2Approximation, Log1pApproximation,
+               ExpApproximation, ExpM1Approximation,
+               SinAndCosApproximation<true, math::SinOp>,
                SinAndCosApproximation<false, math::CosOp>>(
       patterns.getContext());
   if (options.enableAvx2)
diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir
--- a/mlir/test/Dialect/Math/polynomial-approximation.mlir
+++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir
@@ -507,3 +507,83 @@
   %0 = math.rsqrt %arg0 : vector<2x16xf32>
   return %0 : vector<2x16xf32>
 }
+
+// CHECK-LABEL: @atan_scalar
+// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
+// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
+// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
+// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
+// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
+// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
+// CHECK-DAG: %[[ABS:.+]] = math.abs %arg0
+// CHECK-DAG: %[[DIV:.+]] = arith.divf %[[ONE]], %[[ABS]]
+// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
+// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
+// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
+// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
+// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
+// CHECK-DAG: %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]]
+// CHECK-DAG: %[[RES:.+]] = math.copysign %[[EST]], %arg0
+// CHECK: return %[[RES]]
+func @atan_scalar(%arg0: f32) -> f32 {
+  %0 = math.atan %arg0 : f32
+  return %0 : f32
+}
+
+
+// CHECK-LABEL: @atan2_scalar
+
+// ATan approximation:
+// CHECK-DAG: %[[ONE:.+]] = arith.constant 1.000000e+00
+// CHECK-DAG: %[[N1:.+]] = arith.constant 0.144182831
+// CHECK-DAG: %[[N2:.+]] = arith.constant -0.349992335
+// CHECK-DAG: %[[N3:.+]] = arith.constant -0.0106783099
+// CHECK-DAG: %[[N4:.+]] = arith.constant 1.00209987
+// CHECK-DAG: %[[HALF_PI:.+]] = arith.constant 1.57079637
+// CHECK-DAG: %[[RATIO:.+]] = arith.divf %arg0, %arg1
+// CHECK-DAG: %[[ABS:.+]] = math.abs %[[RATIO]]
+// CHECK-DAG: %[[DIV:.+]] = arith.divf %[[ONE]], %[[ABS]]
+// CHECK-DAG: %[[CMP:.+]] = arith.cmpf olt, %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[SEL:.+]] = select %[[CMP]], %[[ABS]], %[[DIV]]
+// CHECK-DAG: %[[P0:.+]] = math.fma %[[SEL]], %[[N1]], %[[N2]]
+// CHECK-DAG: %[[P1:.+]] = math.fma %[[SEL]], %[[P0]], %[[N3]]
+// CHECK-DAG: %[[P2:.+]] = math.fma %[[SEL]], %[[P1]], %[[N4]]
+// CHECK-DAG: %[[P3:.+]] = arith.mulf %[[SEL]], %[[P2]]
+// CHECK-DAG: %[[SUB:.+]] = arith.subf %[[HALF_PI]], %[[P3]]
+// CHECK-DAG: %[[EST:.+]] = select %[[CMP]], %[[P3]], %[[SUB]]
+// CHECK-DAG: %[[ATAN:.+]] = math.copysign %[[EST]], %[[RATIO]]
+
+// Handle the case of x < 0:
+// CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00
+// CHECK-DAG: %[[PI:.+]] = arith.constant 3.14159274
+// CHECK-DAG: %[[SIGNED_PI:.+]] = math.copysign %[[PI]], %arg0
+// CHECK-DAG: %[[ATAN_ADJUST:.+]] = arith.addf %[[ATAN]], %[[SIGNED_PI]]
+// CHECK-DAG: %[[X_POS:.+]] = arith.cmpf ogt, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[ATAN_EST:.+]] = select %[[X_POS]], %[[ATAN]], %[[ATAN_ADJUST]]
+
+// Handle PI / 2 edge case:
+// CHECK-DAG: %[[X_ZERO:.+]] = arith.cmpf oeq, %arg1, %[[ZERO]]
+// CHECK-DAG: %[[Y_POS:.+]] = arith.cmpf ogt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[IS_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_POS]]
+// CHECK-DAG: %[[EDGE1:.+]] = select %[[IS_HALF_PI]], %[[HALF_PI]], %[[ATAN_EST]]
+
+// Handle -PI / 2 edge case:
+// CHECK-DAG: %[[NEG_HALF_PI:.+]] = arith.constant -1.57079637
+// CHECK-DAG: %[[Y_NEG:.+]] = arith.cmpf olt, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[IS_NEG_HALF_PI:.+]] = arith.andi %[[X_ZERO]], %[[Y_NEG]]
+// CHECK-DAG: %[[EDGE2:.+]] = select %[[IS_NEG_HALF_PI]], %[[NEG_HALF_PI]], %[[EDGE1]]
+
+// Handle NaN edge case:
+// CHECK-DAG: %[[Y_ZERO:.+]] = arith.cmpf oeq, %arg0, %[[ZERO]]
+// CHECK-DAG: %[[X_Y_ZERO:.+]] = arith.andi %[[X_ZERO]], %[[Y_ZERO]]
+// CHECK-DAG: %[[NAN:.+]] = arith.constant 0x7FC00000
+// CHECK-DAG: %[[EDGE3:.+]] = select %[[X_Y_ZERO]], %[[NAN]], %[[EDGE2]]
+// CHECK: return %[[EDGE3]]
+
+func @atan2_scalar(%arg0: f32, %arg1: f32) -> f32 {
+  %0 = math.atan2 %arg0, %arg1 : f32
+  return %0 : f32
+}
+
diff --git a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
--- a/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
+++ b/mlir/test/mlir-cpu-runner/math-polynomial-approx.mlir
@@ -371,6 +371,122 @@
   return
 }
 
+// -------------------------------------------------------------------------- //
+// Atan.
+// -------------------------------------------------------------------------- //
+
+func @atan() {
+  // CHECK: -0.785184
+  %0 = arith.constant -1.0 : f32
+  %atan_0 = math.atan %0 : f32
+  vector.print %atan_0 : f32
+
+  // CHECK: 0.785184
+  %1 = arith.constant 1.0 : f32
+  %atan_1 = math.atan %1 : f32
+  vector.print %atan_1 : f32
+
+  // CHECK: -0.463643
+  %2 = arith.constant -0.5 : f32
+  %atan_2 = math.atan %2 : f32
+  vector.print %atan_2 : f32
+
+  // CHECK: 0.463643
+  %3 = arith.constant 0.5 : f32
+  %atan_3 = math.atan %3 : f32
+  vector.print %atan_3 : f32
+
+  // CHECK: 0
+  %4 = arith.constant 0.0 : f32
+  %atan_4 = math.atan %4 : f32
+  vector.print %atan_4 : f32
+
+  // CHECK: -1.10715
+  %5 = arith.constant -2.0 : f32
+  %atan_5 = math.atan %5 : f32
+  vector.print %atan_5 : f32
+
+  // CHECK: 1.10715
+  %6 = arith.constant 2.0 : f32
+  %atan_6 = math.atan %6 : f32
+  vector.print %atan_6 : f32
+
+  return
+}
+
+
+// -------------------------------------------------------------------------- //
+// Atan2.
+// -------------------------------------------------------------------------- //
+
+func @atan2() {
+  %zero = arith.constant 0.0 : f32
+  %one = arith.constant 1.0 : f32
+  %two = arith.constant 2.0 : f32
+  %neg_one = arith.constant -1.0 : f32
+  %neg_two = arith.constant -2.0 : f32
+  %neg_inf = arith.constant 0xFF800000 : f32
+
+  // CHECK: 0
+  %atan2_0 = math.atan2 %zero, %one : f32
+  vector.print %atan2_0 : f32
+
+  // CHECK: 1.5708
+  %atan2_1 = math.atan2 %one, %zero : f32
+  vector.print %atan2_1 : f32
+
+  // CHECK: 3.14159
+  %atan2_2 = math.atan2 %zero, %neg_one : f32
+  vector.print %atan2_2 : f32
+
+  // CHECK: -1.5708
+  %atan2_3 = math.atan2 %neg_one, %zero : f32
+  vector.print %atan2_3 : f32
+
+  // CHECK: nan
+  %atan2_4 = math.atan2 %zero, %zero : f32
+  vector.print %atan2_4 : f32
+
+  // CHECK: 1.10715
+  %atan2_5 = math.atan2 %two, %one : f32
+  vector.print %atan2_5 : f32
+
+  // CHECK: 2.03444
+  %atan2_6 = math.atan2 %two, %neg_one : f32
+  vector.print %atan2_6 : f32
+
+  // CHECK: -2.03444
+  %atan2_7 = math.atan2 %neg_two, %neg_one : f32
+  vector.print %atan2_7 : f32
+
+  // CHECK: -1.10715
+  %atan2_8 = math.atan2 %neg_two, %one : f32
+  vector.print %atan2_8 : f32
+
+  // CHECK: 0.463643
+  %atan2_9 = math.atan2 %one, %two : f32
+  vector.print %atan2_9 : f32
+
+  // CHECK: 2.67795
+  %atan2_10 = math.atan2 %one, %neg_two : f32
+  vector.print %atan2_10 : f32
+
+  // CHECK: -2.67795
+  %atan2_11 = math.atan2 %neg_one, %neg_two : f32
+  vector.print %atan2_11 : f32
+
+  // CHECK: -0.463643
+  %atan2_12 = math.atan2 %neg_one, %two : f32
+  vector.print %atan2_12 : f32
+
+  // atan2(-1, -inf) is -pi: y / x is +0, so the quadrant must come from y.
+  // CHECK: -3.14159
+  %atan2_13 = math.atan2 %neg_one, %neg_inf : f32
+  vector.print %atan2_13 : f32
+
+  return
+}
+
 
 func @main() {
   call @tanh(): () -> ()
@@ -382,5 +498,7 @@
   call @expm1(): () -> ()
   call @sin(): () -> ()
   call @cos(): () -> ()
+  call @atan(): () -> ()
+  call @atan2(): () -> ()
   return
 }