Index: lib/Target/AMDGPU/SIMachineScheduler.h =================================================================== --- lib/Target/AMDGPU/SIMachineScheduler.h +++ lib/Target/AMDGPU/SIMachineScheduler.h @@ -20,6 +20,7 @@ #include "llvm/CodeGen/MachineScheduler.h" #include "llvm/CodeGen/RegisterPressure.h" #include "llvm/CodeGen/ScheduleDAG.h" +#include "llvm/MC/LaneBitmask.h" #include #include #include @@ -85,8 +86,8 @@ // Note that some registers are not 32 bits, // and thus the pressure is not equal // to the number of live registers. - std::set LiveInRegs; - std::set LiveOutRegs; + SmallVector LiveInRegs; + SmallVector LiveOutRegs; bool Scheduled = false; bool HighLatencyBlock = false; @@ -161,8 +162,8 @@ return InternalAdditionnalPressure; } - std::set &getInRegs() { return LiveInRegs; } - std::set &getOutRegs() { return LiveOutRegs; } + SmallVector &getInRegs() { return LiveInRegs; } + SmallVector &getOutRegs() { return LiveOutRegs; } void printDebug(bool Full); @@ -225,6 +226,7 @@ class SIScheduleBlockCreator { SIScheduleDAGMI *DAG; + bool ShouldTrackLaneMasks; // unique_ptr handles freeing memory for us. std::vector> BlockPtrs; std::map CurrentBottomUpReservedDependencyColoring; public: - SIScheduleBlockCreator(SIScheduleDAGMI *DAG); + SIScheduleBlockCreator(SIScheduleDAGMI *DAG, bool ShouldTrackLaneMasks); ~SIScheduleBlockCreator(); SIScheduleBlocks @@ -322,10 +324,22 @@ SISchedulerBlockSchedulerVariant Variant; std::vector Blocks; - std::vector> LiveOutRegsNumUsages; - std::set LiveRegs; + // For each register, a basis of LaneBitmasks from which all + // LaneBitmasks involved in the scheduling can be composed. + // If the register is not in the map, it is assumed to always be + // used with its full mask. + std::map> LaneMaskBasisForReg; + + // The RegisterMaskPairs below are assumed to be pairs (Reg, Mask) + // such that Mask is in LaneMaskBasisForReg[Reg] if that entry exists. 
+ std::vector> LiveOutRegsNumUsages; + std::map LiveRegs; // Num of schedulable unscheduled blocks reading the register. - std::map LiveRegsConsumers; + std::map LiveRegsConsumers; + // Each block's in/out regs, re-expressed as RegisterMaskPairs whose Mask + // is an element of LaneMaskBasisForReg[Reg] when that entry exists. + DenseMap> InRegsForBlock; + DenseMap> OutRegsForBlock; std::vector LastPosHighLatencyParentScheduled; int LastPosWaitedHighLatency; @@ -357,6 +371,15 @@ unsigned getSGPRUsage() { return maxSregUsage; } private: + + // Convert Reg/Mask to a list of Reg/Mask pairs whose Masks are taken + // from LaneMaskBasisForReg[Reg] (a single pair if Reg has no basis). + SmallVector getPairsForReg(unsigned Reg, + LaneBitmask Mask); + // Same conversion, applied to each element of a list of Reg/Mask pairs. + SmallVector getPairsForRegs( + const SmallVector Regs); + struct SIBlockSchedCandidate : SISchedulerCandidate { // The best Block candidate. SIScheduleBlock *Block = nullptr; @@ -392,15 +415,17 @@ SIBlockSchedCandidate &TryCand); SIScheduleBlock *pickBlock(); - void addLiveRegs(std::set &Regs); - void decreaseLiveRegs(SIScheduleBlock *Block, std::set &Regs); + // The Mask of each RegisterMaskPair needs to be an + // element of LaneMaskBasisForReg[Reg] when Reg has a basis. + void addLiveRegs(SmallVector &Regs); + void decreaseLiveRegs(SIScheduleBlock *Block, + SmallVector &Regs); void releaseBlockSuccs(SIScheduleBlock *Parent); void blockScheduled(SIScheduleBlock *Block); // Check register pressure change - // by scheduling a block with these LiveIn and LiveOut. 
- std::vector checkRegUsageImpact(std::set &InRegs, - std::set &OutRegs); + // by scheduling block BlockID; outputs go to DiffVGPR and DiffSGPR. + void checkRegUsageImpact(unsigned BlockID, int &DiffVGPR, int &DiffSGPR); void schedule(); std::set findPathRegUsage(int SearchDepthLimit, @@ -420,7 +445,8 @@ SIScheduleBlockCreator BlockCreator; public: - SIScheduler(SIScheduleDAGMI *DAG) : DAG(DAG), BlockCreator(DAG) {} + SIScheduler(SIScheduleDAGMI *DAG, bool ShouldTrackLaneMasks) : + DAG(DAG), BlockCreator(DAG, ShouldTrackLaneMasks) {} ~SIScheduler() = default; @@ -452,7 +478,7 @@ // To init Block's RPTracker. 
void initRPTracker(RegPressureTracker &RPTracker) { - RPTracker.init(&MF, RegClassInfo, LIS, BB, RegionBegin, false, false); + RPTracker.init(&MF, RegClassInfo, LIS, BB, RegionBegin, ShouldTrackLaneMasks, false); } MachineBasicBlock *getBB() { return BB; } @@ -472,20 +498,14 @@ unsigned &VgprUsage, unsigned &SgprUsage); - std::set getInRegs() { - std::set InRegs; - for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) { - InRegs.insert(RegMaskPair.RegUnit); - } - return InRegs; + // Live-in (Reg, LaneMask) pairs from RPTracker, by reference (no copy). + const SmallVector &getInRegs() { + return RPTracker.getPressure().LiveInRegs; } - std::set getOutRegs() { - std::set OutRegs; - for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) { - OutRegs.insert(RegMaskPair.RegUnit); - } - return OutRegs; + // Live-out (Reg, LaneMask) pairs from RPTracker, by reference (no copy). + const SmallVector &getOutRegs() { + return RPTracker.getPressure().LiveOutRegs; }; unsigned getVGPRSetID() const { return VGPRSetID; } Index: lib/Target/AMDGPU/SIMachineScheduler.cpp =================================================================== --- lib/Target/AMDGPU/SIMachineScheduler.cpp +++ lib/Target/AMDGPU/SIMachineScheduler.cpp @@ -28,6 +28,7 @@ #include "llvm/CodeGen/MachineScheduler.h" #include "llvm/CodeGen/RegisterPressure.h" #include "llvm/CodeGen/SlotIndexes.h" +#include "llvm/MC/LaneBitmask.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -301,22 +302,26 @@ Scheduled = true; } -// Returns if the register was set between first and last. -static bool isDefBetween(unsigned Reg, +// Returns the lanes of Reg defined between First and Last, as the union +// of the lane masks of the defining operands (none if there is no def). 
+static LaneBitmask findDefsBetween(unsigned Reg, SlotIndex First, SlotIndex Last, const MachineRegisterInfo *MRI, const LiveIntervals *LIS) { - for (MachineRegisterInfo::def_instr_iterator - UI = MRI->def_instr_begin(Reg), - UE = MRI->def_instr_end(); UI != UE; ++UI) { - const MachineInstr* MI = &*UI; + LaneBitmask LaneMask = LaneBitmask::getNone(); + const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo(); + for (const MachineOperand &MO : MRI->def_operands(Reg)) { + const MachineInstr* MI = MO.getParent(); if (MI->isDebugValue()) continue; SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot(); - if (InstSlot >= First && InstSlot <= Last) - return true; + if (InstSlot >= First && InstSlot <= Last) { + unsigned SubRegIdx = MO.getSubReg(); + LaneBitmask DefMask = TRI->getSubRegIndexLaneMask(SubRegIdx); + LaneMask |= DefMask; + } } - return false; + return LaneMask; } void SIScheduleBlock::initRegPressure(MachineBasicBlock::iterator BeginBlock, @@ -343,10 +348,12 @@ TopRPTracker.addLiveRegs(RPTracker.getPressure().LiveInRegs); BotRPTracker.addLiveRegs(RPTracker.getPressure().LiveOutRegs); + LiveInRegs.clear(); + // Do not Track Physical Registers, because it messes up. for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) { if (TargetRegisterInfo::isVirtualRegister(RegMaskPair.RegUnit)) - LiveInRegs.insert(RegMaskPair.RegUnit); + LiveInRegs.push_back(RegMaskPair); } LiveOutRegs.clear(); // There is several possibilities to distinguish: @@ -371,14 +378,21 @@ // The RPTracker's LiveOutRegs has 1, 3, (some correct or incorrect)4, 5, 7 // Comparing to LiveInRegs is not sufficient to differenciate 4 vs 5, 7 // The use of findDefBetween removes the case 4. + // LaneMask: Conceptually this is the same as described above, except + // we differentiate what happens for each register lane. 
for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) { unsigned Reg = RegMaskPair.RegUnit; - if (TargetRegisterInfo::isVirtualRegister(Reg) && - isDefBetween(Reg, LIS->getInstructionIndex(*BeginBlock).getRegSlot(), - LIS->getInstructionIndex(*EndBlock).getRegSlot(), MRI, - LIS)) { - LiveOutRegs.insert(Reg); - } + if (!TargetRegisterInfo::isVirtualRegister(Reg)) + continue; + LaneBitmask LaneMask = findDefsBetween(Reg, + LIS->getInstructionIndex(*BeginBlock).getRegSlot(), + LIS->getInstructionIndex(*EndBlock).getRegSlot(), MRI, + LIS); + // Being in LaneMask but not in RegMaskPair.LaneMask means the lane + // was temporarily defined, but then consumed in the block. + LaneMask &= RegMaskPair.LaneMask; + if (LaneMask.any()) + LiveOutRegs.push_back(RegisterMaskPair(RegMaskPair.RegUnit, LaneMask)); } // Pressure = sum_alive_registers register size @@ -597,12 +611,20 @@ dbgs() << "LiveOutPressure " << LiveOutPressure[DAG->getSGPRSetID()] << ' ' << LiveOutPressure[DAG->getVGPRSetID()] << "\n\n"; dbgs() << "LiveIns:\n"; - for (unsigned Reg : LiveInRegs) - dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << ' '; + for (const auto &Ins : LiveInRegs) { + dbgs() << PrintVRegOrUnit(Ins.RegUnit, DAG->getTRI()); + if (!Ins.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(Ins.LaneMask); + dbgs() << ' '; + } dbgs() << "\nLiveOuts:\n"; - for (unsigned Reg : LiveOutRegs) - dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << ' '; + for (const auto &Outs : LiveOutRegs) { + dbgs() << PrintVRegOrUnit(Outs.RegUnit, DAG->getTRI()); + if (!Outs.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(Outs.LaneMask); + dbgs() << ' '; + } } dbgs() << "\nInstructions:\n"; @@ -622,8 +644,9 @@ // SIScheduleBlockCreator // -SIScheduleBlockCreator::SIScheduleBlockCreator(SIScheduleDAGMI *DAG) : -DAG(DAG) { +SIScheduleBlockCreator::SIScheduleBlockCreator(SIScheduleDAGMI *DAG, + bool ShouldTrackLaneMasks) : +DAG(DAG), ShouldTrackLaneMasks(ShouldTrackLaneMasks) { } 
SIScheduleBlockCreator::~SIScheduleBlockCreator() = default; @@ -1303,6 +1326,20 @@ // It would gain a lot if there was a way to recompute the // LiveIntervals for the entire scheduling region. DAG->getLIS()->handleMove(*MI, /*UpdateFlags=*/true); + + if (!MI->isDebugValue()) { + // Reset read-undef flags and update them later. + for (auto &Op : MI->operands()) + if (Op.isReg() && Op.isDef()) + Op.setIsUndef(false); + + RegisterOperands RegOpers; + RegOpers.collect(*MI, *DAG->getTRI(), *DAG->getMRI(), + ShouldTrackLaneMasks, false); + // Adjust liveness and add missing dead+read-undef flags. + auto SlotIdx = DAG->getLIS()->getInstructionIndex(*MI).getRegSlot(); + RegOpers.adjustLaneLiveness(*DAG->getLIS(), *DAG->getMRI(), SlotIdx, MI); + } PosNew.push_back(CurrentTopFastSched); } } @@ -1324,11 +1361,26 @@ MachineBasicBlock::iterator POld = PosOld[i-1]; MachineBasicBlock::iterator PNew = PosNew[i-1]; if (PNew != POld) { + MachineInstr &MI = *POld; // Update the instruction stream. DAG->getBB()->splice(POld, DAG->getBB(), PNew); // Update LiveIntervals. DAG->getLIS()->handleMove(*POld, /*UpdateFlags=*/true); + + if (!MI.isDebugValue()) { + // Reset read-undef flags and update them later. + for (auto &Op : MI.operands()) + if (Op.isReg() && Op.isDef()) + Op.setIsUndef(false); + + RegisterOperands RegOpers; + RegOpers.collect(MI, *DAG->getTRI(), *DAG->getMRI(), + ShouldTrackLaneMasks, false); + // Adjust liveness and add missing dead+read-undef flags. + auto SlotIdx = DAG->getLIS()->getInstructionIndex(MI).getRegSlot(); + RegOpers.adjustLaneLiveness(*DAG->getLIS(), *DAG->getMRI(), SlotIdx, &MI); + } } } @@ -1384,6 +1436,75 @@ LastPosWaitedHighLatency(0), NumBlockScheduled(0), VregCurrentUsage(0), SregCurrentUsage(0), maxVregUsage(0), maxSregUsage(0) { + MachineRegisterInfo *MRI = DAG->getMRI(); + // To track register usage, we define for each register definition + // a number of usages before it gets released. + // This doesn't work with LaneMasks. 
+ // To handle LaneMasks, we 'cut' registers affected by LaneMasks + // into all their possible distinct lanes + // and behave as if that (reg, LaneMask) was a register. + std::map> RegWithLaneMask; + for (SIScheduleBlock *Block : Blocks) { + for (const auto Ins : Block->getInRegs()) { + if (!Ins.LaneMask.all() && + Ins.LaneMask != MRI->getMaxLaneMaskForVReg(Ins.RegUnit)) + RegWithLaneMask[Ins.RegUnit].push_back(Ins.LaneMask); + } + for (const auto Outs : Block->getOutRegs()) { + if (!Outs.LaneMask.all() && + Outs.LaneMask != MRI->getMaxLaneMaskForVReg(Outs.RegUnit)) + RegWithLaneMask[Outs.RegUnit].push_back(Outs.LaneMask); + } + } + for (const auto Ins : DAG->getInRegs()) { + if (!Ins.LaneMask.all() && + Ins.LaneMask != MRI->getMaxLaneMaskForVReg(Ins.RegUnit)) + RegWithLaneMask[Ins.RegUnit].push_back(Ins.LaneMask); + } + for (const auto Outs : DAG->getOutRegs()) { + if (!Outs.LaneMask.all() && + Outs.LaneMask != MRI->getMaxLaneMaskForVReg(Outs.RegUnit)) + RegWithLaneMask[Outs.RegUnit].push_back(Outs.LaneMask); + } + // Since we ignored when the lane mask was getMaxLaneMaskForVReg, + // we need to add it back. It doesn't hurt if there was no element + // with this mask for this register. + for (auto &RegLaneMasks : RegWithLaneMask) { + RegLaneMasks.second.push_back(MRI->getMaxLaneMaskForVReg(RegLaneMasks.first)); + } + + for (const auto &RegLaneMasks : RegWithLaneMask) { + SmallVector &LaneBasis = + LaneMaskBasisForReg[RegLaneMasks.first]; + for (const LaneBitmask &LaneMask : RegLaneMasks.second) { + LaneBitmask Remaining = LaneMask; + // Iterate by index: push_back below may reallocate LaneBasis. + for (unsigned Idx = 0; Idx != LaneBasis.size(); ++Idx) { + LaneBitmask Elem = LaneBasis[Idx]; + if ((Remaining & Elem).none()) + continue; + if ((Remaining & Elem) == Elem) { + Remaining &= ~Elem; + continue; + } + // Remaining intersects with Elem, but Elem is not + // included in Remaining. We divide Elem into two elements. 
+ // The one included in Remaining, and the rest. + LaneBitmask NewElem = Elem & ~Remaining; + LaneBasis[Idx] = Elem & Remaining; + LaneBasis.push_back(NewElem); + } + if (Remaining.any()) + LaneBasis.push_back(Remaining); + } + } + + for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { + SIScheduleBlock *Block = Blocks[i]; + InRegsForBlock[i] = getPairsForRegs(Block->getInRegs()); + OutRegsForBlock[i] = getPairsForRegs(Block->getOutRegs()); + } + // Fill the usage of every output // Warning: while by construction we always have a link between two blocks // when one needs a result from the other, the number of users of an output @@ -1397,26 +1518,31 @@ LiveOutRegsNumUsages.resize(Blocks.size()); for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { SIScheduleBlock *Block = Blocks[i]; - for (unsigned Reg : Block->getInRegs()) { + for (const auto &RegPair : InRegsForBlock[i]) { bool Found = false; int topoInd = -1; for (SIScheduleBlock* Pred: Block->getPreds()) { - std::set PredOutRegs = Pred->getOutRegs(); - std::set::iterator RegPos = PredOutRegs.find(Reg); - - if (RegPos != PredOutRegs.end()) { - Found = true; - if (topoInd < BlocksStruct.TopDownBlock2Index[Pred->getID()]) { - topoInd = BlocksStruct.TopDownBlock2Index[Pred->getID()]; + const SmallVector &PredOutRegs = + OutRegsForBlock[Pred->getID()]; + for (const auto &RegPair2 : PredOutRegs) { + if (RegPair == RegPair2) { + Found = true; + if (topoInd < BlocksStruct.TopDownBlock2Index[Pred->getID()]) { + topoInd = BlocksStruct.TopDownBlock2Index[Pred->getID()]; + } + break; } } } if (!Found) - continue; - - int PredID = BlocksStruct.TopDownIndex2Block[topoInd]; - ++LiveOutRegsNumUsages[PredID][Reg]; + // Fill LiveRegsConsumers for regs that were already + // defined before scheduling. 
+ ++LiveRegsConsumers[RegPair]; + else { + int PredID = BlocksStruct.TopDownIndex2Block[topoInd]; + ++LiveOutRegsNumUsages[PredID][RegPair]; + } } } @@ -1437,45 +1563,33 @@ } #endif - std::set InRegs = DAG->getInRegs(); + SmallVector InRegs = getPairsForRegs(DAG->getInRegs()); addLiveRegs(InRegs); // Increase LiveOutRegsNumUsages for blocks // producing registers consumed in another // scheduling region. - for (unsigned Reg : DAG->getOutRegs()) { + for (const RegisterMaskPair &RegPair : getPairsForRegs(DAG->getOutRegs())) { for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { // Do reverse traversal + bool Found = false; int ID = BlocksStruct.TopDownIndex2Block[Blocks.size()-1-i]; SIScheduleBlock *Block = Blocks[ID]; - const std::set OutRegs = Block->getOutRegs(); - - if (OutRegs.find(Reg) == OutRegs.end()) - continue; - - ++LiveOutRegsNumUsages[ID][Reg]; - break; - } - } + const SmallVector &OutRegs = + OutRegsForBlock[Block->getID()]; - // Fill LiveRegsConsumers for regs that were already - // defined before scheduling. 
- for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { - SIScheduleBlock *Block = Blocks[i]; - for (unsigned Reg : Block->getInRegs()) { - bool Found = false; - for (SIScheduleBlock* Pred: Block->getPreds()) { - std::set PredOutRegs = Pred->getOutRegs(); - std::set::iterator RegPos = PredOutRegs.find(Reg); - - if (RegPos != PredOutRegs.end()) { + for (const auto &RegPair2 : OutRegs) { + if (RegPair == RegPair2) { Found = true; break; } } if (!Found) - ++LiveRegsConsumers[Reg]; + continue; + + ++LiveOutRegsNumUsages[ID][RegPair]; + break; } } @@ -1500,6 +1614,51 @@ ); } +SmallVector +SIScheduleBlockScheduler::getPairsForReg(unsigned Reg, LaneBitmask Mask) +{ + SmallVector Result; + auto Basis = LaneMaskBasisForReg.find(Reg); + if (Basis == LaneMaskBasisForReg.end()) { + assert(Mask.all() || Mask == DAG->getMRI()->getMaxLaneMaskForVReg(Reg)); + // We want the RegisterMaskPair to be unique for a given register/mask. + // Thus replace getMaxLaneMaskForVReg by all, since they have the same + // meaning. + // Note: Physical registers have Mask.all(), and getMaxLaneMaskForVReg + // must not be called on them. + if (!Mask.all() && Mask == DAG->getMRI()->getMaxLaneMaskForVReg(Reg)) + Mask = LaneBitmask::getAll(); + Result.push_back(RegisterMaskPair(Reg, Mask)); + } else { + for (const auto Elem : Basis->second) { + if ((Mask & Elem).any()) { + assert((Mask & Elem) == Elem); + Result.push_back(RegisterMaskPair(Reg, Elem)); + Mask &= ~Elem; + } + } + // A Mask equal to all() still has bits set at this point. + // Only the lanes of the max lane mask must have been covered. 
+ assert((Mask & DAG->getMRI()->getMaxLaneMaskForVReg(Reg)).none()); + } + + return Result; +} + +SmallVector +SIScheduleBlockScheduler::getPairsForRegs(const SmallVector Regs) +{ + SmallVector Result; + for (const auto RegPair : Regs) { + for (const auto RegPairRes : getPairsForReg(RegPair.RegUnit, + RegPair.LaneMask)) { + Result.push_back(RegPairRes); + } + } + + return Result; +} + bool SIScheduleBlockScheduler::tryCandidateLatency(SIBlockSchedCandidate &Cand, SIBlockSchedCandidate &TryCand) { if (!Cand.isValid()) { @@ -1566,8 +1725,12 @@ for (SIScheduleBlock* Block : ReadyBlocks) dbgs() << Block->getID() << ' '; dbgs() << "\nCurrent Live:\n"; - for (unsigned Reg : LiveRegs) - dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << ' '; + for (const auto RegPair : LiveRegs) { + dbgs() << PrintVRegOrUnit(RegPair.first, DAG->getTRI()); + if (!RegPair.second.all()) + dbgs() << ':' << PrintLaneMask(RegPair.second); + dbgs() << ' '; + } dbgs() << '\n'; dbgs() << "Current VGPRs: " << VregCurrentUsage << '\n'; dbgs() << "Current SGPRs: " << SregCurrentUsage << '\n'; @@ -1604,22 +1767,23 @@ for (std::vector::iterator I = ReadyBlocks.begin(), E = ReadyBlocks.end(); I != E; ++I) { SIBlockSchedCandidate TryCand; + unsigned TryCandID; + int SGPRUsageDiff; if (!CurrentPathRegUsage.empty() && CurrentPathRegUsage.find(*I) == CurrentPathRegUsage.end()) continue; TryCand.Block = *I; + TryCandID = TryCand.Block->getID(); TryCand.IsHighLatency = TryCand.Block->isHighLatencyBlock(); - TryCand.VGPRUsageDiff = - checkRegUsageImpact(TryCand.Block->getInRegs(), - TryCand.Block->getOutRegs())[DAG->getVGPRSetID()]; + checkRegUsageImpact(TryCandID, TryCand.VGPRUsageDiff, SGPRUsageDiff); TryCand.NumSuccessors = TryCand.Block->getSuccs().size(); TryCand.NumHighLatencySuccessors = TryCand.Block->getNumHighLatencySuccessors(); TryCand.LastPosHighLatParentScheduled = (unsigned int) std::max (0, - LastPosHighLatencyParentScheduled[TryCand.Block->getID()] - + 
LastPosHighLatencyParentScheduled[TryCandID] - LastPosWaitedHighLatency); TryCand.Height = TryCand.Block->Height; // Try not to increase VGPR usage too much, else we may spill. @@ -1659,27 +1823,32 @@ // Tracking of currently alive registers to determine VGPR Usage. -void SIScheduleBlockScheduler::addLiveRegs(std::set &Regs) { - for (unsigned Reg : Regs) { +void SIScheduleBlockScheduler::addLiveRegs(SmallVector &Regs) { + for (const RegisterMaskPair &RegPair : Regs) { + unsigned Reg = RegPair.RegUnit; // For now only track virtual registers. if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - // If not already in the live set, then add it. - (void) LiveRegs.insert(Reg); + + LiveRegs[Reg] |= RegPair.LaneMask; } } void SIScheduleBlockScheduler::decreaseLiveRegs(SIScheduleBlock *Block, - std::set &Regs) { - for (unsigned Reg : Regs) { + SmallVector &Regs) { + for (const RegisterMaskPair &RegPair : Regs) { // For now only track virtual registers. - std::set::iterator Pos = LiveRegs.find(Reg); + std::map::iterator Pos = LiveRegs.find(RegPair.RegUnit); assert (Pos != LiveRegs.end() && // Reg must be live. 
- LiveRegsConsumers.find(Reg) != LiveRegsConsumers.end() && - LiveRegsConsumers[Reg] >= 1); - --LiveRegsConsumers[Reg]; - if (LiveRegsConsumers[Reg] == 0) - LiveRegs.erase(Pos); + LiveRegsConsumers.find(RegPair) != LiveRegsConsumers.end() && + LiveRegsConsumers[RegPair] >= 1); + --LiveRegsConsumers[RegPair]; + if (LiveRegsConsumers[RegPair] == 0) { + if (Pos->second == RegPair.LaneMask) + LiveRegs.erase(Pos); + else + Pos->second &= ~RegPair.LaneMask; + } } } @@ -1698,54 +1867,82 @@ } void SIScheduleBlockScheduler::blockScheduled(SIScheduleBlock *Block) { - decreaseLiveRegs(Block, Block->getInRegs()); - addLiveRegs(Block->getOutRegs()); + unsigned ID = Block->getID(); + decreaseLiveRegs(Block, InRegsForBlock[ID]); + addLiveRegs(OutRegsForBlock[ID]); releaseBlockSuccs(Block); - for (std::map::iterator RegI = - LiveOutRegsNumUsages[Block->getID()].begin(), - E = LiveOutRegsNumUsages[Block->getID()].end(); RegI != E; ++RegI) { - std::pair RegP = *RegI; + for (std::map::iterator RegI = + LiveOutRegsNumUsages[ID].begin(), + E = LiveOutRegsNumUsages[ID].end(); RegI != E; ++RegI) { + std::pair RegP = *RegI; // We produce this register, thus it must not be previously alive. 
assert(LiveRegsConsumers.find(RegP.first) == LiveRegsConsumers.end() || LiveRegsConsumers[RegP.first] == 0); LiveRegsConsumers[RegP.first] += RegP.second; } - if (LastPosHighLatencyParentScheduled[Block->getID()] > + if (LastPosHighLatencyParentScheduled[ID] > (unsigned)LastPosWaitedHighLatency) LastPosWaitedHighLatency = - LastPosHighLatencyParentScheduled[Block->getID()]; + LastPosHighLatencyParentScheduled[ID]; ++NumBlockScheduled; } -std::vector -SIScheduleBlockScheduler::checkRegUsageImpact(std::set &InRegs, - std::set &OutRegs) { - std::vector DiffSetPressure; - DiffSetPressure.assign(DAG->getTRI()->getNumRegPressureSets(), 0); +void +SIScheduleBlockScheduler::checkRegUsageImpact(unsigned BlockID, + int &DiffVGPR, + int &DiffSGPR) { + SmallDenseMap Map; + SmallPtrSet Set; + unsigned VGPRSetID = DAG->getVGPRSetID(); + unsigned SGPRSetID = DAG->getSGPRSetID(); - for (unsigned Reg : InRegs) { + DiffVGPR = 0; + DiffSGPR = 0; + + for (const auto RegPair : InRegsForBlock[BlockID]) { // For now only track virtual registers. + unsigned Reg = RegPair.RegUnit; if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - if (LiveRegsConsumers[Reg] > 1) + + if (LiveRegsConsumers[RegPair] > 1) continue; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg); - for (; PSetI.isValid(); ++PSetI) { - DiffSetPressure[*PSetI] -= PSetI.getWeight(); + Map[Reg] |= RegPair.LaneMask; + } + for (const auto RegPair : Map) { + auto LiveIt = LiveRegs.find(RegPair.first); // Unlike [], never inserts. + if (LiveIt != LiveRegs.end() && LiveIt->second == RegPair.second) { + PSetIterator PSetI = DAG->getMRI()->getPressureSets(RegPair.first); + for (; PSetI.isValid(); ++PSetI) { + if (*PSetI == VGPRSetID) + DiffVGPR -= PSetI.getWeight(); + if (*PSetI == SGPRSetID) + DiffSGPR -= PSetI.getWeight(); + } } } - for (unsigned Reg : OutRegs) { + for (const auto RegPair : OutRegsForBlock[BlockID]) { // For now only track virtual registers. 
+ unsigned Reg = RegPair.RegUnit; if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg); - for (; PSetI.isValid(); ++PSetI) { - DiffSetPressure[*PSetI] += PSetI.getWeight(); - } + + Set.insert(Reg); } - return DiffSetPressure; + for (unsigned Reg : Set) { + // Only count the register if none of its lanes is already alive. + if (LiveRegs.find(Reg) == LiveRegs.end()) { + PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg); + for (; PSetI.isValid(); ++PSetI) { + if (*PSetI == VGPRSetID) + DiffVGPR += PSetI.getWeight(); + if (*PSetI == SGPRSetID) + DiffSGPR += PSetI.getWeight(); + } + } + } } // Strategy to reduce register pressure: @@ -1817,7 +2014,7 @@ #ifndef NDEBUG void printDebug(SIScheduleDAGMI *DAG, - DenseMap &IdentifierToReg) + DenseMap &IdentifierToReg) { dbgs() << "Block list: "; for (SIScheduleBlock *Block : Dependencies) @@ -1825,49 +2022,55 @@ dbgs() << '\n'; dbgs() << "Consumed registers: "; for (unsigned Reg : ConsumedRegisters) { - unsigned RealReg = IdentifierToReg[Reg]; + RegisterMaskPair &RealReg = IdentifierToReg[Reg]; int DiffV = 0; int DiffS = 0; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg); + PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg.RegUnit); for (; PSetI.isValid(); ++PSetI) { if (*PSetI == DAG->getVGPRSetID()) DiffV -= PSetI.getWeight(); if (*PSetI == DAG->getSGPRSetID()) DiffS -= PSetI.getWeight(); } - dbgs() << PrintVRegOrUnit(RealReg, DAG->getTRI()); + dbgs() << PrintVRegOrUnit(RealReg.RegUnit, DAG->getTRI()); + if (!RealReg.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(RealReg.LaneMask); dbgs() << "(" << DiffV << ", " << DiffS << "), "; } dbgs() << '\n'; dbgs() << "Intermediate registers: "; for (unsigned Reg : ProducedConsumedRegisters) { - unsigned RealReg = IdentifierToReg[Reg]; + RegisterMaskPair &RealReg = IdentifierToReg[Reg]; int DiffV = 0; int DiffS = 0; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg); + 
PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg.RegUnit); for (; PSetI.isValid(); ++PSetI) { if (*PSetI == DAG->getVGPRSetID()) DiffV += PSetI.getWeight(); if (*PSetI == DAG->getSGPRSetID()) DiffS += PSetI.getWeight(); } - dbgs() << PrintVRegOrUnit(RealReg, DAG->getTRI()); + dbgs() << PrintVRegOrUnit(RealReg.RegUnit, DAG->getTRI()); + if (!RealReg.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(RealReg.LaneMask); dbgs() << "(" << DiffV << ", " << DiffS << "), "; } dbgs() << '\n'; dbgs() << "Produced registers: "; for (unsigned Reg : ProducedRegisters) { - unsigned RealReg = IdentifierToReg[Reg]; + RegisterMaskPair &RealReg = IdentifierToReg[Reg]; int DiffV = 0; int DiffS = 0; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg); + PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg.RegUnit); for (; PSetI.isValid(); ++PSetI) { if (*PSetI == DAG->getVGPRSetID()) DiffV += PSetI.getWeight(); if (*PSetI == DAG->getSGPRSetID()) DiffS += PSetI.getWeight(); } - dbgs() << PrintVRegOrUnit(RealReg, DAG->getTRI()); + dbgs() << PrintVRegOrUnit(RealReg.RegUnit, DAG->getTRI()); + if (!RealReg.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(RealReg.LaneMask); dbgs() << "(" << DiffV << ", " << DiffS << "), "; } dbgs() << '\n'; @@ -1893,8 +2096,11 @@ // (input and output of a block). We thus use a mapping // "unique reg identifier" -> register, // and an opposite mapping register -> current associated identifier. - DenseMap IdentifierToReg; - DenseMap RegToIdentifier; + // As before, we treat a given RegisterMaskPair as a + // unique register (by construction, the Masks for the same register + // don't intersect). 
+ DenseMap IdentifierToReg; + std::map RegToIdentifier; unsigned NextIdentifier = 0; DenseMap BlockInfos; @@ -1907,6 +2113,13 @@ int BestDiffVGPR = INT_MAX; int BestDiffSGPR = INT_MAX; + unsigned BestDiffReg = ~0u; + + // Scratch containers and set IDs used to compute the scores. + SmallDenseMap Map; + SmallPtrSet Set; + unsigned VGPRSetID = DAG->getVGPRSetID(); + unsigned SGPRSetID = DAG->getSGPRSetID(); if (ReadyBlocks.empty()) return std::set(); @@ -1915,28 +2128,35 @@ DEBUG(dbgs() << "Initial Live regs:\n"); // Fill info for initial registers - for (unsigned Reg : LiveRegs) { + for (const auto &RegP : LiveRegs) { + unsigned Reg = RegP.first; // Ignoring physical registers if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - (void) LiveRegsInitId.insert(NextIdentifier); - assert(LiveRegsConsumers.find(Reg) != LiveRegsConsumers.end()); - assert(LiveRegsConsumers[Reg] > 0); - RegsConsumers[NextIdentifier] = LiveRegsConsumers[Reg]; + for (const auto RegPair : getPairsForReg(Reg, RegP.second)) { + (void) LiveRegsInitId.insert(NextIdentifier); + assert(LiveRegsConsumers.find(RegPair) != LiveRegsConsumers.end()); + assert(LiveRegsConsumers[RegPair] > 0); + RegsConsumers[NextIdentifier] = LiveRegsConsumers[RegPair]; - IdentifierToReg[NextIdentifier] = Reg; - RegToIdentifier[Reg] = NextIdentifier; - DEBUG(dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << " (--> " << NextIdentifier << ")\n"); - ++NextIdentifier; + IdentifierToReg[NextIdentifier] = RegPair; + RegToIdentifier[RegPair] = NextIdentifier; + DEBUG(dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()); + if (!RegPair.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(RegPair.LaneMask); + dbgs() << " (--> " << NextIdentifier << ")\n";); + ++NextIdentifier; + } } // Fill BlockInfos while (SearchDepthLimit > 0 && !SchedulableBlocks.empty()) { DEBUG(dbgs() << "Iterating... 
Remaining levels: " << SearchDepthLimit << '\n'); for (SIScheduleBlock* Block : SchedulableBlocks) { - SIBlockInfo &BlockInfo = BlockInfos[Block->getID()]; + unsigned ID = Block->getID(); + SIBlockInfo &BlockInfo = BlockInfos[ID]; - DEBUG(dbgs() << "Computing data for Block: " << Block->getID() << '\n'); + DEBUG(dbgs() << "Computing data for Block: " << ID << '\n'); for (SIScheduleBlock *Parent : Block->getPreds()) { DenseMap::iterator ParentBlockInfoPair = @@ -2007,12 +2227,16 @@ // At this point, we have merged the data from all parents. BlockInfo.Dependencies.insert(Block); - for (unsigned Reg : Block->getInRegs()) { + for (const auto &RegPair : InRegsForBlock[ID]) { + unsigned Reg = RegPair.RegUnit; if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - DEBUG(dbgs() << "InReg : " << PrintVRegOrUnit(Reg, DAG->getTRI()) << '\n'); - assert(RegToIdentifier.find(Reg) != RegToIdentifier.end()); - unsigned RegIdentifier = RegToIdentifier[Reg]; + DEBUG(dbgs() << "InReg : " << PrintVRegOrUnit(Reg, DAG->getTRI()); + if (!RegPair.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(RegPair.LaneMask); + dbgs () << '\n'); + assert(RegToIdentifier.find(RegPair) != RegToIdentifier.end()); + unsigned RegIdentifier = RegToIdentifier[RegPair]; SmallPtrSet &BlockRegisterConsumers = BlockInfo.RegisterConsumers[RegIdentifier]; @@ -2038,19 +2262,23 @@ } } } - for (unsigned Reg : Block->getOutRegs()) { + for (const auto &RegPair : OutRegsForBlock[ID]) { + unsigned Reg = RegPair.RegUnit; if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; // Note: By construction, we can overwrite RegToIdentifier[Reg], // because we schedule in a valid order (Reg was consumed at this // point). 
unsigned RegIdentifier = NextIdentifier; - IdentifierToReg[RegIdentifier] = Reg; - RegToIdentifier[Reg] = RegIdentifier; - DEBUG(dbgs() << "OutReg : " << PrintVRegOrUnit(Reg, DAG->getTRI()) << '\n'); + IdentifierToReg[RegIdentifier] = RegPair; + RegToIdentifier[RegPair] = RegIdentifier; + DEBUG(dbgs() << "OutReg : " << PrintVRegOrUnit(Reg, DAG->getTRI()); + if (!RegPair.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(RegPair.LaneMask); + dbgs () << '\n'); ++NextIdentifier; - RegsConsumers[RegIdentifier] = LiveOutRegsNumUsages[Block->getID()][Reg]; + RegsConsumers[RegIdentifier] = LiveOutRegsNumUsages[ID][RegPair]; BlockInfo.ProducedRegisters.insert(RegIdentifier); } @@ -2176,9 +2404,11 @@ DEBUG( for (const auto RInfo : RegisterInfos) { unsigned Reg = RInfo.first; - dbgs() << Reg << "(" << PrintVRegOrUnit(IdentifierToReg[Reg], - DAG->getTRI()) - << ")" << " :\nConsumed: "; + RegisterMaskPair &RealReg = IdentifierToReg[Reg]; + dbgs() << Reg << "(" << PrintVRegOrUnit(RealReg.RegUnit, DAG->getTRI()); + if (!RealReg.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(RealReg.LaneMask); + dbgs() << ")" << " :\nConsumed: "; for (unsigned Reg2 : RInfo.second.ConsumedRegisters) dbgs() << Reg2 << " "; dbgs() << "\nProducedConsumed: "; @@ -2198,26 +2428,44 @@ int DiffVGPR = 0; int DiffSGPR = 0; + Map.clear(); + Set.clear(); + for (unsigned Reg : RInfo.second.ConsumedRegisters) { - unsigned RealReg = IdentifierToReg[Reg]; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg); - for (; PSetI.isValid(); ++PSetI) { - if (*PSetI == DAG->getVGPRSetID()) - DiffVGPR -= PSetI.getWeight(); - if (*PSetI == DAG->getSGPRSetID()) - DiffSGPR -= PSetI.getWeight(); + RegisterMaskPair RealReg = IdentifierToReg[Reg]; + Map[RealReg.RegUnit] |= RealReg.LaneMask; + } + + for (const auto RegPair : Map) { + if (LiveRegs[RegPair.first] == RegPair.second) { + PSetIterator PSetI = DAG->getMRI()->getPressureSets(RegPair.first); + for (; PSetI.isValid(); ++PSetI) { + if (*PSetI == VGPRSetID) + 
DiffVGPR -= PSetI.getWeight(); + if (*PSetI == SGPRSetID) + DiffSGPR -= PSetI.getWeight(); + } } } + for (unsigned Reg : RInfo.second.ProducedRegisters) { - unsigned RealReg = IdentifierToReg[Reg]; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg); - for (; PSetI.isValid(); ++PSetI) { - if (*PSetI == DAG->getVGPRSetID()) - DiffVGPR += PSetI.getWeight(); - if (*PSetI == DAG->getSGPRSetID()) - DiffSGPR += PSetI.getWeight(); + RegisterMaskPair RealReg = IdentifierToReg[Reg]; + Set.insert(RealReg.RegUnit); + } + + for (unsigned Reg : Set) { + // Check register is not already alive (at least some lanes) + if (LiveRegs.find(Reg) == LiveRegs.end()) { + PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg); + for (; PSetI.isValid(); ++PSetI) { + if (*PSetI == VGPRSetID) + DiffVGPR += PSetI.getWeight(); + if (*PSetI == SGPRSetID) + DiffSGPR += PSetI.getWeight(); + } } } + DEBUG(dbgs() << RInfo.first << ": diff = (" << DiffVGPR << ", " << DiffSGPR << ")\n"); // Remove cases that don't match the target. 
@@ -2228,20 +2476,26 @@ if (BestDiffVGPR > DiffVGPR) { BestDiffVGPR = DiffVGPR; BestDiffSGPR = DiffSGPR; + BestDiffReg = RInfo.first; } else if (BestDiffVGPR == DiffVGPR) { - if (BestDiffSGPR > DiffSGPR) + if (BestDiffSGPR > DiffSGPR) { BestDiffSGPR = DiffSGPR; + BestDiffReg = RInfo.first; + } } } else { if (BestDiffSGPR > DiffSGPR) { BestDiffVGPR = DiffVGPR; BestDiffSGPR = DiffSGPR; + BestDiffReg = RInfo.first; } else if (BestDiffSGPR == DiffSGPR) { - if (BestDiffVGPR > DiffVGPR) + if (BestDiffVGPR > DiffVGPR) { BestDiffVGPR = DiffVGPR; + BestDiffReg = RInfo.first; + } } } } @@ -2257,37 +2511,10 @@ DEBUG(dbgs() << "Best diff score: (" << BestDiffVGPR << ", " << BestDiffSGPR << ")\n"); - for (auto RInfo : RegisterInfos) { - int DiffVGPR = 0; - int DiffSGPR = 0; - for (unsigned Reg : RInfo.second.ConsumedRegisters) { - unsigned RealReg = IdentifierToReg[Reg]; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg); - for (; PSetI.isValid(); ++PSetI) { - if (*PSetI == DAG->getVGPRSetID()) - DiffVGPR -= PSetI.getWeight(); - if (*PSetI == DAG->getSGPRSetID()) - DiffSGPR -= PSetI.getWeight(); - } - } - for (unsigned Reg : RInfo.second.ProducedRegisters) { - unsigned RealReg = IdentifierToReg[Reg]; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg); - for (; PSetI.isValid(); ++PSetI) { - if (*PSetI == DAG->getVGPRSetID()) - DiffVGPR += PSetI.getWeight(); - if (*PSetI == DAG->getSGPRSetID()) - DiffSGPR += PSetI.getWeight(); - } - } - if (BestDiffVGPR == DiffVGPR && BestDiffSGPR == DiffSGPR) { - DEBUG(RInfo.second.printDebug(DAG, IdentifierToReg)); - return std::set(RInfo.second.Dependencies.begin(), - RInfo.second.Dependencies.end()); - } - } - - llvm_unreachable("internal error"); + SIMIRegisterInfo &RInfo = RegisterInfos[BestDiffReg]; + DEBUG(RInfo.printDebug(DAG, IdentifierToReg)); + return std::set(RInfo.Dependencies.begin(), + RInfo.Dependencies.end()); } // SIScheduler // @@ -2426,7 +2653,7 @@ VgprUsage = 0; SgprUsage = 0; for (_Iterator RegI = 
First; RegI != End; ++RegI) { - unsigned Reg = *RegI; + unsigned Reg = RegI->first; // For now only track virtual registers if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; @@ -2485,7 +2712,7 @@ IsHighLatencySU[i] = 1; } - SIScheduler Scheduler(this); + SIScheduler Scheduler(this, ShouldTrackLaneMasks); Best = Scheduler.scheduleVariant(SISchedulerBlockCreatorVariant::LatenciesAlone, SISchedulerBlockSchedulerVariant::BlockLatencyRegUsage);