diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td
--- a/llvm/lib/Target/AArch64/AArch64Combine.td
+++ b/llvm/lib/Target/AArch64/AArch64Combine.td
@@ -165,8 +165,8 @@
 def lower_vector_fcmp : GICombineRule<
   (defs root:$root),
   (match (wip_match_opcode G_FCMP):$root,
-         [{ return lowerVectorFCMP(*${root}, MRI, B); }]),
-  (apply [{}])>;
+         [{ return matchLowerVectorFCMP(*${root}, MRI, B); }]),
+  (apply [{ applyLowerVectorFCMP(*${root}, MRI, B); }])>;
 
 def form_truncstore_matchdata : GIDefMatchData<"Register">;
 def form_truncstore : GICombineRule<
diff --git a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
--- a/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
+++ b/llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -949,29 +949,45 @@
 }
 
 /// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
-bool lowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
-                     MachineIRBuilder &MIB) {
+bool matchLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
+                          MachineIRBuilder &MIB) {
   assert(MI.getOpcode() == TargetOpcode::G_FCMP);
   const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
+
   Register Dst = MI.getOperand(0).getReg();
   LLT DstTy = MRI.getType(Dst);
   if (!DstTy.isVector() || !ST.hasNEON())
     return false;
 
-  const auto Pred =
-      static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
   Register LHS = MI.getOperand(2).getReg();
   unsigned EltSize = MRI.getType(LHS).getScalarSizeInBits();
   if (EltSize == 16 && !ST.hasFullFP16())
     return false;
   if (EltSize != 16 && EltSize != 32 && EltSize != 64)
     return false;
-  Register RHS = MI.getOperand(3).getReg();
+
+  return true;
+}
+
+/// Try to lower a vector G_FCMP \p MI into an AArch64-specific pseudo.
+void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,
+                          MachineIRBuilder &MIB) {
+  assert(MI.getOpcode() == TargetOpcode::G_FCMP);
+  const auto &ST = MI.getMF()->getSubtarget<AArch64Subtarget>();
+
+  const auto &CmpMI = cast<GFCmp>(MI);
+
+  Register Dst = CmpMI.getReg(0);
+  CmpInst::Predicate Pred = CmpMI.getCond();
+  Register LHS = CmpMI.getLHSReg();
+  Register RHS = CmpMI.getRHSReg();
+
+  LLT DstTy = MRI.getType(Dst);
+
   auto Splat = getAArch64VectorSplat(*MRI.getVRegDef(RHS), MRI);
 
   // Compares against 0 have special target-specific pseudos.
   bool IsZero = Splat && Splat->isCst() && Splat->getCst() == 0;
-
   bool Invert = false;
   AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
   if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
@@ -984,10 +1000,12 @@
   } else
     changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
 
-  bool NoNans = ST.getTargetLowering()->getTargetMachine().Options.NoNaNsFPMath;
-
   // Instead of having an apply function, just build here to simplify things.
   MIB.setInstrAndDebugLoc(MI);
+
+  const bool NoNans =
+      ST.getTargetLowering()->getTargetMachine().Options.NoNaNsFPMath;
+
   auto Cmp = getVectorFCMP(CC, LHS, RHS, IsZero, NoNans, MRI);
   Register CmpRes;
   if (CC2 == AArch64CC::AL)
@@ -1002,7 +1020,6 @@
     CmpRes = MIB.buildNot(DstTy, CmpRes).getReg(0);
   MRI.replaceRegWith(Dst, CmpRes);
   MI.eraseFromParent();
-  return true;
 }
 
 bool matchFormTruncstore(MachineInstr &MI, MachineRegisterInfo &MRI,