diff --git a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
--- a/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -90,11 +90,14 @@
 /// if BB's Pred has a branch to BB and to AnotherBB, and BB has a single
 /// successor Sing. In this case the branch will be updated with Sing instead of
 /// BB, and BB will still be merged into its predecessor and removed.
+/// If \p DT is not nullptr, update it directly; in that case, DTU must be
+/// nullptr.
 bool MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU = nullptr,
                                LoopInfo *LI = nullptr,
                                MemorySSAUpdater *MSSAU = nullptr,
                                MemoryDependenceResults *MemDep = nullptr,
-                               bool PredecessorWithTwoSuccessors = false);
+                               bool PredecessorWithTwoSuccessors = false,
+                               DominatorTree *DT = nullptr);
 
 /// Merge block(s) sucessors, if possible. Return true if at least two
 /// of the blocks were merged together.
diff --git a/llvm/include/llvm/Transforms/Utils/Local.h b/llvm/include/llvm/Transforms/Utils/Local.h
--- a/llvm/include/llvm/Transforms/Utils/Local.h
+++ b/llvm/include/llvm/Transforms/Utils/Local.h
@@ -342,9 +342,12 @@
 
 /// Insert an unreachable instruction before the specified
 /// instruction, making it and the rest of the code in the block dead.
+/// If \p DT is not nullptr, update it directly; in that case, DTU must be
+/// nullptr.
 unsigned changeToUnreachable(Instruction *I, bool PreserveLCSSA = false,
                              DomTreeUpdater *DTU = nullptr,
-                             MemorySSAUpdater *MSSAU = nullptr);
+                             MemorySSAUpdater *MSSAU = nullptr,
+                             DominatorTree *DT = nullptr);
 
 /// Convert the CallInst to InvokeInst with the specified unwind edge basic
 /// block. This also splits the basic block where CI is located, because
@@ -495,6 +498,16 @@
 /// function, explicitly materialize the maximal set in the IR.
 bool inferAttributesFromOthers(Function &F);
 
+/// Update the \p DT for \p BB using its predecessors. This can be used to
+/// update the dominator tree if predecessors of \p BB have been added or
+/// removed or if the DT for its predecessors changed. The function assumes the
+/// dominator tree is valid for all predecessors and sets \p BB's immediate
+/// dominator to the nearest common dominator of all predecessors. If \p BB has
+/// no predecessors left, \p BB and every block it dominates are removed from
+/// \p DT. Nothing happens if \p BB is not in \p DT. \p BB cannot be a
+/// predecessor of itself.
+void updateDominatorTreeUsingPredecessors(BasicBlock *BB, DominatorTree &DT);
+
 } // end namespace llvm
 
 #endif // LLVM_TRANSFORMS_UTILS_LOCAL_H
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -179,7 +179,8 @@
 
 bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DomTreeUpdater *DTU,
                                      LoopInfo *LI, MemorySSAUpdater *MSSAU,
                                      MemoryDependenceResults *MemDep,
-                                     bool PredecessorWithTwoSuccessors) {
+                                     bool PredecessorWithTwoSuccessors,
+                                     DominatorTree *DT) {
   if (BB->hasAddressTaken())
     return false;
@@ -232,10 +233,21 @@
     FoldSingleEntryPHINodes(BB, MemDep);
   }
 
