diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
--- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
+++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp
@@ -43,16 +43,26 @@
   }
 };
 
-struct EqualOpConversion : public OpConversionPattern<complex::EqualOp> {
-  using OpConversionPattern<complex::EqualOp>::OpConversionPattern;
+/// Lowers complex.eq / complex.neq by comparing the real parts and the
+/// imaginary parts separately with `cmpf p` and combining the two i1 results
+/// with `and` (complex.eq) or `or` (complex.neq).
+template <typename ComparisonOp, CmpFPredicate p>
+struct ComparisonOpConversion : public OpConversionPattern<ComparisonOp> {
+  using OpConversionPattern<ComparisonOp>::OpConversionPattern;
+  // Resolved at compile time so the rewrite never fails after creating IR.
+  using ResultCombiner =
+      std::conditional_t<std::is_same<ComparisonOp, complex::EqualOp>::value,
+                         AndOp, OrOp>;
 
   LogicalResult
-  matchAndRewrite(complex::EqualOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ComparisonOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    complex::EqualOp::Adaptor transformed(operands);
+    typename ComparisonOp::Adaptor transformed(operands);
     auto loc = op.getLoc();
-    auto type =
-        transformed.lhs().getType().cast<ComplexType>().getElementType();
+    auto type = transformed.lhs()
+                    .getType()
+                    .template cast<ComplexType>()
+                    .getElementType();
 
     Value realLhs =
         rewriter.create<complex::ReOp>(loc, type, transformed.lhs());
@@ -62,12 +72,11 @@
         rewriter.create<complex::ReOp>(loc, type, transformed.rhs());
     Value imagRhs =
         rewriter.create<complex::ImOp>(loc, type, transformed.rhs());
-    Value realEqual =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, realLhs, realRhs);
-    Value imagEqual =
-        rewriter.create<CmpFOp>(loc, CmpFPredicate::OEQ, imagLhs, imagRhs);
-
-    rewriter.replaceOpWithNewOp<AndOp>(op, realEqual, imagEqual);
+    Value realComparison = rewriter.create<CmpFOp>(loc, p, realLhs, realRhs);
+    Value imagComparison = rewriter.create<CmpFOp>(loc, p, imagLhs, imagRhs);
+
+    rewriter.replaceOpWithNewOp<ResultCombiner>(op, realComparison,
+                                                imagComparison);
     return success();
   }
 };
@@ -75,7 +84,10 @@
 
 void mlir::populateComplexToStandardConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<AbsOpConversion, EqualOpConversion>(patterns.getContext());
+  patterns.add<AbsOpConversion,
+               ComparisonOpConversion<complex::EqualOp, CmpFPredicate::OEQ>,
+               ComparisonOpConversion<complex::NotEqualOp, CmpFPredicate::UNE>>(
+      patterns.getContext());
 }
 
 namespace {
@@ -94,7 +106,7 @@
   ConversionTarget target(getContext());
   target.addLegalDialect<StandardOpsDialect, math::MathDialect,
                          complex::ComplexDialect>();
-  target.addIllegalOp<complex::AbsOp, complex::EqualOp>();
+  target.addIllegalOp<complex::AbsOp, complex::EqualOp, complex::NotEqualOp>();
   if (failed(applyPartialConversion(function, target, std::move(patterns))))
     signalPassFailure();
 }
diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
--- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir
@@ -28,3 +28,18 @@
 // CHECK-DAG: %[[IMAG_EQUAL:.*]] = cmpf oeq, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
 // CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: return %[[EQUAL]] : i1
+
+// CHECK-LABEL: func @complex_neq
+// CHECK-SAME: %[[LHS:.*]]: complex<f32>, %[[RHS:.*]]: complex<f32>
+func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+  %neq = complex.neq %lhs, %rhs: complex<f32>
+  return %neq : i1
+}
+// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex<f32>
+// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex<f32>
+// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex<f32>
+// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex<f32>
+// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = cmpf une, %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[NOT_EQUAL:.*]] = or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
+// CHECK: return %[[NOT_EQUAL]] : i1
diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
--- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
+++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir
@@ -28,3 +28,18 @@
 // CHECK-DAG: %[[IMAG_EQUAL:.*]] = llvm.fcmp "oeq" %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
 // CHECK: %[[EQUAL:.*]] = llvm.and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1
 // CHECK: llvm.return %[[EQUAL]] : i1
+
+// CHECK-LABEL: llvm.func @complex_neq
+// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*]], %[[RHS:.*]]: ![[C_TY:.*]])
+func @complex_neq(%lhs: complex<f32>, %rhs: complex<f32>) -> i1 {
+  %neq = complex.neq %lhs, %rhs: complex<f32>
+  return %neq : i1
+}
+// CHECK: %[[REAL_LHS:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_LHS:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]]
+// CHECK: %[[REAL_RHS:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]]
+// CHECK: %[[IMAG_RHS:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]]
+// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[REAL_LHS]], %[[REAL_RHS]] : f32
+// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[IMAG_LHS]], %[[IMAG_RHS]] : f32
+// CHECK: %[[NOT_EQUAL:.*]] = llvm.or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1
+// CHECK: llvm.return %[[NOT_EQUAL]] : i1