Index: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -357,11 +357,13 @@
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
     SDValue BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags);
-    SDValue BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
-    SDValue BuildRsqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
-                                 SDNodeFlags *Flags);
-    SDValue BuildRsqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
-                                 SDNodeFlags *Flags);
+    SDValue buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags);
+    SDValue buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags, bool Recip);
+    SDValue buildSqrtNROneConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                SDNodeFlags *Flags, bool Reciprocal);
+    SDValue buildSqrtNRTwoConst(SDValue Op, SDValue Est, unsigned Iterations,
+                                SDNodeFlags *Flags, bool Reciprocal);
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
@@ -8825,12 +8827,12 @@
     // If this FDIV is part of a reciprocal square root, it may be folded
     // into a target-specific square root estimate instruction.
     if (N1.getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0), Flags)) {
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0), Flags)) {
         return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
       }
     } else if (N1.getOpcode() == ISD::FP_EXTEND &&
                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                           Flags)) {
         RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
         AddToWorklist(RV.getNode());
@@ -8838,7 +8840,7 @@
       }
     } else if (N1.getOpcode() == ISD::FP_ROUND &&
                N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0),
+      if (SDValue RV = buildRsqrtEstimate(N1.getOperand(0).getOperand(0),
                                           Flags)) {
         RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
         AddToWorklist(RV.getNode());
@@ -8859,7 +8861,7 @@
       if (SqrtOp.getNode()) {
         // We found a FSQRT, so try to make this fold:
         // x / (y * sqrt(z)) -> x * (rsqrt(z) / y)
-        if (SDValue RV = BuildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
+        if (SDValue RV = buildRsqrtEstimate(SqrtOp.getOperand(0), Flags)) {
           RV = DAG.getNode(ISD::FDIV, SDLoc(N1), VT, RV, OtherOp, Flags);
           AddToWorklist(RV.getNode());
           return DAG.getNode(ISD::FMUL, DL, VT, N0, RV, Flags);
@@ -8916,27 +8918,7 @@
   // For now, create a Flags object for use with all unsafe math transforms.
   SDNodeFlags Flags;
   Flags.setUnsafeAlgebra(true);
-
-  // Compute this as X * (1/sqrt(X)) = X * (X ** -0.5)
-  SDValue RV = BuildRsqrtEstimate(N->getOperand(0), &Flags);
-  if (!RV)
-    return SDValue();
-
-  EVT VT = RV.getValueType();
-  SDLoc DL(N);
-  RV = DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV, &Flags);
-  AddToWorklist(RV.getNode());
-
-  // Unfortunately, RV is now NaN if the input was exactly 0.
-  // Select out this case and force the answer to 0.
-  SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
-  EVT CCVT = getSetCCResultType(VT);
-  SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, N->getOperand(0), Zero, ISD::SETEQ);
-  AddToWorklist(ZeroCmp.getNode());
-  AddToWorklist(RV.getNode());
-
-  return DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
-                     ZeroCmp, Zero, RV);
+  return buildSqrtEstimate(N->getOperand(0), &Flags);
 }
 
 /// copysign(x, fp_extend(y)) -> copysign(x, y)
@@ -14587,9 +14569,9 @@
 ///     =>
 ///   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
 /// As a result, we precompute A/2 prior to the iteration loop.
-SDValue DAGCombiner::BuildRsqrtNROneConst(SDValue Arg, SDValue Est,
-                                          unsigned Iterations,
-                                          SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNROneConst(SDValue Arg, SDValue Est,
+                                         unsigned Iterations,
+                                         SDNodeFlags *Flags, bool Reciprocal) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue ThreeHalves = DAG.getConstantFP(1.5, DL, VT);
@@ -14616,6 +14598,13 @@
     Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst, Flags);
     AddToWorklist(Est.getNode());
   }
+
+  // If non-reciprocal square root is requested, multiply the result by Arg.
+  if (!Reciprocal) {
+    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
+    AddToWorklist(Est.getNode());
+  }
+
   return Est;
 }
 
