diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23190,6 +23190,10 @@ bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); + // We don't need to replace SQRT with RSQRT for half type. + if (VT.getScalarType() == MVT::f16) + return true; + // We never want to use both SQRT and RSQRT instructions for the same input. if (DAG.getNodeIfExists(X86ISD::FRSQRT, DAG.getVTList(VT), Op)) return false; @@ -23228,11 +23232,15 @@ UseOneConstNR = false; // There is no FSQRT for 512-bits, but there is RSQRT14. unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT; - return DAG.getNode(Opcode, DL, VT, Op); + SDValue Estimate = DAG.getNode(Opcode, DL, VT, Op); + if (RefinementSteps == 0 && !Reciprocal) + Estimate = DAG.getNode(ISD::FMUL, DL, VT, Op, Estimate); + return Estimate; } if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) && Subtarget.hasFP16()) { + assert(Reciprocal && "Don't replace SQRT with RSQRT for half type"); if (RefinementSteps == ReciprocalEstimate::Unspecified) RefinementSteps = 0; diff --git a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll --- a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll +++ b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll @@ -969,6 +969,15 @@ ret <8 x half> %2 } +define <8 x half> @test_sqrt_ph_128_fast2(<8 x half> %a0, <8 x half> %a1) { +; CHECK-LABEL: test_sqrt_ph_128_fast2: +; CHECK: # %bb.0: +; CHECK-NEXT: vsqrtph %xmm0, %xmm0 +; CHECK-NEXT: retq + %1 = call fast <8 x half> @llvm.sqrt.v8f16(<8 x half> %a0) + ret <8 x half> %1 +} + define <8 x half> @test_mask_sqrt_ph_128(<8 x half> %a0, <8 x half> %passthru, i8 %mask) { ; CHECK-LABEL: test_mask_sqrt_ph_128: ; CHECK: # %bb.0: diff --git a/llvm/test/CodeGen/X86/sqrt-fastmath.ll b/llvm/test/CodeGen/X86/sqrt-fastmath.ll --- a/llvm/test/CodeGen/X86/sqrt-fastmath.ll +++ b/llvm/test/CodeGen/X86/sqrt-fastmath.ll @@ -384,6 +384,40 @@ ret float %div } +define float @f32_estimate2(float %x) #5 { +; SSE-LABEL: f32_estimate2: +; SSE: # %bb.0: +; SSE-NEXT: rsqrtss %xmm0, %xmm1 +; SSE-NEXT: mulss %xmm0, %xmm1 +; SSE-NEXT: andps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 +; SSE-NEXT: cmpltss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 +; SSE-NEXT: andnps %xmm1, %xmm0 +; SSE-NEXT: retq +; +; AVX1-LABEL: f32_estimate2: +; AVX1: # %bb.0: +; AVX1-NEXT: vrsqrtss %xmm0, %xmm0, %xmm1 +; AVX1-NEXT: vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2 +; AVX1-NEXT: vmulss %xmm1, %xmm0, %xmm0 +; AVX1-NEXT: vcmpltss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm2, %xmm1 +; AVX1-NEXT: vandnps %xmm0, %xmm1, %xmm0 +; AVX1-NEXT: retq +; +; AVX512-LABEL: f32_estimate2: +; AVX512: # %bb.0: +; AVX512-NEXT: vrsqrtss %xmm0, %xmm0, %xmm1 +; AVX512-NEXT: vmulss %xmm1, %xmm0, %xmm1 +; AVX512-NEXT: vbroadcastss {{.*#+}} xmm2 = [NaN,NaN,NaN,NaN] +; AVX512-NEXT: vandps %xmm2, %xmm0, %xmm0 +; AVX512-NEXT: vcmpltss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %k1 +; AVX512-NEXT: vxorps %xmm0, %xmm0, %xmm0 +; AVX512-NEXT: vmovss %xmm0, %xmm1, %xmm1 {%k1} +; AVX512-NEXT: vmovaps %xmm1, %xmm0 +; AVX512-NEXT: retq + %sqrt = tail call fast float @llvm.sqrt.f32(float %x) + ret float %sqrt +} + define <4 x float> @v4f32_no_estimate(<4 x float> %x) #0 { ; SSE-LABEL: v4f32_no_estimate: ; SSE: # %bb.0: @@ -446,6 +480,42 @@ ret <4 x float> %div } +define <4 x float> @v4f32_estimate2(<4 x float> %x) #5 { +; SSE-LABEL: v4f32_estimate2: +; SSE: # %bb.0: +; SSE-NEXT: rsqrtps %xmm0, %xmm2 +; SSE-NEXT: mulps %xmm0, %xmm2 +; SSE-NEXT: andps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 +; SSE-NEXT: movaps {{.*#+}} xmm1 = [1.17549435E-38,1.17549435E-38,1.17549435E-38,1.17549435E-38] +; SSE-NEXT: cmpleps %xmm0, %xmm1 +; SSE-NEXT: andps %xmm2, %xmm1 +; SSE-NEXT: movaps %xmm1, %xmm0 +; SSE-NEXT: retq +; +; AVX1-LABEL: v4f32_estimate2: +; AVX1: # %bb.0: +; AVX1-NEXT: vrsqrtps %xmm0, %xmm1 +; AVX1-NEXT: vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm2 +; AVX1-NEXT: vmulps %xmm1, %xmm0, %xmm0 +; AVX1-NEXT: vmovaps {{.*#+}} xmm1 = [1.17549435E-38,1.17549435E-38,1.17549435E-38,1.17549435E-38] +; AVX1-NEXT: vcmpleps %xmm2, %xmm1, %xmm1 +; AVX1-NEXT: vandps %xmm0, %xmm1, %xmm0 +; AVX1-NEXT: retq +; +; AVX512-LABEL: v4f32_estimate2: +; AVX512: # %bb.0: +; AVX512-NEXT: vrsqrtps %xmm0, %xmm1 +; AVX512-NEXT: vmulps %xmm1, %xmm0, %xmm1 +; AVX512-NEXT: vbroadcastss {{.*#+}} xmm2 = [NaN,NaN,NaN,NaN] +; AVX512-NEXT: vandps %xmm2, %xmm0, %xmm0 +; AVX512-NEXT: vbroadcastss {{.*#+}} xmm2 = [1.17549435E-38,1.17549435E-38,1.17549435E-38,1.17549435E-38] +; AVX512-NEXT: vcmpleps %xmm0, %xmm2, %xmm0 +; AVX512-NEXT: vandps %xmm1, %xmm0, %xmm0 +; AVX512-NEXT: retq + %sqrt = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %x) + ret <4 x float> %sqrt +} + define <8 x float> @v8f32_no_estimate(<8 x float> %x) #0 { ; SSE-LABEL: v8f32_no_estimate: ; SSE: # %bb.0: @@ -1020,3 +1090,4 @@ attributes #2 = { nounwind readnone } attributes #3 = { "unsafe-fp-math"="true" "reciprocal-estimates"="sqrt,vec-sqrt" "denormal-fp-math"="ieee" } attributes #4 = { "unsafe-fp-math"="true" "reciprocal-estimates"="sqrt,vec-sqrt" "denormal-fp-math"="ieee,preserve-sign" } +attributes #5 = { "unsafe-fp-math"="true" "reciprocal-estimates"="all:0" }