diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -23046,9 +23046,10 @@
   if (LegalDAG)
     return SDValue();
 
-  // TODO: Handle half and/or extended types?
+  // TODO: Handle extended types?
   EVT VT = Op.getValueType();
-  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
+      VT.getScalarType() != MVT::f64)
     return SDValue();
 
   // If estimates are explicitly disabled for this function, we're done.
@@ -23185,9 +23186,10 @@
   if (LegalDAG)
     return SDValue();
 
-  // TODO: Handle half and/or extended types?
+  // TODO: Handle extended types?
   EVT VT = Op.getValueType();
-  if (VT.getScalarType() != MVT::f32 && VT.getScalarType() != MVT::f64)
+  if (VT.getScalarType() != MVT::f16 && VT.getScalarType() != MVT::f32 &&
+      VT.getScalarType() != MVT::f64)
     return SDValue();
 
   // If estimates are explicitly disabled for this function, we're done.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -23148,6 +23148,7 @@
                                            int &RefinementSteps,
                                            bool &UseOneConstNR,
                                            bool Reciprocal) const {
+  SDLoc DL(Op);
   EVT VT = Op.getValueType();
 
   // SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
@@ -23169,7 +23170,28 @@
     UseOneConstNR = false;
     // There is no FSQRT for 512-bits, but there is RSQRT14.
     unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RSQRT14 : X86ISD::FRSQRT;
-    return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
+    return DAG.getNode(Opcode, DL, VT, Op);
+  }
+
+  // Only provide the reciprocal form for half types. The generic DAG combiner
+  // only multiplies the estimate by x (to turn rsqrt(x) into sqrt(x)) inside
+  // its Newton-Raphson refinement steps, so with zero steps a non-reciprocal
+  // request would get the bare rsqrt estimate back as sqrt(x). Plain sqrt
+  // stays on vsqrtsh/vsqrtph.
+  if (Reciprocal && VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
+      Subtarget.hasFP16()) {
+    if (RefinementSteps == ReciprocalEstimate::Unspecified)
+      RefinementSteps = 0;
+
+    if (VT == MVT::f16) {
+      SDValue Zero = DAG.getIntPtrConstant(0, DL);
+      SDValue Undef = DAG.getUNDEF(MVT::v8f16);
+      Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
+      Op = DAG.getNode(X86ISD::RSQRT14S, DL, MVT::v8f16, Undef, Op);
+      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
+    }
+
+    return DAG.getNode(X86ISD::RSQRT14, DL, VT, Op);
   }
   return SDValue();
 }
@@ -23179,6 +23201,7 @@
 SDValue X86TargetLowering::getRecipEstimate(SDValue Op, SelectionDAG &DAG,
                                             int Enabled,
                                             int &RefinementSteps) const {
+  SDLoc DL(Op);
   EVT VT = Op.getValueType();
 
   // SSE1 has rcpss and rcpps. AVX adds a 256-bit variant for rcpps.
@@ -23203,7 +23226,23 @@
 
     // There is no FSQRT for 512-bits, but there is RCP14.
     unsigned Opcode = VT == MVT::v16f32 ? X86ISD::RCP14 : X86ISD::FRCP;
-    return DAG.getNode(Opcode, SDLoc(Op), VT, Op);
+    return DAG.getNode(Opcode, DL, VT, Op);
+  }
+
+  if (VT.getScalarType() == MVT::f16 && isTypeLegal(VT) &&
+      Subtarget.hasFP16()) {
+    if (RefinementSteps == ReciprocalEstimate::Unspecified)
+      RefinementSteps = 0;
+
+    if (VT == MVT::f16) {
+      SDValue Zero = DAG.getIntPtrConstant(0, DL);
+      SDValue Undef = DAG.getUNDEF(MVT::v8f16);
+      Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v8f16, Op);
+      Op = DAG.getNode(X86ISD::RCP14S, DL, MVT::v8f16, Undef, Op);
+      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Op, Zero);
+    }
+
+    return DAG.getNode(X86ISD::RCP14, DL, VT, Op);
   }
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
--- a/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-arith-vl-intrinsics.ll
@@ -250,6 +250,16 @@
   ret <16 x half> %res
 }
 
+define <16 x half> @test_int_x86_avx512fp16_div_ph_256_fast(<16 x half> %x1, <16 x half> %x2) {
+; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_256_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrcpph %ymm1, %ymm1
+; CHECK-NEXT:    vmulph %ymm0, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+  %res = fdiv fast <16 x half> %x1, %x2
+  ret <16 x half> %res
+}
+
 define <16 x half> @test_int_x86_avx512fp16_mask_div_ph_256(<16 x half> %x1, <16 x half> %x2, <16 x half> %src, i16 %mask, <16 x half>* %ptr) {
 ; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_256:
 ; CHECK:       # %bb.0:
@@ -290,6 +300,16 @@
   ret <8 x half> %res
 }
 
+define <8 x half> @test_int_x86_avx512fp16_div_ph_128_fast(<8 x half> %x1, <8 x half> %x2) {
+; CHECK-LABEL: test_int_x86_avx512fp16_div_ph_128_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrcpph %xmm1, %xmm1
+; CHECK-NEXT:    vmulph %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %res = fdiv fast <8 x half> %x1, %x2
+  ret <8 x half> %res
+}
+
 define <8 x half> @test_int_x86_avx512fp16_mask_div_ph_128(<8 x half> %x1, <8 x half> %x2, <8 x half> %src, i8 %mask, <8 x half>* %ptr) {
 ; CHECK-LABEL: test_int_x86_avx512fp16_mask_div_ph_128:
 ; CHECK:       # %bb.0:
diff --git a/llvm/test/CodeGen/X86/avx512fp16-arith.ll b/llvm/test/CodeGen/X86/avx512fp16-arith.ll
--- a/llvm/test/CodeGen/X86/avx512fp16-arith.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-arith.ll
@@ -154,6 +154,16 @@
   ret <32 x half> %x
 }
 
