Index: lib/Target/AMDGPU/SIFixSGPRCopies.cpp
===================================================================
--- lib/Target/AMDGPU/SIFixSGPRCopies.cpp
+++ lib/Target/AMDGPU/SIFixSGPRCopies.cpp
@@ -69,6 +69,7 @@
 #include "AMDGPUSubtarget.h"
 #include "SIInstrInfo.h"
 #include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
+#include "llvm/CodeGen/MachinePostDominators.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -85,6 +86,7 @@
 class SIFixSGPRCopies : public MachineFunctionPass {
 
   MachineDominatorTree *MDT;
+  MachinePostDominatorTree *PDT;
 
 public:
   static char ID;
@@ -98,6 +100,8 @@
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<MachineDominatorTree>();
     AU.addPreserved<MachineDominatorTree>();
+    AU.addRequired<MachinePostDominatorTree>();
+    AU.addPreserved<MachinePostDominatorTree>();
     AU.setPreservesCFG();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
@@ -333,6 +337,7 @@
   const SIRegisterInfo *TRI = ST.getRegisterInfo();
   const SIInstrInfo *TII = ST.getInstrInfo();
   MDT = &getAnalysis<MachineDominatorTree>();
+  PDT = &getAnalysis<MachinePostDominatorTree>();
 
   SmallVector<MachineInstr *, 16> Worklist;
 
@@ -380,12 +385,14 @@
-        // We don't need to fix the PHI if the common dominator of the
-        // two incoming blocks terminates with a uniform branch.
+        // We don't need to fix the PHI if its parent block has an immediate
+        // post-dominator and no terminator that modifies EXEC.
         if (MI.getNumExplicitOperands() == 5) {
-          MachineBasicBlock *MBB0 = MI.getOperand(2).getMBB();
-          MachineBasicBlock *MBB1 = MI.getOperand(4).getMBB();
-
-          MachineBasicBlock *NCD = MDT->findNearestCommonDominator(MBB0, MBB1);
-          if (NCD && !hasTerminatorThatModifiesExec(*NCD, *TRI)) {
-            DEBUG(dbgs() << "Not fixing PHI for uniform branch: " << MI << '\n');
+          MachineBasicBlock *ParentBB = MI.getParent();
+          MachineDomTreeNode *N = PDT->getNode(ParentBB);
+          if (!N)
+            break;
+          MachineDomTreeNode *IPostDom = N->getIDom();
+          if (!hasTerminatorThatModifiesExec(*ParentBB, *TRI) && IPostDom) {
+            DEBUG(dbgs() << "Not fixing PHI for uniform branch: "
+                         << MI << '\n');
             break;
           }
         }