diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -2033,6 +2033,44 @@
         "TargetInstrInfo::createVirtualVectorRegisterForSpillToReg!");
   }
 
+  /// \returns true if it is profitable to perform spill2reg on \p MI.
+  /// The default implementation is unreachable, so any target on which this
+  /// is called must override it.
+  virtual bool isSpill2RegProfitable(const MachineInstr *MI,
+                                     const TargetRegisterInfo *TRI,
+                                     const MachineRegisterInfo *MRI) const {
+    llvm_unreachable(
+        "Target didn't implement TargetInstrInfo::isSpill2RegProfitable!");
+  }
+
+  /// Inserts \p SrcReg into the first lane of \p DstReg.
+  /// The instruction is to be placed in \p MBB before \p InsertBeforeIt;
+  /// \p OperationBits is the bit width of the operation.
+  /// The default implementation is unreachable, so any target on which this
+  /// is called must override it.
+  virtual MachineInstr *
+  spill2RegInsertToVectorReg(Register DstReg, Register SrcReg,
+                             int OperationBits, MachineBasicBlock *MBB,
+                             MachineBasicBlock::iterator InsertBeforeIt,
+                             const TargetRegisterInfo *TRI) const {
+    llvm_unreachable(
+        "Target didn't implement TargetInstrInfo::spill2RegInsertToVectorReg!");
+  }
+
+  /// Extracts the first lane of \p SrcReg into \p DstReg.
+  /// The instruction is to be placed in \p InsertMBB before \p InsertBeforeIt;
+  /// \p OperationBits is the bit width of the operation.
+  /// The default implementation is unreachable, so any target on which this
+  /// is called must override it.
+  virtual MachineInstr *
+  spill2RegExtractFromVectorReg(Register DstReg, Register SrcReg,
+                                int OperationBits, MachineBasicBlock *InsertMBB,
+                                MachineBasicBlock::iterator InsertBeforeIt,
+                                const TargetRegisterInfo *TRI) const {
+    llvm_unreachable("Target didn't implement "
+                     "TargetInstrInfo::spill2RegExtractFromVectorReg!");
+  }
+
 private:
   mutable std::unique_ptr<MIRFormatter> Formatter;
   unsigned CallFrameSetupOpcode, CallFrameDestroyOpcode;
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -680,6 +680,22 @@
   const TargetRegisterClass *
   getVectorRegisterClassForSpill2Reg(const TargetRegisterInfo *TRI,
                                      Register SpilledReg) const override;
+
+  bool isSpill2RegProfitable(const MachineInstr *MI,
+                             const TargetRegisterInfo *TRI,
+                             const MachineRegisterInfo *MRI) const override;
+
+  MachineInstr *
+  spill2RegInsertToVectorReg(Register DstReg, Register SrcReg,
+                             int OperationBits, MachineBasicBlock *MBB,
+                             MachineBasicBlock::iterator InsertBeforeIt,
+                             const TargetRegisterInfo *TRI) const override;
+
+  MachineInstr *
+  spill2RegExtractFromVectorReg(Register DstReg, Register SrcReg,
+                                int OperationBits, MachineBasicBlock *InsertMBB,
+                                MachineBasicBlock::iterator InsertBeforeIt,
+                                const TargetRegisterInfo *TRI) const override;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -76,6 +76,21 @@
              "certain undef register reads"),
     cl::init(128), cl::Hidden);
 
+// A value of 2 was empirically found to work with Skylake.
+static cl::opt<int> Spill2RegMemInstrsThreshold(
+    "spill2reg-mem-instrs", cl::Hidden, cl::init(80),
+    cl::desc("Apply spill2reg if we find at least this much percentage of "
+             "memory instrs within the explored distance."));
+
+static cl::opt<int> Spill2RegVecInstrsThreshold(
+    "spill2reg-vec-instrs", cl::Hidden, cl::init(1),
+    cl::desc("Apply spill2reg if we find fewer than this many vector instrs "
+             "within the explored distance."));
+
+static cl::opt<int> Spill2RegExplorationDst(
+    "spill2reg-exploration-distance", cl::Hidden, cl::init(4),
+    cl::desc("When checking for profitability, explore nearby instructions "
+             "at this maximum distance."));
 // Pin the vtable to this file.
 void X86InstrInfo::anchor() {}
 
@@ -9702,5 +9717,144 @@
   return VecRegClass;
 }
 
