diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -4287,9 +4287,7 @@ /// comparison may check if the operand is NAN, INF, zero, normal, etc. The /// result should be used as the condition operand for a select or branch. virtual SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG, - const DenormalMode &Mode) const { - return SDValue(); - } + const DenormalMode &Mode) const; /// Return a target-dependent result if the input operand is not suitable for /// use with a square root estimate calculation. diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -22271,43 +22271,21 @@ Reciprocal)) { AddToWorklist(Est.getNode()); - if (Iterations) { + if (Iterations) Est = UseOneConstNR ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal) : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal); - - if (!Reciprocal) { - SDLoc DL(Op); - EVT CCVT = getSetCCResultType(VT); - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); - DenormalMode DenormMode = DAG.getDenormalMode(VT); - // Try the target specific test first. - SDValue Test = TLI.getSqrtInputTest(Op, DAG, DenormMode); - if (!Test) { - // If no test provided by target, testing it with denormal inputs to - // avoid wrong estimate. - if (DenormMode.Input == DenormalMode::IEEE) { - // This is specifically a check for the handling of denormal inputs, - // not the result.
- - // Test = fabs(X) < SmallestNormal - const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); - APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); - SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT); - SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op); - Test = DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT); - } else - // Test = X == 0.0 - Test = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); - } - - // The estimate is now completely wrong if the input was exactly 0.0 or - // possibly a denormal. Force the answer to 0.0 or value provided by - // target for those cases. - Est = DAG.getNode( - Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, - Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est); - } + if (!Reciprocal) { + SDLoc DL(Op); + // Ask the target (or the generic default) which inputs need fixing up. + SDValue Test = TLI.getSqrtInputTest(Op, DAG, DAG.getDenormalMode(VT)); + + // The estimate is now completely wrong if the input was exactly 0.0 or + // possibly a denormal. Force the answer to 0.0 or value provided by + // target for those cases. + Est = DAG.getNode( + Test.getValueType().isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, + Test, TLI.getSqrtResultForDenormInput(Op, DAG), Est); } return Est; } diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -5851,6 +5851,28 @@ return false; } +SDValue TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, + const DenormalMode &Mode) const { + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + // Flag the inputs for which the square root estimate cannot be trusted. + if (Mode.Input == DenormalMode::IEEE) { + // This is specifically a check for the handling of denormal inputs, + // not the result.
+ + // Test = fabs(X) < SmallestNormal + const fltSemantics &FltSem = DAG.EVTToAPFloatSemantics(VT); + APFloat SmallestNorm = APFloat::getSmallestNormalized(FltSem); + SDValue NormC = DAG.getConstantFP(SmallestNorm, DL, VT); + SDValue Fabs = DAG.getNode(ISD::FABS, DL, VT, Op); + return DAG.getSetCC(DL, CCVT, Fabs, NormC, ISD::SETLT); + } + // Otherwise only an exact zero is flagged: Test = X == 0.0 + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); +} + SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG, bool LegalOps, bool OptForSize, NegatibleCost &Cost, diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -961,6 +961,8 @@ bool Reciprocal) const override; SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps) const override; + SDValue getSqrtInputTest(SDValue Operand, SelectionDAG &DAG, + const DenormalMode &Mode) const override; unsigned combineRepeatedFPDivisors() const override; ConstraintType getConstraintType(StringRef Constraint) const override; diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7471,6 +7471,19 @@ return SDValue(); } +SDValue +AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG, + const DenormalMode &Mode) const { + // Only an exact 0.0 input is flagged; the Mode argument is not consulted. + // NOTE(review): the replacement value is getSqrtResultForDenormInput(), so a + // -0.0 input may now yield +0.0 rather than -0.0; confirm that is intended. + SDLoc DL(Op); + EVT VT = Op.getValueType(); + EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT); + SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); + return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ); +} + SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled, int &ExtraSteps, @@ -7494,17 +7507,8 @@ Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags); Estimate =
DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags); } - if (!Reciprocal) { - EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), - VT); - SDValue FPZero = DAG.getConstantFP(0.0, DL, VT); - SDValue Eq = DAG.getSetCC(DL, CCVT, Operand, FPZero, ISD::SETEQ); - + if (!Reciprocal) Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags); - // Correct the result if the operand is 0.0. - Estimate = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, - VT, Eq, Operand, Estimate); - } ExtraSteps = 0; return Estimate; diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp --- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp +++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp @@ -12133,7 +12133,7 @@ if (!isTypeLegal(MVT::i1) || (VT != MVT::f64 && ((VT != MVT::v2f64 && VT != MVT::v4f32) || !Subtarget.hasVSX()))) - return SDValue(); + return TargetLowering::getSqrtInputTest(Op, DAG, Mode); SDLoc DL(Op); // The output register of FTSQRT is CR field.
diff --git a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll --- a/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll +++ b/llvm/test/CodeGen/AArch64/sqrt-fastmath.ll @@ -72,8 +72,8 @@ ; CHECK-NEXT: frsqrts v2.2s, v0.2s, v2.2s ; CHECK-NEXT: fmul v2.2s, v2.2s, v0.2s ; CHECK-NEXT: fmul v1.2s, v1.2s, v2.2s -; CHECK-NEXT: fcmeq v2.2s, v0.2s, #0.0 -; CHECK-NEXT: bif v0.8b, v1.8b, v2.8b +; CHECK-NEXT: fcmeq v0.2s, v0.2s, #0.0 +; CHECK-NEXT: bic v0.8b, v1.8b, v0.8b ; CHECK-NEXT: ret %1 = tail call fast <2 x float> @llvm.sqrt.v2f32(<2 x float> %a) ret <2 x float> %1 @@ -95,8 +95,8 @@ ; CHECK-NEXT: frsqrts v2.4s, v0.4s, v2.4s ; CHECK-NEXT: fmul v2.4s, v2.4s, v0.4s ; CHECK-NEXT: fmul v1.4s, v1.4s, v2.4s -; CHECK-NEXT: fcmeq v2.4s, v0.4s, #0.0 -; CHECK-NEXT: bif v0.16b, v1.16b, v2.16b +; CHECK-NEXT: fcmeq v0.4s, v0.4s, #0.0 +; CHECK-NEXT: bic v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %1 = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %a) ret <4 x float> %1 @@ -112,25 +112,25 @@ ; CHECK-LABEL: f8sqrt: ; CHECK: // %bb.0: ; CHECK-NEXT: frsqrte v2.4s, v0.4s -; CHECK-NEXT: fmul v3.4s, v2.4s, v2.4s -; CHECK-NEXT: frsqrts v3.4s, v0.4s, v3.4s -; CHECK-NEXT: fmul v2.4s, v2.4s, v3.4s -; CHECK-NEXT: fmul v3.4s, v2.4s, v2.4s -; CHECK-NEXT: frsqrts v3.4s, v0.4s, v3.4s -; CHECK-NEXT: fmul v3.4s, v3.4s, v0.4s -; CHECK-NEXT: fmul v2.4s, v2.4s, v3.4s -; CHECK-NEXT: fcmeq v3.4s, v0.4s, #0.0 -; CHECK-NEXT: bif v0.16b, v2.16b, v3.16b -; CHECK-NEXT: frsqrte v2.4s, v1.4s -; CHECK-NEXT: fmul v3.4s, v2.4s, v2.4s -; CHECK-NEXT: frsqrts v3.4s, v1.4s, v3.4s -; CHECK-NEXT: fmul v2.4s, v2.4s, v3.4s -; CHECK-NEXT: fmul v3.4s, v2.4s, v2.4s -; CHECK-NEXT: frsqrts v3.4s, v1.4s, v3.4s -; CHECK-NEXT: fmul v3.4s, v3.4s, v1.4s -; CHECK-NEXT: fmul v2.4s, v2.4s, v3.4s -; CHECK-NEXT: fcmeq v3.4s, v1.4s, #0.0 -; CHECK-NEXT: bif v1.16b, v2.16b, v3.16b +; CHECK-NEXT: fmul v4.4s, v2.4s, v2.4s +; CHECK-NEXT: frsqrte v3.4s, v1.4s +; CHECK-NEXT: frsqrts v4.4s, v0.4s, v4.4s +; 
CHECK-NEXT: fmul v2.4s, v2.4s, v4.4s +; CHECK-NEXT: fmul v4.4s, v3.4s, v3.4s +; CHECK-NEXT: frsqrts v4.4s, v1.4s, v4.4s +; CHECK-NEXT: fmul v3.4s, v3.4s, v4.4s +; CHECK-NEXT: fmul v4.4s, v2.4s, v2.4s +; CHECK-NEXT: frsqrts v4.4s, v0.4s, v4.4s +; CHECK-NEXT: fmul v4.4s, v4.4s, v0.4s +; CHECK-NEXT: fmul v2.4s, v2.4s, v4.4s +; CHECK-NEXT: fmul v4.4s, v3.4s, v3.4s +; CHECK-NEXT: frsqrts v4.4s, v1.4s, v4.4s +; CHECK-NEXT: fmul v4.4s, v4.4s, v1.4s +; CHECK-NEXT: fmul v3.4s, v3.4s, v4.4s +; CHECK-NEXT: fcmeq v0.4s, v0.4s, #0.0 +; CHECK-NEXT: fcmeq v1.4s, v1.4s, #0.0 +; CHECK-NEXT: bic v0.16b, v2.16b, v0.16b +; CHECK-NEXT: bic v1.16b, v3.16b, v1.16b ; CHECK-NEXT: ret %1 = tail call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %a) ret <8 x float> %1 @@ -207,8 +207,8 @@ ; CHECK-NEXT: frsqrts v2.2d, v0.2d, v2.2d ; CHECK-NEXT: fmul v2.2d, v2.2d, v0.2d ; CHECK-NEXT: fmul v1.2d, v1.2d, v2.2d -; CHECK-NEXT: fcmeq v2.2d, v0.2d, #0.0 -; CHECK-NEXT: bif v0.16b, v1.16b, v2.16b +; CHECK-NEXT: fcmeq v0.2d, v0.2d, #0.0 +; CHECK-NEXT: bic v0.16b, v1.16b, v0.16b ; CHECK-NEXT: ret %1 = tail call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> %a) ret <2 x double> %1 @@ -224,31 +224,31 @@ ; CHECK-LABEL: d4sqrt: ; CHECK: // %bb.0: ; CHECK-NEXT: frsqrte v2.2d, v0.2d -; CHECK-NEXT: fmul v3.2d, v2.2d, v2.2d -; CHECK-NEXT: frsqrts v3.2d, v0.2d, v3.2d -; CHECK-NEXT: fmul v2.2d, v2.2d, v3.2d -; CHECK-NEXT: fmul v3.2d, v2.2d, v2.2d -; CHECK-NEXT: frsqrts v3.2d, v0.2d, v3.2d -; CHECK-NEXT: fmul v2.2d, v2.2d, v3.2d -; CHECK-NEXT: fmul v3.2d, v2.2d, v2.2d -; CHECK-NEXT: frsqrts v3.2d, v0.2d, v3.2d -; CHECK-NEXT: fmul v3.2d, v3.2d, v0.2d -; CHECK-NEXT: fmul v2.2d, v2.2d, v3.2d -; CHECK-NEXT: fcmeq v3.2d, v0.2d, #0.0 -; CHECK-NEXT: bif v0.16b, v2.16b, v3.16b -; CHECK-NEXT: frsqrte v2.2d, v1.2d -; CHECK-NEXT: fmul v3.2d, v2.2d, v2.2d -; CHECK-NEXT: frsqrts v3.2d, v1.2d, v3.2d -; CHECK-NEXT: fmul v2.2d, v2.2d, v3.2d -; CHECK-NEXT: fmul v3.2d, v2.2d, v2.2d -; CHECK-NEXT: frsqrts v3.2d, v1.2d, v3.2d -; 
CHECK-NEXT: fmul v2.2d, v2.2d, v3.2d -; CHECK-NEXT: fmul v3.2d, v2.2d, v2.2d -; CHECK-NEXT: frsqrts v3.2d, v1.2d, v3.2d -; CHECK-NEXT: fmul v3.2d, v3.2d, v1.2d -; CHECK-NEXT: fmul v2.2d, v2.2d, v3.2d -; CHECK-NEXT: fcmeq v3.2d, v1.2d, #0.0 -; CHECK-NEXT: bif v1.16b, v2.16b, v3.16b +; CHECK-NEXT: fmul v4.2d, v2.2d, v2.2d +; CHECK-NEXT: frsqrte v3.2d, v1.2d +; CHECK-NEXT: frsqrts v4.2d, v0.2d, v4.2d +; CHECK-NEXT: fmul v2.2d, v2.2d, v4.2d +; CHECK-NEXT: fmul v4.2d, v3.2d, v3.2d +; CHECK-NEXT: frsqrts v4.2d, v1.2d, v4.2d +; CHECK-NEXT: fmul v3.2d, v3.2d, v4.2d +; CHECK-NEXT: fmul v4.2d, v2.2d, v2.2d +; CHECK-NEXT: frsqrts v4.2d, v0.2d, v4.2d +; CHECK-NEXT: fmul v2.2d, v2.2d, v4.2d +; CHECK-NEXT: fmul v4.2d, v3.2d, v3.2d +; CHECK-NEXT: frsqrts v4.2d, v1.2d, v4.2d +; CHECK-NEXT: fmul v3.2d, v3.2d, v4.2d +; CHECK-NEXT: fmul v4.2d, v2.2d, v2.2d +; CHECK-NEXT: frsqrts v4.2d, v0.2d, v4.2d +; CHECK-NEXT: fmul v4.2d, v4.2d, v0.2d +; CHECK-NEXT: fmul v2.2d, v2.2d, v4.2d +; CHECK-NEXT: fmul v4.2d, v3.2d, v3.2d +; CHECK-NEXT: frsqrts v4.2d, v1.2d, v4.2d +; CHECK-NEXT: fmul v4.2d, v4.2d, v1.2d +; CHECK-NEXT: fmul v3.2d, v3.2d, v4.2d +; CHECK-NEXT: fcmeq v0.2d, v0.2d, #0.0 +; CHECK-NEXT: fcmeq v1.2d, v1.2d, #0.0 +; CHECK-NEXT: bic v0.16b, v2.16b, v0.16b +; CHECK-NEXT: bic v1.16b, v3.16b, v1.16b ; CHECK-NEXT: ret %1 = tail call fast <4 x double> @llvm.sqrt.v4f64(<4 x double> %a) ret <4 x double> %1 @@ -515,10 +515,10 @@ ; CHECK-NEXT: frsqrts d2, d0, d2 ; CHECK-NEXT: fmul d1, d1, d2 ; CHECK-NEXT: fcmp d0, #0.0 -; CHECK-NEXT: fmul d1, d0, d1 -; CHECK-NEXT: fcsel d0, d0, d1, eq -; CHECK-NEXT: str d0, [x0] -; CHECK-NEXT: mov v0.16b, v1.16b +; CHECK-NEXT: fmul d0, d0, d1 +; CHECK-NEXT: fmov d1, xzr +; CHECK-NEXT: fcsel d1, d1, d0, eq +; CHECK-NEXT: str d1, [x0] ; CHECK-NEXT: ret %sqrt = call fast double @llvm.sqrt.f64(double %x) store double %sqrt, double* %p @@ -636,28 +636,29 @@ ; CHECK-LABEL: sqrt_simplify_before_recip_4_uses: ; CHECK: // %bb.0: ; CHECK-NEXT: frsqrte d1, d0 -; 
CHECK-NEXT: fmul d3, d1, d1 -; CHECK-NEXT: frsqrts d3, d0, d3 -; CHECK-NEXT: fmul d1, d1, d3 -; CHECK-NEXT: fmul d3, d1, d1 -; CHECK-NEXT: frsqrts d3, d0, d3 -; CHECK-NEXT: fmul d1, d1, d3 +; CHECK-NEXT: fmul d4, d1, d1 +; CHECK-NEXT: frsqrts d4, d0, d4 +; CHECK-NEXT: fmul d1, d1, d4 +; CHECK-NEXT: fmul d4, d1, d1 +; CHECK-NEXT: frsqrts d4, d0, d4 ; CHECK-NEXT: mov x8, #4631107791820423168 -; CHECK-NEXT: fmul d3, d1, d1 -; CHECK-NEXT: fmov d2, x8 +; CHECK-NEXT: fmul d1, d1, d4 +; CHECK-NEXT: fmov d3, x8 ; CHECK-NEXT: mov x8, #140737488355328 -; CHECK-NEXT: frsqrts d3, d0, d3 +; CHECK-NEXT: fmul d4, d1, d1 ; CHECK-NEXT: movk x8, #16453, lsl #48 -; CHECK-NEXT: fmul d1, d1, d3 -; CHECK-NEXT: fcmp d0, #0.0 +; CHECK-NEXT: frsqrts d4, d0, d4 +; CHECK-NEXT: fmul d1, d1, d4 ; CHECK-NEXT: fmov d4, x8 -; CHECK-NEXT: fmul d3, d0, d1 -; CHECK-NEXT: fmul d2, d1, d2 +; CHECK-NEXT: fcmp d0, #0.0 +; CHECK-NEXT: fmov d2, xzr +; CHECK-NEXT: fmul d3, d1, d3 ; CHECK-NEXT: fmul d4, d1, d4 ; CHECK-NEXT: str d1, [x0] -; CHECK-NEXT: fcsel d1, d0, d3, eq +; CHECK-NEXT: fmul d1, d0, d1 +; CHECK-NEXT: fcsel d1, d2, d1, eq ; CHECK-NEXT: fdiv d0, d0, d1 -; CHECK-NEXT: str d2, [x1] +; CHECK-NEXT: str d3, [x1] ; CHECK-NEXT: str d4, [x2] ; CHECK-NEXT: ret %sqrt = tail call fast double @llvm.sqrt.f64(double %x)