diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -736,14 +736,45 @@ matchAndRewrite(complex::TanOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - Value cos = rewriter.create(loc, adaptor.getComplex()); Value sin = rewriter.create(loc, adaptor.getComplex()); rewriter.replaceOpWithNewOp(op, sin, cos); + return success(); + } +}; + +struct TanhOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::TanhOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto type = adaptor.getComplex().getType().cast(); + auto elementType = type.getElementType().cast(); + // The hyperbolic tangent for a complex number can be calculated as follows. 
+ // tanh(x + i * y) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y)) + // See: https://proofwiki.org/wiki/Hyperbolic_Tangent_of_Complex_Number + Value real = + rewriter.create(loc, elementType, adaptor.getComplex()); + Value imag = + rewriter.create(loc, elementType, adaptor.getComplex()); + Value tanhA = rewriter.create(loc, real); + Value cosB = rewriter.create(loc, imag); + Value sinB = rewriter.create(loc, imag); + Value tanB = rewriter.create(loc, sinB, cosB); + Value numerator = + rewriter.create(loc, type, tanhA, tanB); + Value one = rewriter.create( + loc, elementType, rewriter.getFloatAttr(elementType, 1)); + Value mul = rewriter.create(loc, tanhA, tanB); + Value denominator = rewriter.create(loc, type, one, mul); + rewriter.replaceOpWithNewOp(op, numerator, denominator); return success(); } }; + } // namespace void mlir::populateComplexToStandardConversionPatterns( @@ -765,7 +796,8 @@ NegOpConversion, SignOpConversion, SinOpConversion, - TanOpConversion>(patterns.getContext()); + TanOpConversion, + TanhOpConversion>(patterns.getContext()); // clang-format on } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -pass-pipeline="func.func(convert-complex-to-standard)" | FileCheck %s +// RUN: mlir-opt %s --convert-complex-to-standard --split-input-file | FileCheck %s // CHECK-LABEL: func @complex_abs // CHECK-SAME: %[[ARG:.*]]: complex @@ -14,6 +14,8 @@ // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: return %[[NORM]] : f32 +// ----- + // CHECK-LABEL: func @complex_add // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_add(%lhs: complex, %rhs: complex) -> complex { @@ -29,6 +31,8 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], 
%[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_cos // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_cos(%arg: complex) -> complex { @@ -50,6 +54,8 @@ // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] +// ----- + // CHECK-LABEL: func @complex_div // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_div(%lhs: complex, %rhs: complex) -> complex { @@ -159,6 +165,8 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_eq // CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex func.func @complex_eq(%lhs: complex, %rhs: complex) -> i1 { @@ -174,6 +182,8 @@ // CHECK: %[[EQUAL:.*]] = arith.andi %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 // CHECK: return %[[EQUAL]] : i1 +// ----- + // CHECK-LABEL: func @complex_exp // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_exp(%arg: complex) -> complex { @@ -190,6 +200,8 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func.func @complex_expm1( // CHECK-SAME: %[[ARG:.*]]: complex) -> complex { func.func @complex_expm1(%arg: complex) -> complex { @@ -211,6 +223,8 @@ // CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex // CHECK: return %[[RES]] : complex +// ----- + // CHECK-LABEL: func @complex_log // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_log(%arg: complex) -> complex { @@ -230,6 +244,8 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_log1p // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_log1p(%arg: complex) -> complex { @@ -254,6 +270,8 @@ // CHECK: 
%[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_mul // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_mul(%lhs: complex, %rhs: complex) -> complex { @@ -372,6 +390,8 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[FINAL_REAL]], %[[FINAL_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_neg // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_neg(%arg: complex) -> complex { @@ -385,6 +405,8 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_neq // CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex func.func @complex_neq(%lhs: complex, %rhs: complex) -> i1 { @@ -400,6 +422,8 @@ // CHECK: %[[NOT_EQUAL:.*]] = arith.ori %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1 // CHECK: return %[[NOT_EQUAL]] : i1 +// ----- + // CHECK-LABEL: func @complex_sin // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_sin(%arg: complex) -> complex { @@ -421,6 +445,8 @@ // CHECK-DAG: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] +// ----- + // CHECK-LABEL: func @complex_sign // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_sign(%arg: complex) -> complex { @@ -445,6 +471,8 @@ // CHECK: %[[RESULT:.*]] = arith.select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_sub // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func.func @complex_sub(%lhs: complex, %rhs: complex) -> complex { @@ -460,6 +488,8 @@ // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex // CHECK: return %[[RESULT]] : complex +// ----- + // CHECK-LABEL: func @complex_tan // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_tan(%arg: 
complex) -> complex { @@ -595,4 +625,24 @@ // CHECK: %[[RESULT_REAL_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_REAL_SPECIAL_CASE_1]], %[[RESULT_REAL]] : f32 // CHECK: %[[RESULT_IMAG_WITH_SPECIAL_CASES:.*]] = arith.select %[[RESULT_IS_NAN]], %[[RESULT_IMAG_SPECIAL_CASE_1]], %[[RESULT_IMAG]] : f32 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL_WITH_SPECIAL_CASES]], %[[RESULT_IMAG_WITH_SPECIAL_CASES]] : complex -// CHECK: return %[[RESULT]] : complex \ No newline at end of file +// CHECK: return %[[RESULT]] : complex + +// ----- + +// CHECK-LABEL: func @complex_tanh +// CHECK-SAME: %[[ARG:.*]]: complex +func.func @complex_tanh(%arg: complex) -> complex { + %tanh = complex.tanh %arg: complex + return %tanh : complex +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK: %[[TANH_A:.*]] = math.tanh %[[REAL]] : f32 +// CHECK: %[[COS_B:.*]] = math.cos %[[IMAG]] : f32 +// CHECK: %[[SIN_B:.*]] = math.sin %[[IMAG]] : f32 +// CHECK: %[[TAN_B:.*]] = arith.divf %[[SIN_B]], %[[COS_B]] : f32 +// CHECK: %[[NUM:.*]] = complex.create %[[TANH_A]], %[[TAN_B]] : complex +// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[MUL:.*]] = arith.mulf %[[TANH_A]], %[[TAN_B]] : f32 +// CHECK: %[[DENOM:.*]] = complex.create %[[ONE]], %[[MUL]] : complex +// CHECK: return %{{.*}} : complex