+/// Heuristic profitability check for performing spill2reg on \p MI.
+///
+/// Visits up to `spill2reg-exploration-distance` non-debug instructions on
+/// each side of \p MI (\p MI itself is counted on the upward side) and
+/// tallies how many have memory operands and how many have a vector operand.
+/// \returns true if the percentage of memory instructions is at least
+/// `spill2reg-mem-instrs` and fewer than `spill2reg-vec-instrs` vector
+/// instructions were found; a threshold of 0 disables the respective check.
+/// With a non-zero memory threshold, a walk cut short on either side fails.
+bool X86InstrInfo::isSpill2RegProfitable(const MachineInstr *MI,
+                                         const TargetRegisterInfo *TRI,
+                                         const MachineRegisterInfo *MRI) const {
+  // An operand counts as "vector" if it is a physical register in one of the
+  // VR128/VR256/VR512 classes, or a frame index whose stack object is at
+  // least as large as a VR128 register.
+  auto IsVecMO = [TRI, MI](const MachineOperand &MO) {
+    if (MO.isReg() && MO.getReg().isPhysical()) {
+      for (auto ClassID :
+           {X86::VR128RegClassID, X86::VR256RegClassID, X86::VR512RegClassID})
+        if (TRI->getRegClass(ClassID)->contains(MO.getReg()))
+          return true;
+    }
+    if (MO.isFI()) {
+      const MachineFunction *MF = MI->getParent()->getParent();
+      const unsigned MinVecBits =
+          TRI->getRegSizeInBits(*TRI->getRegClass(X86::VR128RegClassID));
+      // MachineFrameInfo::getObjectSize() is in bytes, so convert the register
+      // width from bits before comparing.
+      const int64_t MinVecBytes = MinVecBits / 8;
+      if (MF->getFrameInfo().getObjectSize(MO.getIndex()) >= MinVecBytes)
+        return true;
+    }
+    return false;
+  };
+
+  /// \returns the previous instruction, skipping debug instrs.
+  auto GetPrevNonDebug = [](const MachineInstr *I) {
+    do {
+      I = I->getPrevNode();
+    } while (I != nullptr && I->isDebugInstr());
+    return I;
+  };
+  /// \returns the next instruction, skipping debug instrs.
+  auto GetNextNonDebug = [](const MachineInstr *I) {
+    do {
+      I = I->getNextNode();
+    } while (I != nullptr && I->isDebugInstr());
+    return I;
+  };
+
+  // Read the command-line knobs once.
+  const int MaxRadius = Spill2RegExplorationDst;
+  const int MemThreshold = Spill2RegMemInstrsThreshold;
+  const int VecThreshold = Spill2RegVecInstrsThreshold;
+
+  // This is a simple heuristic. We count the number of memory and vector
+  // instructions both above and below `MI` with a radius of
+  // Spill2RegExplorationDst, and check against threshold values.
+  int CntMem = 0;
+  int CntAll = 0;
+  int CntVec = 0;
+  // Accounts for one visited instruction; null means that side is exhausted.
+  auto Tally = [&](const MachineInstr *I) {
+    if (I == nullptr)
+      return;
+    ++CntAll;
+    if (!I->memoperands_empty())
+      ++CntMem;
+    if (llvm::any_of(I->operands(), IsVecMO))
+      ++CntVec;
+  };
+
+  const MachineInstr *TopMI = MI;
+  const MachineInstr *BotMI = GetNextNonDebug(MI);
+  for (int Radius = 0;
+       (TopMI != nullptr || BotMI != nullptr) && Radius < MaxRadius;
+       ++Radius) {
+    Tally(TopMI);
+    Tally(BotMI);
+    if (TopMI != nullptr)
+      TopMI = GetPrevNonDebug(TopMI);
+    if (BotMI != nullptr)
+      BotMI = GetNextNonDebug(BotMI);
+  }
+
+  // Return false if exploration ended early because we reached the end of BB.
+  if (MemThreshold != 0 && CntAll < 2 * MaxRadius)
+    return false;
+  // Else check against the thresholds. CntAll can still be 0 here when the
+  // exploration distance is not positive, so guard the division.
+  const bool MemHeuristic =
+      MemThreshold == 0 ||
+      (CntAll != 0 && (CntMem * 100) / CntAll >= MemThreshold);
+  const bool VecHeuristic = VecThreshold == 0 || CntVec < VecThreshold;
+  return MemHeuristic && VecHeuristic;
+}
+
+/// \returns the X86 opcode for a \p Bits wide spill2reg move: the insert
+/// opcode if \p Insert is true, otherwise the extract opcode. Only 32 and 64
+/// bits are supported.
+static unsigned getInsertOrExtractOpcode(unsigned Bits, bool Insert) {
+  switch (Bits) {
+  case 32:
+    return Insert ? X86::MOVDI2PDIrr : X86::MOVPDI2DIrr;
+  case 64:
+    return Insert ? X86::MOV64toPQIrr : X86::MOVPQIto64rr;
+  default:
+    llvm_unreachable("Unsupported bits");
+  }
+}
+
+/// Emits the instruction that inserts \p SrcReg into \p DstReg, placing it in
+/// \p MBB before \p InsertBeforeIt with an empty debug location.
+/// \returns the new instruction.
+MachineInstr *X86InstrInfo::spill2RegInsertToVectorReg(
+    Register DstReg, Register SrcReg, int OperationBits, MachineBasicBlock *MBB,
+    MachineBasicBlock::iterator InsertBeforeIt,
+    const TargetRegisterInfo *TRI) const {
+  DebugLoc DL;
+  const MCInstrDesc &InsertMCID =
+      get(getInsertOrExtractOpcode(OperationBits, /*Insert=*/true));
+  return BuildMI(*MBB, InsertBeforeIt, DL, InsertMCID, DstReg)
+      .addReg(SrcReg);
+}
+
+/// Emits the instruction that extracts \p SrcReg into \p DstReg, placing it
+/// in \p InsertMBB before \p InsertBeforeIt with an empty debug location.
+/// \returns the new instruction.
+MachineInstr *X86InstrInfo::spill2RegExtractFromVectorReg(
+    Register DstReg, Register SrcReg, int OperationBits,
+    MachineBasicBlock *InsertMBB, MachineBasicBlock::iterator InsertBeforeIt,
+    const TargetRegisterInfo *TRI) const {
+  DebugLoc DL;
+  const MCInstrDesc &ExtractMCID =
+      get(getInsertOrExtractOpcode(OperationBits, /*Insert=*/false));
+  return BuildMI(*InsertMBB, InsertBeforeIt, DL, ExtractMCID, DstReg)
+      .addReg(SrcReg);
+}
+
 #define GET_INSTRINFO_HELPERS
 #include "X86GenInstrInfo.inc"