diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -69,8 +69,6 @@ namespace {
 
 auto inst_counter_types() { return enum_seq(VM_CNT, NUM_INST_CNTS); }
 
-using RegInterval = std::pair<int, int>;
-
 struct HardwareLimits {
   unsigned VmcntMax;
   unsigned ExpcntMax;
@@ -78,11 +76,19 @@
   unsigned VscntMax;
 };
 
-struct RegisterEncoding {
-  unsigned VGPR0;
-  unsigned VGPRL;
-  unsigned SGPR0;
-  unsigned SGPRL;
+// TODO: Can ADT's AddressRange be generalized to implement this?
+struct RegInterval {
+  unsigned Start = 0; // First index yielded by seq().
+  unsigned End = 0;   // seq() stops just before this index.
+
+  auto seq() const { return llvm::seq(Start, End); }
+  auto begin() const { return seq().begin(); }
+  auto end() const { return seq().end(); }
+};
+
+struct RegPoolIntervals {
+  RegInterval VGPR;
+  RegInterval SGPR;
 };
 
 enum WaitEventType {
@@ -140,6 +146,12 @@
   VMEM_BVH
 };
 
+static unsigned getRegPoint(unsigned Reg, const GCNSubtarget &ST) {
+  // Use encoding indexes to form register intervals.
+  const SIRegisterInfo &TRI = ST.getInstrInfo()->getRegisterInfo();
+  return TRI.getEncodingValue(AMDGPU::getMCReg(Reg, ST));
+}
+
 static bool updateVMCntOnly(const MachineInstr &Inst) {
   return SIInstrInfo::isVMEM(Inst) || SIInstrInfo::isFLATGlobal(Inst) ||
          SIInstrInfo::isFLATScratch(Inst);
@@ -186,8 +198,8 @@
 class WaitcntBrackets {
 public:
   WaitcntBrackets(const GCNSubtarget *SubTarget, HardwareLimits Limits,
-                  RegisterEncoding Encoding)
-      : ST(SubTarget), Limits(Limits), Encoding(Encoding) {}
+                  RegPoolIntervals Intervals)
+      : ST(SubTarget), Limits(Limits), Intervals(Intervals) {}
 
   unsigned getWaitCountMax(InstCounterType T) const {
     switch (T) {
@@ -338,7 +350,7 @@
 
   const GCNSubtarget *ST = nullptr;
   HardwareLimits Limits = {};
-  RegisterEncoding Encoding = {};
+  RegPoolIntervals Intervals;
   unsigned ScoreLBs[NUM_INST_CNTS] = {0};
   unsigned ScoreUBs[NUM_INST_CNTS] = {0};
   unsigned PendingEvents = 0;
@@ -495,7 +507,7 @@
                                             unsigned OpNo) const {
   const MachineOperand &Op = MI->getOperand(OpNo);
   if (!TRI->isInAllocatableClass(Op.getReg()))
-    return {-1, -1};
+    return {};
 
   // A use via a PW operand does not need a waitcnt.
   // A partial write is not a WAW.
@@ -503,28 +515,28 @@
 
   RegInterval Result;
 
-  unsigned Reg = TRI->getEncodingValue(AMDGPU::getMCReg(Op.getReg(), *ST));
+  unsigned Reg = getRegPoint(Op.getReg(), *ST);
 
   if (TRI->isVectorRegister(*MRI, Op.getReg())) {
-    assert(Reg >= Encoding.VGPR0 && Reg <= Encoding.VGPRL);
-    Result.first = Reg - Encoding.VGPR0;
+    assert(Reg >= Intervals.VGPR.Start && Reg < Intervals.VGPR.End);
+    Result.Start = Reg - Intervals.VGPR.Start;
     if (TRI->isAGPR(*MRI, Op.getReg()))
-      Result.first += AGPR_OFFSET;
-    assert(Result.first >= 0 && Result.first < SQ_MAX_PGM_VGPRS);
+      Result.Start += AGPR_OFFSET;
+    assert(Result.Start < SQ_MAX_PGM_VGPRS);
   } else if (TRI->isSGPRReg(*MRI, Op.getReg())) {
-    assert(Reg >= Encoding.SGPR0 && Reg < SQ_MAX_PGM_SGPRS);
-    Result.first = Reg - Encoding.SGPR0 + NUM_ALL_VGPRS;
-    assert(Result.first >= NUM_ALL_VGPRS &&
-           Result.first < SQ_MAX_PGM_SGPRS + NUM_ALL_VGPRS);
+    assert(Reg >= Intervals.SGPR.Start && Reg < SQ_MAX_PGM_SGPRS);
+    Result.Start = Reg - Intervals.SGPR.Start + NUM_ALL_VGPRS;
+    assert(Result.Start >= NUM_ALL_VGPRS &&
+           Result.Start < SQ_MAX_PGM_SGPRS + NUM_ALL_VGPRS);
   }
   // TODO: Handle TTMP
   // else if (TRI->isTTMP(*MRI, Reg.getReg())) ...
   else
-    return {-1, -1};
+    return {};
 
   const TargetRegisterClass *RC = TII->getOpRegClass(*MI, OpNo);
   unsigned Size = TRI->getRegSizeInBits(*RC);
-  Result.second = Result.first + ((Size + 16) / 32);
+  Result.End = Result.Start + divideCeil(Size, 32);
 
   return Result;
 }
@@ -536,9 +548,8 @@
                                   unsigned Val) {
   RegInterval Interval = getRegInterval(MI, TII, MRI, TRI, OpNo);
   assert(TRI->isVectorRegister(*MRI, MI->getOperand(OpNo).getReg()));
-  for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+  for (unsigned RegNo : Interval)
     setRegScore(RegNo, EXP_CNT, Val);
-  }
 }
 
 // MUBUF and FLAT LDS DMA operations need a wait on vmcnt before LDS written
@@ -655,9 +666,7 @@
             MachineOperand &DefMO = Inst.getOperand(I);
             if (DefMO.isReg() && DefMO.isDef() &&
                 TRI->isVGPR(*MRI, DefMO.getReg())) {
-              setRegScore(
-                  TRI->getEncodingValue(AMDGPU::getMCReg(DefMO.getReg(), *ST)),
-                  EXP_CNT, CurrScore);
+              setRegScore(getRegPoint(DefMO.getReg(), *ST), EXP_CNT, CurrScore);
             }
           }
         }
