Index: llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -23,6 +23,7 @@
 #include <optional>
 
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "AMDGPUtti"
 
@@ -355,6 +356,33 @@
   return false;
 }
 
+/// Match \p Arg as a value that is exactly representable in half precision:
+/// either a single-use fpext whose source type is half, or a floating-point
+/// constant that converts to half without losing information. On success,
+/// returns true and sets \p FPExtSrc to the half-typed value (the fpext
+/// source, or a newly created half constant).
+static bool matchFPExtFromF16(Value *Arg, Value *&FPExtSrc) {
+  // m_Value binds FPExtSrc as soon as the fpext matches, so FPExtSrc may be
+  // overwritten even when its type is not half and false is returned.
+  if (match(Arg, m_OneUse(m_FPExt(m_Value(FPExtSrc)))))
+    return FPExtSrc->getType()->isHalfTy();
+
+  ConstantFP *CFP;
+  if (match(Arg, m_ConstantFP(CFP))) {
+    // Only accept constants whose conversion to half is exact.
+    bool LosesInfo;
+    APFloat Val(CFP->getValueAPF());
+    Val.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &LosesInfo);
+    if (LosesInfo)
+      return false;
+
+    FPExtSrc = ConstantFP::get(Type::getHalfTy(Arg->getContext()), Val);
+    return true;
+  }
+
+  return false;
+}
+
 std::optional<Instruction *>
 GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
   Intrinsic::ID IID = II.getIntrinsicID();
@@ -701,6 +729,20 @@
       }
     }
 
+    if (!ST->hasMed3_16())
+      break;
+
+    Value *X, *Y, *Z;
+
+    // Repeat floating-point width reduction done for minnum/maxnum.
+    // fmed3((fpext X), (fpext Y), (fpext Z)) -> fpext (fmed3(X, Y, Z))
+    if (matchFPExtFromF16(Src0, X) && matchFPExtFromF16(Src1, Y) &&
+        matchFPExtFromF16(Src2, Z)) {
+      Value *NewCall = IC.Builder.CreateIntrinsic(IID, {X->getType()},
+                                                  {X, Y, Z}, &II, II.getName());
+      return new FPExtInst(NewCall, II.getType());
+    }
+
     break;
   }
   case Intrinsic::amdgcn_icmp:
Index: llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
===================================================================
--- llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
+++ llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
@@ -24,10 +24,8 @@
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1:[0-9]+]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG1]], half [[ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -48,10 +46,8 @@
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_flags
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call nsz float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call nsz half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG1]], half [[ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -71,9 +67,8 @@
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k0
 ; GFX9-SAME: (half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG1_EXT]], float [[ARG2_EXT]], float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG1]], half [[ARG2]], half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg1.ext = fpext half %arg1 to float
@@ -92,9 +87,8 @@
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k1
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG2_EXT]], float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG2]], half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -113,9 +107,8 @@
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k2
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG0]], half [[ARG1]], half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -133,8 +126,8 @@
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k0_k1
 ; GFX9-SAME: (half [[ARG2:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG2_EXT]], float 0.000000e+00, float 1.600000e+01)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG2]], half 0xH0000, half 0xH4C00)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg2.ext = fpext half %arg2 to float
@@ -151,8 +144,8 @@
 ;
 ; GFX9-LABEL: define float @fmed3_f32_fpext_f16_k0_k2
 ; GFX9-SAME: (half [[ARG1:%.*]]) #[[ATTR1]] {
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[ARG1]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG1_EXT]], float 0.000000e+00, float 2.000000e+00)
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[ARG1]], half 0xH0000, half 0xH4000)
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg1.ext = fpext half %arg1 to float
@@ -177,10 +170,8 @@
 ; GFX9-NEXT:    [[FABS_ARG0:%.*]] = call half @llvm.fabs.f16(half [[ARG0]])
 ; GFX9-NEXT:    [[FABS_ARG1:%.*]] = call half @llvm.fabs.f16(half [[ARG1]])
 ; GFX9-NEXT:    [[FABS_ARG2:%.*]] = call half @llvm.fabs.f16(half [[ARG2]])
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[FABS_ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[FABS_ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[FABS_ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[FABS_ARG0]], half [[FABS_ARG1]], half [[FABS_ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %fabs.arg0 = call half @llvm.fabs.f16(half %arg0)
@@ -208,12 +199,10 @@
 ; GFX9-LABEL: define float @fmed3_fabs_f32_fpext_f16
 ; GFX9-SAME: (half [[ARG0:%.*]], half [[ARG1:%.*]], half [[ARG2:%.*]]) #[[ATTR1]] {
 ; GFX9-NEXT:    [[TMP1:%.*]] = call half @llvm.fabs.f16(half [[ARG0]])
-; GFX9-NEXT:    [[FABS_EXT_ARG0:%.*]] = fpext half [[TMP1]] to float
 ; GFX9-NEXT:    [[TMP2:%.*]] = call half @llvm.fabs.f16(half [[ARG1]])
-; GFX9-NEXT:    [[FABS_EXT_ARG1:%.*]] = fpext half [[TMP2]] to float
 ; GFX9-NEXT:    [[TMP3:%.*]] = call half @llvm.fabs.f16(half [[ARG2]])
-; GFX9-NEXT:    [[FABS_EXT_ARG2:%.*]] = fpext half [[TMP3]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[FABS_EXT_ARG0]], float [[FABS_EXT_ARG1]], float [[FABS_EXT_ARG2]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[TMP1]], half [[TMP2]], half [[TMP3]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %arg0.ext = fpext half %arg0 to float
@@ -243,10 +232,8 @@
 ; GFX9-NEXT:    [[FNEG_ARG0:%.*]] = fneg half [[ARG0]]
 ; GFX9-NEXT:    [[FNEG_ARG1:%.*]] = fneg half [[ARG1]]
 ; GFX9-NEXT:    [[FNEG_ARG2:%.*]] = fneg half [[ARG2]]
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[FNEG_ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[FNEG_ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[FNEG_ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[FNEG_ARG0]], half [[FNEG_ARG1]], half [[FNEG_ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %fneg.arg0 = fneg half %arg0
@@ -315,10 +302,8 @@
 ; GFX9-NEXT:    [[FNEG_FABS_ARG0:%.*]] = fneg half [[FABS_ARG0]]
 ; GFX9-NEXT:    [[FNEG_FABS_ARG1:%.*]] = fneg half [[FABS_ARG1]]
 ; GFX9-NEXT:    [[FNEG_FABS_ARG2:%.*]] = fneg half [[FABS_ARG2]]
-; GFX9-NEXT:    [[ARG0_EXT:%.*]] = fpext half [[FNEG_FABS_ARG0]] to float
-; GFX9-NEXT:    [[ARG1_EXT:%.*]] = fpext half [[FNEG_FABS_ARG1]] to float
-; GFX9-NEXT:    [[ARG2_EXT:%.*]] = fpext half [[FNEG_FABS_ARG2]] to float
-; GFX9-NEXT:    [[MED3:%.*]] = call float @llvm.amdgcn.fmed3.f32(float [[ARG0_EXT]], float [[ARG1_EXT]], float [[ARG2_EXT]])
+; GFX9-NEXT:    [[MED31:%.*]] = call half @llvm.amdgcn.fmed3.f16(half [[FNEG_FABS_ARG0]], half [[FNEG_FABS_ARG1]], half [[FNEG_FABS_ARG2]])
+; GFX9-NEXT:    [[MED3:%.*]] = fpext half [[MED31]] to float
 ; GFX9-NEXT:    ret float [[MED3]]
 ;
   %fabs.arg0 = call half @llvm.fabs.f16(half %arg0)