diff --git a/llvm/include/llvm/CodeGen/TargetInstrInfo.h b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
--- a/llvm/include/llvm/CodeGen/TargetInstrInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetInstrInfo.h
@@ -2033,6 +2033,42 @@
         "TargetInstrInfo::createVirtualVectorRegisterForSpillToReg!");
   }
 
+  /// \Returns true if it is profitable to perform spill2reg on \p MI.
+  /// The default implementation hits llvm_unreachable, so targets using
+  /// spill2reg must override it.
+  virtual bool isSpill2RegProfitable(const MachineInstr *MI,
+                                     const TargetRegisterInfo *TRI,
+                                     const MachineRegisterInfo *MRI) const {
+    llvm_unreachable(
+        "Target didn't implement TargetInstrInfo::isSpill2RegProfitable!");
+  }
+
+  /// Inserts \p SrcReg into the first lane of \p DstReg by building a new
+  /// instruction in \p InsertMBB before \p InsertBeforeIt. \p OperationBits
+  /// is the width in bits of the value being moved.
+  /// \Returns the newly built instruction.
+  virtual MachineInstr *
+  spill2RegInsertToVectorReg(Register DstReg, Register SrcReg,
+                             int OperationBits, MachineBasicBlock *InsertMBB,
+                             MachineBasicBlock::iterator InsertBeforeIt,
+                             const TargetRegisterInfo *TRI) const {
+    llvm_unreachable(
+        "Target didn't implement TargetInstrInfo::spill2RegInsertToVectorReg!");
+  }
+
+  /// Extracts the first lane of \p SrcReg into \p DstReg by building a new
+  /// instruction in \p InsertMBB before \p InsertBeforeIt. \p OperationBits
+  /// is the width in bits of the value being moved.
+  /// \Returns the newly built instruction.
+  virtual MachineInstr *
+  spill2RegExtractFromVectorReg(Register DstReg, Register SrcReg,
+                                int OperationBits, MachineBasicBlock *InsertMBB,
+                                MachineBasicBlock::iterator InsertBeforeIt,
+                                const TargetRegisterInfo *TRI) const {
+    llvm_unreachable("Target didn't implement "
+                     "TargetInstrInfo::spill2RegExtractFromVectorReg!");
+  }
+
 private:
   mutable std::unique_ptr<MIRFormatter> Formatter;
   unsigned CallFrameSetupOpcode, CallFrameDestroyOpcode;
diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -680,6 +680,22 @@
   const TargetRegisterClass *
   getVectorRegisterClassForSpill2Reg(const TargetRegisterInfo *TRI,
                                      Register SpilledReg) const override;
+
+  bool isSpill2RegProfitable(const MachineInstr *MI,
+                             const TargetRegisterInfo *TRI,
+                             const MachineRegisterInfo *MRI) const override;
+
+  MachineInstr *
+  spill2RegInsertToVectorReg(Register DstReg, Register SrcReg,
+                             int OperationBits, MachineBasicBlock *InsertMBB,
+                             MachineBasicBlock::iterator InsertBeforeIt,
+                             const TargetRegisterInfo *TRI) const override;
+
+  MachineInstr *
+  spill2RegExtractFromVectorReg(Register DstReg, Register SrcReg,
+                                int OperationBits, MachineBasicBlock *InsertMBB,
+                                MachineBasicBlock::iterator InsertBeforeIt,
+                                const TargetRegisterInfo *TRI) const override;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -76,6 +76,21 @@
     "certain undef register reads"),
     cl::init(128), cl::Hidden);
 
+// The default mem-instrs threshold of 2 was found empirically on Skylake.
+static cl::opt<unsigned> Spill2RegMemInstrsThreshold(
+    "spill2reg-mem-instrs", cl::Hidden, cl::init(2),
+    cl::desc("Apply spill2reg if we find at least this many mem "
+             "instrs within the explored distance."));
+
+static cl::opt<unsigned> Spill2RegVecInstrsThreshold(
+    "spill2reg-vec-instrs", cl::Hidden, cl::init(0),
+    cl::desc("Apply spill2reg if we find at most this many vector instrs "
+             "within the explored distance."));
+
+static cl::opt<unsigned> Spill2RegExplorationDst(
+    "spill2reg-exploration-distance", cl::Hidden, cl::init(4),
+    cl::desc("When checking for profitability, explore up to this many nearby "
+             "instructions."));
 // Pin the vtable to this file.
 void X86InstrInfo::anchor() {}
 
@@ -9702,5 +9717,99 @@
   return VecRegClass;
 }
 
+/// Spill2Reg profitability heuristic for X86.
+///
+/// Walks backwards from \p MI (inclusive) over at most
+/// `spill2reg-exploration-distance` non-debug instructions of its block,
+/// counting instructions with memory operands and instructions that touch
+/// vector registers or vector-sized stack slots. \returns true if at least
+/// `spill2reg-mem-instrs` memory instructions (0 disables this check) and at
+/// most `spill2reg-vec-instrs` vector instructions were found. Reaching the
+/// top of the block before exploring the full distance yields false unless
+/// the memory threshold is 0.
+bool X86InstrInfo::isSpill2RegProfitable(const MachineInstr *MI,
+                                         const TargetRegisterInfo *TRI,
+                                         const MachineRegisterInfo *MRI) const {
+  const MachineFrameInfo &MFI = MI->getMF()->getFrameInfo();
+  // A stack object at least as large as an XMM register is treated as a
+  // vector stack slot. getObjectSize() is in bytes while getRegSizeInBits()
+  // is in bits, so convert before comparing.
+  const int64_t MinVecBytes = TRI->getRegSizeInBits(X86::VR128RegClass) / 8;
+  auto IsVecMO = [&](const MachineOperand &MO) {
+    if (MO.isReg() && MO.getReg().isPhysical()) {
+      Register Reg = MO.getReg();
+      // The *X classes also include the AVX-512-only XMM16-31/YMM16-31;
+      // VR512 already covers ZMM0-31.
+      return X86::VR128XRegClass.contains(Reg) ||
+             X86::VR256XRegClass.contains(Reg) ||
+             X86::VR512RegClass.contains(Reg);
+    }
+    return MO.isFI() && MFI.getObjectSize(MO.getIndex()) >= MinVecBytes;
+  };
+
+  // Count memory and vector instructions in a short window ending at MI.
+  const unsigned MaxDist = Spill2RegExplorationDst;
+  unsigned CntAll = 0;
+  unsigned CntMem = 0;
+  unsigned CntVec = 0;
+  for (const MachineInstr *TopMI = MI; TopMI != nullptr && CntAll < MaxDist;
+       ++CntAll) {
+    if (!TopMI->memoperands_empty())
+      ++CntMem;
+    if (llvm::any_of(TopMI->operands(), IsVecMO))
+      ++CntVec;
+    // Skip debug instructions so that they do not influence the decision.
+    do {
+      TopMI = TopMI->getPrevNode();
+    } while (TopMI != nullptr && TopMI->isDebugInstr());
+  }
+  // Not enough instructions before MI in this block to judge profitability.
+  if (Spill2RegMemInstrsThreshold != 0 && CntAll < MaxDist)
+    return false;
+  const bool MemHeuristic =
+      Spill2RegMemInstrsThreshold == 0 || CntMem >= Spill2RegMemInstrsThreshold;
+  const bool VecHeuristic = CntVec <= Spill2RegVecInstrsThreshold;
+  return MemHeuristic && VecHeuristic;
+}
+
+/// \returns the opcode that moves a \p Bits-wide scalar between a GPR and the
+/// low lane of an XMM register: GPR to XMM if \p Insert, XMM to GPR otherwise.
+/// Only 32 and 64 bits are supported.
+// NOTE(review): these are legacy-SSE encodings. They cannot address
+// XMM16-31, and on AVX targets the VEX forms (e.g. VMOVDI2PDIrr) would avoid
+// SSE/AVX transition penalties -- confirm this is intended.
+static unsigned getInsertOrExtractOpcode(unsigned Bits, bool Insert) {
+  switch (Bits) {
+  case 32:
+    return Insert ? X86::MOVDI2PDIrr : X86::MOVPDI2DIrr;
+  case 64:
+    return Insert ? X86::MOV64toPQIrr : X86::MOVPQIto64rr;
+  default:
+    llvm_unreachable("Unsupported bits");
+  }
+}
+
+/// Moves \p SrcReg (a GPR) into the low lane of \p DstReg with MOVD/MOVQ.
+MachineInstr *X86InstrInfo::spill2RegInsertToVectorReg(
+    Register DstReg, Register SrcReg, int OperationBits,
+    MachineBasicBlock *InsertMBB, MachineBasicBlock::iterator InsertBeforeIt,
+    const TargetRegisterInfo *TRI) const {
+  const unsigned Opc =
+      getInsertOrExtractOpcode(OperationBits, /*Insert=*/true);
+  return BuildMI(*InsertMBB, InsertBeforeIt, DebugLoc(), get(Opc), DstReg)
+      .addReg(SrcReg);
+}
+
+/// Moves the low lane of \p SrcReg into \p DstReg (a GPR) with MOVD/MOVQ.
+MachineInstr *X86InstrInfo::spill2RegExtractFromVectorReg(
+    Register DstReg, Register SrcReg, int OperationBits,
+    MachineBasicBlock *InsertMBB, MachineBasicBlock::iterator InsertBeforeIt,
+    const TargetRegisterInfo *TRI) const {
+  const unsigned Opc =
+      getInsertOrExtractOpcode(OperationBits, /*Insert=*/false);
+  return BuildMI(*InsertMBB, InsertBeforeIt, DebugLoc(), get(Opc), DstReg)
+      .addReg(SrcReg);
+}
+
 #define GET_INSTRINFO_HELPERS
 #include "X86GenInstrInfo.inc"