Index: llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -408,6 +408,17 @@
                                                     int DMaskIdx = -1,
                                                     bool IsLoad = true);
 
+/// Return true if it's legal to contract llvm.amdgcn.rcp(llvm.sqrt) into
+/// llvm.amdgcn.rsq, judging only by the type and precision of \p SqrtOp:
+///  - half: always.
+///  - float: only if the sqrt is 'afn' or its FP accuracy is at least 1.0.
+///  - any other type (e.g. double): never.
+static bool canContractSqrtToRsq(const FPMathOperator *SqrtOp) {
+  return (SqrtOp->getType()->isFloatTy() &&
+          (SqrtOp->hasApproxFunc() || SqrtOp->getFPAccuracy() >= 1.0f)) ||
+         SqrtOp->getType()->isHalfTy();
+}
+
 std::optional<Instruction *>
 GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
   Intrinsic::ID IID = II.getIntrinsicID();
@@ -437,6 +448,43 @@
       return IC.replaceInstUsesWith(II, ConstantFP::get(II.getContext(), Val));
     }
 
+    // Folding rcp(sqrt(x)) into rsq(x) requires 'contract' on the rcp call.
+    FastMathFlags FMF = cast<FPMathOperator>(II).getFastMathFlags();
+    if (!FMF.allowContract())
+      break;
+    auto *SrcCI = dyn_cast<IntrinsicInst>(Src);
+    if (!SrcCI)
+      break;
+
+    auto SrcIID = SrcCI->getIntrinsicID();
+    // llvm.amdgcn.rcp(llvm.amdgcn.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable
+    //
+    // llvm.amdgcn.rcp(llvm.sqrt(x)) -> llvm.amdgcn.rsq(x) if contractable and
+    // relaxed.
+    if (SrcIID == Intrinsic::amdgcn_sqrt || SrcIID == Intrinsic::sqrt) {
+      const FPMathOperator *SqrtOp = cast<FPMathOperator>(SrcCI);
+      FastMathFlags InnerFMF = SqrtOp->getFastMathFlags();
+      // The sqrt must be contractable as well and must have exactly one use.
+      if (!InnerFMF.allowContract() || !SrcCI->hasOneUse())
+        break;
+
+      // Generic llvm.sqrt additionally has to pass the type/precision check.
+      if (SrcIID == Intrinsic::sqrt && !canContractSqrtToRsq(SqrtOp))
+        break;
+
+      Function *NewDecl = Intrinsic::getDeclaration(
+          SrcCI->getModule(), Intrinsic::amdgcn_rsq, {SrcCI->getType()});
+
+      // The fused call carries the union of both calls' fast-math flags.
+      InnerFMF |= FMF;
+      II.setFastMathFlags(InnerFMF);
+
+      // Rewrite this call in place: point it at rsq and pass the sqrt's
+      // operand directly.
+      II.setCalledFunction(NewDecl);
+      return IC.replaceOperand(II, 0, SrcCI->getArgOperand(0));
+    }
+
     break;
   }
   case Intrinsic::amdgcn_sqrt:
Index: llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
===================================================================
--- llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
+++ llvm/test/Transforms/InstCombine/AMDGPU/rcp-contract-rsq.ll
@@ -22,8 +22,7 @@
 define float @amdgcn_rcp_amdgcn_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_amdgcn_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1:[0-9]+]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract float @llvm.amdgcn.sqrt.f32(float [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call contract float @llvm.amdgcn.sqrt.f32(float %x)
@@ -76,8 +75,7 @@
 define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call nnan contract float @llvm.amdgcn.sqrt.f32(float [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call ninf contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call nnan ninf contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call nnan contract float @llvm.amdgcn.sqrt.f32(float %x)
@@ -89,8 +87,7 @@
 define half @amdgcn_rcp_amdgcn_sqrt_f16_contract(half %x) {
 ; CHECK-LABEL: define half @amdgcn_rcp_amdgcn_sqrt_f16_contract
 ; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract half @llvm.amdgcn.sqrt.f16(half [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rsq.f16(half [[X]])
 ; CHECK-NEXT:    ret half [[RSQ]]
 ;
   %sqrt = call contract half @llvm.amdgcn.sqrt.f16(half %x)
@@ -143,8 +140,7 @@
 define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract(half %x) {
 ; CHECK-LABEL: define half @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f16_contract
 ; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call nnan contract half @llvm.amdgcn.sqrt.f16(half [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call ninf contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call nnan ninf contract half @llvm.amdgcn.rsq.f16(half [[X]])
 ; CHECK-NEXT:    ret half [[RSQ]]
 ;
   %sqrt = call nnan contract half @llvm.amdgcn.sqrt.f16(half %x)
@@ -156,8 +152,7 @@
 define double @amdgcn_rcp_amdgcn_sqrt_f64_contract(double %x) {
 ; CHECK-LABEL: define double @amdgcn_rcp_amdgcn_sqrt_f64_contract
 ; CHECK-SAME: (double [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract double @llvm.amdgcn.sqrt.f64(double [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract double @llvm.amdgcn.rcp.f64(double [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract double @llvm.amdgcn.rsq.f64(double [[X]])
 ; CHECK-NEXT:    ret double [[RSQ]]
 ;
   %sqrt = call contract double @llvm.amdgcn.sqrt.f64(double %x)
@@ -210,8 +205,7 @@
 define double @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f64_contract(double %x) {
 ; CHECK-LABEL: define double @amdgcn_rcp_nnan_amdgcn_sqrt_ninf_f64_contract
 ; CHECK-SAME: (double [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call nnan contract double @llvm.amdgcn.sqrt.f64(double [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call ninf contract double @llvm.amdgcn.rcp.f64(double [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call nnan ninf contract double @llvm.amdgcn.rsq.f64(double [[X]])
 ; CHECK-NEXT:    ret double [[RSQ]]
 ;
   %sqrt = call nnan contract double @llvm.amdgcn.sqrt.f64(double %x)
@@ -236,8 +230,7 @@
 define half @amdgcn_rcp_sqrt_f16_contract(half %x) {
 ; CHECK-LABEL: define half @amdgcn_rcp_sqrt_f16_contract
 ; CHECK-SAME: (half [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract half @llvm.sqrt.f16(half [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rcp.f16(half [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract half @llvm.amdgcn.rsq.f16(half [[X]])
 ; CHECK-NEXT:    ret half [[RSQ]]
 ;
   %sqrt = call contract half @llvm.sqrt.f16(half %x)
@@ -261,8 +254,7 @@
 define float @amdgcn_rcp_afn_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_afn_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract afn float @llvm.sqrt.f32(float [[X]])
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract afn float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call afn contract float @llvm.sqrt.f32(float %x)
@@ -273,8 +265,7 @@
 define float @amdgcn_rcp_fpmath3_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_fpmath3_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]), !fpmath !0
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !0
@@ -285,8 +276,7 @@
 define float @amdgcn_rcp_fpmath1_sqrt_f32_contract(float %x) {
 ; CHECK-LABEL: define float @amdgcn_rcp_fpmath1_sqrt_f32_contract
 ; CHECK-SAME: (float [[X:%.*]]) #[[ATTR1]] {
-; CHECK-NEXT:    [[SQRT:%.*]] = call contract float @llvm.sqrt.f32(float [[X]]), !fpmath !1
-; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rcp.f32(float [[SQRT]])
+; CHECK-NEXT:    [[RSQ:%.*]] = call contract float @llvm.amdgcn.rsq.f32(float [[X]])
 ; CHECK-NEXT:    ret float [[RSQ]]
 ;
   %sqrt = call contract float @llvm.sqrt.f32(float %x), !fpmath !1