Index: llvm/lib/Target/AMDGPU/AMDGPU.h
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPU.h
+++ llvm/lib/Target/AMDGPU/AMDGPU.h
@@ -18,6 +18,7 @@
 class FunctionPass;
 class GCNTargetMachine;
 class ImmutablePass;
+class MachineFunctionPass;
 class ModulePass;
 class Pass;
 class Target;
@@ -73,6 +74,19 @@
 ModulePass *createAMDGPULowerModuleLDSPass();
 FunctionPass *createSIModeRegisterPass();
 
+namespace AMDGPU {
+/// Bit mask selecting which register kinds GCNRegBankReassign processes.
+enum RegBankReassignMode {
+  RM_VGPR = 1,                ///< Process VGPR operands.
+  RM_SGPR = 2,                ///< Process SGPR operands and SMRD/SALU insts.
+  RM_BOTH = RM_VGPR | RM_SGPR ///< Process both register kinds.
+};
+} // end namespace AMDGPU
+
+/// Create a GCNRegBankReassign pass restricted to the kinds set in \p Mode.
+MachineFunctionPass *
+createGCNRegBankReassignPass(AMDGPU::RegBankReassignMode Mode);
+
 struct AMDGPUSimplifyLibCallsPass : PassInfoMixin<AMDGPUSimplifyLibCallsPass> {
   AMDGPUSimplifyLibCallsPass(TargetMachine &TM) : TM(TM) {}
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
Index: llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -1180,7 +1180,7 @@
 bool GCNPassConfig::addPreRewrite() {
   if (EnableRegReassign) {
     addPass(&GCNNSAReassignID);
-    addPass(&GCNRegBankReassignID);
+    addPass(createGCNRegBankReassignPass(AMDGPU::RM_BOTH));
   }
   return true;
 }
Index: llvm/lib/Target/AMDGPU/GCNRegBankReassign.cpp
===================================================================
--- llvm/lib/Target/AMDGPU/GCNRegBankReassign.cpp
+++ llvm/lib/Target/AMDGPU/GCNRegBankReassign.cpp
@@ -42,6 +42,7 @@
 #include "llvm/InitializePasses.h"
 
 using namespace llvm;
+using namespace AMDGPU;
 
 static cl::opt VerifyStallCycles("amdgpu-verify-regbanks-reassign",
   cl::desc("Verify stall cycles in the regbanks reassign pass"),
@@ -135,7 +136,8 @@
   static char ID;
 
 public:
-  GCNRegBankReassign() : MachineFunctionPass(ID) {
+  explicit GCNRegBankReassign(RegBankReassignMode Mode = RM_BOTH)
+      : MachineFunctionPass(ID), Mode(Mode) {
     initializeGCNRegBankReassignPass(*PassRegistry::getPassRegistry());
   }
 
@@ -167,6 +169,9 @@
 
   LiveIntervals *LIS;
 
+  // Register kinds (VGPR, SGPR or both) this pass instance processes.
+  RegBankReassignMode Mode;
+
   unsigned MaxNumVGPRs;
 
   unsigned MaxNumSGPRs;
@@ -396,6 +401,11 @@
   if (MI.isDebugValue())
     return std::make_pair(StallCycles, UsedBanks);
 
+  // SGPR processing is off: ignore SMRD and SALU instructions.
+  if (!(Mode & RM_SGPR) &&
+      MI.getDesc().TSFlags & (SIInstrFlags::SMRD | SIInstrFlags::SALU))
+    return std::make_pair(StallCycles, UsedBanks);
+
   RegsUsed.reset();
   OperandMasks.clear();
   for (const auto& Op : MI.explicit_uses()) {
@@ -410,6 +420,9 @@
     // Do not compute stalls for AGPRs
     if (TRI->hasAGPRs(RC))
       continue;
+    // Skip registers of a kind that Mode does not select
+    if ((Mode != RM_BOTH) && !(Mode & (TRI->hasVGPRs(RC) ? RM_VGPR : RM_SGPR)))
+      continue;
 
     // Do not compute stalls if sub-register covers all banks
     if (Op.getSubReg()) {
@@ -813,8 +826,11 @@
 
   MRI = &MF.getRegInfo();
 
-  LLVM_DEBUG(dbgs() << "=== RegBanks reassign analysis on function " << MF.getName()
-               << "\nNumVirtRegs = " << MRI->getNumVirtRegs() << "\n\n");
+  LLVM_DEBUG(dbgs() << "=== RegBanks reassign analysis on function "
+                    << MF.getName() << '\n'
+                    << ((Mode & RM_VGPR) ? "VGPR " : "")
+                    << ((Mode & RM_SGPR) ? "SGPR " : "") << "mode\n"
+                    << "NumVirtRegs = " << MRI->getNumVirtRegs() << "\n\n");
 
   if (MRI->getNumVirtRegs() > VRegThresh) {
     LLVM_DEBUG(dbgs() << "NumVirtRegs > " << VRegThresh
@@ -880,3 +896,9 @@
 
   return CyclesSaved > 0;
 }
+
+/// Create a GCNRegBankReassign pass restricted to the kinds set in \p Mode.
+MachineFunctionPass *
+llvm::createGCNRegBankReassignPass(RegBankReassignMode Mode) {
+  return new GCNRegBankReassign(Mode);
+}