@@ -676,10 +685,8 @@
     MachineOperand *MO = TII->getNamedOperand(Inst, AMDGPU::OpName::data);
     unsigned OpNo;//TODO: find the OpNo for this operand;
     RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, OpNo);
-    for (int RegNo = Interval.first; RegNo < Interval.second;
-    ++RegNo) {
+    for (unsigned RegNo : Interval)
       setRegScore(RegNo + NUM_ALL_VGPRS, t, CurrScore);
-    }
 #endif
   } else {
     // Match the score to the destination registers.
@@ -689,17 +696,16 @@
         continue;
       RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, I);
       if (T == VM_CNT) {
-        if (Interval.first >= NUM_ALL_VGPRS)
+        if (Interval.Start >= NUM_ALL_VGPRS)
           continue;
         if (updateVMCntOnly(Inst)) {
           VmemType V = getVmemType(Inst);
-          for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo)
+          for (unsigned RegNo : Interval)
             VgprVmemTypes[RegNo] |= 1 << V;
         }
       }
-      for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+      for (unsigned RegNo : Interval)
         setRegScore(RegNo, T, CurrScore);
-      }
     }
     if (Inst.mayStore() && (TII->isDS(Inst) || mayWriteLDSThroughDMA(Inst))) {
       setRegScore(SQ_MAX_PGM_VGPRS + EXTRA_VGPR_LDS, T, CurrScore);
@@ -1136,8 +1142,7 @@
         RegInterval CallAddrOpInterval =
           ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, CallAddrOpIdx);
 
