diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -23190,6 +23190,10 @@ bool X86TargetLowering::isFsqrtCheap(SDValue Op, SelectionDAG &DAG) const { EVT VT = Op.getValueType(); + // We don't need to replace SQRT with RSQRT for half type. + if (VT.getScalarType() == MVT::f16) + return true; + // We never want to use both SQRT and RSQRT instructions for the same input. if (DAG.getNodeIfExists(X86ISD::FRSQRT, DAG.getVTList(VT), Op)) return false; @@ -23236,6 +23240,7 @@ if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) && Subtarget.hasFP16()) { + assert(Reciprocal && "Don't replace SQRT with RSQRT for half type"); if (RefinementSteps == ReciprocalEstimate::Unspecified) RefinementSteps = 0; diff --git a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll --- a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll +++ b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll @@ -123,12 +123,7 @@ define half @test_sqrt_sh3(half %a0, half %a1) { ; CHECK-LABEL: test_sqrt_sh3: ; CHECK: # %bb.0: -; CHECK-NEXT: vpbroadcastw {{.*#+}} xmm1 = [NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN] -; CHECK-NEXT: vpand %xmm1, %xmm0, %xmm1 -; CHECK-NEXT: vcmpltsh {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1, %k1 -; CHECK-NEXT: vrsqrtsh %xmm0, %xmm0, %xmm0 -; CHECK-NEXT: vxorps %xmm1, %xmm1, %xmm1 -; CHECK-NEXT: vmovsh %xmm1, %xmm0, %xmm0 {%k1} +; CHECK-NEXT: vsqrtsh %xmm0, %xmm0, %xmm0 ; CHECK-NEXT: retq %1 = call fast half @llvm.sqrt.f16(half %a0) ret half %1 diff --git a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll --- a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll +++ b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll @@ -972,10 +972,7 @@ define <8 x half> @test_sqrt_ph_128_fast2(<8 x half> %a0, <8 x half> %a1) { ; CHECK-LABEL: test_sqrt_ph_128_fast2: ; CHECK: # %bb.0: -; CHECK-NEXT: vpbroadcastw {{.*#+}} xmm1 = [NaN,NaN,NaN,NaN,NaN,NaN,NaN,NaN] -; CHECK-NEXT: vpand %xmm1, %xmm0, %xmm1 -; CHECK-NEXT: vcmpgeph {{\.?LCPI[0-9]+_[0-9]+}}(%rip){1to8}, %xmm1, %k1 -; CHECK-NEXT: vrsqrtph %xmm0, %xmm0 {%k1} {z} +; CHECK-NEXT: vsqrtph %xmm0, %xmm0 ; CHECK-NEXT: retq %1 = call fast <8 x half> @llvm.sqrt.v8f16(<8 x half> %a0) ret <8 x half> %1