Index: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -357,11 +357,13 @@
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
     SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags);
-    SDValue BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
-    SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
-                                 SDNodeFlags *Flags);
-    SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
-                                 SDNodeFlags *Flags);
+    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags, bool Recip);
+    SDValue buildSqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                SDNodeFlags *Flags, bool Reciprocal);
+    SDValue buildSqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                SDNodeFlags *Flags, bool Reciprocal);
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
@@ -8825,12 +8827,12 @@
     // If this FDIV is part of a reciprocal square root, it may be folded
     // into a target-specific square root estimate instruction.
     if (N1.getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0), Flags)) {
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags)) {
         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
       }
     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                           Flags)) {
         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
         AddToWorklist(RV.getNode());
@@ -8838,7 +8840,7 @@
       }
     } else if (N1.getOpcode() == ISD::FP_ROUND &&
                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                           Flags)) {
         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
         AddToWorklist(RV.getNode());
@@ -8859,7 +8861,7 @@
       if (SqrtOp.getNode()) {
         // We found a FSQRT, so try to make this fold:
         // x / (y * sqrt(z)) -> x * (rsqrt(z) / y)
-        if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
+        if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
           RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags);
           AddToWorklist(RV.getNode());
           return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
@@ -8916,27 +8918,7 @@
   // For now, create a Flags object for use with all unsafe math transforms.
   SDNodeFlags Flags;
   Flags.setUnsafeAlgebra(true);
-
-  // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5)
-  SDValue RV = BuildRsqrtEstimate(N->getOperand(0), &Flags);
-  if (!RV)
-    return SDValue();
-
-  EVT VT = RV.getValueType();
-  SDLoc DL(N);
-  RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV, &Flags);
-  AddToWorklist(RV.getNode());
-
-  // Unfortunately, RV is now NaN if the input was exactly 0.
-  // Select out this case and force the answer to 0.
-  SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
-  EVT CCVT = getSetCCResultType(VT);
-  SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, N->getOperand(0), Zero, ISD::SETEQ);
-  AddToWorklist(ZeroCmp.getNode());
-  AddToWorklist(RV.getNode());
-
-  return DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
-                     ZeroCmp, Zero, RV);
+  return buildSqrtEstimate(N->getOperand(0), &Flags);
 }
 
 /// copysign(x, fp_extend(y)) -> copysign(x, y)
@@ -14587,9 +14569,9 @@
 /// =>
 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
 /// As a result, we precompute A/2 prior to the iteration loop.
-SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est,
-                                          unsigned Iterations,
-                                          SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
+                                         unsigned Iterations,
+                                         SDNodeFlags *Flags, bool Reciprocal) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
@@ -14616,6 +14598,13 @@
     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
     AddToWorklist(Est.getNode());
   }
+
+  // If non-reciprocal square root is requested, multiply the result by Arg.
+  if (!Reciprocal) {
+    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
+    AddToWorklist(Est.getNode());
+  }
+
   return Est;
 }
 
@@ -14624,35 +14613,55 @@
 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
 /// =>
 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
-SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est,
-                                          unsigned Iterations,
-                                          SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
+                                         unsigned Iterations,
+                                         SDNodeFlags *Flags, bool Reciprocal) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
 
-  // Newton iterations: Est = -0.5 * Est * (-3.0 + Arg * Est * Est)
+  // This routine must enter the loop below to work correctly
+  // when (Reciprocal == false).
+  assert(Iterations > 0 && "NR refinement needs at least one iteration");
+
+  // Newton iterations for reciprocal square root:
+  // E = (E * -0.5) * ((A * E) * E + -3.0)
   for (unsigned i = 0; i < Iterations; ++i) {
-    SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
-    AddToWorklist(HalfEst.getNode());
+    SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
+    AddToWorklist(AE.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
-    AddToWorklist(Est.getNode());
+    SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
+    AddToWorklist(AEE.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
-    AddToWorklist(Est.getNode());
+    SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
+    AddToWorklist(RHS.getNode());
 
-    Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree, Flags);
-    AddToWorklist(Est.getNode());
+    // When calculating a square root at the last iteration build:
+    // S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
+    // (notice a common subexpression)
+    SDValue LHS;
+    if (Reciprocal || (i + 1) < Iterations) {
+      // RSQRT: LHS = (E * -0.5)
+      LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
+    } else {
+      // SQRT: LHS = (A * E) * -0.5
+      LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
+    }
+    AddToWorklist(LHS.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst, Flags);
+    Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
     AddToWorklist(Est.getNode());
   }
+
   return Est;
 }
 
-SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
+/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
+/// Op can be zero.
+SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags,
+                                           bool Reciprocal) {
   if (Level >= AfterLegalizeDAG)
     return SDValue();
 
@@ -14663,9 +14672,16 @@
   if (SDValue Est = TLI.getRsqrtEstimate(Op, DCI, Iterations, UseOneConstNR)) {
     AddToWorklist(Est.getNode());
     if (Iterations) {
-      Est = UseOneConstNR ?
-        BuildRsqrtNROneConst(Op, Est, Iterations, Flags) :
-        BuildRsqrtNRTwoConst(Op, Est, Iterations, Flags);
+      Est = UseOneConstNR
+                ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
+                : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
+    } else if (!Reciprocal) {
+      // No refinement steps were requested, so Est is still the raw
+      // rsqrt(Op) estimate. Turn it into Op * rsqrt(Op) to match what the
+      // buildSqrtNR* helpers return when Reciprocal is false.
+      Est = DAG.getNode(ISD::FMUL, SDLoc(Op), Op.getValueType(), Op, Est,
+                        Flags);
+      AddToWorklist(Est.getNode());
     }
     return Est;
   }
@@ -14673,6 +14689,30 @@
   return SDValue();
 }
 
+SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+  return buildSqrtEstimateImpl(Op, Flags, true);
+}
+
+SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+  SDValue Est = buildSqrtEstimateImpl(Op, Flags, false);
+  if (!Est)
+    return SDValue();
+
+  // Unfortunately, Est is now NaN if the input was exactly 0.
+  // Select out this case and force the answer to 0.
+  EVT VT = Est.getValueType();
+  SDLoc DL(Op);
+  SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
+  EVT CCVT = getSetCCResultType(VT);
+  SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, Zero, ISD::SETEQ);
+  AddToWorklist(ZeroCmp.getNode());
+
+  Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, ZeroCmp,
+                    Zero, Est);
+  AddToWorklist(Est.getNode());
+  return Est;
+}
+
 /// Return true if base is a frame index, which is known not to alias with
 /// anything but itself. Provides base object and offset as results.
 static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
Index: llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
+++ llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
@@ -0,0 +1,52 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx2,fma -recip=sqrt:2 -stop-after=expand-isel-pseudos 2>&1 | FileCheck %s
+
+declare float @llvm.sqrt.f32(float) #0
+
+define float @foo(float %f) #0 {
+; CHECK: {{name: *foo}}
+; CHECK: body:
+; CHECK: %0 = COPY %xmm0
+; CHECK: %1 = VRSQRTSSr killed %2, %0
+; CHECK: %3 = VMULSSrr %0, %1
+; CHECK: %4 = VMOVSSrm
+; CHECK: %5 = VFMADDSSr213r %1, killed %3, %4
+; CHECK: %6 = VMOVSSrm
+; CHECK: %7 = VMULSSrr %1, %6
+; CHECK: %8 = VMULSSrr killed %7, killed %5
+; CHECK: %9 = VMULSSrr %0, %8
+; CHECK: %10 = VFMADDSSr213r %8, %9, %4
+; CHECK: %11 = VMULSSrr %9, %6
+; CHECK: %12 = VMULSSrr killed %11, killed %10
+; CHECK: %13 = FsFLD0SS
+; CHECK: %14 = VCMPSSrr %0, killed %13, 0
+; CHECK: %15 = VFsANDNPSrr killed %14, killed %12
+; CHECK: %xmm0 = COPY %15
+; CHECK: RET 0, %xmm0
+  %call = tail call float @llvm.sqrt.f32(float %f) #1
+  ret float %call
+}
+
+define float @rfoo(float %f) #0 {
+; CHECK: {{name: *rfoo}}
+; CHECK: body: |
+; CHECK: %0 = COPY %xmm0
+; CHECK: %1 = VRSQRTSSr killed %2, %0
+; CHECK: %3 = VMULSSrr %0, %1
+; CHECK: %4 = VMOVSSrm
+; CHECK: %5 = VFMADDSSr213r %1, killed %3, %4
+; CHECK: %6 = VMOVSSrm
+; CHECK: %7 = VMULSSrr %1, %6
+; CHECK: %8 = VMULSSrr killed %7, killed %5
+; CHECK: %9 = VMULSSrr %0, %8
+; CHECK: %10 = VFMADDSSr213r %8, killed %9, %4
+; CHECK: %11 = VMULSSrr %8, %6
+; CHECK: %12 = VMULSSrr killed %11, killed %10
+; CHECK: %xmm0 = COPY %12
+; CHECK: RET 0, %xmm0
+  %sqrt = tail call float @llvm.sqrt.f32(float %f)
+  %div = fdiv fast float 1.0, %sqrt
+  ret float %div
+}
+
+attributes #0 = { "unsafe-fp-math"="true" }
+attributes #1 = { nounwind readnone }
Index: llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
+++ llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
@@ -34,12 +34,11 @@
 ; ESTIMATE-LABEL: ff:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtss %xmm0, %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm3
-; ESTIMATE-NEXT:    vmulss %xmm3, %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm0, %xmm2
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm1, %xmm1
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm2, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; ESTIMATE-NEXT:    vxorps %xmm2, %xmm2, %xmm2
 ; ESTIMATE-NEXT:    vcmpeqss %xmm2, %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vandnps %xmm1, %xmm0, %xmm0
@@ -78,11 +77,11 @@
 ; ESTIMATE-LABEL: reciprocal_square_root:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtss %xmm0, %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm1, %xmm2
 ; ESTIMATE-NEXT:    vmulss %xmm2, %xmm0, %xmm0
+; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call float @llvm.sqrt.f32(float %x)
   %div = fdiv fast float 1.0, %sqrt
@@ -100,11 +99,11 @@
 ; ESTIMATE-LABEL: reciprocal_square_root_v4f32:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtps %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
+; ESTIMATE-NEXT:    vmulps %xmm1, %xmm1, %xmm2
+; ESTIMATE-NEXT:    vmulps %xmm2, %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vaddps {{.*}}(%rip), %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vmulps {{.*}}(%rip), %xmm1, %xmm1
-; ESTIMATE-NEXT:    vmulps %xmm1, %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
   %div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
@@ -125,11 +124,11 @@
 ; ESTIMATE-LABEL: reciprocal_square_root_v8f32:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtps %ymm0, %ymm1
-; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
-; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
+; ESTIMATE-NEXT:    vmulps %ymm1, %ymm1, %ymm2
+; ESTIMATE-NEXT:    vmulps %ymm2, %ymm0, %ymm0
 ; ESTIMATE-NEXT:    vaddps {{.*}}(%rip), %ymm0, %ymm0
 ; ESTIMATE-NEXT:    vmulps {{.*}}(%rip), %ymm1, %ymm1
-; ESTIMATE-NEXT:    vmulps %ymm1, %ymm0, %ymm0
+; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
   %div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt