diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.h
@@ -74,6 +74,8 @@
   AMDGPUTTIImpl CommonTTI;
   bool IsGraphicsShader;
   bool HasFP32Denormals;
+  // VGPR budget (getMaxNumVGPRs) for F's minimum required occupancy.
+  unsigned MaxVGPRs;
 
   const FeatureBitset InlineFeatureIgnoreList = {
     // Codegen control options which don't matter.
@@ -133,7 +135,11 @@
         TLI(ST->getTargetLowering()),
         CommonTTI(TM, F),
         IsGraphicsShader(AMDGPU::isShader(F.getCallingConv())),
-        HasFP32Denormals(AMDGPU::SIModeRegisterDefaults(F).allFP32Denormals()) {}
+        HasFP32Denormals(AMDGPU::SIModeRegisterDefaults(F).allFP32Denormals()),
+        MaxVGPRs(ST->getMaxNumVGPRs(
+            std::max(ST->getWavesPerEU(F).first,
+                     ST->getWavesPerEUForWorkGroup(
+                         ST->getFlatWorkGroupSizes(F).second)))) {}
 
   bool hasBranchDivergence() { return true; }
   bool useGPUDivergenceAnalysis() const;
@@ -148,6 +154,7 @@
 
   unsigned getHardwareNumberOfRegisters(bool Vector) const;
   unsigned getNumberOfRegisters(bool Vector) const;
+  unsigned getNumberOfRegisters(unsigned RCID) const;
   unsigned getRegisterBitWidth(bool Vector) const;
   unsigned getMinVectorRegisterBitWidth() const;
   unsigned getLoadVectorFactor(unsigned VF, unsigned LoadSize,
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetTransformInfo.cpp
@@ -239,7 +239,7 @@
 unsigned GCNTTIImpl::getHardwareNumberOfRegisters(bool Vec) const {
   // The concept of vector registers doesn't really exist. Some packed vector
   // operations operate on the normal 32-bit registers.
-  return 256;
+  return MaxVGPRs;
 }
 
 unsigned GCNTTIImpl::getNumberOfRegisters(bool Vec) const {
@@ -248,6 +248,17 @@
   return getHardwareNumberOfRegisters(Vec) >> 3;
 }
 
+unsigned GCNTTIImpl::getNumberOfRegisters(unsigned RCID) const {
+  // RCID is a generic TTI register class ID, not an SIRegisterInfo class ID:
+  // callers obtain it from getRegisterClassForType(), whose default
+  // implementation returns 0 for scalar and 1 for vector types. Feeding it
+  // to SIRegisterInfo::getRegClass() would select an unrelated class.
+  // Packed "vector" values still occupy ordinary 32-bit VGPRs, so use the
+  // same occupancy-scaled budget as the bool overload -- which is also what
+  // the implicit unsigned -> bool conversion chose before this overload.
+  return getNumberOfRegisters(RCID != 0);
+}
+
 unsigned GCNTTIImpl::getRegisterBitWidth(bool Vector) const {
   return 32;
 }