diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -315,6 +315,30 @@
   }
 };
 
+/// Lowers complex.log(z) to log(|z|) + i * atan2(im(z), re(z)). The
+/// intermediate complex.abs is itself expanded by this conversion.
+struct LogOpConversion : public OpConversionPattern<complex::LogOp> {
+  using OpConversionPattern<complex::LogOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::LogOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::LogOp::Adaptor transformed(operands);
+    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+    mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter);
+
+    Value abs = b.create<complex::AbsOp>(elementType, transformed.complex());
+    Value resultReal = b.create<math::LogOp>(elementType, abs);
+    Value real = b.create<complex::ReOp>(elementType, transformed.complex());
+    Value imag = b.create<complex::ImOp>(elementType, transformed.complex());
+    Value resultImag = b.create<math::Atan2Op>(elementType, imag, real);
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, resultReal,
+                                                   resultImag);
+    return success();
+  }
+};
+
 struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
   using OpConversionPattern<complex::NegOp>::OpConversionPattern;
 
@@ -374,6 +398,7 @@
       ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
       DivOpConversion,
       ExpOpConversion,
+      LogOpConversion,
       NegOpConversion,
       SignOpConversion>(patterns.getContext());
   // clang-format on
@@ -396,8 +421,8 @@
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::NotEqualOp, complex::NegOp,
-                      complex::SignOp>();
+                      complex::ExpOp, complex::LogOp, complex::NotEqualOp,
+                      complex::NegOp, complex::SignOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -154,6 +154,25 @@
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_log
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_log(%arg: complex<f32>) -> complex<f32> {
+  %log = complex.log %arg: complex<f32>
+  return %log : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[SQR_REAL:.*]] = mulf %[[REAL]], %[[REAL]] : f32
+// CHECK: %[[SQR_IMAG:.*]] = mulf %[[IMAG]], %[[IMAG]] : f32
+// CHECK: %[[SQ_NORM:.*]] = addf %[[SQR_REAL]], %[[SQR_IMAG]] : f32
+// CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
+// CHECK: %[[RESULT_REAL:.*]] = math.log %[[NORM]] : f32
+// CHECK: %[[REAL2:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG2:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[RESULT_IMAG:.*]] = math.atan2 %[[IMAG2]], %[[REAL2]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_neg
 // CHECK-SAME: %[[ARG:.*]]: complex<f32>
 func @complex_neg(%arg: complex<f32>) -> complex<f32> {