diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -210,6 +210,25 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+def NegOp : ComplexUnaryOp<"neg", [SameOperandsAndResultType]> {
+  let summary = "Negation operator";
+  let description = [{
+    The `neg` op takes a single complex number `complex` and returns `-complex`.
+
+    Example:
+
+    ```mlir
+    %a = complex.neg %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // NotEqualOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,6 +313,30 @@
     return success();
   }
 };
+
+/// Lowers `complex.neg` by negating the real and imaginary parts separately,
+/// since -(a + bi) = (-a) + (-b)i.
+struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
+  using OpConversionPattern<complex::NegOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::NegOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    auto type = transformed.complex().getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+
+    Value real =
+        rewriter.create<complex::ReOp>(loc, elementType, transformed.complex());
+    Value imag =
+        rewriter.create<complex::ImOp>(loc, elementType, transformed.complex());
+    Value negReal = rewriter.create<NegFOp>(loc, real);
+    Value negImag = rewriter.create<NegFOp>(loc, imag);
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -320,7 +344,8 @@
   patterns.add<AbsOpConversion,
                ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
                ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
-               DivOpConversion, ExpOpConversion>(patterns.getContext());
+               DivOpConversion, ExpOpConversion, NegOpConversion>(
+      patterns.getContext());
 }
 
 namespace {
@@ -340,7 +365,7 @@
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::NotEqualOp>();
+                      complex::ExpOp, complex::NotEqualOp, complex::NegOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -154,6 +154,19 @@
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_neg
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_neg(%arg: complex<f32>) -> complex<f32> {
+  %neg = complex.neg %arg: complex<f32>
+  return %neg : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[NEG_REAL:.*]] = negf %[[REAL]] : f32
+// CHECK-DAG: %[[NEG_IMAG:.*]] = negf %[[IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_neq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {