Index: include/llvm/CodeGen/MachineScheduler.h =================================================================== --- include/llvm/CodeGen/MachineScheduler.h +++ include/llvm/CodeGen/MachineScheduler.h @@ -462,7 +462,7 @@ void initRegPressure(); - void updatePressureDiffs(ArrayRef LiveUses); + void updatePressureDiffs(ArrayRef LiveUses); void updateScheduledPressure(const SUnit *SU, const std::vector &NewMaxPressure); Index: include/llvm/CodeGen/RegisterPressure.h =================================================================== --- include/llvm/CodeGen/RegisterPressure.h +++ include/llvm/CodeGen/RegisterPressure.h @@ -26,14 +26,22 @@ class RegisterClassInfo; class MachineInstr; +struct RegisterMaskPair { + unsigned RegUnit; ///< Virtual register or register unit. + LaneBitmask LaneMask; + + RegisterMaskPair(unsigned RegUnit, LaneBitmask LaneMask) + : RegUnit(RegUnit), LaneMask(LaneMask) {} +}; + /// Base class for register pressure results. struct RegisterPressure { /// Map of max reg pressure indexed by pressure set ID, not class ID. std::vector MaxSetPressure; /// List of live in virtual registers or physical register units. - SmallVector LiveInRegs; - SmallVector LiveOutRegs; + SmallVector LiveInRegs; + SmallVector LiveOutRegs; void dump(const TargetRegisterInfo *TRI) const; }; @@ -144,16 +152,23 @@ /// List of register defined and used by a machine instruction. class RegisterOperands { public: - SmallVector Uses; - SmallVector Defs; - SmallVector DeadDefs; + SmallVector Uses; + SmallVector Defs; + SmallVector DeadDefs; void collect(const MachineInstr &MI, const TargetRegisterInfo &TRI, - const MachineRegisterInfo &MRI, bool IgnoreDead = false); + const MachineRegisterInfo &MRI, bool TrackLaneMasks, + bool IgnoreDead); /// Use liveness information to find dead defs not marked with a dead flag /// and move them to the DeadDefs vector. 
void detectDeadDefs(const MachineInstr &MI, const LiveIntervals &LIS); + + /// Use liveness information to find out which uses/defs are partially + /// undefined/dead and adjust the RegisterMaskPairs accordingly. + void adjustLaneLiveness(const LiveIntervals &LIS, + const MachineRegisterInfo &MRI, SlotIndex Pos); + }; /// Array of PressureDiffs. @@ -218,7 +233,20 @@ /// and virtual register indexes to an index usable by the sparse set. class LiveRegSet { private: - SparseSet Regs; + struct IndexMaskPair { + unsigned Index; + LaneBitmask LaneMask; + + IndexMaskPair(unsigned Index, LaneBitmask LaneMask) + : Index(Index), LaneMask(LaneMask) {} + + unsigned getSparseSetIndex() const { + return Index; + } + }; + + typedef SparseSet RegSet; + RegSet Regs; unsigned NumRegUnits; unsigned getSparseIndexFromReg(unsigned Reg) const { @@ -237,19 +265,37 @@ void clear(); void init(const MachineRegisterInfo &MRI); - bool contains(unsigned Reg) const { + LaneBitmask contains(unsigned Reg) const { unsigned SparseIndex = getSparseIndexFromReg(Reg); - return Regs.count(SparseIndex); + RegSet::const_iterator I = Regs.find(SparseIndex); + if (I == Regs.end()) + return 0; + return I->LaneMask; } - bool insert(unsigned Reg) { - unsigned SparseIndex = getSparseIndexFromReg(Reg); - return Regs.insert(SparseIndex).second; + /// Mark the \p Pair.LaneMask lanes of \p Pair.Reg as live. + /// Returns the previously live lanes of \p Pair.Reg. + LaneBitmask insert(RegisterMaskPair Pair) { + unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit); + auto InsertRes = Regs.insert(IndexMaskPair(SparseIndex, Pair.LaneMask)); + if (!InsertRes.second) { + unsigned PrevMask = InsertRes.first->LaneMask; + InsertRes.first->LaneMask |= Pair.LaneMask; + return PrevMask; + } + return 0; } - bool erase(unsigned Reg) { - unsigned SparseIndex = getSparseIndexFromReg(Reg); - return Regs.erase(SparseIndex); + /// Clears the \p Pair.LaneMask lanes of \p Pair.Reg (mark them as dead). 
+ /// Returns the previously live lanes of \p Pair.Reg. + LaneBitmask erase(RegisterMaskPair Pair) { + unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit); + RegSet::iterator I = Regs.find(SparseIndex); + if (I == Regs.end()) + return 0; + unsigned PrevMask = I->LaneMask; + I->LaneMask &= ~Pair.LaneMask; + return PrevMask; } size_t size() const { @@ -258,9 +304,10 @@ template void appendTo(ContainerT &To) const { - for (unsigned I : Regs) { - unsigned Reg = getRegFromSparseIndex(I); - To.push_back(Reg); + for (const IndexMaskPair &P : Regs) { + unsigned Reg = getRegFromSparseIndex(P.Index); + if (P.LaneMask != 0) + To.push_back(RegisterMaskPair(Reg, P.LaneMask)); } } }; @@ -301,6 +348,9 @@ /// True if UntiedDefs will be populated. bool TrackUntiedDefs; + /// True if lanemasks should be tracked. + bool TrackLaneMasks; + /// Register pressure corresponds to liveness before this instruction /// iterator. It may point to the end of the block or a DebugValue rather than /// an instruction. @@ -320,23 +370,23 @@ public: RegPressureTracker(IntervalPressure &rp) : MF(nullptr), TRI(nullptr), RCI(nullptr), LIS(nullptr), MBB(nullptr), P(rp), - RequireIntervals(true), TrackUntiedDefs(false) {} + RequireIntervals(true), TrackUntiedDefs(false), TrackLaneMasks(false) {} RegPressureTracker(RegionPressure &rp) : MF(nullptr), TRI(nullptr), RCI(nullptr), LIS(nullptr), MBB(nullptr), P(rp), - RequireIntervals(false), TrackUntiedDefs(false) {} + RequireIntervals(false), TrackUntiedDefs(false), TrackLaneMasks(false) {} void reset(); void init(const MachineFunction *mf, const RegisterClassInfo *rci, const LiveIntervals *lis, const MachineBasicBlock *mbb, MachineBasicBlock::const_iterator pos, - bool ShouldTrackUntiedDefs = false); + bool TrackLaneMasks, bool TrackUntiedDefs); /// Force liveness of virtual registers or physical register /// units. Particularly useful to initialize the livein/out state of the /// tracker before the first call to advance/recede. 
- void addLiveRegs(ArrayRef Regs); + void addLiveRegs(ArrayRef Regs); /// Get the MI position corresponding to this register pressure. MachineBasicBlock::const_iterator getPos() const { return CurrPos; } @@ -348,14 +398,14 @@ void setPos(MachineBasicBlock::const_iterator Pos) { CurrPos = Pos; } /// Recede across the previous instruction. - void recede(SmallVectorImpl *LiveUses = nullptr); + void recede(SmallVectorImpl *LiveUses = nullptr); /// Recede across the previous instruction. /// This "low-level" variant assumes that recedeSkipDebugValues() was /// called previously and takes precomputed RegisterOperands for the /// instruction. void recede(const RegisterOperands &RegOpers, - SmallVectorImpl *LiveUses = nullptr); + SmallVectorImpl *LiveUses = nullptr); /// Recede until we find an instruction which is not a DebugValue. void recedeSkipDebugValues(); @@ -462,18 +512,31 @@ void dump() const; protected: - void discoverLiveOut(unsigned Reg); - void discoverLiveIn(unsigned Reg); + /// Add Reg to the live out set and increase max pressure. + void discoverLiveOut(RegisterMaskPair Pair); + /// Add Reg to the live in set and increase max pressure. + void discoverLiveIn(RegisterMaskPair Pair); /// \brief Get the SlotIndex for the first nondebug instruction including or /// after the current position. 
SlotIndex getCurrSlot() const; - void increaseRegPressure(ArrayRef Regs); - void decreaseRegPressure(ArrayRef Regs); + void increaseRegPressure(unsigned RegUnit, LaneBitmask PreviousMask, + LaneBitmask NewMask); + void decreaseRegPressure(unsigned RegUnit, LaneBitmask PreviousMask, + LaneBitmask NewMask); + + void bumpDeadDefs(ArrayRef DeadDefs); void bumpUpwardPressure(const MachineInstr *MI); void bumpDownwardPressure(const MachineInstr *MI); + + void discoverLiveInOrOut(RegisterMaskPair Pair, + SmallVectorImpl &LiveInOrOut); + + LaneBitmask getLastUsedLanes(unsigned RegUnit, SlotIndex Pos) const; + LaneBitmask getLiveLanesAt(unsigned RegUnit, SlotIndex Pos) const; + LaneBitmask getLiveThroughAt(unsigned RegUnit, SlotIndex Pos) const; }; void dumpRegSetPressure(ArrayRef SetPressure, Index: lib/CodeGen/MachineScheduler.cpp =================================================================== --- lib/CodeGen/MachineScheduler.cpp +++ lib/CodeGen/MachineScheduler.cpp @@ -874,8 +874,8 @@ // Setup the register pressure trackers for the top scheduled top and bottom // scheduled regions. void ScheduleDAGMILive::initRegPressure() { - TopRPTracker.init(&MF, RegClassInfo, LIS, BB, RegionBegin); - BotRPTracker.init(&MF, RegClassInfo, LIS, BB, LiveRegionEnd); + TopRPTracker.init(&MF, RegClassInfo, LIS, BB, RegionBegin, false, false); + BotRPTracker.init(&MF, RegClassInfo, LIS, BB, LiveRegionEnd, false, false); // Close the RPTracker to finalize live ins. RPTracker.closeRegion(); @@ -905,7 +905,7 @@ // Account for liveness generated by the region boundary. if (LiveRegionEnd != RegionEnd) { - SmallVector LiveUses; + SmallVector LiveUses; BotRPTracker.recede(&LiveUses); updatePressureDiffs(LiveUses); } @@ -969,10 +969,12 @@ /// Update the PressureDiff array for liveness after scheduling this /// instruction. 
-void ScheduleDAGMILive::updatePressureDiffs(ArrayRef LiveUses) { - for (unsigned LUIdx = 0, LUEnd = LiveUses.size(); LUIdx != LUEnd; ++LUIdx) { +void ScheduleDAGMILive::updatePressureDiffs( + ArrayRef LiveUses) { + for (const RegisterMaskPair &P : LiveUses) { /// FIXME: Currently assuming single-use physregs. - unsigned Reg = LiveUses[LUIdx]; + unsigned Reg = P.RegUnit; + assert(P.LaneMask != 0); DEBUG(dbgs() << " LiveReg: " << PrintVRegOrUnit(Reg, TRI) << "\n"); if (!TRI->isVirtualRegister(Reg)) continue; @@ -1111,7 +1113,7 @@ // Initialize the register pressure tracker used by buildSchedGraph. RPTracker.init(&MF, RegClassInfo, LIS, BB, LiveRegionEnd, - /*TrackUntiedDefs=*/true); + false, /*TrackUntiedDefs=*/true); // Account for liveness generate by the region boundary. if (LiveRegionEnd != RegionEnd) @@ -1167,10 +1169,8 @@ unsigned MaxCyclicLatency = 0; // Visit each live out vreg def to find def/use pairs that cross iterations. - ArrayRef LiveOuts = RPTracker.getPressure().LiveOutRegs; - for (ArrayRef::iterator RI = LiveOuts.begin(), RE = LiveOuts.end(); - RI != RE; ++RI) { - unsigned Reg = *RI; + for (const RegisterMaskPair &P : RPTracker.getPressure().LiveOutRegs) { + unsigned Reg = P.RegUnit; if (!TRI->isVirtualRegister(Reg)) continue; const LiveInterval &LI = LIS->getInterval(Reg); @@ -1265,7 +1265,7 @@ } if (ShouldTrackPressure) { // Update bottom scheduled pressure. - SmallVector LiveUses; + SmallVector LiveUses; BotRPTracker.recede(&LiveUses); assert(BotRPTracker.getPos() == CurrentBottom && "out of sync"); DEBUG( Index: lib/CodeGen/RegisterPressure.cpp =================================================================== --- lib/CodeGen/RegisterPressure.cpp +++ lib/CodeGen/RegisterPressure.cpp @@ -24,7 +24,13 @@ /// Increase pressure for each pressure set provided by TargetRegisterInfo. 
static void increaseSetPressure(std::vector &CurrSetPressure, - PSetIterator PSetI) { + const MachineRegisterInfo &MRI, unsigned Reg, + LaneBitmask PrevMask, LaneBitmask NewMask) { + assert((PrevMask & ~NewMask) == 0 && "Must not remove bits"); + if (PrevMask != 0 || NewMask == 0) + return; + + PSetIterator PSetI = MRI.getPressureSets(Reg); unsigned Weight = PSetI.getWeight(); for (; PSetI.isValid(); ++PSetI) CurrSetPressure[*PSetI] += Weight; @@ -32,7 +38,13 @@ /// Decrease pressure for each pressure set provided by TargetRegisterInfo. static void decreaseSetPressure(std::vector &CurrSetPressure, - PSetIterator PSetI) { + const MachineRegisterInfo &MRI, unsigned Reg, + LaneBitmask PrevMask, LaneBitmask NewMask) { + assert((NewMask & ~PrevMask) == 0 && "Must not add bits"); + if (NewMask != 0 || PrevMask == 0) + return; + + PSetIterator PSetI = MRI.getPressureSets(Reg); unsigned Weight = PSetI.getWeight(); for (; PSetI.isValid(); ++PSetI) { assert(CurrSetPressure[*PSetI] >= Weight && "register pressure underflow"); @@ -59,12 +71,20 @@ dbgs() << "Max Pressure: "; dumpRegSetPressure(MaxSetPressure, TRI); dbgs() << "Live In: "; - for (unsigned Reg : LiveInRegs) - dbgs() << PrintVRegOrUnit(Reg, TRI) << " "; + for (const RegisterMaskPair &P : LiveInRegs) { + dbgs() << PrintVRegOrUnit(P.RegUnit, TRI); + if (P.LaneMask != ~0u) + dbgs() << ':' << PrintLaneMask(P.LaneMask); + dbgs() << ' '; + } dbgs() << '\n'; dbgs() << "Live Out: "; - for (unsigned Reg : LiveOutRegs) - dbgs() << PrintVRegOrUnit(Reg, TRI) << " "; + for (const RegisterMaskPair &P : LiveOutRegs) { + dbgs() << PrintVRegOrUnit(P.RegUnit, TRI); + if (P.LaneMask != ~0u) + dbgs() << ':' << PrintLaneMask(P.LaneMask); + dbgs() << ' '; + } dbgs() << '\n'; } @@ -89,25 +109,25 @@ dbgs() << '\n'; } -/// Increase the current pressure as impacted by these registers and bump -/// the high water mark if needed. 
-void RegPressureTracker::increaseRegPressure(ArrayRef RegUnits) { - for (unsigned RegUnit : RegUnits) { - PSetIterator PSetI = MRI->getPressureSets(RegUnit); - unsigned Weight = PSetI.getWeight(); - for (; PSetI.isValid(); ++PSetI) { - CurrSetPressure[*PSetI] += Weight; - if (CurrSetPressure[*PSetI] > P.MaxSetPressure[*PSetI]) { - P.MaxSetPressure[*PSetI] = CurrSetPressure[*PSetI]; - } - } +void RegPressureTracker::increaseRegPressure(unsigned RegUnit, + LaneBitmask PreviousMask, + LaneBitmask NewMask) { + if (PreviousMask != 0 || NewMask == 0) + return; + + PSetIterator PSetI = MRI->getPressureSets(RegUnit); + unsigned Weight = PSetI.getWeight(); + for (; PSetI.isValid(); ++PSetI) { + CurrSetPressure[*PSetI] += Weight; + if (CurrSetPressure[*PSetI] > P.MaxSetPressure[*PSetI]) + P.MaxSetPressure[*PSetI] = CurrSetPressure[*PSetI]; } } -/// Simply decrease the current pressure as impacted by these registers. -void RegPressureTracker::decreaseRegPressure(ArrayRef RegUnits) { - for (unsigned RegUnit : RegUnits) - decreaseSetPressure(CurrSetPressure, MRI->getPressureSets(RegUnit)); +void RegPressureTracker::decreaseRegPressure(unsigned RegUnit, + LaneBitmask PreviousMask, + LaneBitmask NewMask) { + decreaseSetPressure(CurrSetPressure, *MRI, RegUnit, PreviousMask, NewMask); } /// Clear the result so it can be used for another round of pressure tracking. 
@@ -202,8 +222,7 @@ const LiveIntervals *lis, const MachineBasicBlock *mbb, MachineBasicBlock::const_iterator pos, - bool ShouldTrackUntiedDefs) -{ + bool TrackLaneMasks, bool TrackUntiedDefs) { reset(); MF = mf; @@ -211,7 +230,8 @@ RCI = rci; MRI = &MF->getRegInfo(); MBB = mbb; - TrackUntiedDefs = ShouldTrackUntiedDefs; + this->TrackUntiedDefs = TrackUntiedDefs; + this->TrackLaneMasks = TrackLaneMasks; if (RequireIntervals) { assert(lis && "IntervalPressure requires LiveIntervals"); @@ -298,20 +318,92 @@ void RegPressureTracker::initLiveThru(const RegPressureTracker &RPTracker) { LiveThruPressure.assign(TRI->getNumRegPressureSets(), 0); assert(isBottomClosed() && "need bottom-up tracking to intialize."); - for (unsigned Reg : P.LiveOutRegs) { - if (TargetRegisterInfo::isVirtualRegister(Reg) - && !RPTracker.hasUntiedDef(Reg)) { - increaseSetPressure(LiveThruPressure, MRI->getPressureSets(Reg)); - } + for (const RegisterMaskPair &Pair : P.LiveOutRegs) { + unsigned RegUnit = Pair.RegUnit; + if (TargetRegisterInfo::isVirtualRegister(RegUnit) + && !RPTracker.hasUntiedDef(RegUnit)) + increaseSetPressure(LiveThruPressure, *MRI, RegUnit, 0, Pair.LaneMask); + } +} + +static unsigned getRegLanes(ArrayRef RegUnits, + unsigned RegUnit) { + auto I = std::find_if(RegUnits.begin(), RegUnits.end(), + [RegUnit](const RegisterMaskPair Other) { + return Other.RegUnit == RegUnit; + }); + if (I == RegUnits.end()) + return 0; + return I->LaneMask; +} + +static void addRegLanes(SmallVectorImpl &RegUnits, + RegisterMaskPair Pair) { + unsigned RegUnit = Pair.RegUnit; + assert(Pair.LaneMask != 0); + auto I = std::find_if(RegUnits.begin(), RegUnits.end(), + [RegUnit](const RegisterMaskPair Other) { + return Other.RegUnit == RegUnit; + }); + if (I == RegUnits.end()) { + RegUnits.push_back(Pair); + } else { + I->LaneMask |= Pair.LaneMask; + } +} + +static void removeRegLanes(SmallVectorImpl &RegUnits, + RegisterMaskPair Pair) { + unsigned RegUnit = Pair.RegUnit; + assert(Pair.LaneMask != 0); 
+ auto I = std::find_if(RegUnits.begin(), RegUnits.end(), + [RegUnit](const RegisterMaskPair Other) { + return Other.RegUnit == RegUnit; + }); + if (I != RegUnits.end()) { + I->LaneMask &= ~Pair.LaneMask; + if (I->LaneMask == 0) + RegUnits.erase(I); + } +} + +static LaneBitmask getLanesWithProperty(const LiveIntervals &LIS, + const MachineRegisterInfo &MRI, bool TrackLaneMasks, unsigned RegUnit, + SlotIndex Pos, + bool(*Property)(const LiveRange &LR, SlotIndex Pos)) { + if (TargetRegisterInfo::isVirtualRegister(RegUnit)) { + const LiveInterval &LI = LIS.getInterval(RegUnit); + LaneBitmask Result = 0; + if (TrackLaneMasks && LI.hasSubRanges()) { + for (const LiveInterval::SubRange &SR : LI.subranges()) { + if (Property(SR, Pos)) + Result |= SR.LaneMask; + } + } else if (Property(LI, Pos)) + Result = MRI.getMaxLaneMaskForVReg(RegUnit); + + return Result; + } else { + const LiveRange *LR = LIS.getCachedRegUnit(RegUnit); + // Be prepared for missing liveranges: We usually do not compute liveranges + // for physical registers on targets with many registers (GPUs). + if (LR == nullptr) + return 0; + return Property(*LR, Pos) ? ~0u : 0; } } -/// \brief Convenient wrapper for checking membership in RegisterOperands. -/// (std::count() doesn't have an early exit). 
-static bool containsReg(ArrayRef RegUnits, unsigned RegUnit) { - return std::find(RegUnits.begin(), RegUnits.end(), RegUnit) != RegUnits.end(); +static LaneBitmask getLiveLanesAt(const LiveIntervals &LIS, + const MachineRegisterInfo &MRI, + bool TrackLaneMasks, unsigned RegUnit, + SlotIndex Pos) { + return getLanesWithProperty(LIS, MRI, TrackLaneMasks, RegUnit, Pos, + [](const LiveRange &LR, SlotIndex Pos) { + return LR.liveAt(Pos); + }); } + namespace { /// Collect this instruction's unique uses and defs into SmallVectors for @@ -322,23 +414,23 @@ RegisterOperands &RegOpers; const TargetRegisterInfo &TRI; const MachineRegisterInfo &MRI; + bool TrackLaneMasks; bool IgnoreDead; RegisterOperandsCollector(RegisterOperands &RegOpers, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, - bool IgnoreDead) - : RegOpers(RegOpers), TRI(TRI), MRI(MRI), IgnoreDead(IgnoreDead) {} + bool TrackLaneMasks, bool IgnoreDead) + : RegOpers(RegOpers), TRI(TRI), MRI(MRI), + TrackLaneMasks(TrackLaneMasks), IgnoreDead(IgnoreDead) {} void collectInstr(const MachineInstr &MI) const { for (ConstMIBundleOperands OperI(&MI); OperI.isValid(); ++OperI) collectOperand(*OperI); // Remove redundant physreg dead defs. - SmallVectorImpl::iterator I = - std::remove_if(RegOpers.DeadDefs.begin(), RegOpers.DeadDefs.end(), - std::bind1st(std::ptr_fun(containsReg), RegOpers.Defs)); - RegOpers.DeadDefs.erase(I, RegOpers.DeadDefs.end()); + for (const RegisterMaskPair &P : RegOpers.Defs) + removeRegLanes(RegOpers.DeadDefs, P); } /// Push this operand's register onto the correct vectors. 
@@ -346,28 +438,39 @@ if (!MO.isReg() || !MO.getReg()) return; unsigned Reg = MO.getReg(); - if (MO.readsReg()) - pushRegUnits(Reg, RegOpers.Uses); - if (MO.isDef()) { + unsigned SubRegIdx = MO.getSubReg(); + if (MO.isUse()) { + if (!MO.isUndef() && !MO.isInternalRead()) + pushRegUnits(Reg, SubRegIdx, RegOpers.Uses); + } else { + assert(MO.isDef()); + if (MO.isUndef()) { + // Treat read-undef subreg defs as definitions of the whole register. + SubRegIdx = 0; + } else if (!TrackLaneMasks && SubRegIdx != 0 && !MO.isInternalRead()) { + // Interpret the subregister def as read-modify-store: A use+def of the + // full register. + pushRegUnits(Reg, SubRegIdx, RegOpers.Uses); + } + if (MO.isDead()) { if (!IgnoreDead) - pushRegUnits(Reg, RegOpers.DeadDefs); + pushRegUnits(Reg, SubRegIdx, RegOpers.DeadDefs); } else - pushRegUnits(Reg, RegOpers.Defs); + pushRegUnits(Reg, SubRegIdx, RegOpers.Defs); } } - void pushRegUnits(unsigned Reg, SmallVectorImpl &RegUnits) const { + void pushRegUnits(unsigned Reg, unsigned SubRegIdx, + SmallVectorImpl &RegUnits) const { if (TargetRegisterInfo::isVirtualRegister(Reg)) { - if (containsReg(RegUnits, Reg)) - return; - RegUnits.push_back(Reg); + LaneBitmask LaneMask = TrackLaneMasks && SubRegIdx != 0 + ? 
TRI.getSubRegIndexLaneMask(SubRegIdx) + : MRI.getMaxLaneMaskForVReg(Reg); + addRegLanes(RegUnits, RegisterMaskPair(Reg, LaneMask)); } else if (MRI.isAllocatable(Reg)) { - for (MCRegUnitIterator Units(Reg, &TRI); Units.isValid(); ++Units) { - if (containsReg(RegUnits, *Units)) - continue; - RegUnits.push_back(*Units); - } + for (MCRegUnitIterator Units(Reg, &TRI); Units.isValid(); ++Units) + addRegLanes(RegUnits, RegisterMaskPair(*Units, ~0u)); } } @@ -379,24 +482,24 @@ void RegisterOperands::collect(const MachineInstr &MI, const TargetRegisterInfo &TRI, const MachineRegisterInfo &MRI, - bool IgnoreDead) { - RegisterOperandsCollector Collector(*this, TRI, MRI, IgnoreDead); + bool TrackLaneMasks, bool IgnoreDead) { + RegisterOperandsCollector Collector(*this, TRI, MRI, TrackLaneMasks, + IgnoreDead); Collector.collectInstr(MI); } void RegisterOperands::detectDeadDefs(const MachineInstr &MI, const LiveIntervals &LIS) { SlotIndex SlotIdx = LIS.getInstructionIndex(&MI); - for (SmallVectorImpl::iterator RI = Defs.begin(); - RI != Defs.end(); /*empty*/) { - unsigned Reg = *RI; + for (auto RI = Defs.begin(); RI != Defs.end(); /*empty*/) { + unsigned Reg = RI->RegUnit; const LiveRange *LR = getLiveRange(LIS, Reg); if (LR != nullptr) { LiveQueryResult LRQ = LR->Query(SlotIdx); if (LRQ.isDeadDef()) { // LiveIntervals knows this is a dead even though it's MachineOperand is // not flagged as such. 
- DeadDefs.push_back(Reg); + DeadDefs.push_back(*RI); RI = Defs.erase(RI); continue; } @@ -405,6 +508,38 @@ } } +void RegisterOperands::adjustLaneLiveness(const LiveIntervals &LIS, + const MachineRegisterInfo &MRI, + SlotIndex Pos) { + for (auto I = Defs.begin(); I != Defs.end(); ) { + LaneBitmask LiveAfter = getLiveLanesAt(LIS, MRI, true, I->RegUnit, + Pos.getDeadSlot()); +#if 0 + unsigned DeadDef = I->LaneMask & ~LiveAfter; + if (DeadDef != 0) + addRegLanes(DeadDefs, RegisterMaskPair(I->RegUnit, DeadDef)); +#endif + unsigned LaneMask = I->LaneMask & LiveAfter; + if (LaneMask == 0) + I = Defs.erase(I); + else { + I->LaneMask = LaneMask; + ++I; + } + } + for (auto I = Uses.begin(); I != Uses.end(); ) { + LaneBitmask LiveBefore = getLiveLanesAt(LIS, MRI, true, I->RegUnit, + Pos.getBaseIndex()); + unsigned LaneMask = I->LaneMask & LiveBefore; + if (LaneMask == 0) { + I = Uses.erase(I); + } else { + I->LaneMask = LaneMask; + ++I; + } + } +} + /// Initialize an array of N PressureDiffs. void PressureDiffs::init(unsigned N) { Size = N; @@ -422,11 +557,11 @@ const MachineRegisterInfo &MRI) { PressureDiff &PDiff = (*this)[Idx]; assert(!PDiff.begin()->isValid() && "stale PDiff"); - for (unsigned Reg : RegOpers.Defs) - PDiff.addPressureChange(Reg, true, &MRI); + for (const RegisterMaskPair &P : RegOpers.Defs) + PDiff.addPressureChange(P.RegUnit, true, &MRI); - for (unsigned Reg : RegOpers.Uses) - PDiff.addPressureChange(Reg, false, &MRI); + for (const RegisterMaskPair &P : RegOpers.Uses) + PDiff.addPressureChange(P.RegUnit, false, &MRI); } /// Add a change in pressure to the pressure diff of a given instruction. @@ -466,33 +601,59 @@ } /// Force liveness of registers. 
-void RegPressureTracker::addLiveRegs(ArrayRef Regs) { - for (unsigned Reg : Regs) { - if (LiveRegs.insert(Reg)) - increaseRegPressure(Reg); +void RegPressureTracker::addLiveRegs(ArrayRef Regs) { + for (const RegisterMaskPair &P : Regs) { + unsigned PrevMask = LiveRegs.insert(P); + unsigned NewMask = PrevMask | P.LaneMask; + increaseRegPressure(P.RegUnit, PrevMask, NewMask); } } -/// Add Reg to the live in set and increase max pressure. -void RegPressureTracker::discoverLiveIn(unsigned Reg) { - assert(!LiveRegs.contains(Reg) && "avoid bumping max pressure twice"); - if (containsReg(P.LiveInRegs, Reg)) +void RegPressureTracker::discoverLiveInOrOut(RegisterMaskPair Pair, + SmallVectorImpl &LiveInOrOut) { + if (Pair.LaneMask == 0) return; - // At live in discovery, unconditionally increase the high water mark. - P.LiveInRegs.push_back(Reg); - increaseSetPressure(P.MaxSetPressure, MRI->getPressureSets(Reg)); + unsigned RegUnit = Pair.RegUnit; + auto I = std::find_if(LiveInOrOut.begin(), LiveInOrOut.end(), + [RegUnit](const RegisterMaskPair &Other) { + return Other.RegUnit == RegUnit; + }); + LaneBitmask PrevMask; + LaneBitmask NewMask; + if (I == LiveInOrOut.end()) { + PrevMask = 0; + NewMask = Pair.LaneMask; + LiveInOrOut.push_back(Pair); + } else { + PrevMask = I->LaneMask; + NewMask = PrevMask | Pair.LaneMask; + I->LaneMask = NewMask; + } + increaseSetPressure(P.MaxSetPressure, *MRI, RegUnit, PrevMask, NewMask); } -/// Add Reg to the live out set and increase max pressure. -void RegPressureTracker::discoverLiveOut(unsigned Reg) { - assert(!LiveRegs.contains(Reg) && "avoid bumping max pressure twice"); - if (containsReg(P.LiveOutRegs, Reg)) - return; +void RegPressureTracker::discoverLiveIn(RegisterMaskPair Pair) { + discoverLiveInOrOut(Pair, P.LiveInRegs); +} - // At live out discovery, unconditionally increase the high water mark. 
- P.LiveOutRegs.push_back(Reg); - increaseSetPressure(P.MaxSetPressure, MRI->getPressureSets(Reg)); +void RegPressureTracker::discoverLiveOut(RegisterMaskPair Pair) { + discoverLiveInOrOut(Pair, P.LiveOutRegs); +} + +void RegPressureTracker::bumpDeadDefs(ArrayRef DeadDefs) { + for (const RegisterMaskPair &P : DeadDefs) { + unsigned Reg = P.RegUnit; + LaneBitmask LiveMask = LiveRegs.contains(Reg); + LaneBitmask BumpedMask = LiveMask | P.LaneMask; + increaseRegPressure(Reg, LiveMask, BumpedMask); + } + for (const RegisterMaskPair &P : DeadDefs) { + unsigned Reg = P.RegUnit; + LaneBitmask LiveMask = LiveRegs.contains(Reg); + LaneBitmask BumpedMask = LiveMask | P.LaneMask; + decreaseRegPressure(Reg, BumpedMask, LiveMask); + } } /// Recede across the previous instruction. If LiveUses is provided, record any @@ -501,20 +662,29 @@ /// difference pointer is provided record the changes is pressure caused by this /// instruction independent of liveness. void RegPressureTracker::recede(const RegisterOperands &RegOpers, - SmallVectorImpl *LiveUses) { + SmallVectorImpl *LiveUses) { assert(!CurrPos->isDebugValue()); // Boost pressure for all dead defs together. - increaseRegPressure(RegOpers.DeadDefs); - decreaseRegPressure(RegOpers.DeadDefs); + bumpDeadDefs(RegOpers.DeadDefs); // Kill liveness at live defs. // TODO: consider earlyclobbers? - for (unsigned Reg : RegOpers.Defs) { - if (LiveRegs.erase(Reg)) - decreaseRegPressure(Reg); - else - discoverLiveOut(Reg); + for (const RegisterMaskPair &Def : RegOpers.Defs) { + unsigned Reg = Def.RegUnit; + + LaneBitmask PreviousMask = LiveRegs.erase(Def); + LaneBitmask NewMask = PreviousMask & ~Def.LaneMask; + + LaneBitmask LiveOut = Def.LaneMask & ~PreviousMask; + if (LiveOut != 0) { + discoverLiveOut(RegisterMaskPair(Reg, LiveOut)); + // Retroactively model effects on pressure of the live out lanes. 
+ increaseSetPressure(CurrSetPressure, *MRI, Reg, 0, LiveOut); + PreviousMask = LiveOut; + } + + decreaseRegPressure(Reg, PreviousMask, NewMask); } SlotIndex SlotIdx; @@ -522,27 +692,34 @@ SlotIdx = LIS->getInstructionIndex(CurrPos).getRegSlot(); // Generate liveness for uses. - for (unsigned Reg : RegOpers.Uses) { - if (!LiveRegs.contains(Reg)) { - // Adjust liveouts if LiveIntervals are available. - if (RequireIntervals) { - const LiveRange *LR = getLiveRange(*LIS, Reg); - if (LR) { - LiveQueryResult LRQ = LR->Query(SlotIdx); - if (!LRQ.isKill() && !LRQ.valueDefined()) - discoverLiveOut(Reg); - } + for (const RegisterMaskPair &Use : RegOpers.Uses) { + unsigned Reg = Use.RegUnit; + assert(Use.LaneMask != 0); + LaneBitmask PreviousMask = LiveRegs.insert(Use); + LaneBitmask NewMask = PreviousMask | Use.LaneMask; + if (NewMask == PreviousMask) + continue; + + // Did the register just become live? + if (PreviousMask == 0) { + if (LiveUses != nullptr) { + unsigned NewLanes = NewMask & ~PreviousMask; + addRegLanes(*LiveUses, RegisterMaskPair(Reg, NewLanes)); } - increaseRegPressure(Reg); - LiveRegs.insert(Reg); - if (LiveUses && !containsReg(*LiveUses, Reg)) - LiveUses->push_back(Reg); + + // Discover live outs if this may be the first occurrence of this register. 
+ LaneBitmask LiveOut = getLiveThroughAt(Reg, SlotIdx); + discoverLiveOut(RegisterMaskPair(Reg, LiveOut)); } + + increaseRegPressure(Reg, PreviousMask, NewMask); } if (TrackUntiedDefs) { - for (unsigned Reg : RegOpers.Defs) { - if (TargetRegisterInfo::isVirtualRegister(Reg) && !LiveRegs.contains(Reg)) - UntiedDefs.insert(Reg); + for (const RegisterMaskPair &Def : RegOpers.Defs) { + unsigned RegUnit = Def.RegUnit; + if (TargetRegisterInfo::isVirtualRegister(RegUnit) && + (LiveRegs.contains(RegUnit) & Def.LaneMask) == 0) + UntiedDefs.insert(RegUnit); } } } @@ -570,14 +747,18 @@ static_cast(P).openTop(SlotIdx); } -void RegPressureTracker::recede(SmallVectorImpl *LiveUses) { +void RegPressureTracker::recede(SmallVectorImpl *LiveUses) { recedeSkipDebugValues(); const MachineInstr &MI = *CurrPos; RegisterOperands RegOpers; - RegOpers.collect(MI, *TRI, *MRI); - if (RequireIntervals) + RegOpers.collect(MI, *TRI, *MRI, TrackLaneMasks, false); + if (TrackLaneMasks) { + SlotIndex SlotIdx = LIS->getInstructionIndex(CurrPos).getRegSlot(); + RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx); + } else if (RequireIntervals) { RegOpers.detectDeadDefs(MI, *LIS); + } recede(RegOpers, LiveUses); } @@ -603,38 +784,36 @@ } RegisterOperands RegOpers; - RegOpers.collect(*CurrPos, *TRI, *MRI); - - for (unsigned Reg : RegOpers.Uses) { - // Discover live-ins. - bool isLive = LiveRegs.contains(Reg); - if (!isLive) - discoverLiveIn(Reg); + RegOpers.collect(*CurrPos, *TRI, *MRI, TrackLaneMasks, false); + if (TrackLaneMasks) + RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx); + + for (const RegisterMaskPair &Use : RegOpers.Uses) { + unsigned Reg = Use.RegUnit; + LaneBitmask LiveMask = LiveRegs.contains(Reg); + LaneBitmask LiveIn = Use.LaneMask & ~LiveMask; + if (LiveIn != 0) { + discoverLiveIn(RegisterMaskPair(Reg, LiveIn)); + increaseRegPressure(Reg, LiveMask, LiveMask | LiveIn); + LiveRegs.insert(RegisterMaskPair(Reg, LiveIn)); + } // Kill liveness at last uses. 
- bool lastUse = false; - if (RequireIntervals) { - const LiveRange *LR = getLiveRange(*LIS, Reg); - lastUse = LR && LR->Query(SlotIdx).isKill(); - } else { - // Allocatable physregs are always single-use before register rewriting. - lastUse = !TargetRegisterInfo::isVirtualRegister(Reg); + LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx); + if (LastUseMask != 0) { + LiveRegs.erase(RegisterMaskPair(Reg, LastUseMask)); + decreaseRegPressure(Reg, LiveMask, LiveMask & ~LastUseMask); } - if (lastUse && isLive) { - LiveRegs.erase(Reg); - decreaseRegPressure(Reg); - } else if (!lastUse && !isLive) - increaseRegPressure(Reg); } // Generate liveness for defs. - for (unsigned Reg : RegOpers.Defs) { - if (LiveRegs.insert(Reg)) - increaseRegPressure(Reg); + for (const RegisterMaskPair &Def : RegOpers.Defs) { + LaneBitmask PreviousMask = LiveRegs.insert(Def); + LaneBitmask NewMask = PreviousMask | Def.LaneMask; + increaseRegPressure(Def.RegUnit, PreviousMask, NewMask); } // Boost pressure for all dead defs together. - increaseRegPressure(RegOpers.DeadDefs); - decreaseRegPressure(RegOpers.DeadDefs); + bumpDeadDefs(RegOpers.DeadDefs); // Find the next instruction. do @@ -729,22 +908,38 @@ void RegPressureTracker::bumpUpwardPressure(const MachineInstr *MI) { assert(!MI->isDebugValue() && "Expect a nondebug instruction."); + SlotIndex SlotIdx; + if (RequireIntervals) + SlotIdx = LIS->getInstructionIndex(MI).getRegSlot(); + // Account for register pressure similar to RegPressureTracker::recede(). RegisterOperands RegOpers; - RegOpers.collect(*MI, *TRI, *MRI, /*IgnoreDead=*/true); + RegOpers.collect(*MI, *TRI, *MRI, TrackLaneMasks, /*IgnoreDead=*/true); assert(RegOpers.DeadDefs.size() == 0); - if (RequireIntervals) + if (TrackLaneMasks) + RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx); + else if (RequireIntervals) RegOpers.detectDeadDefs(*MI, *LIS); + // Boost max pressure for all dead defs together. 
+ // Since CurrSetPressure and MaxSetPressure are tracked separately, the bump should only raise the max and leave the current pressure unchanged (TODO: confirm against bumpDeadDefs). + bumpDeadDefs(RegOpers.DeadDefs); + // Kill liveness at live defs. - for (unsigned Reg : RegOpers.Defs) { - if (!containsReg(RegOpers.Uses, Reg)) - decreaseRegPressure(Reg); + for (const RegisterMaskPair &P : RegOpers.Defs) { + unsigned Reg = P.RegUnit; + LaneBitmask LiveLanes = LiveRegs.contains(Reg); + LaneBitmask UseLanes = getRegLanes(RegOpers.Uses, Reg); + LaneBitmask DefLanes = P.LaneMask; + LaneBitmask LiveAfter = (LiveLanes & ~DefLanes) | UseLanes; + decreaseRegPressure(Reg, LiveLanes, LiveAfter); } // Generate liveness for uses. - for (unsigned Reg : RegOpers.Uses) { - if (!LiveRegs.contains(Reg)) - increaseRegPressure(Reg); + for (const RegisterMaskPair &P : RegOpers.Uses) { + unsigned Reg = P.RegUnit; + LaneBitmask LiveLanes = LiveRegs.contains(Reg); + LaneBitmask LiveAfter = LiveLanes | P.LaneMask; + increaseRegPressure(Reg, LiveLanes, LiveAfter); } } @@ -889,15 +1084,64 @@ } /// Helper to find a vreg use between two indices [PriorUseIdx, NextUseIdx). -static bool findUseBetween(unsigned Reg, SlotIndex PriorUseIdx, - SlotIndex NextUseIdx, const MachineRegisterInfo &MRI, - const LiveIntervals *LIS) { - for (const MachineInstr &MI : MRI.use_nodbg_instructions(Reg)) { - SlotIndex InstSlot = LIS->getInstructionIndex(&MI).getRegSlot(); - if (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx) - return true; +/// The query starts with a lane bitmask which gets lanes/bits removed for every +/// use we find. 
+static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, + SlotIndex PriorUseIdx, SlotIndex NextUseIdx, + const MachineRegisterInfo &MRI, + const LiveIntervals *LIS) { + const TargetRegisterInfo &TRI = *MRI.getTargetRegisterInfo(); + for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) { + if (MO.isUndef()) + continue; + const MachineInstr *MI = MO.getParent(); + SlotIndex InstSlot = LIS->getInstructionIndex(MI).getRegSlot(); + if (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx) { + unsigned SubRegIdx = MO.getSubReg(); + LaneBitmask UseMask = TRI.getSubRegIndexLaneMask(SubRegIdx); + LastUseMask &= ~UseMask; + if (LastUseMask == 0) + return 0; + } } - return false; + return LastUseMask; +} + +LaneBitmask RegPressureTracker::getLiveLanesAt(unsigned RegUnit, + SlotIndex Pos) const { + if (!RequireIntervals) + return 0; + + return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos, + [](const LiveRange &LR, SlotIndex Pos) { + return LR.liveAt(Pos); + }); +} + +LaneBitmask RegPressureTracker::getLastUsedLanes(unsigned RegUnit, + SlotIndex Pos) const { + if (!RequireIntervals) + return 0; + + return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, + Pos.getBaseIndex(), + [](const LiveRange &LR, SlotIndex Pos) { + const LiveRange::Segment *S = LR.getSegmentContaining(Pos); + return S != nullptr && S->end == Pos.getRegSlot(); + }); +} + +LaneBitmask RegPressureTracker::getLiveThroughAt(unsigned RegUnit, + SlotIndex Pos) const { + if (!RequireIntervals) + return 0; + + return getLanesWithProperty(*LIS, *MRI, TrackLaneMasks, RegUnit, Pos, + [](const LiveRange &LR, SlotIndex Pos) { + const LiveRange::Segment *S = LR.getSegmentContaining(Pos); + return S != nullptr && S->start < Pos.getRegSlot(true) && + S->end != Pos.getDeadSlot(); + }); } /// Record the downward impact of a single instruction on current register @@ -909,39 +1153,49 @@ void RegPressureTracker::bumpDownwardPressure(const MachineInstr *MI) { 
assert(!MI->isDebugValue() && "Expect a nondebug instruction."); - // Account for register pressure similar to RegPressureTracker::recede(). - RegisterOperands RegOpers; - RegOpers.collect(*MI, *TRI, *MRI); - - // Kill liveness at last uses. Assume allocatable physregs are single-use - // rather than checking LiveIntervals. SlotIndex SlotIdx; if (RequireIntervals) SlotIdx = LIS->getInstructionIndex(MI).getRegSlot(); - for (unsigned Reg : RegOpers.Uses) { + // Account for register pressure similar to RegPressureTracker::recede(). + RegisterOperands RegOpers; + RegOpers.collect(*MI, *TRI, *MRI, TrackLaneMasks, false); + if (TrackLaneMasks) + RegOpers.adjustLaneLiveness(*LIS, *MRI, SlotIdx); + + for (const RegisterMaskPair &Use : RegOpers.Uses) { + unsigned Reg = Use.RegUnit; + LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx); + if (LastUseMask == 0) + continue; if (RequireIntervals) { + // The LastUseMask is queried from the liveness information of instruction + // which may be further down the schedule. Some lanes may actually not be + // last uses for the current position. // FIXME: allow the caller to pass in the list of vreg uses that remain // to be bottom-scheduled to avoid searching uses at each query. SlotIndex CurrIdx = getCurrSlot(); - const LiveRange *LR = getLiveRange(*LIS, Reg); - if (LR) { - LiveQueryResult LRQ = LR->Query(SlotIdx); - if (LRQ.isKill() && !findUseBetween(Reg, CurrIdx, SlotIdx, *MRI, LIS)) - decreaseRegPressure(Reg); - } - } else if (!TargetRegisterInfo::isVirtualRegister(Reg)) { - // Allocatable physregs are always single-use before register rewriting. - decreaseRegPressure(Reg); + LastUseMask + = findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, LIS); + if (LastUseMask == 0) + continue; } + + LaneBitmask LiveMask = LiveRegs.contains(Reg); + LaneBitmask NewMask = LiveMask & ~LastUseMask; + decreaseRegPressure(Reg, LiveMask, NewMask); } // Generate liveness for defs. 
- increaseRegPressure(RegOpers.Defs); + for (const RegisterMaskPair &Def : RegOpers.Defs) { + unsigned Reg = Def.RegUnit; + LaneBitmask LiveMask = LiveRegs.contains(Reg); + LaneBitmask NewMask = LiveMask | Def.LaneMask; + increaseRegPressure(Reg, LiveMask, NewMask); + } // Boost pressure for all dead defs together. - increaseRegPressure(RegOpers.DeadDefs); - decreaseRegPressure(RegOpers.DeadDefs); + bumpDeadDefs(RegOpers.DeadDefs); } /// Consider the pressure increase caused by traversing this instruction Index: lib/CodeGen/ScheduleDAGInstrs.cpp =================================================================== --- lib/CodeGen/ScheduleDAGInstrs.cpp +++ lib/CodeGen/ScheduleDAGInstrs.cpp @@ -899,7 +899,7 @@ collectVRegUses(SU); RegisterOperands RegOpers; - RegOpers.collect(*MI, *TRI, MRI); + RegOpers.collect(*MI, *TRI, MRI, TrackLaneMasks, false); if (PDiffs != nullptr) PDiffs->addInstruction(SU->NodeNum, RegOpers, MRI);