Review notes (applied on top of the SIMachineScheduler lane-mask tracking patch below):
* LaneMaskBasisForReg construction: the split loop walked LaneBasis with an
  iterator while push_back'ing into it (invalid after a reallocation) and never
  removed the split-off lanes from Remaining, so the trailing
  push_back(Remaining) could overlap an existing basis element and trip the
  (Mask & Elem) == Elem assert in getPairsForReg. Now index-based, and
  Remaining is reduced after each split (hunk line count unchanged).
* checkRegUsageImpact: LiveRegs[...] (operator[]) inserted an empty-mask entry
  for absent registers, which the later LiveRegs.find() liveness test then
  treats as live; replaced with find(). The out-register set was a SmallPtrSet
  filled with unsigned register numbers, which SmallPtrSet does not support;
  now std::set<unsigned> (TODO(review): confirm <set> is still included).
  Hunk headers adjusted for the one added line.
* TODO(review): RegisterMaskPair is used as a std::map key and compared with
  ==; confirm operator< and operator== exist for it (not part of this diff).
* TODO(review): InRegsForBlock/OutRegsForBlock are filled by index i but read
  by getID(); this assumes Blocks[i]->getID() == i.
* Note: template arguments and #include targets are missing throughout the
  text below (angle-bracket contents were lost), so it will not apply as-is.
Index: lib/Target/AMDGPU/SIMachineScheduler.h =================================================================== --- lib/Target/AMDGPU/SIMachineScheduler.h +++ lib/Target/AMDGPU/SIMachineScheduler.h @@ -20,6 +20,7 @@ #include "llvm/CodeGen/MachineScheduler.h" #include "llvm/CodeGen/RegisterPressure.h" #include "llvm/CodeGen/ScheduleDAG.h" +#include "llvm/MC/LaneBitmask.h" #include #include #include @@ -85,8 +86,8 @@ // Note that some registers are not 32 bits, // and thus the pressure is not equal // to the number of live registers. - std::set LiveInRegs; - std::set LiveOutRegs; + SmallVector LiveInRegs; + SmallVector LiveOutRegs; bool Scheduled = false; bool HighLatencyBlock = false; @@ -161,8 +162,8 @@ return InternalAdditionnalPressure; } - std::set &getInRegs() { return LiveInRegs; } - std::set &getOutRegs() { return LiveOutRegs; } + SmallVector &getInRegs() { return LiveInRegs; } + SmallVector &getOutRegs() { return LiveOutRegs; } void printDebug(bool Full); @@ -306,6 +307,7 @@ void topologicalSort(); + void adjustLaneLiveness(MachineInstr &MI); void scheduleInsideBlocks(); void fillStats(); @@ -322,10 +324,22 @@ SISchedulerBlockSchedulerVariant Variant; std::vector Blocks; - std::vector> LiveOutRegsNumUsages; - std::set LiveRegs; + // For each register, basis of LaneBitmask to generate all + // LaneBitmasks involved in the scheduling. + // If the register is not in the map, it is assumed it is always + // used with full mask. + std::map> LaneMaskBasisForReg; + + // The RegisterMaskPair below are assumed to be pairs (Reg, Mask) + // such that Mask is in LaneMaskBasisForReg[Reg] if exists. + std::vector> LiveOutRegsNumUsages; + std::map LiveRegs; // Num of schedulable unscheduled blocks reading the register. - std::map LiveRegsConsumers; + std::map LiveRegsConsumers; + // Blocks's getInRegs, but with RegisterMaskPair with Mask + // in LaneMaskBasisForReg[Reg] if exists.
+ DenseMap> InRegsForBlock; + DenseMap> OutRegsForBlock; std::vector LastPosHighLatencyParentScheduled; int LastPosWaitedHighLatency; @@ -356,6 +370,18 @@ unsigned getSGPRUsage() { return maxSregUsage; } private: + + // Convert Reg/Mask to a list of Reg/Mask, with Mask in + // LaneMaskBasisForReg. + SmallVector getPairsForReg(unsigned Reg, + LaneBitmask Mask); + // ToAppend: where to append the result. + void getPairsForReg(SmallVector &ToAppend, + unsigned Reg, LaneBitmask Mask); + // Idem for a list of Reg/Mask + SmallVector getPairsForRegs( + const SmallVector &Regs); + struct SIBlockSchedCandidate : SISchedulerCandidate { // The best Block candidate. SIScheduleBlock *Block = nullptr; @@ -391,15 +417,17 @@ SIBlockSchedCandidate &TryCand); SIScheduleBlock *pickBlock(); - void addLiveRegs(std::set &Regs); - void decreaseLiveRegs(SIScheduleBlock *Block, std::set &Regs); + // The Mask in RegisterMaskPair needs to be + // element of LaneMaskBasisForReg. + void addLiveRegs(SmallVector &Regs); + void decreaseLiveRegs(SIScheduleBlock *Block, + SmallVector &Regs); void releaseBlockSuccs(SIScheduleBlock *Parent); void blockScheduled(SIScheduleBlock *Block); // Check register pressure change - // by scheduling a block with these LiveIn and LiveOut. - std::vector checkRegUsageImpact(std::set &InRegs, - std::set &OutRegs); + // by scheduling a block + void checkRegUsageImpact(unsigned BlockID, int &DiffVGPR, int &DiffSGPR); void schedule(); }; @@ -415,7 +443,8 @@ SIScheduleBlockCreator BlockCreator; public: - SIScheduler(SIScheduleDAGMI *DAG) : DAG(DAG), BlockCreator(DAG) {} + SIScheduler(SIScheduleDAGMI *DAG) : + DAG(DAG), BlockCreator(DAG) {} ~SIScheduler() = default; @@ -447,7 +476,7 @@ // To init Block's RPTracker.
void initRPTracker(RegPressureTracker &RPTracker) { - RPTracker.init(&MF, RegClassInfo, LIS, BB, RegionBegin, false, false); + RPTracker.init(&MF, RegClassInfo, LIS, BB, RegionBegin, ShouldTrackLaneMasks, false); } MachineBasicBlock *getBB() { return BB; } @@ -467,25 +496,19 @@ unsigned &VgprUsage, unsigned &SgprUsage); - std::set getInRegs() { - std::set InRegs; - for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) { - InRegs.insert(RegMaskPair.RegUnit); - } - return InRegs; + const SmallVector &getInRegs() { + return RPTracker.getPressure().LiveInRegs; } - std::set getOutRegs() { - std::set OutRegs; - for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) { - OutRegs.insert(RegMaskPair.RegUnit); - } - return OutRegs; + const SmallVector &getOutRegs() { + return RPTracker.getPressure().LiveOutRegs; }; unsigned getVGPRSetID() const { return VGPRSetID; } unsigned getSGPRSetID() const { return SGPRSetID; } + bool shouldTrackLaneMasks() const { return ShouldTrackLaneMasks; } + private: void topologicalSort(); // After scheduling is done, improve low latency placements. Index: lib/Target/AMDGPU/SIMachineScheduler.cpp =================================================================== --- lib/Target/AMDGPU/SIMachineScheduler.cpp +++ lib/Target/AMDGPU/SIMachineScheduler.cpp @@ -25,6 +25,7 @@ #include "llvm/CodeGen/MachineScheduler.h" #include "llvm/CodeGen/RegisterPressure.h" #include "llvm/CodeGen/SlotIndexes.h" +#include "llvm/MC/LaneBitmask.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -298,22 +299,26 @@ Scheduled = true; } -// Returns if the register was set between first and last. -static bool isDefBetween(unsigned Reg, +// Returns if one of the register lanes was set between first and last. +// Lanes is set to the lane mask set.
+static LaneBitmask findDefsBetween(unsigned Reg, SlotIndex First, SlotIndex Last, const MachineRegisterInfo *MRI, const LiveIntervals *LIS) { - for (MachineRegisterInfo::def_instr_iterator - UI = MRI->def_instr_begin(Reg), - UE = MRI->def_instr_end(); UI != UE; ++UI) { - const MachineInstr* MI = &*UI; + LaneBitmask LaneMask = LaneBitmask::getNone(); + const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo(); + for (const MachineOperand &MO : MRI->def_operands(Reg)) { + const MachineInstr* MI = MO.getParent(); if (MI->isDebugValue()) continue; SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot(); - if (InstSlot >= First && InstSlot <= Last) - return true; + if (InstSlot >= First && InstSlot <= Last) { + unsigned SubRegIdx = MO.getSubReg(); + LaneBitmask DefMask = TRI->getSubRegIndexLaneMask(SubRegIdx); + LaneMask |= DefMask; + } } - return false; + return LaneMask; } void SIScheduleBlock::initRegPressure(MachineBasicBlock::iterator BeginBlock, @@ -322,6 +327,8 @@ RegPressureTracker RPTracker(Pressure), BotRPTracker(BotPressure); LiveIntervals *LIS = DAG->getLIS(); MachineRegisterInfo *MRI = DAG->getMRI(); + SlotIndex BeginBlockIdx = LIS->getInstructionIndex(*BeginBlock).getRegSlot(); + SlotIndex EndBlockIdx = LIS->getInstructionIndex(*EndBlock).getRegSlot(); DAG->initRPTracker(TopRPTracker); DAG->initRPTracker(BotRPTracker); DAG->initRPTracker(RPTracker); @@ -340,10 +347,12 @@ TopRPTracker.addLiveRegs(RPTracker.getPressure().LiveInRegs); BotRPTracker.addLiveRegs(RPTracker.getPressure().LiveOutRegs); + LiveInRegs.clear(); + // Do not Track Physical Registers, because it messes up.
for (const auto &RegMaskPair : RPTracker.getPressure().LiveInRegs) { if (TargetRegisterInfo::isVirtualRegister(RegMaskPair.RegUnit)) - LiveInRegs.insert(RegMaskPair.RegUnit); + LiveInRegs.push_back(RegMaskPair); } LiveOutRegs.clear(); // There is several possibilities to distinguish: @@ -368,14 +377,21 @@ // The RPTracker's LiveOutRegs has 1, 3, (some correct or incorrect)4, 5, 7 // Comparing to LiveInRegs is not sufficient to differenciate 4 vs 5, 7 // The use of findDefBetween removes the case 4. + // LaneMask: Conceptually this is the same than described above, except + // we differenciate what happens for all reg lanes. for (const auto &RegMaskPair : RPTracker.getPressure().LiveOutRegs) { unsigned Reg = RegMaskPair.RegUnit; - if (TargetRegisterInfo::isVirtualRegister(Reg) && - isDefBetween(Reg, LIS->getInstructionIndex(*BeginBlock).getRegSlot(), - LIS->getInstructionIndex(*EndBlock).getRegSlot(), MRI, - LIS)) { - LiveOutRegs.insert(Reg); - } + if (!TargetRegisterInfo::isVirtualRegister(Reg)) + continue; + LaneBitmask LaneMask = findDefsBetween(Reg, BeginBlockIdx, + EndBlockIdx, MRI, LIS); + // Being in LaneMask but not in RegMaskPair.LaneMask means the lane + // Was temporarily defined, but then consumed in the block. + LaneMask &= RegMaskPair.LaneMask; + if (LaneMask.any()) + LiveOutRegs.push_back(RegisterMaskPair(RegMaskPair.RegUnit, + DAG->shouldTrackLaneMasks() ?
LaneMask : + LaneBitmask::getAll())); } // Pressure = sum_alive_registers register size @@ -594,12 +610,20 @@ dbgs() << "LiveOutPressure " << LiveOutPressure[DAG->getSGPRSetID()] << ' ' << LiveOutPressure[DAG->getVGPRSetID()] << "\n\n"; dbgs() << "LiveIns:\n"; - for (unsigned Reg : LiveInRegs) - dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << ' '; + for (const auto &Ins : LiveInRegs) { + dbgs() << PrintVRegOrUnit(Ins.RegUnit, DAG->getTRI()); + if (!Ins.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(Ins.LaneMask); + dbgs() << ' '; + } dbgs() << "\nLiveOuts:\n"; - for (unsigned Reg : LiveOutRegs) - dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << ' '; + for (const auto &Outs : LiveOutRegs) { + dbgs() << PrintVRegOrUnit(Outs.RegUnit, DAG->getTRI()); + if (!Outs.LaneMask.all()) + dbgs() << ':' << PrintLaneMask(Outs.LaneMask); + dbgs() << ' '; + } } dbgs() << "\nInstructions:\n"; @@ -1271,6 +1295,23 @@ TopDownIndex2Block.rend()); } +void SIScheduleBlockCreator::adjustLaneLiveness(MachineInstr &MI) +{ + if (!MI.isDebugValue()) { + // Reset read - undef flags and update them later. + for (auto &Op : MI.operands()) + if (Op.isReg() && Op.isDef()) + Op.setIsUndef(false); + + RegisterOperands RegOpers; + RegOpers.collect(MI, *DAG->getTRI(), *DAG->getMRI(), + DAG->shouldTrackLaneMasks(), false); + // Adjust liveness and add missing dead+read-undef flags. + auto SlotIdx = DAG->getLIS()->getInstructionIndex(MI).getRegSlot(); + RegOpers.adjustLaneLiveness(*DAG->getLIS(), *DAG->getMRI(), SlotIdx, &MI); + } +} + void SIScheduleBlockCreator::scheduleInsideBlocks() { unsigned DAGSize = CurrentBlocks.size(); @@ -1317,6 +1358,7 @@ // It would gain a lot if there was a way to recompute the // LiveIntervals for the entire scheduling region.
DAG->getLIS()->handleMove(*MI, /*UpdateFlags=*/true); + adjustLaneLiveness(*MI); PosNew.push_back(CurrentTopFastSched); } } @@ -1338,11 +1380,13 @@ MachineBasicBlock::iterator POld = PosOld[i-1]; MachineBasicBlock::iterator PNew = PosNew[i-1]; if (PNew != POld) { + MachineInstr &MI = *POld; // Update the instruction stream. DAG->getBB()->splice(POld, DAG->getBB(), PNew); // Update LiveIntervals. DAG->getLIS()->handleMove(*POld, /*UpdateFlags=*/true); + adjustLaneLiveness(MI); } } @@ -1395,6 +1439,75 @@ LastPosWaitedHighLatency(0), NumBlockScheduled(0), VregCurrentUsage(0), SregCurrentUsage(0), maxVregUsage(0), maxSregUsage(0) { + MachineRegisterInfo *MRI = DAG->getMRI(); + // To track register usage, we define for each register definition + // a number of usages before it gets released. + // This doesn't work with LaneMasks. + // To handle LaneMasks, we 'cut' registers affected by LaneMasks + // into all their different Lanes possible + // and behave as if that (reg, LaneMask) was a register. + std::map> RegWithLaneMask; + for (SIScheduleBlock *Block : Blocks) { + for (const auto &Ins : Block->getInRegs()) { + if (!Ins.LaneMask.all() && + Ins.LaneMask != MRI->getMaxLaneMaskForVReg(Ins.RegUnit)) + RegWithLaneMask[Ins.RegUnit].push_back(Ins.LaneMask); + } + for (const auto &Outs : Block->getOutRegs()) { + if (!Outs.LaneMask.all() && + Outs.LaneMask != MRI->getMaxLaneMaskForVReg(Outs.RegUnit)) + RegWithLaneMask[Outs.RegUnit].push_back(Outs.LaneMask); + } + } + for (const auto &Ins : DAG->getInRegs()) { + if (!Ins.LaneMask.all() && + Ins.LaneMask != MRI->getMaxLaneMaskForVReg(Ins.RegUnit)) + RegWithLaneMask[Ins.RegUnit].push_back(Ins.LaneMask); + } + for (const auto &Outs : DAG->getOutRegs()) { + if (!Outs.LaneMask.all() && + Outs.LaneMask != MRI->getMaxLaneMaskForVReg(Outs.RegUnit)) + RegWithLaneMask[Outs.RegUnit].push_back(Outs.LaneMask); + } + // Since we ignored when the lane mask was getMaxLaneMaskForVReg, + // we need to add it back.
It doesn't hurt if there was no element + // with this mask for this register. + for (auto &RegLaneMasks : RegWithLaneMask) { + RegLaneMasks.second.push_back(MRI->getMaxLaneMaskForVReg(RegLaneMasks.first)); + } + + for (const auto &RegLaneMasks : RegWithLaneMask) { + SmallVector &LaneBasis = + LaneMaskBasisForReg[RegLaneMasks.first]; + for (const LaneBitmask &LaneMask : RegLaneMasks.second) { + LaneBitmask Remaining = LaneMask; + // Index-based: push_back below may reallocate LaneBasis. + for (unsigned Idx = 0; Idx != LaneBasis.size(); ++Idx) { + LaneBitmask Elem = LaneBasis[Idx]; + if ((Remaining & Elem).none()) + continue; + if ((Remaining & Elem) == Elem) { + Remaining &= ~Elem; + continue; + } + // Elem straddles Remaining: split it into the part covered by + // Remaining and the rest, then drop the covered part from Remaining. + LaneBitmask NewElem = Elem & ~Remaining; + LaneBasis[Idx] = Elem & Remaining; + Remaining &= ~Elem; + LaneBasis.push_back(NewElem); + } + if (Remaining.any()) + LaneBasis.push_back(Remaining); + } + } + + for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { + SIScheduleBlock *Block = Blocks[i]; + InRegsForBlock[i] = getPairsForRegs(Block->getInRegs()); + OutRegsForBlock[i] = getPairsForRegs(Block->getOutRegs()); + } + // Fill the usage of every output // Warning: while by construction we always have a link between two blocks // when one needs a result from the other, the number of users of an output @@ -1408,26 +1521,31 @@ LiveOutRegsNumUsages.resize(Blocks.size()); for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { SIScheduleBlock *Block = Blocks[i]; - for (unsigned Reg : Block->getInRegs()) { + for (const auto &RegPair : InRegsForBlock[i]) { bool Found = false; int topoInd = -1; for (SIScheduleBlock* Pred: Block->getPreds()) { - std::set PredOutRegs = Pred->getOutRegs(); - std::set::iterator RegPos = PredOutRegs.find(Reg); - - if (RegPos != PredOutRegs.end()) { - Found = true; - if (topoInd < BlocksStruct.TopDownBlock2Index[Pred->getID()]) { - topoInd =
BlocksStruct.TopDownBlock2Index[Pred->getID()]; + const SmallVector &PredOutRegs = + OutRegsForBlock[Pred->getID()]; + for (const auto &RegPair2 : PredOutRegs) { + if (RegPair == RegPair2) { + Found = true; + if (topoInd < BlocksStruct.TopDownBlock2Index[Pred->getID()]) { + topoInd = BlocksStruct.TopDownBlock2Index[Pred->getID()]; + } + break; } } } if (!Found) - continue; - - int PredID = BlocksStruct.TopDownIndex2Block[topoInd]; - ++LiveOutRegsNumUsages[PredID][Reg]; + // Fill LiveRegsConsumers for regs that were already + // defined before scheduling. + ++LiveRegsConsumers[RegPair]; + else { + int PredID = BlocksStruct.TopDownIndex2Block[topoInd]; + ++LiveOutRegsNumUsages[PredID][RegPair]; + } } } @@ -1448,45 +1566,33 @@ } #endif - std::set InRegs = DAG->getInRegs(); + SmallVector InRegs = getPairsForRegs(DAG->getInRegs()); addLiveRegs(InRegs); // Increase LiveOutRegsNumUsages for blocks // producing registers consumed in another // scheduling region. - for (unsigned Reg : DAG->getOutRegs()) { + for (const RegisterMaskPair &RegPair : getPairsForRegs(DAG->getOutRegs())) { for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { // Do reverse traversal + bool Found = false; int ID = BlocksStruct.TopDownIndex2Block[Blocks.size()-1-i]; SIScheduleBlock *Block = Blocks[ID]; - const std::set &OutRegs = Block->getOutRegs(); - - if (OutRegs.find(Reg) == OutRegs.end()) - continue; + const SmallVector &OutRegs = + OutRegsForBlock[Block->getID()]; - ++LiveOutRegsNumUsages[ID][Reg]; - break; - } - } - - // Fill LiveRegsConsumers for regs that were already - // defined before scheduling.
- for (unsigned i = 0, e = Blocks.size(); i != e; ++i) { - SIScheduleBlock *Block = Blocks[i]; - for (unsigned Reg : Block->getInRegs()) { - bool Found = false; - for (SIScheduleBlock* Pred: Block->getPreds()) { - std::set PredOutRegs = Pred->getOutRegs(); - std::set::iterator RegPos = PredOutRegs.find(Reg); - - if (RegPos != PredOutRegs.end()) { + for (const auto &RegPair2 : OutRegs) { + if (RegPair == RegPair2) { Found = true; break; } } if (!Found) - ++LiveRegsConsumers[Reg]; + continue; + + ++LiveOutRegsNumUsages[ID][RegPair]; + break; } } @@ -1511,6 +1617,59 @@ ); } +SmallVector +SIScheduleBlockScheduler::getPairsForReg(unsigned Reg, LaneBitmask Mask) +{ + SmallVector Result; + + getPairsForReg(Result, Reg, Mask); + + return Result; +} + +void +SIScheduleBlockScheduler::getPairsForReg(SmallVector &ToAppend, + unsigned Reg, LaneBitmask Mask) +{ + auto Basis = LaneMaskBasisForReg.find(Reg); + if (Basis == LaneMaskBasisForReg.end()) { + assert(Mask.all() || Mask == DAG->getMRI()->getMaxLaneMaskForVReg(Reg)); + // We want unicity of the RegisterMaskPair for a same register/mask + // Thus replace getMaxLaneMaskForVReg by all, since they have the same + // meaning. + // Note: Physical registers have Mask.all(), but are disallowed + // to call getMaxLaneMaskForVReg. + if (!Mask.all() && Mask == DAG->getMRI()->getMaxLaneMaskForVReg(Reg)) + Mask = LaneBitmask::getAll(); + ToAppend.push_back(RegisterMaskPair(Reg, Mask)); + } else { + for (const auto &Elem : Basis->second) { + if ((Mask & Elem).any()) { + assert((Mask & Elem) == Elem); + ToAppend.push_back(RegisterMaskPair(Reg, Elem)); + Mask &= ~Elem; + } + } + // Mask.all will have a non-none value. + // We want Mask.all equivalent to the max lane mask.
+ assert((Mask & DAG->getMRI()->getMaxLaneMaskForVReg(Reg)).none()); + } +} + +SmallVector +SIScheduleBlockScheduler::getPairsForRegs(const SmallVector &Regs) +{ + SmallVector Result; + + std::for_each(Regs.begin(), Regs.end(), + [&](const RegisterMaskPair &RegPair){ + getPairsForReg(Result, RegPair.RegUnit, + RegPair.LaneMask); + }); + + return Result; +} + bool SIScheduleBlockScheduler::tryCandidateLatency(SIBlockSchedCandidate &Cand, SIBlockSchedCandidate &TryCand) { if (!Cand.isValid()) { @@ -1577,8 +1736,12 @@ for (SIScheduleBlock* Block : ReadyBlocks) dbgs() << Block->getID() << ' '; dbgs() << "\nCurrent Live:\n"; - for (unsigned Reg : LiveRegs) - dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << ' '; + for (const auto &RegPair : LiveRegs) { + dbgs() << PrintVRegOrUnit(RegPair.first, DAG->getTRI()); + if (!RegPair.second.all()) + dbgs() << ':' << PrintLaneMask(RegPair.second); + dbgs() << ' '; + } dbgs() << '\n'; dbgs() << "Current VGPRs: " << VregCurrentUsage << '\n'; dbgs() << "Current SGPRs: " << SregCurrentUsage << '\n'; @@ -1588,17 +1751,19 @@ for (std::vector::iterator I = ReadyBlocks.begin(), E = ReadyBlocks.end(); I != E; ++I) { SIBlockSchedCandidate TryCand; + unsigned TryCandID; + int SGPRUsageDiff; + TryCand.Block = *I; + TryCandID = TryCand.Block->getID(); TryCand.IsHighLatency = TryCand.Block->isHighLatencyBlock(); - TryCand.VGPRUsageDiff = - checkRegUsageImpact(TryCand.Block->getInRegs(), - TryCand.Block->getOutRegs())[DAG->getVGPRSetID()]; + checkRegUsageImpact(TryCandID, TryCand.VGPRUsageDiff, SGPRUsageDiff); TryCand.NumSuccessors = TryCand.Block->getSuccs().size(); TryCand.NumHighLatencySuccessors = TryCand.Block->getNumHighLatencySuccessors(); TryCand.LastPosHighLatParentScheduled = (unsigned int) std::max (0, - LastPosHighLatencyParentScheduled[TryCand.Block->getID()] - + LastPosHighLatencyParentScheduled[TryCandID] - LastPosWaitedHighLatency); TryCand.Height = TryCand.Block->Height; // Try not to increase VGPR usage too much, else we may
spill. @@ -1636,27 +1801,32 @@ // Tracking of currently alive registers to determine VGPR Usage. -void SIScheduleBlockScheduler::addLiveRegs(std::set &Regs) { - for (unsigned Reg : Regs) { +void SIScheduleBlockScheduler::addLiveRegs(SmallVector &Regs) { + for (const RegisterMaskPair &RegPair : Regs) { + unsigned Reg = RegPair.RegUnit; // For now only track virtual registers. if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - // If not already in the live set, then add it. - (void) LiveRegs.insert(Reg); + + LiveRegs[Reg] |= RegPair.LaneMask; } } void SIScheduleBlockScheduler::decreaseLiveRegs(SIScheduleBlock *Block, - std::set &Regs) { - for (unsigned Reg : Regs) { + SmallVector &Regs) { + for (const RegisterMaskPair &RegPair : Regs) { // For now only track virtual registers. - std::set::iterator Pos = LiveRegs.find(Reg); + std::map::iterator Pos = LiveRegs.find(RegPair.RegUnit); assert (Pos != LiveRegs.end() && // Reg must be live. - LiveRegsConsumers.find(Reg) != LiveRegsConsumers.end() && - LiveRegsConsumers[Reg] >= 1); - --LiveRegsConsumers[Reg]; - if (LiveRegsConsumers[Reg] == 0) - LiveRegs.erase(Pos); + LiveRegsConsumers.find(RegPair) != LiveRegsConsumers.end() && + LiveRegsConsumers[RegPair] >= 1); + --LiveRegsConsumers[RegPair]; + if (LiveRegsConsumers[RegPair] == 0) { + if (Pos->second == RegPair.LaneMask) + LiveRegs.erase(Pos); + else + Pos->second &= ~RegPair.LaneMask; + } } } @@ -1672,54 +1842,83 @@ } void SIScheduleBlockScheduler::blockScheduled(SIScheduleBlock *Block) { - decreaseLiveRegs(Block, Block->getInRegs()); - addLiveRegs(Block->getOutRegs()); + unsigned ID = Block->getID(); + decreaseLiveRegs(Block, InRegsForBlock[ID]); + addLiveRegs(OutRegsForBlock[ID]); releaseBlockSuccs(Block); - for (std::map::iterator RegI = - LiveOutRegsNumUsages[Block->getID()].begin(), - E = LiveOutRegsNumUsages[Block->getID()].end(); RegI != E; ++RegI) { - std::pair RegP = *RegI; + for (std::map::iterator RegI = + LiveOutRegsNumUsages[ID].begin(), + E =
LiveOutRegsNumUsages[ID].end(); RegI != E; ++RegI) { + std::pair RegP = *RegI; // We produce this register, thus it must not be previously alive. assert(LiveRegsConsumers.find(RegP.first) == LiveRegsConsumers.end() || LiveRegsConsumers[RegP.first] == 0); LiveRegsConsumers[RegP.first] += RegP.second; } - if (LastPosHighLatencyParentScheduled[Block->getID()] > + if (LastPosHighLatencyParentScheduled[ID] > (unsigned)LastPosWaitedHighLatency) LastPosWaitedHighLatency = - LastPosHighLatencyParentScheduled[Block->getID()]; + LastPosHighLatencyParentScheduled[ID]; ++NumBlockScheduled; } -std::vector -SIScheduleBlockScheduler::checkRegUsageImpact(std::set &InRegs, - std::set &OutRegs) { - std::vector DiffSetPressure; - DiffSetPressure.assign(DAG->getTRI()->getNumRegPressureSets(), 0); +void +SIScheduleBlockScheduler::checkRegUsageImpact(unsigned BlockID, + int &DiffVGPR, + int &DiffSGPR) { + SmallDenseMap Map; + std::set<unsigned> Set; + unsigned VGPRSetID = DAG->getVGPRSetID(); + unsigned SGPRSetID = DAG->getSGPRSetID(); - for (unsigned Reg : InRegs) { + DiffVGPR = 0; + DiffSGPR = 0; + + for (const auto &RegPair : InRegsForBlock[BlockID]) { // For now only track virtual registers. + unsigned Reg = RegPair.RegUnit; if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - if (LiveRegsConsumers[Reg] > 1) + + if (LiveRegsConsumers[RegPair] > 1) continue; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg); - for (; PSetI.isValid(); ++PSetI) { - DiffSetPressure[*PSetI] -= PSetI.getWeight(); + Map[Reg] |= RegPair.LaneMask; + } + + for (const auto &RegPair : Map) { + auto LiveIt = LiveRegs.find(RegPair.first); + if (LiveIt != LiveRegs.end() && LiveIt->second == RegPair.second) { + PSetIterator PSetI = DAG->getMRI()->getPressureSets(RegPair.first); + for (; PSetI.isValid(); ++PSetI) { + if (*PSetI == VGPRSetID) + DiffVGPR -= PSetI.getWeight(); + if (*PSetI == SGPRSetID) + DiffSGPR -= PSetI.getWeight(); + } } } - for (unsigned Reg : OutRegs) { + for (const auto &RegPair : OutRegsForBlock[BlockID]) { // For now only track virtual registers.
+ unsigned Reg = RegPair.RegUnit; if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue; - PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg); - for (; PSetI.isValid(); ++PSetI) { - DiffSetPressure[*PSetI] += PSetI.getWeight(); - } + + Set.insert(Reg); } - return DiffSetPressure; + for (unsigned Reg : Set) { + // Check register is not already alive (at least some lanes) + if (LiveRegs.find(Reg) == LiveRegs.end()) { + PSetIterator PSetI = DAG->getMRI()->getPressureSets(Reg); + for (; PSetI.isValid(); ++PSetI) { + if (*PSetI == VGPRSetID) + DiffVGPR += PSetI.getWeight(); + if (*PSetI == SGPRSetID) + DiffSGPR += PSetI.getWeight(); + } + } + } } // SIScheduler // @@ -1858,7 +2057,7 @@ VgprUsage = 0; SgprUsage = 0; for (_Iterator RegI = First; RegI != End; ++RegI) { - unsigned Reg = *RegI; + unsigned Reg = RegI->first; // For now only track virtual registers if (!TargetRegisterInfo::isVirtualRegister(Reg)) continue;