+  if (DT) {
+    assert(!DTU && "cannot use both DT and DTU for updates");
+    DomTreeNode *PredNode = DT->getNode(PredBB);
+    DomTreeNode *BBNode = DT->getNode(BB);
+    if (PredNode) {
+      assert(BBNode && "PredBB reachable but BB unreachable?");
+      for (DomTreeNode *C : to_vector(BBNode->children()))
+        C->setIDom(PredNode);
+    }
+  }
   // DTU update: Collect all the edges that exit BB.
   // These dominator edges will be redirected from Pred.
   std::vector Updates;
   if (DTU) {
+    assert(!DT && "cannot use both DT and DTU for updates");
     // To avoid processing the same predecessor more than once.
     SmallPtrSet SeenSuccs;
     SmallPtrSet SuccsOfPredBB(succ_begin(PredBB),
@@ -311,6 +323,11 @@
   if (DTU)
     DTU->applyUpdates(Updates);
 
+  // BB's dominator-tree children were re-parented to PredBB above, so BB's
+  // node is a leaf now and is the only one that has to go.
+  if (DT && DT->getNode(BB))
+    DT->eraseNode(BB);
+
   // Finally, erase the old block and update dominator info.
   DeleteDeadBlock(BB, DTU);
 
diff --git a/llvm/lib/Transforms/Utils/Local.cpp b/llvm/lib/Transforms/Utils/Local.cpp
--- a/llvm/lib/Transforms/Utils/Local.cpp
+++ b/llvm/lib/Transforms/Utils/Local.cpp
@@ -2232,8 +2232,8 @@
 }
 
 unsigned llvm::changeToUnreachable(Instruction *I, bool PreserveLCSSA,
-                                   DomTreeUpdater *DTU,
-                                   MemorySSAUpdater *MSSAU) {
+                                   DomTreeUpdater *DTU, MemorySSAUpdater *MSSAU,
+                                   DominatorTree *DT) {
   BasicBlock *BB = I->getParent();
 
   if (MSSAU)
@@ -2261,12 +2261,19 @@
     ++NumInstrsRemoved;
   }
   if (DTU) {
+    assert(!DT && "cannot use both DT and DTU for updates");
     SmallVector Updates;
     Updates.reserve(UniqueSuccessors.size());
     for (BasicBlock *UniqueSuccessor : UniqueSuccessors)
       Updates.push_back({DominatorTree::Delete, BB, UniqueSuccessor});
     DTU->applyUpdates(Updates);
   }
+
+  if (DT) {
+    assert(!DTU && "cannot use both DT and DTU for updates");
+    for (BasicBlock *UniqueSuccessor : UniqueSuccessors)
+      updateDominatorTreeUsingPredecessors(UniqueSuccessor, *DT);
+  }
   return NumInstrsRemoved;
 }
 
@@ -3516,3 +3523,34 @@
 
   return Changed;
 }
+
+void llvm::updateDominatorTreeUsingPredecessors(BasicBlock *BB,
+                                                DominatorTree &DT) {
+  DomTreeNode *BBNode = DT.getNode(BB);
+  // BB is not in the tree (it was already unreachable), nothing to update.
+  if (!BBNode)
+    return;
+
+  if (pred_empty(BB)) {
+    // BB lost its last predecessor, so BB and everything it dominates became
+    // unreachable. eraseNode() only accepts leaves, hence collect the
+    // dominator subtree rooted at BB in pre-order and erase it back to front.
+    // Blocks that are merely CFG-reachable from BB keep their nodes, as they
+    // still have other predecessors.
+    SmallVector<BasicBlock *, 8> DeadBlocks;
+    for (DomTreeNode *N : depth_first(BBNode))
+      DeadBlocks.push_back(N->getBlock());
+    for (BasicBlock *Dead : reverse(DeadBlocks))
+      DT.eraseNode(Dead);
+    return;
+  }
+
+  assert(
+      all_of(predecessors(BB), [BB](BasicBlock *Pred) { return Pred != BB; }) &&
+      "BB is a predecessor of itself");
+  auto PredI = pred_begin(BB);
+  BasicBlock *NewIDom = *PredI;
+  for (PredI = std::next(PredI); PredI != pred_end(BB); ++PredI)
+    NewIDom = DT.findNearestCommonDominator(NewIDom, *PredI);
+  BBNode->setIDom(DT.getNode(NewIDom));
+}
diff --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
--- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
@@ -680,8 +680,6 @@
   assert(!UnrollVerifyDomtree ||
          DT->verify(DominatorTree::VerificationLevel::Fast));
 
-  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
-
   auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
     auto *Term = cast(Src->getTerminator());
     const unsigned Idx = ExitOnTrue ^ WillExit;
@@ -695,7 +693,12 @@
     BranchInst::Create(Dest, Term);
     Term->eraseFromParent();
 
-    DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}});
+    updateDominatorTreeUsingPredecessors(DeadSucc, *DT);
+    for (DomTreeNode *Child : to_vector(DT->getNode(Src)->children())) {
+      if (L->contains(Child->getBlock()))
+        continue;
+      updateDominatorTreeUsingPredecessors(Child->getBlock(), *DT);
+    }
   };
 
   auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j,
@@ -757,7 +760,9 @@
 
   // When completely unrolling, the last latch becomes unreachable.
   if (!LatchIsExiting && CompletelyUnroll)
-    changeToUnreachable(Latches.back()->getTerminator(), PreserveLCSSA, &DTU);
+    changeToUnreachable(Latches.back()->getTerminator(), PreserveLCSSA,
+                        /*DTU=*/nullptr,
+                        /*MSSAU=*/nullptr, DT);
 
   // Merge adjacent basic blocks, if possible.
   for (BasicBlock *Latch : Latches) {
@@ -769,16 +774,15 @@
     if (Term && Term->isUnconditional()) {
       BasicBlock *Dest = Term->getSuccessor(0);
       BasicBlock *Fold = Dest->getUniquePredecessor();
-      if (MergeBlockIntoPredecessor(Dest, &DTU, LI)) {
+      if (MergeBlockIntoPredecessor(
+              Dest, /*DTU=*/nullptr, LI, /*MSSAU=*/nullptr, /*MemDep=*/nullptr,
+              /*PredecessorWithTwoSuccessors=*/false, DT)) {
         // Dest has been folded into Fold. Update our worklists accordingly.
         std::replace(Latches.begin(), Latches.end(), Dest, Fold);
         llvm::erase_value(UnrolledLoopBlocks, Dest);
       }
     }
   }
-  // Apply updates to the DomTree.
-  DT = &DTU.getDomTree();
-
   assert(!UnrollVerifyDomtree ||
          DT->verify(DominatorTree::VerificationLevel::Fast));
 
diff --git a/llvm/unittests/Transforms/Utils/LocalTest.cpp b/llvm/unittests/Transforms/Utils/LocalTest.cpp
--- a/llvm/unittests/Transforms/Utils/LocalTest.cpp
+++ b/llvm/unittests/Transforms/Utils/LocalTest.cpp
@@ -15,6 +15,7 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/DIBuilder.h"
 #include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
@@ -1087,3 +1088,84 @@
 
   BB0->dropAllReferences();
 }
+
+static BasicBlock *getBasicBlockByName(Function &F, StringRef Name) {
+  for (BasicBlock &BB : F)
+    if (BB.getName() == Name)
+      return &BB;
+  llvm_unreachable("Expected to find basic block!");
+}
+
+static Value *getArgumentByName(Function &F, StringRef Name) {
+  for (Argument &A : F.args())
+    if (A.getName() == Name)
+      return &A;
+  llvm_unreachable("Expected to find argument!");
+}
+
+TEST(Local, updateDominatorTreeUsingPredecessors) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"IR(
+define i32 @test(i1 %cond) {
+entry:
+  br i1 %cond, label %bb0, label %bb1
+
+bb0:
+  br label %bb1
+
+bb1:
+  %phi = phi i32 [ 0, %entry ], [ 1, %bb0 ]
+  ret i32 %phi
+}
+)IR");
+  Function *F = M->getFunction("test");
+  DominatorTree DT(*F);
+
+  BasicBlock *Entry = getBasicBlockByName(*F, "entry");
+  BasicBlock *BB0 = getBasicBlockByName(*F, "bb0");
+  BasicBlock *BB1 = getBasicBlockByName(*F, "bb1");
+
+  // Remove edge from entry -> bb1.
+  Entry->getTerminator()->eraseFromParent();
+  BranchInst::Create(BB0, Entry);
+  updateDominatorTreeUsingPredecessors(BB1, DT);
+  EXPECT_TRUE(DT.verify());
+
+  // Remove edge from bb0 -> bb1.
+  BB0->getTerminator()->eraseFromParent();
+  new UnreachableInst(C, BB0);
+  updateDominatorTreeUsingPredecessors(BB1, DT);
+  EXPECT_TRUE(DT.verify());
+
+  M = parseIR(C, R"IR(
+define i32 @test(i1 %cond) {
+entry:
+  br label %bb0
+
+bb0:
+  br i1 %cond, label %bb1, label %bb0
+
+bb1:
+  %phi = phi i32 [ 0, %entry ], [ 1, %bb0 ]
+  ret i32 %phi
+}
+)IR");
+  F = M->getFunction("test");
+  DT.recalculate(*F);
+
+  Entry = getBasicBlockByName(*F, "entry");
+  BB0 = getBasicBlockByName(*F, "bb0");
+  BB1 = getBasicBlockByName(*F, "bb1");
+
+  // Remove edge from bb0 -> bb0.
+  BB0->getTerminator()->eraseFromParent();
+  BranchInst::Create(BB1, BB0);
+  updateDominatorTreeUsingPredecessors(BB1, DT);
+  EXPECT_TRUE(DT.verify());
+
+  // Add edge from entry -> bb1.
+  Entry->getTerminator()->eraseFromParent();
+  BranchInst::Create(BB0, BB1, getArgumentByName(*F, "cond"), Entry);
+  updateDominatorTreeUsingPredecessors(BB1, DT);
+  EXPECT_TRUE(DT.verify());
+}