diff --git a/llvm/include/llvm/CodeGen/MachinePostDominators.h b/llvm/include/llvm/CodeGen/MachinePostDominators.h
--- a/llvm/include/llvm/CodeGen/MachinePostDominators.h
+++ b/llvm/include/llvm/CodeGen/MachinePostDominators.h
@@ -71,11 +71,21 @@
     return DT->properlyDominates(A, B);
   }
 
+  bool isVirtualRoot(const MachineDomTreeNode *BB) const {
+    return DT->isVirtualRoot(BB);
+  }
+
   MachineBasicBlock *findNearestCommonDominator(MachineBasicBlock *A,
-                                                MachineBasicBlock *B) {
+                                                MachineBasicBlock *B) const {
     return DT->findNearestCommonDominator(A, B);
   }
 
+  /// Returns the nearest common dominator of the given blocks, which must be
+  /// a non-empty list of non-null blocks. If that tree node is the virtual
+  /// root, nullptr is returned instead.
+  MachineBasicBlock *
+  findNearestCommonDominator(ArrayRef<MachineBasicBlock *> Blocks) const;
+
   bool runOnMachineFunction(MachineFunction &MF) override;
   void getAnalysisUsage(AnalysisUsage &AU) const override;
   void print(llvm::raw_ostream &OS, const Module *M = nullptr) const override;
diff --git a/llvm/lib/CodeGen/MachinePostDominators.cpp b/llvm/lib/CodeGen/MachinePostDominators.cpp
--- a/llvm/lib/CodeGen/MachinePostDominators.cpp
+++ b/llvm/lib/CodeGen/MachinePostDominators.cpp
@@ -13,6 +13,8 @@
 
 #include "llvm/CodeGen/MachinePostDominators.h"
 
+#include "llvm/ADT/STLExtras.h"
+
 using namespace llvm;
 
 namespace llvm {
@@ -51,7 +53,25 @@
   MachineFunctionPass::getAnalysisUsage(AU);
 }
 
-void
-MachinePostDominatorTree::print(llvm::raw_ostream &OS, const Module *M) const {
+MachineBasicBlock *MachinePostDominatorTree::findNearestCommonDominator(
+    ArrayRef<MachineBasicBlock *> Blocks) const {
+  assert(!Blocks.empty() && "Expected at least one block");
+  assert(llvm::all_of(Blocks, [](MachineBasicBlock *B) { return B; }) &&
+         "Invalid block found");
+
+  MachineBasicBlock *NCD = Blocks.front();
+  for (MachineBasicBlock *BB : Blocks.drop_front()) {
+    NCD = DT->findNearestCommonDominator(NCD, BB);
+
+    // Reaching the virtual root means there is no common real block.
+    if (DT->isVirtualRoot(DT->getNode(NCD)))
+      return nullptr;
+  }
+
+  return NCD;
+}
+
+void MachinePostDominatorTree::print(llvm::raw_ostream &OS,
+                                     const Module *M) const {
   DT->print(OS);
 }
diff --git a/llvm/lib/Target/AMDGPU/SILowerI1Copies.cpp b/llvm/lib/Target/AMDGPU/SILowerI1Copies.cpp
--- a/llvm/lib/Target/AMDGPU/SILowerI1Copies.cpp
+++ b/llvm/lib/Target/AMDGPU/SILowerI1Copies.cpp
@@ -589,12 +589,12 @@
 
     // Phis in a loop that are observed outside the loop receive a simple but
     // conservatively correct treatment.
-    MachineBasicBlock *PostDomBound = &MBB;
-    for (MachineInstr &Use : MRI->use_instructions(DstReg)) {
-      PostDomBound =
-          PDT->findNearestCommonDominator(PostDomBound, Use.getParent());
-    }
+    std::vector<MachineBasicBlock *> DomBlocks = {&MBB};
+    for (MachineInstr &Use : MRI->use_instructions(DstReg))
+      DomBlocks.push_back(Use.getParent());
 
+    MachineBasicBlock *PostDomBound =
+        PDT->findNearestCommonDominator(DomBlocks);
     unsigned FoundLoopLevel = LF.findLoop(PostDomBound);
 
     SSAUpdater.Initialize(DstReg);
@@ -711,12 +711,12 @@
 
       // Defs in a loop that are observed outside the loop must be transformed
       // into appropriate bit manipulation.
-      MachineBasicBlock *PostDomBound = &MBB;
-      for (MachineInstr &Use : MRI->use_instructions(DstReg)) {
-        PostDomBound =
-            PDT->findNearestCommonDominator(PostDomBound, Use.getParent());
-      }
+      std::vector<MachineBasicBlock *> DomBlocks = {&MBB};
+      for (MachineInstr &Use : MRI->use_instructions(DstReg))
+        DomBlocks.push_back(Use.getParent());
 
+      MachineBasicBlock *PostDomBound =
+          PDT->findNearestCommonDominator(DomBlocks);
       unsigned FoundLoopLevel = LF.findLoop(PostDomBound);
       if (FoundLoopLevel) {
         SSAUpdater.Initialize(DstReg);