Index: llvm/include/llvm/CodeGen/ReachingDefAnalysis.h =================================================================== --- llvm/include/llvm/CodeGen/ReachingDefAnalysis.h +++ llvm/include/llvm/CodeGen/ReachingDefAnalysis.h @@ -72,6 +72,7 @@ const int ReachingDefDefaultVal = -(1 << 20); using InstSet = SmallPtrSetImpl<MachineInstr*>; + using BlockSet = SmallPtrSetImpl<MachineBasicBlock*>; public: static char ID; // Pass identification, replacement for typeid @@ -107,14 +108,6 @@ /// PhysReg that reaches MI, relative to the begining of MI's basic block. int getReachingDef(MachineInstr *MI, int PhysReg) const; - /// Provides the instruction of the closest reaching def instruction of - /// PhysReg that reaches MI, relative to the begining of MI's basic block. - MachineInstr *getReachingMIDef(MachineInstr *MI, int PhysReg) const; - - /// Provides the MI, from the given block, corresponding to the Id or a - /// nullptr if the id does not refer to the block. - MachineInstr *getInstFromId(MachineBasicBlock *MBB, int InstId) const; - /// Return whether A and B use the same def of PhysReg. bool hasSameReachingDef(MachineInstr *A, MachineInstr *B, int PhysReg) const; @@ -127,6 +120,18 @@ MachineInstr *getLocalLiveOutMIDef(MachineBasicBlock *MBB, int PhysReg) const; + /// If a single MachineInstr creates the reaching definition, then return it. + /// Otherwise return null. + MachineInstr *getUniqueReachingMIDef(MachineInstr *MI, int PhysReg) const; + + /// If a single MachineInstr creates the reaching definition for MI's operand + /// at Idx, then return it. Otherwise return null. + MachineInstr *getMIOperand(MachineInstr *MI, unsigned Idx) const; + + /// If a single MachineInstr creates the reaching definition for MI's MO, + /// then return it. Otherwise return null. + MachineInstr *getMIOperand(MachineInstr *MI, MachineOperand &MO) const; + /// Provide whether the register has been defined in the same basic block as, /// and before, MI. 
bool hasLocalDefBefore(MachineInstr *MI, int PhysReg) const; @@ -147,6 +152,11 @@ void getReachingLocalUses(MachineInstr *MI, int PhysReg, InstSet &Uses) const; + /// Search MBB for a definition of PhysReg and insert it into Defs. If no + /// definition is found, recursively search the predecessor blocks for one. + void getLiveOuts(MachineBasicBlock *MBB, int PhysReg, InstSet &Defs, + BlockSet &VisitedBBs) const; + /// For the given block, collect the instructions that use the live-in /// value of the provided register. Return whether the value is still /// live on exit. @@ -206,6 +216,14 @@ /// the redundant use-def chain. bool isSafeToRemove(MachineInstr *MI, InstSet &Visited, InstSet &ToRemove, InstSet &Ignore) const; + + /// Provides the MI, from the given block, corresponding to the Id or a + /// nullptr if the id does not refer to the block. + MachineInstr *getInstFromId(MachineBasicBlock *MBB, int InstId) const; + + /// Provides the instruction of the closest reaching def instruction of + /// PhysReg that reaches MI, relative to the beginning of MI's basic block. + MachineInstr *getReachingLocalMIDef(MachineInstr *MI, int PhysReg) const; }; } // namespace llvm Index: llvm/lib/CodeGen/ReachingDefAnalysis.cpp =================================================================== --- llvm/lib/CodeGen/ReachingDefAnalysis.cpp +++ llvm/lib/CodeGen/ReachingDefAnalysis.cpp @@ -195,7 +195,7 @@ return LatestDef; } -MachineInstr* ReachingDefAnalysis::getReachingMIDef(MachineInstr *MI, +MachineInstr* ReachingDefAnalysis::getReachingLocalMIDef(MachineInstr *MI, int PhysReg) const { return getInstFromId(MI->getParent(), getReachingDef(MI, PhysReg)); } @@ -250,7 +250,7 @@ // If/when we find a new reaching def, we know that there's no more uses // of 'Def'. 
- if (getReachingMIDef(&*MI, PhysReg) != Def) + if (getReachingLocalMIDef(&*MI, PhysReg) != Def) return; for (auto &MO : MI->operands()) { @@ -311,6 +311,59 @@ } } +void +ReachingDefAnalysis::getLiveOuts(MachineBasicBlock *MBB, int PhysReg, + InstSet &Defs, BlockSet &VisitedBBs) const { + if (VisitedBBs.count(MBB)) + return; + + VisitedBBs.insert(MBB); + LivePhysRegs LiveRegs(*TRI); + LiveRegs.addLiveOuts(*MBB); + if (!LiveRegs.contains(PhysReg)) + return; + + if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg)) + Defs.insert(Def); + else + for (auto *Pred : MBB->predecessors()) + getLiveOuts(Pred, PhysReg, Defs, VisitedBBs); +} + +MachineInstr *ReachingDefAnalysis::getUniqueReachingMIDef(MachineInstr *MI, + int PhysReg) const { + // If there's a local def before MI, return it (LocalDef may be null). + MachineInstr *LocalDef = getReachingLocalMIDef(MI, PhysReg); + if (LocalDef && InstIds.lookup(LocalDef) < InstIds.lookup(MI)) + return LocalDef; + + SmallPtrSet<MachineBasicBlock*, 4> VisitedBBs; + SmallPtrSet<MachineInstr*, 2> Incoming; + for (auto *Pred : MI->getParent()->predecessors()) + getLiveOuts(Pred, PhysReg, Incoming, VisitedBBs); + + // If we have a local def and an incoming instruction, then there's not a + // unique instruction def. 
+ if (!Incoming.empty() && LocalDef) + return nullptr; + else if (Incoming.size() == 1) + return *Incoming.begin(); + else + return LocalDef; +} + +MachineInstr *ReachingDefAnalysis::getMIOperand(MachineInstr *MI, + unsigned Idx) const { + assert(MI->getOperand(Idx).isReg() && "Expected register operand"); + return getUniqueReachingMIDef(MI, MI->getOperand(Idx).getReg()); +} + +MachineInstr *ReachingDefAnalysis::getMIOperand(MachineInstr *MI, + MachineOperand &MO) const { + assert(MO.isReg() && "Expected register operand"); + return getUniqueReachingMIDef(MI, MO.getReg()); +} + bool ReachingDefAnalysis::isRegUsedAfter(MachineInstr *MI, int PhysReg) const { MachineBasicBlock *MBB = MI->getParent(); LivePhysRegs LiveRegs(*TRI); @@ -337,7 +390,7 @@ return true; if (auto *Def = getLocalLiveOutMIDef(MBB, PhysReg)) - return Def == getReachingMIDef(MI, PhysReg); + return Def == getReachingLocalMIDef(MI, PhysReg); return false; } @@ -482,7 +535,7 @@ InstSet &Ignore) const { // Check for any uses of the register after MI. if (isRegUsedAfter(MI, PhysReg)) { - if (auto *Def = getReachingMIDef(MI, PhysReg)) { + if (auto *Def = getReachingLocalMIDef(MI, PhysReg)) { SmallPtrSet Uses; getReachingLocalUses(Def, PhysReg, Uses); for (auto *Use : Uses) Index: llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp =================================================================== --- llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp +++ llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp @@ -342,7 +342,7 @@ // Find an insertion point: // - Is there a (mov lr, Count) before Start? If so, and nothing else writes // to Count before Start, we can insert at that mov. 
- if (auto *LRDef = RDA.getReachingMIDef(Start, ARM::LR)) + if (auto *LRDef = RDA.getUniqueReachingMIDef(Start, ARM::LR)) if (IsMoveLR(LRDef) && RDA.hasSameReachingDef(Start, LRDef, CountReg)) return LRDef; @@ -479,7 +479,7 @@ }; MBB = VCTP->getParent(); - if (MachineInstr *Def = RDA.getReachingMIDef(&MBB->back(), NumElements)) { + if (auto *Def = RDA.getUniqueReachingMIDef(&MBB->back(), NumElements)) { SmallPtrSet ElementChain; SmallPtrSet Ignore = { VCTP }; unsigned ExpectedVectorWidth = getTailPredVectorWidth(VCTP->getOpcode()); @@ -897,8 +897,7 @@ if (!LoLoop.IsTailPredicationLegal()) return; - if (auto *Def = RDA->getReachingMIDef(LoLoop.Start, - LoLoop.Start->getOperand(0).getReg())) { + if (auto *Def = RDA->getMIOperand(LoLoop.Start, 0)) { SmallPtrSet Remove; SmallPtrSet Ignore = { LoLoop.Start, LoLoop.Dec, LoLoop.End, LoLoop.InsertPt }; @@ -945,7 +944,7 @@ for (auto &MO : MI->operands()) { if (!MO.isReg() || !MO.isUse() || MO.getReg() == 0) continue; - if (auto *Op = RDA->getReachingMIDef(MI, MO.getReg())) + if (auto *Op = RDA->getMIOperand(MI, MO)) Chain.push_back(Op); } Ignore.insert(MI); @@ -1133,6 +1132,9 @@ for (auto *MBB : reverse(PostOrder)) recomputeLivenessFlags(*MBB); + + // We've moved, removed and inserted new instructions, so update RDA. + RDA->reset(); } bool ARMLowOverheadLoops::RevertNonLoops() {