diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -2030,6 +2030,14 @@
     return false;
   }
 
+  virtual const TargetRegisterClass *
+  getVectorRegisterClassForSpill2Reg(const TargetRegisterInfo *TRI,
+                                     Register SpilledReg) const {
+    llvm_unreachable(
+        "Target didn't implement "
+        "TargetInstrInfo::getVectorRegisterClassForSpill2Reg!");
+  }
+
 private:
   mutable std::unique_ptr<MIRFormatter> Formatter;
   unsigned CallFrameSetupOpcode, CallFrameDestroyOpcode;
diff --git a/llvm/lib/CodeGen/Spill2Reg.cpp b/llvm/lib/CodeGen/Spill2Reg.cpp
--- a/llvm/lib/CodeGen/Spill2Reg.cpp
+++ b/llvm/lib/CodeGen/Spill2Reg.cpp
@@ -15,6 +15,8 @@
 ///
 //===----------------------------------------------------------------------===//
 
+#include "AllocationOrder.h"
+#include "llvm/CodeGen/LiveRegUnits.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -75,6 +77,22 @@
   /// Look for candidates for spill2reg. These candidates are in places with
   /// high memory unit contention. Fills in StackSlotData.
   void collectSpillsAndReloads();
+  /// \Returns true if \p MI is profitable to apply spill-to-reg by checking
+  /// whether this would remove pipeline bubbles.
+  bool isProfitable(const MachineInstr *MI) const;
+  /// \Returns true if all stack-based spills/reloads in \p Entry are
+  /// profitable to replace with reg-based spills/reloads.
+  bool allAccessesProfitable(const StackSlotDataEntry &Entry) const;
+  /// Look for a free physical register in \p LRU of reg class \p RegClass.
+  llvm::Optional<MCRegister>
+  tryGetFreePhysicalReg(const TargetRegisterClass *RegClass,
+                        const LiveRegUnits &LRU);
+  /// Helper for generateCode(). It replaces stack spills or reloads with movs
+  /// to \p VectorReg.
+  void replaceStackWithReg(StackSlotDataEntry &Entry, Register VectorReg);
+  /// Updates \p LRU with the liveness of physical registers around the spills
+  /// and reloads in \p Entry.
+  void calculateLiveRegs(StackSlotDataEntry &Entry, LiveRegUnits &LRU);
   /// Replace spills to stack with spills to registers (same for reloads).
   void generateCode();
   /// Cleanup data structures once the pass is finished.
@@ -90,6 +108,7 @@
   MachineFrameInfo *MFI = nullptr;
   const TargetInstrInfo *TII = nullptr;
   const TargetRegisterInfo *TRI = nullptr;
+  RegisterClassInfo RegClassInfo;
 };
 
 } // namespace
@@ -115,6 +134,8 @@
   if (!TII->targetSupportsSpill2Reg(&MF->getSubtarget()))
     return false;
 
+  RegClassInfo.runOnMachineFunction(MFn);
+
   return run();
 }
 
@@ -176,7 +197,66 @@
   }
 }
 
-void Spill2Reg::generateCode() { llvm_unreachable("Unimplemented"); }
+bool Spill2Reg::isProfitable(const MachineInstr *MI) const {
+  // TODO: Unimplemented.
+  return true;
+}
+
+bool Spill2Reg::allAccessesProfitable(const StackSlotDataEntry &Entry) const {
+  auto IsProfitable = [this](const auto &MID) { return isProfitable(MID.MI); };
+  return llvm::all_of(Entry.Spills, IsProfitable) &&
+         llvm::all_of(Entry.Reloads, IsProfitable);
+}
+
+llvm::Optional<MCRegister>
+Spill2Reg::tryGetFreePhysicalReg(const TargetRegisterClass *RegClass,
+                                 const LiveRegUnits &LRU) {
+  auto Order = RegClassInfo.getOrder(RegClass);
+  for (auto I = Order.begin(), E = Order.end(); I != E; ++I) {
+    MCRegister PhysVectorReg = *I;
+    if (LRU.available(PhysVectorReg))
+      return PhysVectorReg;
+  }
+  return None;
+}
+
+// Replace stack-based spills/reloads with register-based ones.
+void Spill2Reg::replaceStackWithReg(StackSlotDataEntry &Entry,
+                                    Register VectorReg) {
+  // TODO: Unimplemented
+}
+
+void Spill2Reg::calculateLiveRegs(StackSlotDataEntry &Entry,
+                                  LiveRegUnits &LRU) {
+  // TODO: Unimplemented
+}
+
+void Spill2Reg::generateCode() {
+  for (auto &Pair : StackSlotData) {
+    StackSlotDataEntry &Entry = Pair.second;
+    // Skip if this stack slot was disabled during data collection.
+    if (Entry.Disable)
+      continue;
+
+    // We spill2reg only if all of the spills/reloads are profitable.
+    if (!allAccessesProfitable(Entry))
+      continue;
+
+    // Calculate liveness for Entry.
+    LiveRegUnits LRU(*TRI);
+    calculateLiveRegs(Entry, LRU);
+
+    // Look for a physical register that is free in LRU.
+    llvm::Optional<MCRegister> PhysVectorRegOpt = tryGetFreePhysicalReg(
+        TII->getVectorRegisterClassForSpill2Reg(TRI, Entry.getSpilledReg()),
+        LRU);
+    if (!PhysVectorRegOpt)
+      continue;
+
+    // Replace stack accesses with register accesses.
+    replaceStackWithReg(Entry, *PhysVectorRegOpt);
+  }
+}
 
 void Spill2Reg::cleanup() { StackSlotData.clear(); }
 
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -679,6 +679,10 @@
                                 const MachineRegisterInfo *MRI) const override;
 
   bool targetSupportsSpill2Reg(const TargetSubtargetInfo *STI) const override;
+
+  const TargetRegisterClass *
+  getVectorRegisterClassForSpill2Reg(const TargetRegisterInfo *TRI,
+                                     Register SpilledReg) const override;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -9750,5 +9750,13 @@
   return X86STI->hasSSE41();
 }
 
+const TargetRegisterClass *
+X86InstrInfo::getVectorRegisterClassForSpill2Reg(const TargetRegisterInfo *TRI,
+                                                 Register SpilledReg) const {
+  const TargetRegisterClass *VecRegClass =
+      TRI->getRegClass(X86::VR128RegClassID);
+  return VecRegClass;
+}
+
 #define GET_INSTRINFO_HELPERS
 #include "X86GenInstrInfo.inc"