ComplexToStandard: lower complex.add and complex.sub component-wise.

This free-form preamble is not part of any hunk; git apply and GNU patch
skip text that precedes the first file header below.

Changes to mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp:
  * New ComponentwiseBinaryOpConversion pattern, templated on a complex op
    (BinaryComplexOp) and a second op type used on the parts. It takes the
    real parts of both operands with complex.re and combines them, takes the
    imaginary parts with complex.im and combines them, then replaces the
    matched op with a complex.create of the two results. The replacement's
    type is the lhs operand's type, and the parts are created with that
    type's element type.
  * Two instantiations are added to the pattern list in
    populateComplexToStandardConversionPatterns. Judging by the new tests,
    these presumably map complex.add to addf and complex.sub to subf.
  * In the pass, the target's addIllegalOp(...) call is replaced by an
    addLegalOp(...) call. The addLegalDialect(...) line is removed and
    re-added, presumably with a different dialect list (its arguments are
    not visible in this copy).

Changes to mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir:
  * New FileCheck cases @complex_add and @complex_sub. Each checks that the
    op becomes complex.re and complex.im on both operands, an addf (resp.
    subf) on the real parts and on the imaginary parts with f32 results,
    and a complex.create of the two results, which is then returned.

NOTE(review): this copy of the patch appears damaged. Its lines are run
together onto a few physical lines, and text inside angle brackets seems to
have been stripped:
  * `template` has no parameter list;
  * OpConversionPattern, ArrayRef, cast, b.create and replaceOpWithNewOp
    have no template arguments;
  * addLegalDialect / addIllegalOp / addLegalOp have no op lists;
  * the MLIR test spells `complex` with no element type (the f32 results
    in the CHECK lines suggest f32 -- TODO confirm).
As stored it will neither apply nor compile. Restore it from the original
commit rather than hand-editing it; the stripped arguments on the removed
lines cannot be recovered from this text.

NOTE(review): the pass uses applyPartialConversion. Under partial conversion
an op that is not marked illegal may remain unconverted without failing the
pass. Dropping the explicit addIllegalOp list therefore likely means that a
complex op whose pattern fails to apply is no longer reported. Confirm this
is intended, e.g. by still marking the remaining complex ops illegal.

diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -79,6 +79,33 @@ } }; +template +struct ComponentwiseBinaryOpConversion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(BinaryComplexOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + typename BinaryComplexOp::Adaptor transformed(operands); + auto type = transformed.lhs().getType().template cast(); + auto elementType = type.getElementType().template cast(); + mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); + + Value realLhs = b.create(elementType, transformed.lhs()); + Value realRhs = b.create(elementType, transformed.rhs()); + Value resultReal = + b.create(elementType, realLhs, realRhs); + Value imagLhs = b.create(elementType, transformed.lhs()); + Value imagRhs = b.create(elementType, transformed.rhs()); + Value resultImag = + b.create(elementType, imagLhs, imagRhs); + rewriter.replaceOpWithNewOp(op, type, resultReal, + resultImag); + return success(); + } +}; + struct DivOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -554,6 +581,8 @@ AbsOpConversion, ComparisonOpConversion, ComparisonOpConversion, + ComponentwiseBinaryOpConversion, + ComponentwiseBinaryOpConversion, DivOpConversion, ExpOpConversion, LogOpConversion, @@ -578,12 +607,8 @@ populateComplexToStandardConversionPatterns(patterns); ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addIllegalOp(); + target.addLegalDialect(); + target.addLegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir 
b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -14,6 +14,21 @@ // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32 // CHECK: return %[[NORM]] : f32 +// CHECK-LABEL: func @complex_add +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func @complex_add(%lhs: complex, %rhs: complex) -> complex { + %add = complex.add %lhs, %rhs: complex + return %add : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = addf %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] = addf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex + // CHECK-LABEL: func @complex_div // CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) func @complex_div(%lhs: complex, %rhs: complex) -> complex { @@ -366,3 +381,18 @@ // CHECK: %[[SIGN:.*]] = complex.create %[[REAL_SIGN]], %[[IMAG_SIGN]] : complex // CHECK: %[[RESULT:.*]] = select %[[IS_ZERO]], %[[ARG]], %[[SIGN]] : complex // CHECK: return %[[RESULT]] : complex + +// CHECK-LABEL: func @complex_sub +// CHECK-SAME: (%[[LHS:.*]]: complex, %[[RHS:.*]]: complex) +func @complex_sub(%lhs: complex, %rhs: complex) -> complex { + %sub = complex.sub %lhs, %rhs: complex + return %sub : complex +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[RESULT_REAL:.*]] = subf %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK: %[[RESULT_IMAG:.*]] 
= subf %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[RESULT_REAL]], %[[RESULT_IMAG]] : complex +// CHECK: return %[[RESULT]] : complex