Index: lib/Target/AMDGPU/SIMachineScheduler.h
===================================================================
--- lib/Target/AMDGPU/SIMachineScheduler.h
+++ lib/Target/AMDGPU/SIMachineScheduler.h
@@ -326,6 +326,7 @@
   std::vector<SIScheduleBlock*> BlocksScheduled;
   unsigned NumBlockScheduled;
   std::vector<SIScheduleBlock*> ReadyBlocks;
+  std::set<SIScheduleBlock*> CurrentPathRegUsage;
 
   unsigned VregCurrentUsage;
   unsigned SregCurrentUsage;
@@ -395,6 +396,10 @@
                                        std::set<unsigned> &OutRegs);
 
   void schedule();
+  std::set<SIScheduleBlock*> findPathRegUsage(int SearchDepthLimit,
+                                              int VGPRDiffGoal,
+                                              int SGPRDiffGoal,
+                                              bool PriorityVGPR);
 };
 
 struct SIScheduleBlockResult {
Index: lib/Target/AMDGPU/SIMachineScheduler.cpp
===================================================================
--- lib/Target/AMDGPU/SIMachineScheduler.cpp
+++ lib/Target/AMDGPU/SIMachineScheduler.cpp
@@ -1503,10 +1503,42 @@
     dbgs() << "Current SGPRs: " << SregCurrentUsage << '\n';
   );
 
+  // 200 and 70 are arbitrary thresholds.
+  // TODO: better way to set them?
+  // If finding a path of depth 6 fails, try to find a path
+  // of depth 12.
+  if (CurrentPathRegUsage.empty()) {
+    if (VregCurrentUsage > 200) {
+      CurrentPathRegUsage = findPathRegUsage(6, 0, INT_MAX, true);
+      if (CurrentPathRegUsage.empty())
+        CurrentPathRegUsage = findPathRegUsage(12, 0, INT_MAX, true);
+    } else if (SregCurrentUsage > 70) {
+      CurrentPathRegUsage = findPathRegUsage(6, INT_MAX, 0, false);
+      if (CurrentPathRegUsage.empty())
+        CurrentPathRegUsage = findPathRegUsage(12, INT_MAX, 0, false);
+    }
+  }
+
+  DEBUG(
+    if (!CurrentPathRegUsage.empty()) {
+      dbgs() << "Restricting search among: ";
+      for (SIScheduleBlock *Block : ReadyBlocks) {
+        if (CurrentPathRegUsage.find(Block) == CurrentPathRegUsage.end())
+          continue;
+        dbgs() << Block->getID() << ' ';
+      }
+      dbgs() << '\n';
+    }
+  );
+
   Cand.Block = nullptr;
   for (std::vector<SIScheduleBlock*>::iterator I = ReadyBlocks.begin(),
        E = ReadyBlocks.end(); I != E; ++I) {
     SIBlockSchedCandidate TryCand;
+
+    if (!CurrentPathRegUsage.empty() &&
+        CurrentPathRegUsage.find(*I) == CurrentPathRegUsage.end())
+      continue;
+
     TryCand.Block = *I;
     TryCand.IsHighLatency = TryCand.Block->isHighLatencyBlock();
     TryCand.VGPRUsageDiff =
@@ -1550,6 +1582,8 @@
 
   Block = Cand.Block;
   ReadyBlocks.erase(Best);
+  if (!CurrentPathRegUsage.empty())
+    CurrentPathRegUsage.erase(Block);
   return Block;
 }
 
@@ -1647,6 +1681,513 @@
   return DiffSetPressure;
 }
 
+// Strategy to reduce register pressure:
+// Idea: Reducing register pressure is hard. Heuristics based on
+// scores have failed (giving each instruction a score for its "potential"
+// to release registers).
+// Instead of using heuristics, try to find a path of instructions with a
+// small number of instructions, such that at the end of the path the
+// number of live registers is reduced.
+//
+// Algorithm:
+// For each register that is live or not yet produced:
+// -> if we need more than k instructions to consume the register, ignore it.
+// -> else compute the minimal set of registers that would be produced to
+//    schedule all the instructions needed to consume the register, and idem
+//    for the set of registers that would be consumed (separating currently
+//    alive registers and intermediate registers).
+//
+// We then look at the register such that the path to consume it releases the
+// most registers.
+//
+// A way to get better results would then be to try combinations that
+// directly give a better consumed/produced ratio.
+//
+// Another way to improve would be to pick, among all the possible choices,
+// the most latency-friendly path.
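+//
+// Illustrative example (hypothetical registers and blocks, not taken from
+// real output): suppose v0 and v1 are live, block A consumes v0 and
+// produces v2, and block B consumes v1 and v2 and produces nothing.
+// The path to release v1 is {A, B}: it consumes the live registers
+// {v0, v1}, the intermediate register v2 is produced then consumed,
+// and nothing new stays alive, so the net effect is two fewer live
+// registers.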
+
+struct SIBlockInfo {
+  // All Blocks that have to be scheduled first, plus the Block itself.
+  std::set<SIScheduleBlock *> Dependencies;
+  // The registers that were initially defined, and that
+  // scheduling both this Block and its dependencies
+  // consumed.
+  std::set<unsigned> ConsumedRegisters;
+  // Idem for the intermediate registers (that were produced
+  // at some point, and then consumed).
+  std::set<unsigned> ProducedConsumedRegisters;
+  // Idem for registers that were produced at some point, but
+  // are still alive.
+  std::set<unsigned> ProducedRegisters;
+  // Registers that were specifically eaten by this Block (identical
+  // to LiveInRegs, but in the "unique reg identifier" space).
+  std::set<unsigned> BlockInRegs;
+  // Unique identifier -> which Blocks in the Dependencies eat this register.
+  // We only use this to count how many times a given register is consumed,
+  // but we use a set to enable unions when merging several branches.
+  // Registers that are fully consumed are removed from the map.
+  std::map<unsigned, std::set<SIScheduleBlock *>> RegisterConsumers;
+};
+
+struct SIMIRegisterInfo {
+  // The minimal set of Blocks to schedule to release this register.
+  std::set<SIScheduleBlock *> Dependencies;
+  // The livein registers that are consumed by scheduling the Dependencies.
+  std::set<unsigned> ConsumedRegisters;
+  // The registers that are produced then consumed by scheduling the
+  // Dependencies.
+  std::set<unsigned> ProducedConsumedRegisters;
+  // The registers that are produced by scheduling the Dependencies,
+  // but are still alive afterwards.
+  std::set<unsigned> ProducedRegisters;
+  // Same as for SIBlockInfo.
+  std::map<unsigned, std::set<SIScheduleBlock *>> RegisterConsumers;
+};
+
+std::set<SIScheduleBlock *>
+SIScheduleBlockScheduler::findPathRegUsage(int SearchDepthLimit,
+                                           int VGPRDiffGoal,
+                                           int SGPRDiffGoal,
+                                           bool PriorityVGPR)
+{
+  std::set<unsigned> LiveRegsInitId;
+  std::map<unsigned, unsigned> RegsConsumers;
+
+  std::vector<unsigned> BlockNumPredsLeftCurrent = BlockNumPredsLeft;
+
+  // We want a unique identifier per register, but a register can be reused
+  // (input and output of a block). We thus use a mapping
+  // "unique reg identifier" -> register + what produced it,
+  // and an opposite mapping register -> current associated identifier.
+  std::map<unsigned, std::pair<unsigned, SIScheduleBlock *>> IdentifierToReg;
+  std::map<unsigned, unsigned> RegToIdentifier;
+  unsigned CurrentIdentifier = 0;
+
+  std::map<SIScheduleBlock *, SIBlockInfo> BlockInfos;
+  std::vector<SIScheduleBlock *> SchedulableBlocks = ReadyBlocks;
+  std::vector<SIScheduleBlock *> SchedulableBlockNextDepth;
+
+  std::map<unsigned, SIMIRegisterInfo> RegisterInfos;
+
+  int BestDiffVGPR = INT_MAX;
+  int BestDiffSGPR = INT_MAX;
+
+  if (ReadyBlocks.empty())
+    return std::set<SIScheduleBlock *>();
+
+  DEBUG(dbgs() << "findPathRegUsage(" << SearchDepthLimit << ")\n");
+  DEBUG(dbgs() << "Initial Live regs:\n");
+
+  // Fill info for the initial registers.
+  for (unsigned Reg : LiveRegs) {
+    // Ignore physical registers.
+    if (!TargetRegisterInfo::isVirtualRegister(Reg))
+      continue;
+    (void) LiveRegsInitId.insert(CurrentIdentifier);
+    assert(LiveRegsConsumers.find(Reg) != LiveRegsConsumers.end());
+    assert(LiveRegsConsumers[Reg] > 0);
+    RegsConsumers[CurrentIdentifier] = LiveRegsConsumers[Reg];
+
+    IdentifierToReg[CurrentIdentifier] =
+      std::pair<unsigned, SIScheduleBlock *>(Reg, nullptr);
+    RegToIdentifier[Reg] = CurrentIdentifier;
+    DEBUG(dbgs() << PrintVRegOrUnit(Reg, DAG->getTRI()) << " (--> "
+                 << CurrentIdentifier << ")\n");
+    ++CurrentIdentifier;
+  }
+
+  // Fill BlockInfos.
+  while (SearchDepthLimit > 0 && !SchedulableBlocks.empty()) {
+    DEBUG(dbgs() << "Iterating... Remaining levels: " << SearchDepthLimit
+                 << '\n');
+    for (SIScheduleBlock *Block : SchedulableBlocks) {
+      struct SIBlockInfo BlockInfo = {}; // TODO: check memset
+
+      DEBUG(dbgs() << "Computing data for Block: " << Block->getID() << '\n');
+
+      for (SIScheduleBlock *Parent : Block->getPreds()) {
+        if (BlockInfos.find(Parent) != BlockInfos.end()) {
+          // The Parent was not scheduled before findPathRegUsage,
+          // thus it has an entry. Add its dependencies.
+          BlockInfo.Dependencies.insert(
+            BlockInfos[Parent].Dependencies.begin(),
+            BlockInfos[Parent].Dependencies.end());
+          // Idem for consumed and produced registers.
+          BlockInfo.ConsumedRegisters.insert(
+            BlockInfos[Parent].ConsumedRegisters.begin(),
+            BlockInfos[Parent].ConsumedRegisters.end());
+          BlockInfo.ProducedConsumedRegisters.insert(
+            BlockInfos[Parent].ProducedConsumedRegisters.begin(),
+            BlockInfos[Parent].ProducedConsumedRegisters.end());
+          BlockInfo.ProducedRegisters.insert(
+            BlockInfos[Parent].ProducedRegisters.begin(),
+            BlockInfos[Parent].ProducedRegisters.end());
+          for (const auto &RConsumers : BlockInfos[Parent].RegisterConsumers) {
+            // Check it is filled in RegsConsumers.
+            assert(RegsConsumers.find(RConsumers.first) != RegsConsumers.end());
+
+            // Note: the registers in RegisterConsumers are either in
+            // LiveRegsInitId or in BlockInfo.ProducedRegisters.
+            // In all cases they are registers that aren't consumed yet.
+            if (BlockInfo.RegisterConsumers.find(RConsumers.first) ==
+                BlockInfo.RegisterConsumers.end()) {
+              // Nothing to do, add to the list.
+              BlockInfo.RegisterConsumers[RConsumers.first] =
+                RConsumers.second;
+            } else {
+              BlockInfo.RegisterConsumers[RConsumers.first].insert(
+                RConsumers.second.begin(), RConsumers.second.end());
+              // If the register has all its consumers in the Dependencies
+              // list, move it from RegisterConsumers to the correct list.
+              if (BlockInfo.RegisterConsumers[RConsumers.first].size() ==
+                  RegsConsumers[RConsumers.first]) {
+                // Register is consumed, add it to the correct list.
+                BlockInfo.RegisterConsumers.erase(RConsumers.first);
+                if (LiveRegsInitId.find(RConsumers.first) !=
+                    LiveRegsInitId.end())
+                  BlockInfo.ConsumedRegisters.insert(RConsumers.first);
+                else {
+                  assert(BlockInfo.ProducedRegisters.find(RConsumers.first) !=
+                         BlockInfo.ProducedRegisters.end());
+                  BlockInfo.ProducedConsumedRegisters.insert(RConsumers.first);
+                  BlockInfo.ProducedRegisters.erase(RConsumers.first);
+                }
+              }
+            }
+          }
+        }
+      }
+      // At this point, we have merged the data from all parents.
+      BlockInfo.Dependencies.insert(Block);
+
+      for (unsigned Reg : Block->getInRegs()) {
+        if (!TargetRegisterInfo::isVirtualRegister(Reg))
+          continue;
+        DEBUG(dbgs() << "InReg: " << Reg << '\n');
+        assert(RegToIdentifier.find(Reg) != RegToIdentifier.end());
+        unsigned RegIdentifier = RegToIdentifier[Reg];
+
+        // Check it is filled in RegsConsumers.
+        assert(RegsConsumers.find(RegIdentifier) != RegsConsumers.end());
+
+        BlockInfo.BlockInRegs.insert(RegIdentifier);
+
+        BlockInfo.RegisterConsumers[RegIdentifier].insert(Block);
+        // If the register has all its consumers in the Dependencies
+        // list, move it from RegisterConsumers to the correct list.
+        if (BlockInfo.RegisterConsumers[RegIdentifier].size() ==
+            RegsConsumers[RegIdentifier]) {
+          // Register is consumed, add it to the correct list.
+          BlockInfo.RegisterConsumers.erase(RegIdentifier);
+          if (LiveRegsInitId.find(RegIdentifier) != LiveRegsInitId.end())
+            BlockInfo.ConsumedRegisters.insert(RegIdentifier);
+          else {
+            assert(BlockInfo.ProducedRegisters.find(RegIdentifier) !=
+                   BlockInfo.ProducedRegisters.end());
+            BlockInfo.ProducedConsumedRegisters.insert(RegIdentifier);
+            BlockInfo.ProducedRegisters.erase(RegIdentifier);
+          }
+        }
+      }
+
+      for (unsigned Reg : Block->getOutRegs()) {
+        if (!TargetRegisterInfo::isVirtualRegister(Reg))
+          continue;
+        // Note: by construction, we can overwrite RegToIdentifier[Reg],
+        // because we walk the blocks in a valid schedule order (Reg was
+        // consumed at this point).
+        unsigned RegIdentifier = CurrentIdentifier;
+        IdentifierToReg[RegIdentifier] =
+          std::pair<unsigned, SIScheduleBlock *>(Reg, Block);
+        RegToIdentifier[Reg] = RegIdentifier;
+        DEBUG(dbgs() << "OutReg: " << Reg << '\n');
+        ++CurrentIdentifier;
+
+        RegsConsumers[RegIdentifier] = LiveOutRegsNumUsages[Block->getID()][Reg];
+        BlockInfo.ProducedRegisters.insert(RegIdentifier);
+      }
+
+      for (SIScheduleBlock *Child : Block->getSuccs()) {
+        --BlockNumPredsLeftCurrent[Child->getID()];
+        if (BlockNumPredsLeftCurrent[Child->getID()] == 0) {
+          SchedulableBlockNextDepth.push_back(Child);
+        }
+      }
+
+      BlockInfos[Block] = BlockInfo;
+    }
+    --SearchDepthLimit;
+    SchedulableBlocks = SchedulableBlockNextDepth;
+    SchedulableBlockNextDepth.clear();
+  }
+
+  // We now have, for each block, the registers produced and consumed
+  // by the block and its dependencies.
+  // We could at this point pick the block whose consumed - produced
+  // score is best, but working on consumable registers (instead of
+  // blocks) gives better results: the path releasing a specific
+  // register already combines several blocks and their dependencies.
+
+  // Fill RegisterInfos. The algorithm could be implemented more efficiently.
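+  // For each register consumed by some visited block, merge the infos of all
+  // the blocks that consume it (illustrative: if an identifier is eaten by
+  // two blocks, releasing it requires the union of both Dependencies sets).
+  // The consumed/produced sets are re-checked after the union, since a
+  // register that was only partially consumed on each side may become fully
+  // consumed once merged.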
+  for (const auto &BlockInfo : BlockInfos) {
+    for (unsigned Reg : BlockInfo.second.BlockInRegs) {
+      if (RegisterInfos.find(Reg) == RegisterInfos.end()) {
+        SIMIRegisterInfo RInfo;
+        RInfo.Dependencies = BlockInfo.second.Dependencies;
+        RInfo.ConsumedRegisters = BlockInfo.second.ConsumedRegisters;
+        RInfo.ProducedConsumedRegisters =
+          BlockInfo.second.ProducedConsumedRegisters;
+        RInfo.ProducedRegisters = BlockInfo.second.ProducedRegisters;
+        RInfo.RegisterConsumers = BlockInfo.second.RegisterConsumers;
+        RegisterInfos[Reg] = RInfo;
+      } else {
+        SIMIRegisterInfo RInfo = RegisterInfos[Reg];
+        RInfo.Dependencies.insert(
+          BlockInfo.second.Dependencies.begin(),
+          BlockInfo.second.Dependencies.end());
+        RInfo.ConsumedRegisters.insert(
+          BlockInfo.second.ConsumedRegisters.begin(),
+          BlockInfo.second.ConsumedRegisters.end());
+        RInfo.ProducedConsumedRegisters.insert(
+          BlockInfo.second.ProducedConsumedRegisters.begin(),
+          BlockInfo.second.ProducedConsumedRegisters.end());
+        // Contrary to Blocks, we can have things in RInfo's ProducedRegisters
+        // that are in BlockInfo's (Produced)ConsumedRegisters and vice-versa.
+        // -> Do the union, fix the sets, then apply RegisterConsumers.
+        RInfo.ProducedRegisters.insert(
+          BlockInfo.second.ProducedRegisters.begin(),
+          BlockInfo.second.ProducedRegisters.end());
+        for (const auto &RConsumers : BlockInfo.second.RegisterConsumers) {
+          if (RInfo.RegisterConsumers.find(RConsumers.first) ==
+              RInfo.RegisterConsumers.end()) {
+            RInfo.RegisterConsumers[RConsumers.first] =
+              RConsumers.second;
+          } else {
+            RInfo.RegisterConsumers[RConsumers.first].insert(
+              RConsumers.second.begin(), RConsumers.second.end());
+          }
+        }
+        // Fix
+        for (unsigned Reg : RInfo.ConsumedRegisters) {
+          if (RInfo.ProducedRegisters.find(Reg) !=
+              RInfo.ProducedRegisters.end())
+            RInfo.ProducedRegisters.erase(Reg);
+          if (RInfo.RegisterConsumers.find(Reg) !=
+              RInfo.RegisterConsumers.end())
+            RInfo.RegisterConsumers.erase(Reg);
+        }
+        for (unsigned Reg : RInfo.ProducedConsumedRegisters) {
+          if (RInfo.ProducedRegisters.find(Reg) !=
+              RInfo.ProducedRegisters.end())
+            RInfo.ProducedRegisters.erase(Reg);
+          if (RInfo.RegisterConsumers.find(Reg) !=
+              RInfo.RegisterConsumers.end())
+            RInfo.RegisterConsumers.erase(Reg);
+        }
+        // Apply
+        for (std::map<unsigned, std::set<SIScheduleBlock *>>::iterator I =
+               RInfo.RegisterConsumers.begin();
+             I != RInfo.RegisterConsumers.end();) {
+          // If the register has all its consumers in the Dependencies
+          // list, move it from RegisterConsumers to the correct list.
+          std::pair<unsigned, std::set<SIScheduleBlock *>> RConsumers = *I;
+          bool ToErase = RConsumers.second.size() ==
+                         RegsConsumers[RConsumers.first];
+
+          if (ToErase) {
+            RInfo.RegisterConsumers.erase(I++);
+            if (LiveRegsInitId.find(RConsumers.first) != LiveRegsInitId.end())
+              RInfo.ConsumedRegisters.insert(RConsumers.first);
+            else {
+              assert(RInfo.ProducedRegisters.find(RConsumers.first) !=
+                     RInfo.ProducedRegisters.end());
+              RInfo.ProducedConsumedRegisters.insert(RConsumers.first);
+              RInfo.ProducedRegisters.erase(RConsumers.first);
+            }
+          } else
+            ++I;
+        }
+        RegisterInfos[Reg] = RInfo;
+      }
+    }
+  }
+
+  DEBUG(
+    for (const auto &RInfo : RegisterInfos) {
+      unsigned Reg = RInfo.first;
+      dbgs() << Reg << " (" << PrintVRegOrUnit(IdentifierToReg[Reg].first,
+                                               DAG->getTRI())
+             << "):\nConsumed: ";
+      for (unsigned Reg2 : RInfo.second.ConsumedRegisters)
+        dbgs() << Reg2 << " ";
+      dbgs() << "\nProducedConsumed: ";
+      for (unsigned Reg2 : RInfo.second.ProducedConsumedRegisters)
+        dbgs() << Reg2 << " ";
+      dbgs() << "\nProduced: ";
+      for (unsigned Reg2 : RInfo.second.ProducedRegisters)
+        dbgs() << Reg2 << " ";
+      dbgs() << "\nRegisterConsumers: ";
+      for (const auto &RConsumers : RInfo.second.RegisterConsumers)
+        dbgs() << " " << RConsumers.first << " (" << RConsumers.second.size()
+               << ", " << RegsConsumers[RConsumers.first] << ")";
+      dbgs() << "\n\n";
+    });
+
+  for (std::map<unsigned, SIMIRegisterInfo>::iterator I =
+         RegisterInfos.begin(); I != RegisterInfos.end();) {
+    // Remove registers that are not fully released: those, if in
+    // RegisterInfos, are still in their own RegisterConsumers list.
+    if ((*I).second.RegisterConsumers.find((*I).first) !=
+        (*I).second.RegisterConsumers.end()) {
+      DEBUG(dbgs() << PrintVRegOrUnit(IdentifierToReg[(*I).first].first,
+                                      DAG->getTRI())
+                   << " consumed, but not releasable.\n");
+      RegisterInfos.erase(I++);
+    } else
+      ++I;
+  }
+
+  // Find the best score.
+  for (const auto &RInfo : RegisterInfos) {
+    int DiffVGPR = 0;
+    int DiffSGPR = 0;
+    for (unsigned Reg : RInfo.second.ConsumedRegisters) {
+      unsigned RealReg = IdentifierToReg[Reg].first;
+      PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg);
+      for (; PSetI.isValid(); ++PSetI) {
+        if (*PSetI == DAG->getVGPRSetID())
+          DiffVGPR -= PSetI.getWeight();
+        if (*PSetI == DAG->getSGPRSetID())
+          DiffSGPR -= PSetI.getWeight();
+      }
+    }
+    for (unsigned Reg : RInfo.second.ProducedRegisters) {
+      unsigned RealReg = IdentifierToReg[Reg].first;
+      PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg);
+      for (; PSetI.isValid(); ++PSetI) {
+        if (*PSetI == DAG->getVGPRSetID())
+          DiffVGPR += PSetI.getWeight();
+        if (*PSetI == DAG->getSGPRSetID())
+          DiffSGPR += PSetI.getWeight();
+      }
+    }
+    DEBUG(dbgs() << RInfo.first << ": diff = (" << DiffVGPR << ", "
+                 << DiffSGPR << ")\n");
+    // Remove cases that don't match the target.
+    if (DiffVGPR > VGPRDiffGoal || DiffSGPR > SGPRDiffGoal)
+      continue;
+    // Give priority to reducing VGPR or SGPR pressure.
+    if (PriorityVGPR) {
+      if (BestDiffVGPR > DiffVGPR) {
+        BestDiffVGPR = DiffVGPR;
+        BestDiffSGPR = DiffSGPR;
+      } else if (BestDiffVGPR == DiffVGPR) {
+        if (BestDiffSGPR > DiffSGPR)
+          BestDiffSGPR = DiffSGPR;
+      }
+    } else {
+      if (BestDiffSGPR > DiffSGPR) {
+        BestDiffVGPR = DiffVGPR;
+        BestDiffSGPR = DiffSGPR;
+      } else if (BestDiffSGPR == DiffSGPR) {
+        if (BestDiffVGPR > DiffVGPR)
+          BestDiffVGPR = DiffVGPR;
+      }
+    }
+  }
+
+  // No path found.
+  if (BestDiffVGPR == INT_MAX) {
+    DEBUG(dbgs() << "No good path found\n");
+    return std::set<SIScheduleBlock *>();
+  }
+
+  assert(RegisterInfos.size() != 0);
+
+  DEBUG(dbgs() << "Best diff score: (" << BestDiffVGPR << ", " << BestDiffSGPR
+               << ")\n");
+
+  for (const auto &RInfo : RegisterInfos) {
+    int DiffVGPR = 0;
+    int DiffSGPR = 0;
+    for (unsigned Reg : RInfo.second.ConsumedRegisters) {
+      unsigned RealReg = IdentifierToReg[Reg].first;
+      PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg);
+      for (; PSetI.isValid(); ++PSetI) {
+        if (*PSetI == DAG->getVGPRSetID())
+          DiffVGPR -= PSetI.getWeight();
+        if (*PSetI == DAG->getSGPRSetID())
+          DiffSGPR -= PSetI.getWeight();
+      }
+    }
+    for (unsigned Reg : RInfo.second.ProducedRegisters) {
+      unsigned RealReg = IdentifierToReg[Reg].first;
+      PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg);
+      for (; PSetI.isValid(); ++PSetI) {
+        if (*PSetI == DAG->getVGPRSetID())
+          DiffVGPR += PSetI.getWeight();
+        if (*PSetI == DAG->getSGPRSetID())
+          DiffSGPR += PSetI.getWeight();
+      }
+    }
+    if (BestDiffVGPR == DiffVGPR && BestDiffSGPR == DiffSGPR) {
+      DEBUG(
+        dbgs() << "Blocks picked: ";
+        for (SIScheduleBlock *Block : RInfo.second.Dependencies)
+          dbgs() << Block->getID() << " ";
+        dbgs() << '\n';
+        dbgs() << "Consumed registers: ";
+        for (unsigned Reg : RInfo.second.ConsumedRegisters) {
+          unsigned RealReg = IdentifierToReg[Reg].first;
+          int DiffV = 0;
+          int DiffS = 0;
+          PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg);
+          for (; PSetI.isValid(); ++PSetI) {
+            if (*PSetI == DAG->getVGPRSetID())
+              DiffV -= PSetI.getWeight();
+            if (*PSetI == DAG->getSGPRSetID())
+              DiffS -= PSetI.getWeight();
+          }
+          dbgs() << PrintVRegOrUnit(RealReg, DAG->getTRI());
+          dbgs() << "(" << DiffV << ", " << DiffS << "), ";
+        }
+        dbgs() << '\n';
+        dbgs() << "Intermediate registers: ";
+        for (unsigned Reg : RInfo.second.ProducedConsumedRegisters) {
+          unsigned RealReg = IdentifierToReg[Reg].first;
+          int DiffV = 0;
+          int DiffS = 0;
+          PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg);
+          for (; PSetI.isValid(); ++PSetI) {
+            if (*PSetI == DAG->getVGPRSetID())
+              DiffV += PSetI.getWeight();
+            if (*PSetI == DAG->getSGPRSetID())
+              DiffS += PSetI.getWeight();
+          }
+          dbgs() << PrintVRegOrUnit(RealReg, DAG->getTRI());
+          dbgs() << "(" << DiffV << ", " << DiffS << "), ";
+        }
+        dbgs() << '\n';
+        dbgs() << "Produced registers: ";
+        for (unsigned Reg : RInfo.second.ProducedRegisters) {
+          unsigned RealReg = IdentifierToReg[Reg].first;
+          int DiffV = 0;
+          int DiffS = 0;
+          PSetIterator PSetI = DAG->getMRI()->getPressureSets(RealReg);
+          for (; PSetI.isValid(); ++PSetI) {
+            if (*PSetI == DAG->getVGPRSetID())
+              DiffV += PSetI.getWeight();
+            if (*PSetI == DAG->getSGPRSetID())
+              DiffS += PSetI.getWeight();
+          }
+          dbgs() << PrintVRegOrUnit(RealReg, DAG->getTRI());
+          dbgs() << "(" << DiffV << ", " << DiffS << "), ";
+        }
+        dbgs() << '\n';
+      );
+      return RInfo.second.Dependencies;
+    }
+  }
+
+  llvm_unreachable("internal error");
+}
+
 // SIScheduler //
 
 struct SIScheduleBlockResult