diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -42,11 +42,45 @@
     return success();
   }
 };
+
+/// Lowers `complex.eq` to a component-wise comparison: the real parts and the
+/// imaginary parts are each compared with `cmpf oeq`, and the two i1 results
+/// are combined with `and`. `oeq` is an ordered predicate, so the lowered
+/// result is false whenever any compared component is NaN.
+struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
+  using OpConversionPattern<complex::EqualOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(complex::EqualOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    complex::EqualOp::Adaptor transformed(operands);
+    auto loc = op.getLoc();
+    // complex.re / complex.im produce values of the complex element type.
+    auto type =
+        transformed.lhs().getType().cast<ComplexType>().getElementType();
+
+    Value realLhs =
+        rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
+    Value imagLhs =
+        rewriter.create<complex::ImOp>(loc, type, transformed.lhs());
+    Value realRhs =
+        rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
+    Value imagRhs =
+        rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
+    Value realEqual =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, realLhs, realRhs);
+    Value imagEqual =
+        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, imagLhs, imagRhs);
+
+    rewriter.replaceOpWithNewOp<AndOp>(op, realEqual, imagEqual);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateComplexToStandardConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<AbsOpConversion>(patterns.getContext());
+  patterns.add<AbsOpConversion, EqualOpConversion>(patterns.getContext());
 }
 
 namespace {
@@ -65,7 +99,7 @@
   ConversionTarget target(getContext());
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
-  target.addIllegalOp<complex::AbsOp>();
+  target.addIllegalOp<complex::AbsOp, complex::EqualOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -14,3 +14,17 @@
 // CHECK: %[[NORM:.*]] = math.sqrt %[[SQ_NORM]] : f32
// CHECK: return %[[NORM]] : f32 +// CHECK-LABEL: func @complex_eq +// CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +func @complex_eq(%lhs: complex, %rhs: complex) -> i1 { + %eq = complex.eq %lhs, %rhs: complex + return %eq : i1 +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK-DAG: %[[REAL_EQUAL:.*]] = cmpf oeq, %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-DAG: %[[IMAG_EQUAL:.*]] = cmpf oeq, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 +// CHECK: return %[[EQUAL]] : i1 diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir --- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -14,3 +14,17 @@ // CHECK: %[[NORM:.*]] = "llvm.intr.sqrt"(%[[SQ_NORM]]) : (f32) -> f32 // CHECK: llvm.return %[[NORM]] : f32 +// CHECK-LABEL: llvm.func @complex_eq +// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*]], %[[RHS:.*]]: ![[C_TY:.*]]) +func @complex_eq(%lhs: complex, %rhs: complex) -> i1 { + %eq = complex.eq %lhs, %rhs: complex + return %eq : i1 +} +// CHECK: %[[REAL_LHS:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]] +// CHECK: %[[IMAG_LHS:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]] +// CHECK: %[[REAL_RHS:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]] +// CHECK: %[[IMAG_RHS:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]] +// CHECK-DAG: %[[REAL_EQUAL:.*]] = llvm.fcmp "oeq" %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-DAG: %[[IMAG_EQUAL:.*]] = llvm.fcmp "oeq" %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[EQUAL:.*]] = llvm.and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 +// CHECK: llvm.return %[[EQUAL]] : i1