diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
--- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
+++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td
@@ -210,6 +210,25 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NegOp
+//===----------------------------------------------------------------------===//
+
+def NegOp : ComplexUnaryOp<"neg", [SameOperandsAndResultType]> {
+  let summary = "Negation operator";
+  let description = [{
+    The `neg` op takes a single complex number `complex` and returns `-complex`.
+
+    Example:
+
+    ```mlir
+    %a = complex.neg %b : complex<f32>
+    ```
+  }];
+
+  let results = (outs Complex<AnyFloat>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // NotEqualOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -313,6 +313,30 @@
     return success();
   }
 };
+
+/// Lowers `complex.neg` by negating the real and imaginary parts separately
+/// with `negf` and recombining them with `complex.create`.
+struct NegOpConversion : public OpConversionPattern<complex::NegOp> {
+  using OpConversionPattern<complex::NegOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::NegOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::NegOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    Value operand = transformed.complex();
+    auto type = operand.getType().cast<ComplexType>();
+    auto elementType = type.getElementType().cast<FloatType>();
+
+    // -(a + bi) == (-a) + (-b)i, so each component is negated independently.
+    Value real = rewriter.create<complex::ReOp>(loc, elementType, operand);
+    Value imag = rewriter.create<complex::ImOp>(loc, elementType, operand);
+    Value negReal = rewriter.create<NegFOp>(loc, real);
+    Value negImag = rewriter.create<NegFOp>(loc, imag);
+    rewriter.replaceOpWithNewOp<complex::CreateOp>(op, type, negReal, negImag);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
@@ -320,7 +344,8 @@
   patterns.add<AbsOpConversion,
                ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
                ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>,
-               DivOpConversion, ExpOpConversion>(patterns.getContext());
+               DivOpConversion, ExpOpConversion, NegOpConversion>(
+      patterns.getContext());
 }
 
 namespace {
@@ -340,7 +365,7 @@
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
   target.addIllegalOp<complex::AbsOp, complex::DivOp, complex::EqualOp,
-                      complex::ExpOp, complex::NotEqualOp>();
+                      complex::ExpOp, complex::NegOp, complex::NotEqualOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -154,6 +154,19 @@
 // CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex<f32>
 // CHECK: return %[[RESULT]] : complex<f32>
 
+// CHECK-LABEL: func @complex_neg
+// CHECK-SAME: %[[ARG:.*]]: complex<f32>
+func @complex_neg(%arg: complex<f32>) -> complex<f32> {
+  %neg = complex.neg %arg: complex<f32>
+  return %neg : complex<f32>
+}
+// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex<f32>
+// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex<f32>
+// CHECK-DAG: %[[NEG_REAL:.*]] = negf %[[REAL]] : f32
+// CHECK-DAG: %[[NEG_IMAG:.*]] = negf %[[IMAG]] : f32
+// CHECK: %[[RESULT:.*]] = complex.create %[[NEG_REAL]], %[[NEG_IMAG]] : complex<f32>
+// CHECK: return %[[RESULT]] : complex<f32>
+
 // CHECK-LABEL: func @complex_neq
 // CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
 func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {