diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -9,6 +9,7 @@ #include "mlir/Conversion/ComplexToStandard/ComplexToStandard.h" #include +#include #include "../PassDetail.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -43,16 +44,22 @@ } }; -struct EqualOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +template +struct ComparisonOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using ResultCombiner = + std::conditional_t::value, + AndOp, OrOp>; LogicalResult - matchAndRewrite(complex::EqualOp op, ArrayRef operands, + matchAndRewrite(ComparisonOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - complex::EqualOp::Adaptor transformed(operands); + typename ComparisonOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = - transformed.lhs().getType().cast().getElementType(); + auto type = transformed.lhs() + .getType() + .template cast() + .getElementType(); Value realLhs = rewriter.create(loc, type, transformed.lhs()); @@ -62,12 +69,11 @@ rewriter.create(loc, type, transformed.rhs()); Value imagRhs = rewriter.create(loc, type, transformed.rhs()); - Value realEqual = - rewriter.create(loc, CmpFPredicate::OEQ, realLhs, realRhs); - Value imagEqual = - rewriter.create(loc, CmpFPredicate::OEQ, imagLhs, imagRhs); + Value realComparison = rewriter.create(loc, p, realLhs, realRhs); + Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); - rewriter.replaceOpWithNewOp(op, realEqual, imagEqual); + rewriter.replaceOpWithNewOp(op, realComparison, + imagComparison); return success(); } }; @@ -75,7 +81,10 @@ void mlir::populateComplexToStandardConversionPatterns( RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add, + ComparisonOpConversion>( + patterns.getContext()); } namespace { @@ -94,7 +103,7 @@ ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + target.addIllegalOp(); if (failed(applyPartialConversion(function, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -28,3 +28,18 @@ // CHECK-DAG: %[[IMAG_EQUAL:.*]] = cmpf oeq, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 // CHECK: %[[EQUAL:.*]] = and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 // CHECK: return %[[EQUAL]] : i1 + +// CHECK-LABEL: func @complex_neq +// CHECK-SAME: %[[LHS:.*]]: complex, %[[RHS:.*]]: complex +func @complex_neq(%lhs: complex, %rhs: complex) -> i1 { + %neq = complex.neq %lhs, %rhs: complex + return %neq : i1 +} +// CHECK: %[[REAL_LHS:.*]] = complex.re %[[LHS]] : complex +// CHECK: %[[IMAG_LHS:.*]] = complex.im %[[LHS]] : complex +// CHECK: %[[REAL_RHS:.*]] = complex.re %[[RHS]] : complex +// CHECK: %[[IMAG_RHS:.*]] = complex.im %[[RHS]] : complex +// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = cmpf une, %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = cmpf une, %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[NOT_EQUAL:.*]] = or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1 +// CHECK: return %[[NOT_EQUAL]] : i1 diff --git a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir --- a/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir +++ b/mlir/test/Conversion/ComplexToStandard/full-conversion.mlir @@ -28,3 +28,18 @@ // CHECK-DAG: %[[IMAG_EQUAL:.*]] = llvm.fcmp "oeq" %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 // CHECK: %[[EQUAL:.*]] = llvm.and %[[REAL_EQUAL]], %[[IMAG_EQUAL]] : i1 // CHECK: llvm.return %[[EQUAL]] : i1 + +// CHECK-LABEL: llvm.func @complex_neq +// CHECK-SAME: %[[LHS:.*]]: ![[C_TY:.*]], %[[RHS:.*]]: ![[C_TY:.*]]) +func @complex_neq(%lhs: complex, %rhs: complex) -> i1 { + %neq = complex.neq %lhs, %rhs: complex + return %neq : i1 +} +// CHECK: %[[REAL_LHS:.*]] = llvm.extractvalue %[[LHS]][0] : ![[C_TY]] +// CHECK: %[[IMAG_LHS:.*]] = llvm.extractvalue %[[LHS]][1] : ![[C_TY]] +// CHECK: %[[REAL_RHS:.*]] = llvm.extractvalue %[[RHS]][0] : ![[C_TY]] +// CHECK: %[[IMAG_RHS:.*]] = llvm.extractvalue %[[RHS]][1] : ![[C_TY]] +// CHECK-DAG: %[[REAL_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[REAL_LHS]], %[[REAL_RHS]] : f32 +// CHECK-DAG: %[[IMAG_NOT_EQUAL:.*]] = llvm.fcmp "une" %[[IMAG_LHS]], %[[IMAG_RHS]] : f32 +// CHECK: %[[NOT_EQUAL:.*]] = llvm.or %[[REAL_NOT_EQUAL]], %[[IMAG_NOT_EQUAL]] : i1 +// CHECK: llvm.return %[[NOT_EQUAL]] : i1