@@ -14624,35 +14613,55 @@
 ///   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
 ///     =>
 ///   X_{i+1} = (-0.5 * X_i) * (A * X_i * X_i + (-3.0))
-SDValue DAGCombiner::BuildRsqrtNRTwoConst(SDValue Arg, SDValue Est,
-                                          unsigned Iterations,
-                                          SDNodeFlags *Flags) {
+SDValue DAGCombiner::buildSqrtNRTwoConst(SDValue Arg, SDValue Est,
+                                         unsigned Iterations,
+                                         SDNodeFlags *Flags, bool Reciprocal) {
   EVT VT = Arg.getValueType();
   SDLoc DL(Arg);
   SDValue MinusThree = DAG.getConstantFP(-3.0, DL, VT);
   SDValue MinusHalf = DAG.getConstantFP(-0.5, DL, VT);
 
-  // Newton iterations: Est = -0.5 * Est * (-3.0 + Arg * Est * Est)
+  // The sqrt result (A * E) is only formed inside the loop, on its last
+  // iteration, so (Reciprocal == false) needs at least one iteration.
+  assert(Iterations > 0 && "buildSqrtNRTwoConst needs at least one iteration");
+
+  // Newton iterations for reciprocal square root:
+  // E = (E * -0.5) * ((A * E) * E + -3.0)
   for (unsigned i = 0; i < Iterations; ++i) {
-    SDValue HalfEst = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
-    AddToWorklist(HalfEst.getNode());
+    SDValue AE = DAG.getNode(ISD::FMUL, DL, VT, Arg, Est, Flags);
+    AddToWorklist(AE.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Est, Flags);
-    AddToWorklist(Est.getNode());
+    SDValue AEE = DAG.getNode(ISD::FMUL, DL, VT, AE, Est, Flags);
+    AddToWorklist(AEE.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, Arg, Flags);
-    AddToWorklist(Est.getNode());
+    SDValue RHS = DAG.getNode(ISD::FADD, DL, VT, AEE, MinusThree, Flags);
+    AddToWorklist(RHS.getNode());
 
-    Est = DAG.getNode(ISD::FADD, DL, VT, Est, MinusThree, Flags);
-    AddToWorklist(Est.getNode());
+    // For a square root, fold the final multiply by A into the last
+    // iteration: S = ((A * E) * -0.5) * ((A * E) * E + -3.0)
+    // (A * E is a common subexpression shared with the RHS)
+    SDValue LHS;
+    if (Reciprocal || (i + 1) < Iterations) {
+      // RSQRT: LHS = (E * -0.5)
+      LHS = DAG.getNode(ISD::FMUL, DL, VT, Est, MinusHalf, Flags);
+    } else {
+      // SQRT: LHS = (A * E) * -0.5
+      LHS = DAG.getNode(ISD::FMUL, DL, VT, AE, MinusHalf, Flags);
+    }
+    AddToWorklist(LHS.getNode());
 
-    Est = DAG.getNode(ISD::FMUL, DL, VT, Est, HalfEst, Flags);
+    Est = DAG.getNode(ISD::FMUL, DL, VT, LHS, RHS, Flags);
     AddToWorklist(Est.getNode());
   }
+
   return Est;
 }
 
-SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+/// Build code to calculate either rsqrt(Op) or sqrt(Op). In the latter case
+/// Op*rsqrt(Op) is actually computed, so additional postprocessing is needed if
+/// Op can be zero.
+SDValue DAGCombiner::buildSqrtEstimateImpl(SDValue Op, SDNodeFlags *Flags,
+                                           bool Reciprocal) {
   if (Level >= AfterLegalizeDAG)
     return SDValue();
 
@@ -14663,9 +14672,15 @@
   if (SDValue Est = TLI.getRsqrtEstimate(Op, DCI, Iterations, UseOneConstNR)) {
     AddToWorklist(Est.getNode());
     if (Iterations) {
-      Est = UseOneConstNR ?
-        BuildRsqrtNROneConst(Op, Est, Iterations, Flags) :
-        BuildRsqrtNRTwoConst(Op, Est, Iterations, Flags);
+      Est = UseOneConstNR
+                ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
+                : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
+    } else if (!Reciprocal) {
+      // With no refinement steps Est is still just rsqrt(Op). Scale it by Op
+      // so that a sqrt request really gets Op * rsqrt(Op).
+      Est = DAG.getNode(ISD::FMUL, SDLoc(Op), Op.getValueType(), Op, Est,
+                        Flags);
+      AddToWorklist(Est.getNode());
     }
     return Est;
   }
@@ -14673,6 +14682,30 @@
   return SDValue();
 }
 
+SDValue DAGCombiner::buildRsqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+  return buildSqrtEstimateImpl(Op, Flags, /*Reciprocal=*/true);
+}
+
+SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
+  // Ask for the non-reciprocal form, i.e. an estimate of Op * rsqrt(Op).
+  SDValue SqrtEst = buildSqrtEstimateImpl(Op, Flags, /*Reciprocal=*/false);
+  if (!SqrtEst)
+    return SDValue();
+
+  // That product is NaN when Op is exactly 0, so force the result to 0 there.
+  EVT VT = SqrtEst.getValueType();
+  SDLoc DL(Op);
+  SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
+  SDValue IsZero =
+      DAG.getSetCC(DL, getSetCCResultType(VT), Op, Zero, ISD::SETEQ);
+  AddToWorklist(IsZero.getNode());
+
+  unsigned SelectOpc = VT.isVector() ? ISD::VSELECT : ISD::SELECT;
+  SqrtEst = DAG.getNode(SelectOpc, DL, VT, IsZero, Zero, SqrtEst);
+  AddToWorklist(SqrtEst.getNode());
+  return SqrtEst;
+}
+
 /// Return true if base is a frame index, which is known not to alias with
 /// anything but itself.  Provides base object and offset as results.
 static bool FindBaseOffset(SDValue Ptr, SDValue &Base, int64_t &Offset,
Index: llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
+++ llvm/trunk/test/CodeGen/X86/sqrt-fastmath-mir.ll
@@ -0,0 +1,57 @@
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=avx2,fma -recip=sqrt:2 -stop-after=expand-isel-pseudos 2>&1 | FileCheck %s
+
+declare float @llvm.sqrt.f32(float) #0
+
+; sqrt(x) with two Newton-Raphson steps: the last step multiplies (x * E)
+; rather than E by -0.5 (%11), so no separate multiply by x follows, and a
+; compare with 0.0 plus an and-not force the result to 0 for x == 0.
+define float @foo(float %f) #0 {
+; CHECK: {{name: *foo}}
+; CHECK: body:
+; CHECK:     %0 = COPY %xmm0
+; CHECK:     %1 = VRSQRTSSr killed %2, %0
+; CHECK:     %3 = VMULSSrr %0, %1
+; CHECK:     %4 = VMOVSSrm
+; CHECK:     %5 = VFMADDSSr213r %1, killed %3, %4
+; CHECK:     %6 = VMOVSSrm
+; CHECK:     %7 = VMULSSrr %1, %6
+; CHECK:     %8 = VMULSSrr killed %7, killed %5
+; CHECK:     %9 = VMULSSrr %0, %8
+; CHECK:     %10 = VFMADDSSr213r %8, %9, %4
+; CHECK:     %11 = VMULSSrr %9, %6
+; CHECK:     %12 = VMULSSrr killed %11, killed %10
+; CHECK:     %13 = FsFLD0SS
+; CHECK:     %14 = VCMPSSrr %0, killed %13, 0
+; CHECK:     %15 = VFsANDNPSrr killed %14, killed %12
+; CHECK:     %xmm0 = COPY %15
+; CHECK:     RET 0, %xmm0
+  %call = tail call float @llvm.sqrt.f32(float %f) #1
+  ret float %call
+}
+
+; 1/sqrt(x) with two Newton-Raphson steps: every step, including the last,
+; multiplies E by -0.5 (%11), and no zero check is emitted.
+define float @rfoo(float %f) #0 {
+; CHECK: {{name: *rfoo}}
+; CHECK: body:             |
+; CHECK:     %0 = COPY %xmm0
+; CHECK:     %1 = VRSQRTSSr killed %2, %0
+; CHECK:     %3 = VMULSSrr %0, %1
+; CHECK:     %4 = VMOVSSrm
+; CHECK:     %5 = VFMADDSSr213r %1, killed %3, %4
+; CHECK:     %6 = VMOVSSrm
+; CHECK:     %7 = VMULSSrr %1, %6
+; CHECK:     %8 = VMULSSrr killed %7, killed %5
+; CHECK:     %9 = VMULSSrr %0, %8
+; CHECK:     %10 = VFMADDSSr213r %8, killed %9, %4
+; CHECK:     %11 = VMULSSrr %8, %6
+; CHECK:     %12 = VMULSSrr killed %11, killed %10
+; CHECK:     %xmm0 = COPY %12
+; CHECK:     RET 0, %xmm0
+  %sqrt = tail call float @llvm.sqrt.f32(float %f)
+  %div = fdiv fast float 1.0, %sqrt
+  ret float %div
+}
+
+attributes #0 = { "unsafe-fp-math"="true" }
+attributes #1 = { nounwind readnone }
Index: llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
+++ llvm/trunk/test/CodeGen/X86/sqrt-fastmath.ll
@@ -34,12 +34,11 @@
 ; ESTIMATE-LABEL: ff:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtss %xmm0, %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm3
