diff --git a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
--- a/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
+++ b/llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/CodeGen/Register.h"
+#include "llvm/IR/InstrTypes.h"
 #include "llvm/Support/LowLevelTypeImpl.h"
 #include <functional>
 
 namespace llvm {
@@ -755,6 +756,10 @@
   /// Transform G_ADD(G_SUB(y, x), x) to y.
   bool matchAddSubSameReg(MachineInstr &MI, Register &Src);
 
+  /// \returns true if it is possible to simplify a select instruction \p MI
+  /// to a min/max instruction of some sort.
+  bool matchSimplifySelectToMinMax(MachineInstr &MI, BuildFnTy &MatchInfo);
+
 private:
   /// Given a non-indexed load or store instruction \p MI, find an offset that
   /// can be usefully and legally folded into it as a post-indexing operation.
@@ -800,6 +805,57 @@
   /// a re-association of its operands would break an existing legal addressing
   /// mode that the address computation currently represents.
   bool reassociationCanBreakAddressingModePattern(MachineInstr &PtrAdd);
+
+  /// Behavior when a floating point min/max is given one NaN and one
+  /// non-NaN as input.
+  enum class SelectPatternNaNBehaviour {
+    NOT_APPLICABLE = 0, ///< NaN behavior not applicable.
+    RETURNS_NAN,        ///< Given one NaN input, returns the NaN.
+    RETURNS_OTHER,      ///< Given one NaN input, returns the non-NaN.
+    RETURNS_ANY         ///< Given one NaN input, can return either (or both
+                        ///< operands are known non-NaN.)
+  };
+
+  /// Classifies which operand `select (fcmp Pred LHS, RHS), LHS, RHS` yields
+  /// when one of \p LHS and \p RHS is NaN.
+  ///
+  /// \p IsOrderedComparison is true for an ordered predicate, which is false
+  /// on a NaN input so the select yields \p RHS, and false for an unordered
+  /// predicate, which is true on a NaN input so the select yields \p LHS.
+  ///
+  /// \returns RETURNS_NAN or RETURNS_OTHER when exactly one of \p LHS and
+  /// \p RHS may be NaN, RETURNS_ANY when neither may be NaN, and
+  /// NOT_APPLICABLE when both may be NaN.
+  SelectPatternNaNBehaviour
+  computeRetValAgainstNaN(Register LHS, Register RHS,
+                          bool IsOrderedComparison) const;
+
+  /// Determines the floating point min/max opcode which should be used for
+  /// a G_SELECT fed by a G_FCMP with predicate \p Pred.
+  ///
+  /// \p VsNaNRetVal describes what the select yields when one input is NaN;
+  /// it must not be SelectPatternNaNBehaviour::NOT_APPLICABLE.
+  ///
+  /// \returns 0 if this G_SELECT should not be combined to a floating point
+  /// min or max. If it should be combined, returns one of
+  ///
+  /// * G_FMAXNUM
+  /// * G_FMAXIMUM
+  /// * G_FMINNUM
+  /// * G_FMINIMUM
+  ///
+  /// Helper function for matchFPSelectToMinMax.
+  unsigned getFPMinMaxOpcForSelect(CmpInst::Predicate Pred, LLT DstTy,
+                                   SelectPatternNaNBehaviour VsNaNRetVal) const;
+
+  /// Handle floating point cases for matchSimplifySelectToMinMax.
+  ///
+  /// E.g., when x may be NaN and the resulting opcode is legal:
+  ///
+  /// select (fcmp uge x, 1.0) x, 1.0 -> fmaximum x, 1.0
+  /// select (fcmp uge x, 1.0) 1.0, x -> fminnum 1.0, x
+  bool matchFPSelectToMinMax(Register Dst, Register Cond, Register TrueVal,
+                             Register FalseVal, BuildFnTy &MatchInfo);
 };
 
 } // namespace llvm
diff --git a/llvm/include/llvm/Target/GlobalISel/Combine.td b/llvm/include/llvm/Target/GlobalISel/Combine.td
--- a/llvm/include/llvm/Target/GlobalISel/Combine.td
+++ b/llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -929,6 +929,12 @@
   (apply [{ return Helper.replaceSingleDefInstWithReg(*${root}, ${matchinfo}); }])>;
 
+def select_to_minmax: GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$info),
+  (match (wip_match_opcode G_SELECT):$root,
+    [{ return Helper.matchSimplifySelectToMinMax(*${root}, ${info}); }]),
+  (apply [{ Helper.applyBuildFn(*${root}, ${info}); }])>;
+
 // FIXME: These should use the custom predicate feature once it lands.
def undef_combines : GICombineGroup<[undef_to_fp_zero, undef_to_int_zero, undef_to_negative_one, diff --git a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp --- a/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp +++ b/llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp @@ -5818,6 +5818,138 @@ return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); } +unsigned CombinerHelper::getFPMinMaxOpcForSelect( + CmpInst::Predicate Pred, LLT DstTy, + SelectPatternNaNBehaviour VsNaNRetVal) const { + assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE && + "Expected a NaN behaviour?"); + // Choose an opcode based off of legality or the behaviour when one of the + // LHS/RHS may be NaN. + switch (Pred) { + default: + return 0; + case CmpInst::FCMP_UGT: + case CmpInst::FCMP_UGE: + case CmpInst::FCMP_OGT: + case CmpInst::FCMP_OGE: + if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) + return TargetOpcode::G_FMAXNUM; + if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) + return TargetOpcode::G_FMAXIMUM; + if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}})) + return TargetOpcode::G_FMAXNUM; + if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}})) + return TargetOpcode::G_FMAXIMUM; + return 0; + case CmpInst::FCMP_ULT: + case CmpInst::FCMP_ULE: + case CmpInst::FCMP_OLT: + case CmpInst::FCMP_OLE: + if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) + return TargetOpcode::G_FMINNUM; + if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) + return TargetOpcode::G_FMINIMUM; + if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}})) + return TargetOpcode::G_FMINNUM; + if (!isLegal({TargetOpcode::G_FMINIMUM, {DstTy}})) + return 0; + return TargetOpcode::G_FMINIMUM; + } +} + +CombinerHelper::SelectPatternNaNBehaviour +CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS, + bool IsOrderedComparison) const { + bool LHSSafe = isKnownNeverNaN(LHS, MRI); + bool RHSSafe = isKnownNeverNaN(RHS, MRI); + // Completely 
unsafe. + if (!LHSSafe && !RHSSafe) + return SelectPatternNaNBehaviour::NOT_APPLICABLE; + if (LHSSafe && RHSSafe) + return SelectPatternNaNBehaviour::RETURNS_ANY; + // An ordered comparison will return false when given a NaN, so it + // returns the RHS. + if (IsOrderedComparison) + return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN + : SelectPatternNaNBehaviour::RETURNS_OTHER; + // An unordered comparison will return true when given a NaN, so it + // returns the LHS. + return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER + : SelectPatternNaNBehaviour::RETURNS_NAN; +} + +bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond, + Register TrueVal, Register FalseVal, + BuildFnTy &MatchInfo) { + // Match: select (fcmp cond x, y) x, y + // select (fcmp cond x, y) y, x + // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition. + LLT DstTy = MRI.getType(Dst); + // Bail out early on pointers, since we'll never want to fold to a min/max. + // TODO: Handle vectors. + if (DstTy.isPointer() || DstTy.isVector()) + return false; + // Match a floating point compare with a less-than/greater-than predicate. + // TODO: Allow multiple users of the compare if they are all selects. 
+ CmpInst::Predicate Pred; + Register CmpLHS, CmpRHS; + if (!mi_match(Cond, MRI, + m_OneNonDBGUse( + m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) || + CmpInst::isEquality(Pred)) + return false; + SelectPatternNaNBehaviour ResWithKnownNaNInfo = + computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred)); + if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE) + return false; + if (TrueVal == CmpRHS && FalseVal == CmpLHS) { + std::swap(CmpLHS, CmpRHS); + Pred = CmpInst::getSwappedPredicate(Pred); + if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN) + ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER; + else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER) + ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN; + } + if (TrueVal != CmpLHS || FalseVal != CmpRHS) + return false; + // Decide what type of max/min this should be based off of the predicate. + unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo); + if (!Opc || !isLegal({Opc, {DstTy}})) + return false; + // Comparisons between signed zero and zero may have different results... + // unless we have fmaximum/fminimum. In that case, we know -0 < 0. + if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) { + // We don't know if a comparison between two 0s will give us a consistent + // result. Be conservative and only proceed if at least one side is + // non-zero. 
+ auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI); + if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) { + KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI); + if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) + return false; + } + } + MatchInfo = [=](MachineIRBuilder &B) { + B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS}); + }; + return true; +} + +bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI, + BuildFnTy &MatchInfo) { + // TODO: Handle integer cases. + assert(MI.getOpcode() == TargetOpcode::G_SELECT); + // Condition may be fed by a truncated compare. + Register Cond = MI.getOperand(1).getReg(); + Register MaybeTrunc; + if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc))))) + Cond = MaybeTrunc; + Register Dst = MI.getOperand(0).getReg(); + Register TrueVal = MI.getOperand(2).getReg(); + Register FalseVal = MI.getOperand(3).getReg(); + return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo); +} + bool CombinerHelper::tryCombine(MachineInstr &MI) { if (tryCombineCopy(MI)) return true; diff --git a/llvm/lib/Target/AArch64/AArch64Combine.td b/llvm/lib/Target/AArch64/AArch64Combine.td --- a/llvm/lib/Target/AArch64/AArch64Combine.td +++ b/llvm/lib/Target/AArch64/AArch64Combine.td @@ -228,6 +228,7 @@ select_combines, fold_merge_to_zext, constant_fold, identity_combines, ptr_add_immed_chain, overlapping_and, - split_store_zero_128, undef_combines]> { + split_store_zero_128, undef_combines, + select_to_minmax]> { let DisableRuleOption = "aarch64postlegalizercombiner-disable-rule"; }