diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -155,6 +155,27 @@ let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs)"; } +//===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// + +def ExpOp : ComplexUnaryOp<"exp", [SameOperandsAndResultType]> { + let summary = "computes exponential of a complex number"; + let description = [{ + The `exp` op takes a single complex number and computes the exponential of + it, i.e. `exp(x)` or `e^(x)`, where `x` is the input value. + `e` denotes Euler's number and is approximately equal to 2.718281. + + Example: + + ```mlir + %a = complex.exp %b : complex + ``` + }]; + + let results = (outs Complex:$result); +} + //===----------------------------------------------------------------------===// // ImOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -86,7 +86,7 @@ ConversionPatternRewriter &rewriter) const override { complex::DivOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.lhs().getType().template cast(); + auto type = transformed.lhs().getType().cast(); auto elementType = type.getElementType().cast(); Value lhsReal = @@ -286,6 +286,34 @@ return success(); } }; + +struct ExpOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(complex::ExpOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override 
{ + complex::ExpOp::Adaptor transformed(operands); + auto loc = op.getLoc(); + auto type = transformed.complex().getType().cast(); + auto elementType = type.getElementType().cast(); + + Value real = + rewriter.create(loc, elementType, transformed.complex()); + Value imag = + rewriter.create(loc, elementType, transformed.complex()); + // exp(x + iy) = exp(x) * cos(y) + i * exp(x) * sin(y) + Value expReal = rewriter.create(loc, real); + Value cosImag = rewriter.create(loc, imag); + Value resultReal = rewriter.create(loc, expReal, cosImag); + Value sinImag = rewriter.create(loc, imag); + Value resultImag = rewriter.create(loc, expReal, sinImag); + + rewriter.replaceOpWithNewOp(op, type, resultReal, + resultImag); + return success(); + } +}; } // namespace void mlir::populateComplexToStandardConversionPatterns( @@ -293,7 +321,7 @@ patterns.add, ComparisonOpConversion, - DivOpConversion>(patterns.getContext()); + DivOpConversion, ExpOpConversion>(patterns.getContext()); } namespace { @@ -313,7 +341,7 @@ target.addLegalDialect(); target.addIllegalOp(); + complex::ExpOp, complex::NotEqualOp>(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -138,6 +138,22 @@ // CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 // CHECK: return %[[EQUAL]] : i1 +// CHECK-LABEL: func @complex_exp +// CHECK-SAME: %[[ARG:.*]]: complex +func @complex_exp(%arg: complex) -> complex { + %exp = complex.exp %arg: complex + return %exp : complex +} +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-DAG: %[[COS_IMAG:.*]] = math.cos %[[IMAG]] : f32 +// CHECK-DAG: %[[EXP_REAL:.*]] = math.exp %[[REAL]] : f32 +// CHECK-DAG: %[[RESULT_REAL:.*]] = 
mulf %[[EXP_REAL]], %[[COS_IMAG]] : f32 +// CHECK-DAG: %[[SIN_IMAG:.*]] = math.sin %[[IMAG]] : f32 +// CHECK-DAG: %[[RESULT_IMAG:.*]] = mulf %[[EXP_REAL]], %[[SIN_IMAG]] : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex + // CHECK-LABEL: func @complex_neq // CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex func @complex_neq(%lhs: complex, %rhs: complex) -> i1 {