diff --git a/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp b/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp --- a/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp +++ b/llvm/lib/CodeGen/FixupStatepointCallerSaved.cpp @@ -46,8 +46,20 @@ cl::desc("Allow spill in spill slot of greater size than register size"), cl::Hidden); +static cl::opt PassGCPtrInCSR( + "fixup-allow-gcptr-in-csr", cl::Hidden, cl::init(false), + cl::desc("Allow passing GC Pointer arguments in callee saved registers")); + +// This is purely debugging option. +// It may be handy for investigating statepoint spilling issues. +static cl::opt MaxStatepointsWithRegs( + "fixup-max-csr-statepoints", cl::Hidden, + cl::desc("Max number of statepoints allowed to pass GC Ptrs in registers")); + namespace { +class FrameIndexesCache; + class FixupStatepointCallerSaved : public MachineFunctionPass { public: static char ID; @@ -66,7 +78,12 @@ } bool runOnMachineFunction(MachineFunction &MF) override; + +private: + void collectGlobalFIs(MachineBasicBlock &BB, FrameIndexesCache &Cache, + const TargetRegisterInfo *TRI); }; + } // End anonymous namespace. 
char FixupStatepointCallerSaved::ID = 0; @@ -83,6 +100,105 @@ return TRI.getSpillSize(*RC); } +// Advance iterator to the next stack map entry +static MachineInstr::const_mop_iterator +advanceToNextStackMapElt(MachineInstr::const_mop_iterator MOI) { + if (MOI->isImm()) { + switch (MOI->getImm()) { + default: + llvm_unreachable("Unrecognized operand type."); + case StackMaps::DirectMemRefOp: + MOI += 2; // , + break; + case StackMaps::IndirectMemRefOp: + MOI += 3; // , , + break; + case StackMaps::ConstantOp: + MOI += 1; + break; + } + } + return ++MOI; +} + +// Return statepoint GC args as a set +static SmallSet collectGCRegs(MachineInstr &MI) { + StatepointOpers SO(&MI); + unsigned NumDeoptIdx = SO.getNumDeoptArgsIdx(); + unsigned NumDeoptArgs = MI.getOperand(NumDeoptIdx).getImm(); + MachineInstr::const_mop_iterator MOI(MI.operands_begin() + NumDeoptIdx + 1), + MOE(MI.operands_end()); + + // Skip deopt args + for (unsigned i = 0; i < NumDeoptArgs; ++i) + MOI = advanceToNextStackMapElt(MOI); + + SmallSet Result; + while (MOI != MOE) { + if (MOI->isReg() && !MOI->isImplicit()) + Result.insert(MOI->getReg()); + MOI = advanceToNextStackMapElt(MOI); + } + return Result; +} + +// Try to eliminate redundant copy to register which we're going to +// spill, i.e. try to change: +// X = COPY Y +// SPILL X +// to +// SPILL Y +// If there are no uses of X between copy and STATEPOINT, that COPY +// may be eliminated. +// Reg - register we're about to spill +// RI - On entry points to statepoint. +// On successful copy propagation set to new spill point. +// IsKill - set to true if COPY is Kill (there are no uses of Y) +// Returns either found source copy register or original one. 
+static Register performCopyPropagation(Register Reg, + MachineBasicBlock::iterator &RI, + bool &IsKill, const TargetInstrInfo &TII, + const TargetRegisterInfo &TRI) { + MachineBasicBlock *MBB = RI->getParent(); + MachineBasicBlock::reverse_iterator E = MBB->rend(); + MachineInstr *Def = nullptr, *Use = nullptr; + for (auto It = ++(RI.getReverse()); It != E; ++It) { + if (It->readsRegister(Reg, &TRI) && !Use) + Use = &*It; + if (It->modifiesRegister(Reg, &TRI)) { + Def = &*It; + break; + } + } + if (!Def) + return Reg; + + auto DestSrc = TII.isCopyInstr(*Def); + if (!DestSrc || DestSrc->Destination->getReg() != Reg) + return Reg; + + Register SrcReg = DestSrc->Source->getReg(); + + if (getRegisterSize(TRI, Reg) != getRegisterSize(TRI, SrcReg)) + return Reg; + + LLVM_DEBUG(dbgs() << "spillRegisters: perform copy propagation " + << printReg(Reg, &TRI) << " -> " << printReg(SrcReg, &TRI) + << "\n"); + + // Insert spill immediately after Def + RI = ++MachineBasicBlock::iterator(Def); + IsKill = DestSrc->Source->isKill(); + + // There are no uses of original register between COPY and STATEPOINT. + // There can't be any after STATEPOINT, so we can eliminate Def. + if (!Use) { + LLVM_DEBUG(dbgs() << "spillRegisters: removing dead copy " << *Def); + Def->eraseFromParent(); + } + return SrcReg; +} + namespace { // Cache used frame indexes during statepoint re-write to re-use them in // processing next statepoint instruction. @@ -105,6 +221,13 @@ // size will be increased. DenseMap Cache; + // Landing pad can be destination of several statepoints. Every register + // defined by such statepoints must be spilled to the same stack slot. + // This map keeps that information. + // NOTE: we assume that spill slot live ranges do not intersect. 
+ using RegStatepointPair = std::pair; + DenseMap GlobalIndices; + public: FrameIndexesCache(MachineFrameInfo &MFI, const TargetRegisterInfo &TRI) : MFI(MFI), TRI(TRI) {} @@ -114,8 +237,19 @@ for (auto &It : Cache) It.second.Index = 0; } + // Get frame index to spill the register. - int getFrameIndex(Register Reg) { + int getFrameIndex(Register Reg, MachineInstr *MI = nullptr) { + if (MI) { + auto It = GlobalIndices.find(std::make_pair(Reg, MI)); + if (It != GlobalIndices.end()) { + int FI = It->second; + LLVM_DEBUG(dbgs() << "Found global FI " << FI << " for register " + << printReg(Reg, &TRI) << " at " << *MI); + return FI; + } + } + unsigned Size = getRegisterSize(TRI, Reg); // In FixupSCSExtendSlotSize mode the bucket with 0 index is used // for all sizes. @@ -148,8 +282,32 @@ return getRegisterSize(TRI, A) > getRegisterSize(TRI, B); }); } + + // Record frame index to be used to spill register \p Reg at instr \p MI. + void addGlobalSpillSlot(Register Reg, MachineInstr *MI, int FI) { + auto P = std::make_pair(Reg, MI); + GlobalIndices.insert(std::make_pair(P, FI)); + } }; +// Check if we already inserted reload of register Reg from spill slot FI +// in basic block MBB. +// This can happen in EH pad block which is successor of several +// statepoints. +static bool hasRegReload(Register Reg, int FI, MachineBasicBlock *MBB, + const TargetInstrInfo *TII, + const TargetRegisterInfo *TRI) { + auto I = MBB->SkipPHIsLabelsAndDebug(MBB->begin()), E = MBB->end(); + int Dummy; + for (; I != E; ++I) { + if (TII->isLoadFromStackSlot(*I, Dummy) == Reg && Dummy == FI) + return true; + if (I->modifiesRegister(Reg, TRI) || I->readsRegister(Reg, TRI)) + return false; + } + return false; +} + // Describes the state of the current processing statepoint instruction. class StatepointState { private: @@ -163,26 +321,32 @@ const uint32_t *Mask; // Cache of frame indexes used on previous instruction processing. 
FrameIndexesCache &CacheFI; + bool AllowGCPtrInCSR; // Operands with physical registers requiring spilling. SmallVector OpsToSpill; // Set of register to spill. SmallVector RegsToSpill; + // Set of registers to reload after statepoint. + SmallVector RegsToReload; // Map Register to Frame Slot index. DenseMap RegToSlotIdx; public: StatepointState(MachineInstr &MI, const uint32_t *Mask, - FrameIndexesCache &CacheFI) + FrameIndexesCache &CacheFI, bool AllowGCPtrInCSR) : MI(MI), MF(*MI.getMF()), TRI(*MF.getSubtarget().getRegisterInfo()), TII(*MF.getSubtarget().getInstrInfo()), MFI(MF.getFrameInfo()), - Mask(Mask), CacheFI(CacheFI) {} + Mask(Mask), CacheFI(CacheFI), AllowGCPtrInCSR(AllowGCPtrInCSR) {} + // Return true if register is callee saved. bool isCalleeSaved(Register Reg) { return (Mask[Reg / 32] >> Reg % 32) & 1; } + // Iterates over statepoint meta args to find caller saver registers. // Also cache the size of found registers. // Returns true if caller save registers found. bool findRegistersToSpill() { SmallSet VisitedRegs; + SmallSet GCRegs = collectGCRegs(MI); for (unsigned Idx = StatepointOpers(&MI).getVarIdx(), EndIdx = MI.getNumOperands(); Idx < EndIdx; ++Idx) { @@ -191,8 +355,13 @@ continue; Register Reg = MO.getReg(); assert(Reg.isPhysical() && "Only physical regs are expected"); - if (isCalleeSaved(Reg)) + + if (isCalleeSaved(Reg) && (AllowGCPtrInCSR || !is_contained(GCRegs, Reg))) continue; + + LLVM_DEBUG(dbgs() << "Will spill " << printReg(Reg, &TRI) << " at index " + << Idx << "\n"); + if (VisitedRegs.insert(Reg).second) RegsToSpill.push_back(Reg); OpsToSpill.push_back(Idx); @@ -200,30 +369,117 @@ CacheFI.sortRegisters(RegsToSpill); return !RegsToSpill.empty(); } + // Spill all caller saved registers right before statepoint instruction. // Remember frame index where register is spilled. 
void spillRegisters() { for (Register Reg : RegsToSpill) { - int FI = CacheFI.getFrameIndex(Reg); + int FI = CacheFI.getFrameIndex(Reg, &MI); const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg); - TII.storeRegToStackSlot(*MI.getParent(), MI, Reg, true /*is_Kill*/, FI, - RC, &TRI); + NumSpilledRegisters++; RegToSlotIdx[Reg] = FI; + + LLVM_DEBUG(dbgs() << "Spilling " << printReg(Reg, &TRI) << "\n"); + + // Perform trivial copy propagation + bool IsKill = true; + MachineBasicBlock::iterator InsertBefore(MI); + Reg = performCopyPropagation(Reg, InsertBefore, IsKill, TII, TRI); + + LLVM_DEBUG(dbgs() << "Insert spill before " << *InsertBefore); + TII.storeRegToStackSlot(*MI.getParent(), InsertBefore, Reg, IsKill, FI, + RC, &TRI); + } + } + + void insertReloadBefore(unsigned Reg, MachineBasicBlock::iterator It, + MachineBasicBlock *MBB) { + const TargetRegisterClass *RC = TRI.getMinimalPhysRegClass(Reg); + int FI = RegToSlotIdx[Reg]; + if (It != MBB->end()) { + TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI); + return; + } + + // To insert reload at the end of MBB, insert it before last instruction + // and then swap them. + assert(MBB->begin() != MBB->end() && "Empty block"); + --It; + TII.loadRegFromStackSlot(*MBB, It, Reg, FI, RC, &TRI); + MachineInstr *Reload = It->getPrevNode(); + int Dummy = 0; + assert(TII.isLoadFromStackSlot(*Reload, Dummy) == Reg); + assert(Dummy == FI); + MBB->remove(Reload); + MBB->insertAfter(It, Reload); + } + + // Insert reloads of (relocated) registers spilled in statepoint. + void insertReloads(MachineInstr *NewStatepoint) { + MachineBasicBlock *MBB = NewStatepoint->getParent(); + auto It = std::next(NewStatepoint->getIterator()); + + MachineBasicBlock *EHPad = nullptr; + MachineBasicBlock::iterator EHPadInsertPoint; + // Invoke statepoint must be last one in block. 
+ bool Last = + std::none_of(It, MBB->end().getInstrIterator(), [](MachineInstr &I) { + return I.getOpcode() == TargetOpcode::STATEPOINT; + }); + if (Last) { + auto It = llvm::find_if(MBB->successors(), [](MachineBasicBlock *B) { + return B->isEHPad(); + }); + if (It != MBB->succ_end()) { + EHPad = *It; + EHPadInsertPoint = EHPad->SkipPHIsLabelsAndDebug(EHPad->begin()); + } + } + + for (auto Reg : RegsToReload) { + insertReloadBefore(Reg, It, MBB); + + if (EHPad && !hasRegReload(Reg, RegToSlotIdx[Reg], EHPad, &TII, &TRI)) + insertReloadBefore(Reg, EHPadInsertPoint, EHPad); } } + // Re-write statepoint machine instruction to replace caller saved operands // with indirect memory location (frame index). - void rewriteStatepoint() { + MachineInstr* rewriteStatepoint() { MachineInstr *NewMI = MF.CreateMachineInstr(TII.get(MI.getOpcode()), MI.getDebugLoc(), true); MachineInstrBuilder MIB(MF, NewMI); + unsigned NumOps = MI.getNumOperands(); + + // New indices for the remaining defs. + SmallVector NewIndices; + unsigned NumDefs = MI.getNumDefs(); + for (unsigned I = 0; I < NumDefs; ++I) { + MachineOperand &DefMO = MI.getOperand(I); + assert(DefMO.isReg() && DefMO.isDef() && "Expected Reg Def operand"); + Register Reg = DefMO.getReg(); + if (!AllowGCPtrInCSR) { + assert(is_contained(RegsToSpill, Reg)); + RegsToReload.push_back(Reg); + } else { + if (isCalleeSaved(Reg)) { + NewIndices.push_back(NewMI->getNumOperands()); + MIB.addReg(Reg, RegState::Define); + } else { + NewIndices.push_back(NumOps); + RegsToReload.push_back(Reg); + } + } + } + // Add End marker. 
OpsToSpill.push_back(MI.getNumOperands()); unsigned CurOpIdx = 0; - for (unsigned I = 0; I < MI.getNumOperands(); ++I) { + for (unsigned I = NumDefs; I < MI.getNumOperands(); ++I) { MachineOperand &MO = MI.getOperand(I); if (I == OpsToSpill[CurOpIdx]) { int FI = RegToSlotIdx[MO.getReg()]; @@ -234,8 +490,15 @@ MIB.addFrameIndex(FI); MIB.addImm(0); ++CurOpIdx; - } else + } else { MIB.add(MO); + unsigned OldDef; + if (AllowGCPtrInCSR && MI.isRegTiedToDefOperand(I, &OldDef)) { + assert(OldDef < NumDefs); + assert(NewIndices[OldDef] < NumOps); + MIB->tieOperands(NewIndices[OldDef], MIB->getNumOperands() - 1); + } + } } assert(CurOpIdx == (OpsToSpill.size() - 1) && "Not all operands processed"); // Add mem operands. @@ -248,9 +511,13 @@ MFI.getObjectAlign(FrameIndex)); NewMI->addMemOperand(MF, MMO); } + // Insert new statepoint and erase old one. MI.getParent()->insert(MI, NewMI); + + LLVM_DEBUG(dbgs() << "rewritten statepoint to : " << *NewMI << "\n"); MI.eraseFromParent(); + return NewMI; } }; @@ -265,27 +532,104 @@ : MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()), CacheFI(MF.getFrameInfo(), TRI) {} - bool process(MachineInstr &MI) { + StatepointProcessor(MachineFunction &MF, FrameIndexesCache &Cache) + : MF(MF), TRI(*MF.getSubtarget().getRegisterInfo()), CacheFI(Cache) {} + + bool process(MachineInstr &MI, bool AllowGCPtrInCSR) { StatepointOpers SO(&MI); uint64_t Flags = SO.getFlags(); // Do nothing for LiveIn, it supports all registers. 
if (Flags & (uint64_t)StatepointFlags::DeoptLiveIn) return false; + LLVM_DEBUG(dbgs() << "\nMBB " << MI.getParent()->getNumber() << " " + << MI.getParent()->getName() << " : process statepoint " + << MI); CallingConv::ID CC = SO.getCallingConv(); const uint32_t *Mask = TRI.getCallPreservedMask(MF, CC); CacheFI.reset(); - StatepointState SS(MI, Mask, CacheFI); + StatepointState SS(MI, Mask, CacheFI, AllowGCPtrInCSR); if (!SS.findRegistersToSpill()) return false; SS.spillRegisters(); - SS.rewriteStatepoint(); + auto *NewStatepoint = SS.rewriteStatepoint(); + SS.insertReloads(NewStatepoint); return true; } }; } // namespace +// Return live out definition of Reg in MBB or null +static MachineInstr *findLiveOutDef(Register Reg, MachineBasicBlock *MBB, + const TargetRegisterInfo *TRI) { + bool StatepointSeen = false; + for (auto I = MBB->rbegin(), E = MBB->rend(); I != E; ++I) { + // Special case for statepoint because we're looking specifically + // for explicit defs, not any implicit effects like regmask or register + // implicit def. + if (I->getOpcode() == TargetOpcode::STATEPOINT) { + int DefIdx = I->findRegisterDefOperandIdx(Reg, false, false, TRI); + if (DefIdx >= 0 && (unsigned)DefIdx < I->getNumDefs()) { + assert(!StatepointSeen && "Found gc value live across statepoint"); + return &*I; + } + StatepointSeen = true; + } else if (I->modifiesRegister(Reg, TRI)) { + return &*I; + } + } + return nullptr; +} + +// For EH pad block with multiple predecessors check if its live-in +// registers are defined by statepoints in preds. If so, assign same +// spill slot for each register at each statepoint. +// NOTE: It works only if all reaching definitions of register are statepoints, +// otherwise we cannot insert reload into EH pad and must insert multiple +// reloads on edges. 
+void FixupStatepointCallerSaved::collectGlobalFIs( + MachineBasicBlock &BB, FrameIndexesCache &Cache, + const TargetRegisterInfo *TRI) { + if (!BB.isEHPad() || BB.livein_empty() || BB.pred_size() == 1) + return; + SmallVector Preds(BB.predecessors()); + auto isStatepoint = [](MachineInstr *I) { + return I && I->getOpcode() == TargetOpcode::STATEPOINT; + }; + + // Resetting the cache allows us to reuse stack slots between + // different 'statepoint sets' (a set of statepoints reaching + // same EH Pad). This works under assumption that we allocate + // these 'global' spill slots before starting to process + // individual statepoints. + Cache.reset(); + + for (const auto &LI : BB.liveins()) { + Register Reg = LI.PhysReg; + SmallVector RegDefs; + for (auto *B : Preds) + RegDefs.push_back(findLiveOutDef(Reg, B, TRI)); + if (llvm::all_of(RegDefs, isStatepoint)) { + int FI = Cache.getFrameIndex(Reg); + for (auto *Def : RegDefs) { + Cache.addGlobalSpillSlot(Reg, Def, FI); + LLVM_DEBUG(dbgs() << "EH Pad bb." << BB.getNumber() << ": reserving FI " + << FI << " to spill register " << printReg(Reg, TRI) + << " at statepoint in bb." + << Def->getParent()->getNumber() << "\n"); + } + } else { + // That spilling stuff is all-or-nothing: either all defining instructions + // are statepoints (and we can spill to the same slot) or none of them are + // statepoints (so we do not need any reloads). Otherwise we're in + // trouble. 
+ assert(llvm::none_of(RegDefs, isStatepoint) && + "Cannot safely reload register"); + } + } +} + bool FixupStatepointCallerSaved::runOnMachineFunction(MachineFunction &MF) { if (skipFunction(MF.getFunction())) return false; @@ -294,18 +638,30 @@ if (!F.hasGC()) return false; + const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo(); + FrameIndexesCache FICache(MF.getFrameInfo(), *TRI); + SmallVector Statepoints; - for (MachineBasicBlock &BB : MF) + for (MachineBasicBlock &BB : MF) { + collectGlobalFIs(BB, FICache, TRI); for (MachineInstr &I : BB) if (I.getOpcode() == TargetOpcode::STATEPOINT) Statepoints.push_back(&I); + } if (Statepoints.empty()) return false; bool Changed = false; - StatepointProcessor SPP(MF); - for (MachineInstr *I : Statepoints) - Changed |= SPP.process(*I); + StatepointProcessor SPP(MF, FICache); + unsigned NumStatepoints = 0; + bool AllowGCPtrInCSR = PassGCPtrInCSR; + for (MachineInstr *I : Statepoints) { + ++NumStatepoints; + if (MaxStatepointsWithRegs.getNumOccurrences() && + NumStatepoints >= MaxStatepointsWithRegs) + AllowGCPtrInCSR = false; + Changed |= SPP.process(*I, AllowGCPtrInCSR); + } return Changed; } diff --git a/llvm/test/CodeGen/X86/statepoint-vreg.mir b/llvm/test/CodeGen/X86/statepoint-vreg.mir --- a/llvm/test/CodeGen/X86/statepoint-vreg.mir +++ b/llvm/test/CodeGen/X86/statepoint-vreg.mir @@ -1,5 +1,5 @@ # NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py -# RUN: llc -o - %s -start-after=finalize-isel | FileCheck %s +# RUN: llc -o - %s -fixup-allow-gcptr-in-csr=true -start-after=finalize-isel | FileCheck %s --- | ; ModuleID = 'test/CodeGen/X86/statepoint-vreg.ll'