diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -564,4 +564,24 @@
   let results = (outs Complex<AnyFloat>:$result);
 }
 
+//===----------------------------------------------------------------------===//
+// Conj
+//===----------------------------------------------------------------------===//
+
+def ConjOp : ComplexUnaryOp<"conj", [SameOperandsAndResultType]> {
+  let summary = "Calculate the complex conjugate";
+  let description = [{
+    The `conj` op takes a single complex number and computes the
+    complex conjugate.
+
+    Example:
+
+    ```mlir
+    %a = complex.conj %b: complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 #endif // COMPLEX_OPS
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -885,6 +885,29 @@
   }
 };
 
+/// Lowers `complex.conj` to `complex.re` / `complex.im` on the operand, an
+/// `arith.negf` of the imaginary part, and a `complex.create` of the result.
+struct ConjOpConversion : public OpConversionPattern<complex::ConjOp> {
+  using OpConversionPattern<complex::ConjOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::ConjOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto type = adaptor.getComplex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, adaptor.getComplex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, adaptor.getComplex());
+    Value negImag = rewriter.create<arith::NegFOp>(loc, elementType, imag);
+
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, real, negImag);
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -909,7 +932,8 @@
       SinOpConversion,
       SqrtOpConversion,
       TanOpConversion,
-      TanhOpConversion>(patterns.getContext());
+      TanhOpConversion,
+      ConjOpConversion>(patterns.getContext());
   // clang-format on
 }
 
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -663,3 +663,17 @@
   %sqrt = complex.sqrt %arg : complex<f32>
   return %sqrt : complex<f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @complex_conj
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func.func @complex_conj(%arg: complex<f32>) -> complex<f32> {
+  %conj = complex.conj %arg: complex<f32>
+  return %conj : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK: %[[NEG_IMAG:.*]] = arith.negf %[[IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[REAL]], %[[NEG_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>