-; ESTIMATE-NEXT:    vmulss %xmm3, %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm0, %xmm2
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm1, %xmm1
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm2, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm2, %xmm1
 ; ESTIMATE-NEXT:    vxorps %xmm2, %xmm2, %xmm2
 ; ESTIMATE-NEXT:    vcmpeqss %xmm2, %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vandnps %xmm1, %xmm0, %xmm0
@@ -78,11 +77,11 @@
 ; ESTIMATE-LABEL: reciprocal_square_root:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtss %xmm0, %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm2
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulss %xmm1, %xmm1, %xmm2
 ; ESTIMATE-NEXT:    vmulss %xmm2, %xmm0, %xmm0
+; ESTIMATE-NEXT:    vaddss {{.*}}(%rip), %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulss {{.*}}(%rip), %xmm1, %xmm1
+; ESTIMATE-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call float @llvm.sqrt.f32(float %x)
   %div = fdiv fast float 1.0, %sqrt
@@ -100,11 +99,11 @@
 ; ESTIMATE-LABEL: reciprocal_square_root_v4f32:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtps %xmm0, %xmm1
-; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
-; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
+; ESTIMATE-NEXT:    vmulps %xmm1, %xmm1, %xmm2
+; ESTIMATE-NEXT:    vmulps %xmm2, %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vaddps {{.*}}(%rip), %xmm0, %xmm0
 ; ESTIMATE-NEXT:    vmulps {{.*}}(%rip), %xmm1, %xmm1
-; ESTIMATE-NEXT:    vmulps %xmm1, %xmm0, %xmm0
+; ESTIMATE-NEXT:    vmulps %xmm0, %xmm1, %xmm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call <4 x float> @llvm.sqrt.v4f32(<4 x float> %x)
   %div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt
@@ -125,11 +124,11 @@
 ; ESTIMATE-LABEL: reciprocal_square_root_v8f32:
 ; ESTIMATE:       # BB#0:
 ; ESTIMATE-NEXT:    vrsqrtps %ymm0, %ymm1
-; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
-; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
+; ESTIMATE-NEXT:    vmulps %ymm1, %ymm1, %ymm2
+; ESTIMATE-NEXT:    vmulps %ymm2, %ymm0, %ymm0
 ; ESTIMATE-NEXT:    vaddps {{.*}}(%rip), %ymm0, %ymm0
 ; ESTIMATE-NEXT:    vmulps {{.*}}(%rip), %ymm1, %ymm1
-; ESTIMATE-NEXT:    vmulps %ymm1, %ymm0, %ymm0
+; ESTIMATE-NEXT:    vmulps %ymm0, %ymm1, %ymm0
 ; ESTIMATE-NEXT:    retq
   %sqrt = tail call <8 x float> @llvm.sqrt.v8f32(<8 x float> %x)
   %div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %sqrt