diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -4287,9 +4287,7 @@ /// comparison may check if the operand is NAN, INF, zero, normal, etc. The /// result should be used as the condition operand for a select or branch. virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG, - const DenormalMode &Mode) const { - return SDValue(); - } + const DenormalMode &Mode) const; /// Return a target-dependent result if the input operand is not suitable for /// use with a square root estimate calculation. diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -22275,43 +22275,21 @@ Reciprocal)) { AddToWorklist(Est.getNode()); - if (Iterations) { + if (Iterations) Est = UseOneConstNR ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal) : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal); - - if (!Reciprocal) { - SDLoc DL(Op); - EVT CCVT = getSetCCResultType(VT); - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); - DenormalMode DenormMode = DAG.getDenormalMode(VT); - // Try the target specific test first. - SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode); - if (!Test) { - // If no test provided by target, testing it with denormal inputs to - // avoid wrong estimate. - if (DenormMode.Input == DenormalMode::IEEE) { - // This is specifically a check for the handling of denormal inputs, - // not the result. - - // Test = fabs(X) < SmallestNormal - const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); - APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); - SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT); - SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op); - Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT); - } else - // Test = X == 0.0 - Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); - } - - // The estimate is now completely wrong if the input was exactly 0.0 or - // possibly a denormal. Force the answer to 0.0 or value provided by - // target for those cases. - Est = DAG.getNode( - Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, - Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est); - } + if (!Reciprocal) { + SDLoc DL(Op); + // Try the target specific test first. + SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT)); + + // The estimate is now completely wrong if the input was exactly 0.0 or + // possibly a denormal. Force the answer to 0.0 or value provided by + // target for those cases. + Est = DAG.getNode( + Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, + Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est); } return Est; } diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -5841,6 +5841,28 @@ return false; } +SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, + const DenormalMode &Mode) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + // Testing it with denormal inputs to avoid wrong estimate. + if (Mode.Input == DenormalMode::IEEE) { + // This is specifically a check for the handling of denormal inputs, + // not the result. + + // Test = fabs(X) < SmallestNormal + const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); + APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); + SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT); + SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op); + return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT); + } + // Test = X == 0.0 + return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); +} + SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps, bool OptForSize, NegatibleCost &Cost, diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -961,6 +961,10 @@ bool Reciprocal) const override; SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps) const override; + SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG, + const DenormalMode &Mode) const override; + SDValue getSqrtResultForDenormInput(SDValue Operand, + SelectionDAG &DAG) const override; unsigned combineRepeatedFPDivisors() const override; ConstraintType getConstraintType(StringRef Constraint) const override; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7471,6 +7471,22 @@ return SDValue(); } +SDValue +AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, + const DenormalMode &Mode) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); +} + +SDValue +AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op, + SelectionDAG &DAG) const { + return Op; +} + SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, @@ -7494,17 +7510,8 @@ Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags); Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags); } - if (!Reciprocal) { - EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), - VT); - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); - SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ); - + if (!Reciprocal) Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags); - // Correct the result if the operand is 0.0. - Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, - VT, Eq, Operand, Estimate); - } ExtraSteps = 0; return Estimate; diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -12133,7 +12133,7 @@ if (!isTypeLegal(MVT::i1) || (VT != MVT::f64 && ((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX()))) - return SDValue(); + return TargetLowering::getSqrtInputTest(Op, DAG, Mode); SDLoc DL(Op); // The output register of FTSQRT is CR field.