-        for (int RegNo = CallAddrOpInterval.first;
-             RegNo < CallAddrOpInterval.second; ++RegNo)
+        for (unsigned RegNo : CallAddrOpInterval)
           ScoreBrackets.determineWait(LGKM_CNT, RegNo, Wait);
 
         int RtnAddrOpIdx =
@@ -1146,8 +1151,7 @@
           RegInterval RtnAddrOpInterval =
             ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, RtnAddrOpIdx);
 
-          for (int RegNo = RtnAddrOpInterval.first;
-               RegNo < RtnAddrOpInterval.second; ++RegNo)
+          for (unsigned RegNo : RtnAddrOpInterval)
             ScoreBrackets.determineWait(LGKM_CNT, RegNo, Wait);
         }
       }
@@ -1200,7 +1204,7 @@
             ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, I);
 
         const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
-        for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+        for (unsigned RegNo : Interval) {
           if (IsVGPR) {
             // RAW always needs an s_waitcnt. WAW needs an s_waitcnt unless the
             // previous write and this write are the same type of VMEM
@@ -1777,7 +1781,7 @@
         RegInterval Interval = Brackets.getRegInterval(&MI, TII, MRI, TRI, I);
         // Vgpr use
         if (Op.isUse()) {
-          for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+          for (unsigned RegNo : Interval) {
             // If we find a register that is loaded inside the loop, 1. and 2.
             // are invalidated and we can exit.
             if (VgprDef.contains(RegNo))
@@ -1793,7 +1797,7 @@
         }
         // VMem load vgpr def
         else if (isVMEMOrFlatVMEM(MI) && MI.mayLoad() && Op.isDef())
-          for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
+          for (unsigned RegNo : Interval) {
             // If we find a register that is loaded inside the loop, 1. and 2.
             // are invalidated and we can exit.
             if (VgprUse.contains(RegNo))
@@ -1833,11 +1837,11 @@
   assert(NumVGPRsMax <= SQ_MAX_PGM_VGPRS);
   assert(NumSGPRsMax <= SQ_MAX_PGM_SGPRS);
 
-  RegisterEncoding Encoding = {};
-  Encoding.VGPR0 = TRI->getEncodingValue(AMDGPU::VGPR0);
-  Encoding.VGPRL = Encoding.VGPR0 + NumVGPRsMax - 1;
-  Encoding.SGPR0 = TRI->getEncodingValue(AMDGPU::SGPR0);
-  Encoding.SGPRL = Encoding.SGPR0 + NumSGPRsMax - 1;
+  RegPoolIntervals Intervals = {};
+  Intervals.VGPR.Start = getRegPoint(AMDGPU::VGPR0, *ST);
+  Intervals.VGPR.End = Intervals.VGPR.Start + NumVGPRsMax; // Exclusive.
+  Intervals.SGPR.Start = getRegPoint(AMDGPU::SGPR0, *ST);
+  Intervals.SGPR.End = Intervals.SGPR.Start + NumSGPRsMax; // Exclusive.
 
   TrackedWaitcntSet.clear();
   BlockInfos.clear();
@@ -1884,9 +1888,9 @@
           *Brackets = *BI.Incoming;
       } else {
         if (!Brackets)
-          Brackets = std::make_unique<WaitcntBrackets>(ST, Limits, Encoding);
+          Brackets = std::make_unique<WaitcntBrackets>(ST, Limits, Intervals);
         else
-          *Brackets = WaitcntBrackets(ST, Limits, Encoding);
+          *Brackets = WaitcntBrackets(ST, Limits, Intervals);
       }
 
       Modified |= insertWaitcntInBlock(MF, *MBB, *Brackets);