+define <32 x half> @vdivph_512_test_fast(<32 x half> %i, <32 x half> %j) nounwind readnone {
+; CHECK-LABEL: vdivph_512_test_fast:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    vrcpph %zmm1, %zmm1
+; CHECK-NEXT:    vmulph %zmm0, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+  %x = fdiv fast <32 x half> %i, %j
+  ret <32 x half> %x
+}
+
 define half @add_sh(half %i, half %j, half* %x.ptr) nounwind readnone {
 ; CHECK-LABEL: add_sh:
 ; CHECK:       ## %bb.0:
@@ -228,6 +238,16 @@
   ret half %r
 }
 
+define half @div_sh_3(half %i, half %j) nounwind readnone {
+; CHECK-LABEL: div_sh_3:
+; CHECK:       ## %bb.0:
+; CHECK-NEXT:    vrcpsh %xmm1, %xmm1, %xmm1
+; CHECK-NEXT:    vmulsh %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %r = fdiv fast half %i, %j
+  ret half %r
+}
+
 define i1 @cmp_une_sh(half %x, half %y) {
 ; CHECK-LABEL: cmp_une_sh:
 ; CHECK:       ## %bb.0: ## %entry
diff --git a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
--- a/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-intrinsics.ll
@@ -24,6 +24,17 @@
   ret <32 x half> %1
 }
 
+define <32 x half> @test_sqrt_ph_512_fast(<32 x half> %a0, <32 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_512_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtph %zmm0, %zmm0
+; CHECK-NEXT:    vmulph %zmm0, %zmm1, %zmm0
+; CHECK-NEXT:    retq
+  %1 = call fast <32 x half> @llvm.sqrt.v32f16(<32 x half> %a0)
+  %2 = fdiv fast <32 x half> %a1, %1
+  ret <32 x half> %2
+}
+
 define <32 x half> @test_mask_sqrt_ph_512(<32 x half> %a0, <32 x half> %passthru, i32 %mask) {
 ; CHECK-LABEL: test_mask_sqrt_ph_512:
 ; CHECK:       # %bb.0:
@@ -98,6 +109,28 @@
   ret <8 x half> %res
 }
 
+define half @test_sqrt_sh2(half %a0, half %a1) {
+; CHECK-LABEL: test_sqrt_sh2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtsh %xmm0, %xmm0, %xmm0
+; CHECK-NEXT:    vmulsh %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %1 = call fast half @llvm.sqrt.f16(half %a0)
+  %2 = fdiv fast half %a1, %1
+  ret half %2
+}
+
+define half @test_sqrt_sh3(half %a0) {
+; CHECK-LABEL: test_sqrt_sh3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsqrtsh %xmm0, %xmm0, %xmm0
+; CHECK-NEXT:    retq
+  %1 = call fast half @llvm.sqrt.f16(half %a0)
+  ret half %1
+}
+
+declare half @llvm.sqrt.f16(half)
+
 define <8 x half> @test_sqrt_sh_r(<8 x half> %a0, <8 x half> %a1, <8 x half> %a2, i8 %mask) {
 ; CHECK-LABEL: test_sqrt_sh_r:
 ; CHECK:       # %bb.0:
diff --git a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
--- a/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16vl-intrinsics.ll
@@ -958,6 +958,17 @@
   ret <8 x half> %1
 }
 
+define <8 x half> @test_sqrt_ph_128_fast(<8 x half> %a0, <8 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_128_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtph %xmm0, %xmm0
+; CHECK-NEXT:    vmulph %xmm0, %xmm1, %xmm0
+; CHECK-NEXT:    retq
+  %1 = call fast <8 x half> @llvm.sqrt.v8f16(<8 x half> %a0)
+  %2 = fdiv fast <8 x half> %a1, %1
+  ret <8 x half> %2
+}
+
 define <8 x half> @test_mask_sqrt_ph_128(<8 x half> %a0, <8 x half> %passthru, i8 %mask) {
 ; CHECK-LABEL: test_mask_sqrt_ph_128:
 ; CHECK:       # %bb.0:
@@ -992,6 +1003,17 @@
   ret <16 x half> %1
 }
 
+define <16 x half> @test_sqrt_ph_256_fast(<16 x half> %a0, <16 x half> %a1) {
+; CHECK-LABEL: test_sqrt_ph_256_fast:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vrsqrtph %ymm0, %ymm0
+; CHECK-NEXT:    vmulph %ymm0, %ymm1, %ymm0
+; CHECK-NEXT:    retq
+  %1 = call fast <16 x half> @llvm.sqrt.v16f16(<16 x half> %a0)
+  %2 = fdiv fast <16 x half> %a1, %1
+  ret <16 x half> %2
+}
+
 define <16 x half> @test_mask_sqrt_ph_256(<16 x half> %a0, <16 x half> %passthru, i16 %mask) {
 ; CHECK-LABEL: test_mask_sqrt_ph_256:
 ; CHECK:       # %bb.0: