Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13224,6 +13224,24 @@
       Y = N1.getOperand(0);
     }
     if (Sqrt.getNode()) {
+      // If the other multiply operand is known positive, pull it into the
+      // sqrt. That will eliminate the division if we convert to an estimate:
+      // X / (fabs(A) * sqrt(Z)) --> X / sqrt(A*A*Z) --> X * rsqrt(A*A*Z)
+      // TODO: Also fold the case where A == Z (fabs is missing).
+      if (Flags.hasAllowReassociation() && N1.hasOneUse() &&
+          N1->getFlags().hasAllowReassociation() && Sqrt.hasOneUse() &&
+          Y.getOpcode() == ISD::FABS && Y.hasOneUse()) {
+        SDValue AA = DAG.getNode(ISD::FMUL, DL, VT, Y.getOperand(0),
+                                 Y.getOperand(0), Flags);
+        SDValue AAZ =
+            DAG.getNode(ISD::FMUL, DL, VT, AA, Sqrt.getOperand(0), Flags);
+        if (SDValue Rsqrt = buildRsqrtEstimate(AAZ, Flags))
+          return DAG.getNode(ISD::FMUL, DL, VT, N0, Rsqrt, Flags);
+
+        // Estimate creation failed. Clean up speculatively created nodes.
+        recursivelyDeleteUnusedNodes(AAZ.getNode());
+      }
+
       // We found a FSQRT, so try to make this fold:
       // X / (Y * sqrt(Z)) -> X * (rsqrt(Z) / Y)
       if (SDValue Rsqrt = buildRsqrtEstimate(Sqrt.getOperand(0), Flags)) {
Index: llvm/test/CodeGen/X86/sqrt-fastmath.ll
===================================================================
--- llvm/test/CodeGen/X86/sqrt-fastmath.ll
+++ llvm/test/CodeGen/X86/sqrt-fastmath.ll
@@ -618,46 +618,47 @@
   ret <16 x float> %div
 }
 
-; x / (fabs(y) * sqrt(z))
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
 
 define float @div_sqrt_fabs_f32(float %x, float %y, float %z) {
 ; SSE-LABEL: div_sqrt_fabs_f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    rsqrtss %xmm2, %xmm3
-; SSE-NEXT:    mulss %xmm3, %xmm2
-; SSE-NEXT:    mulss %xmm3, %xmm2
-; SSE-NEXT:    addss {{.*}}(%rip), %xmm2
-; SSE-NEXT:    mulss {{.*}}(%rip), %xmm3
-; SSE-NEXT:    mulss %xmm2, %xmm3
-; SSE-NEXT:    andps {{.*}}(%rip), %xmm1
-; SSE-NEXT:    divss %xmm1, %xmm3
-; SSE-NEXT:    mulss %xmm3, %xmm0
+; SSE-NEXT:    mulss %xmm1, %xmm1
+; SSE-NEXT:    mulss %xmm2, %xmm1
+; SSE-NEXT:    xorps %xmm2, %xmm2
+; SSE-NEXT:    rsqrtss %xmm1, %xmm2
+; SSE-NEXT:    mulss %xmm2, %xmm1
+; SSE-NEXT:    mulss %xmm2, %xmm1
+; SSE-NEXT:    addss {{.*}}(%rip), %xmm1
+; SSE-NEXT:    mulss {{.*}}(%rip), %xmm2
+; SSE-NEXT:    mulss %xmm0, %xmm2
+; SSE-NEXT:    mulss %xmm1, %xmm2
+; SSE-NEXT:    movaps %xmm2, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX1-LABEL: div_sqrt_fabs_f32:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vrsqrtss %xmm2, %xmm2, %xmm3
-; AVX1-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vaddss {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT:    vmulss {{.*}}(%rip), %xmm3, %xmm3
-; AVX1-NEXT:    vmulss %xmm2, %xmm3, %xmm2
-; AVX1-NEXT:    vandps {{.*}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT:    vdivss %xmm1, %xmm2, %xmm1
-; AVX1-NEXT:    vmulss %xmm1, %xmm0, %xmm0
+; AVX1-NEXT:    vmulss %xmm1, %xmm1, %xmm1
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vrsqrtss %xmm1, %xmm1, %xmm2
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vaddss {{.*}}(%rip), %xmm1, %xmm1
+; AVX1-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
+; AVX1-NEXT:    vmulss %xmm0, %xmm2, %xmm0
+; AVX1-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX512-LABEL: div_sqrt_fabs_f32:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vrsqrtss %xmm2, %xmm2, %xmm3
-; AVX512-NEXT:    vmulss %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm2 = (xmm3 * xmm2) + mem
-; AVX512-NEXT:    vmulss {{.*}}(%rip), %xmm3, %xmm3
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm4 = [NaN,NaN,NaN,NaN]
-; AVX512-NEXT:    vmulss %xmm2, %xmm3, %xmm2
-; AVX512-NEXT:    vandps %xmm4, %xmm1, %xmm1
-; AVX512-NEXT:    vdivss %xmm1, %xmm2, %xmm1
-; AVX512-NEXT:    vmulss %xmm1, %xmm0, %xmm0
+; AVX512-NEXT:    vmulss %xmm1, %xmm1, %xmm1
+; AVX512-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vrsqrtss %xmm1, %xmm1, %xmm2
+; AVX512-NEXT:    vmulss %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vfmadd213ss {{.*#+}} xmm1 = (xmm2 * xmm1) + mem
+; AVX512-NEXT:    vmulss {{.*}}(%rip), %xmm2, %xmm2
+; AVX512-NEXT:    vmulss %xmm0, %xmm2, %xmm0
+; AVX512-NEXT:    vmulss %xmm0, %xmm1, %xmm0
 ; AVX512-NEXT:    retq
   %s = call fast float @llvm.sqrt.f32(float %z)
   %a = call fast float @llvm.fabs.f32(float %y)
@@ -666,47 +667,46 @@
   ret float %d
 }
 
-; x / (fabs(y) * sqrt(z))
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(y*y*z)
 
 define <4 x float> @div_sqrt_fabs_v4f32(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
 ; SSE-LABEL: div_sqrt_fabs_v4f32:
 ; SSE:       # %bb.0:
-; SSE-NEXT:    rsqrtps %xmm2, %xmm3
-; SSE-NEXT:    mulps %xmm3, %xmm2
-; SSE-NEXT:    mulps %xmm3, %xmm2
-; SSE-NEXT:    addps {{.*}}(%rip), %xmm2
-; SSE-NEXT:    mulps {{.*}}(%rip), %xmm3
-; SSE-NEXT:    mulps %xmm2, %xmm3
-; SSE-NEXT:    andps {{.*}}(%rip), %xmm1
-; SSE-NEXT:    divps %xmm1, %xmm3
-; SSE-NEXT:    mulps %xmm3, %xmm0
+; SSE-NEXT:    mulps %xmm1, %xmm1
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    rsqrtps %xmm1, %xmm2
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    mulps %xmm2, %xmm1
+; SSE-NEXT:    addps {{.*}}(%rip), %xmm1
+; SSE-NEXT:    mulps {{.*}}(%rip), %xmm2
+; SSE-NEXT:    mulps %xmm1, %xmm2
+; SSE-NEXT:    mulps %xmm2, %xmm0
 ; SSE-NEXT:    retq
 ;
 ; AVX1-LABEL: div_sqrt_fabs_v4f32:
 ; AVX1:       # %bb.0:
-; AVX1-NEXT:    vrsqrtps %xmm2, %xmm3
-; AVX1-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX1-NEXT:    vaddps {{.*}}(%rip), %xmm2, %xmm2
-; AVX1-NEXT:    vmulps {{.*}}(%rip), %xmm3, %xmm3
-; AVX1-NEXT:    vmulps %xmm2, %xmm3, %xmm2
-; AVX1-NEXT:    vandps {{.*}}(%rip), %xmm1, %xmm1
-; AVX1-NEXT:    vdivps %xmm1, %xmm2, %xmm1
+; AVX1-NEXT:    vmulps %xmm1, %xmm1, %xmm1
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vrsqrtps %xmm1, %xmm2
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX1-NEXT:    vaddps {{.*}}(%rip), %xmm1, %xmm1
+; AVX1-NEXT:    vmulps {{.*}}(%rip), %xmm2, %xmm2
+; AVX1-NEXT:    vmulps %xmm1, %xmm2, %xmm1
 ; AVX1-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; AVX1-NEXT:    retq
 ;
 ; AVX512-LABEL: div_sqrt_fabs_v4f32:
 ; AVX512:       # %bb.0:
-; AVX512-NEXT:    vrsqrtps %xmm2, %xmm3
-; AVX512-NEXT:    vmulps %xmm3, %xmm2, %xmm2
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm4 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
-; AVX512-NEXT:    vfmadd231ps {{.*#+}} xmm4 = (xmm3 * xmm2) + xmm4
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm2 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
-; AVX512-NEXT:    vmulps %xmm2, %xmm3, %xmm2
-; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm3 = [NaN,NaN,NaN,NaN]
-; AVX512-NEXT:    vmulps %xmm4, %xmm2, %xmm2
-; AVX512-NEXT:    vandps %xmm3, %xmm1, %xmm1
-; AVX512-NEXT:    vdivps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT:    vmulps %xmm1, %xmm1, %xmm1
+; AVX512-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vrsqrtps %xmm1, %xmm2
+; AVX512-NEXT:    vmulps %xmm2, %xmm1, %xmm1
+; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm3 = [-3.0E+0,-3.0E+0,-3.0E+0,-3.0E+0]
+; AVX512-NEXT:    vfmadd231ps {{.*#+}} xmm3 = (xmm2 * xmm1) + xmm3
+; AVX512-NEXT:    vbroadcastss {{.*#+}} xmm1 = [-5.0E-1,-5.0E-1,-5.0E-1,-5.0E-1]
+; AVX512-NEXT:    vmulps %xmm1, %xmm2, %xmm1
+; AVX512-NEXT:    vmulps %xmm3, %xmm1, %xmm1
 ; AVX512-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; AVX512-NEXT:    retq
   %s = call <4 x float> @llvm.sqrt.v4f32(<4 x float> %z)
@@ -716,6 +716,11 @@
   ret <4 x float> %d
 }
 
+; This has 'arcp' but does not have 'reassoc' FMF.
+; We allow converting the sqrt to an estimate, but
+; do not pull the divisor into the estimate.
+; x / (fabs(y) * sqrt(z)) --> x * rsqrt(z) / fabs(y)
+
 define <4 x float> @div_sqrt_fabs_v4f32_fmf(<4 x float> %x, <4 x float> %y, <4 x float> %z) {
 ; SSE-LABEL: div_sqrt_fabs_v4f32_fmf:
 ; SSE:       # %bb.0:
@@ -765,6 +770,8 @@
   ret <4 x float> %d
 }
 
+; No estimates for f64, so do not convert fabs into an fmul.
+
 define double @div_sqrt_fabs_f64(double %x, double %y, double %z) {
 ; SSE-LABEL: div_sqrt_fabs_f64:
 ; SSE:       # %bb.0:
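
Note for reviewers: a minimal IR reproducer for the new fold, distilled from the
div_sqrt_fabs_f32 test above (the function name is illustrative, not part of the
patch). The combine requires 'reassoc' on both the fdiv and its fmul divisor
(checked via hasAllowReassociation()); the tests get that from the full 'fast'
flag set. With x86 rsqrt estimates available, the fdiv below should lower with
no division:

define float @rsqrt_fold_demo(float %x, float %y, float %z) {
  %s = call fast float @llvm.sqrt.f32(float %z)
  %a = call fast float @llvm.fabs.f32(float %y)
  %m = fmul fast float %a, %s   ; divisor |y| * sqrt(z) is known non-negative
  %d = fdiv fast float %x, %m   ; folds to x * rsqrt(y*y*z)
  ret float %d
}

declare float @llvm.sqrt.f32(float)
declare float @llvm.fabs.f32(float)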