diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -487,6 +487,61 @@
   return complex.cast<fir::ComplexType>().getElementType();
 }
 
+/// Compare complex values
+///
+/// Per 10.1, the only comparisons available are .EQ. (oeq) and .NE. (une).
+///
+/// For completeness, all other comparisons are done on the real component
+/// only; the imaginary components are then neither extracted nor compared.
+struct CmpcOpConversion : public FIROpConversion<fir::CmpcOp> {
+  using FIROpConversion::FIROpConversion;
+
+  mlir::LogicalResult
+  matchAndRewrite(fir::CmpcOp cmp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    mlir::ValueRange operands = adaptor.getOperands();
+    mlir::MLIRContext *ctxt = cmp.getContext();
+    mlir::Type eleTy = convertType(getComplexEleTy(cmp.lhs().getType()));
+    mlir::Type resTy = convertType(cmp.getType());
+    mlir::Location loc = cmp.getLoc();
+    // Extract the component at struct position `pos` (0: real, 1: imaginary)
+    // from both operands and compare the two with an llvm.fcmp that carries
+    // the attributes (including the predicate) of the fir.cmpc.
+    auto compareParts = [&](int pos) -> mlir::Value {
+      auto posAttr =
+          mlir::ArrayAttr::get(ctxt, rewriter.getI32IntegerAttr(pos));
+      SmallVector<mlir::Value, 2> parts{
+          rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, operands[0],
+                                                      posAttr),
+          rewriter.create<mlir::LLVM::ExtractValueOp>(loc, eleTy, operands[1],
+                                                      posAttr)};
+      return rewriter.create<mlir::LLVM::FCmpOp>(loc, resTy, parts,
+                                                 cmp->getAttrs());
+    };
+    mlir::Value rcp = compareParts(0);
+    switch (cmp.getPredicate()) {
+    case mlir::arith::CmpFPredicate::OEQ: { // .EQ.
+      // Equal iff both the real and the imaginary parts compare equal.
+      SmallVector<mlir::Value, 2> cp{rcp, compareParts(1)};
+      rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmp, resTy, cp);
+      break;
+    }
+    case mlir::arith::CmpFPredicate::UNE: { // .NE.
+      // Unequal iff either part compares unequal (or unordered).
+      SmallVector<mlir::Value, 2> cp{rcp, compareParts(1)};
+      rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmp, resTy, cp);
+      break;
+    }
+    default:
+      // Only the real parts are compared; the imaginary parts are never
+      // materialized, so no dead llvm.extractvalue/llvm.fcmp ops remain.
+      rewriter.replaceOp(cmp, rcp);
+      break;
+    }
+    return success();
+  }
+};
+
 /// convert value of from-type to value of to-type
 struct ConvertOpConversion : public FIROpConversion<fir::ConvertOp> {
   using FIROpConversion::FIROpConversion;
@@ -1514,15 +1569,16 @@
         AllocaOpConversion, BoxAddrOpConversion, BoxDimsOpConversion,
         BoxEleSizeOpConversion, BoxIsAllocOpConversion, BoxIsArrayOpConversion,
         BoxIsPtrOpConversion, BoxRankOpConversion, CallOpConversion,
-        ConvertOpConversion, DispatchOpConversion, DispatchTableOpConversion,
-        DTEntryOpConversion, DivcOpConversion, EmboxCharOpConversion,
-        ExtractValueOpConversion, HasValueOpConversion, GlobalLenOpConversion,
-        GlobalOpConversion, InsertOnRangeOpConversion, InsertValueOpConversion,
-        IsPresentOpConversion, LoadOpConversion, NegcOpConversion,
-        MulcOpConversion, SelectCaseOpConversion, SelectOpConversion,
-        SelectRankOpConversion, SelectTypeOpConversion, StoreOpConversion,
-        SubcOpConversion, UnboxCharOpConversion, UndefOpConversion,
-        UnreachableOpConversion, ZeroOpConversion>(typeConverter);
+        CmpcOpConversion, ConvertOpConversion, DispatchOpConversion,
+        DispatchTableOpConversion, DTEntryOpConversion, DivcOpConversion,
+        EmboxCharOpConversion, ExtractValueOpConversion, HasValueOpConversion,
+        GlobalLenOpConversion, GlobalOpConversion, InsertOnRangeOpConversion,
+        InsertValueOpConversion, IsPresentOpConversion, LoadOpConversion,
+        NegcOpConversion, MulcOpConversion, SelectCaseOpConversion,
+        SelectOpConversion, SelectRankOpConversion, SelectTypeOpConversion,
+        StoreOpConversion, SubcOpConversion, UnboxCharOpConversion,
+        UndefOpConversion, UnreachableOpConversion, ZeroOpConversion>(
+            typeConverter);
     mlir::populateStdToLLVMConversionPatterns(typeConverter, pattern);
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             pattern);
diff --git a/flang/test/Fir/convert-to-llvm.fir b/flang/test/Fir/convert-to-llvm.fir
--- a/flang/test/Fir/convert-to-llvm.fir
+++ b/flang/test/Fir/convert-to-llvm.fir
@@ -521,6 +521,58 @@
 
 // -----
 
+// Test FIR complex compare conversion
+
+func @compare_complex_eq(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+  %r = fir.cmpc "oeq", %a, %b : !fir.complex<8>
+  return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_eq
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "oeq" [[RA]], [[RB]] : f64
+// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "oeq" [[IA]], [[IB]] : f64
+// CHECK: [[RES:%.*]] = llvm.and [[RESR]], [[RESI]] : i1
+// CHECK: return [[RES]] : i1
+
+func @compare_complex_ne(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+  %r = fir.cmpc "une", %a, %b : !fir.complex<8>
+  return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_ne
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IA:%.*]] = llvm.extractvalue [[A]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[IB:%.*]] = llvm.extractvalue [[B]][1 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RESR:%.*]] = llvm.fcmp "une" [[RA]], [[RB]] : f64
+// CHECK-DAG: [[RESI:%.*]] = llvm.fcmp "une" [[IA]], [[IB]] : f64
+// CHECK: [[RES:%.*]] = llvm.or [[RESR]], [[RESI]] : i1
+// CHECK: return [[RES]] : i1
+
+func @compare_complex_other(%a : !fir.complex<8>, %b : !fir.complex<8>) -> i1 {
+  %r = fir.cmpc "ogt", %a, %b : !fir.complex<8>
+  return %r : i1
+}
+
+// CHECK-LABEL: llvm.func @compare_complex_other
+// CHECK-SAME: [[A:%.*]]: !llvm.struct<(f64, f64)>,
+// CHECK-SAME: [[B:%.*]]: !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RA:%.*]] = llvm.extractvalue [[A]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK-DAG: [[RB:%.*]] = llvm.extractvalue [[B]][0 : i32] : !llvm.struct<(f64, f64)>
+// CHECK: [[RESR:%.*]] = llvm.fcmp "ogt" [[RA]], [[RB]] : f64
+// CHECK-NOT: llvm.fcmp
+// CHECK: return [[RESR]] : i1
+
+// -----
+
 // Test `fir.convert` operation conversion from Float type.
 
 func @convert_from_float(%arg0 : f32) {