Index: llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
+++ llvm/include/llvm/Transforms/Utils/BasicBlockUtils.h
@@ -444,7 +444,45 @@
                                    Instruction **ThenTerm,
                                    Instruction **ElseTerm,
                                    MDNode *BranchWeights = nullptr,
-                                   DomTreeUpdater *DTU = nullptr);
+                                   DomTreeUpdater *DTU = nullptr,
+                                   LoopInfo *LI = nullptr);
+
+/// Split the containing block at the specified instruction - everything before
+/// SplitBefore stays in the old basic block, and the rest of the instructions
+/// in the BB are moved to a new block. The two blocks are connected by a
+/// conditional branch (with value of Cond being the condition).
+/// Before:
+///   Head
+///   SplitBefore
+///   Tail
+/// After:
+///   Head
+///   if (Cond)
+///     TrueBlock
+///   else
+///     FalseBlock
+///   SplitBefore
+///   Tail
+///
+/// If \p ThenBlock is null, the resulting CFG won't contain the TrueBlock. If
+/// \p ThenBlock is non-null and points to non-null BasicBlock pointer, that
+/// block will be inserted as the TrueBlock. Otherwise a new block will be
+/// created. Likewise for the \p ElseBlock parameter. At least one of
+/// \p ThenBlock and \p ElseBlock must be non-null.
+/// If \p UnreachableThen or \p UnreachableElse is true, the corresponding newly
+/// created blocks will end with UnreachableInst, otherwise with branches to
+/// Tail. The function will not modify existing basic blocks passed to it. The
+/// caller must ensure that Tail is reachable from Head.
+/// Returns the newly created blocks in \p ThenBlock and \p ElseBlock.
+/// Updates DTU and LI if given.
+void SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore,
+                                   BasicBlock **ThenBlock,
+                                   BasicBlock **ElseBlock,
+                                   bool UnreachableThen = false,
+                                   bool UnreachableElse = false,
+                                   MDNode *BranchWeights = nullptr,
+                                   DomTreeUpdater *DTU = nullptr,
+                                   LoopInfo *LI = nullptr);
 
 /// Insert a for (int i = 0; i < End; i++) loop structure (with the exception
 /// that \p End is assumed > 0, and thus not checked on entry) at \p
Index: llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
===================================================================
--- llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -1477,94 +1477,110 @@
                                              MDNode *BranchWeights,
                                              DomTreeUpdater *DTU, LoopInfo *LI,
                                              BasicBlock *ThenBlock) {
-  SmallVector<DominatorTree::UpdateType, 8> Updates;
-  BasicBlock *Head = SplitBefore->getParent();
-  BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator());
-  if (DTU) {
-    SmallPtrSet<BasicBlock *, 8> UniqueSuccessorsOfHead;
-    Updates.push_back({DominatorTree::Insert, Head, Tail});
-    Updates.reserve(Updates.size() + 2 * succ_size(Tail));
-    for (BasicBlock *SuccessorOfHead : successors(Tail))
-      if (UniqueSuccessorsOfHead.insert(SuccessorOfHead).second) {
-        Updates.push_back({DominatorTree::Insert, Tail, SuccessorOfHead});
-        Updates.push_back({DominatorTree::Delete, Head, SuccessorOfHead});
-      }
-  }
-  Instruction *HeadOldTerm = Head->getTerminator();
-  LLVMContext &C = Head->getContext();
-  Instruction *CheckTerm;
-  bool CreateThenBlock = (ThenBlock == nullptr);
-  if (CreateThenBlock) {
-    ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail);
-    if (Unreachable)
-      CheckTerm = new UnreachableInst(C, ThenBlock);
-    else {
-      CheckTerm = BranchInst::Create(Tail, ThenBlock);
-      if (DTU)
-        Updates.push_back({DominatorTree::Insert, ThenBlock, Tail});
-    }
-    CheckTerm->setDebugLoc(SplitBefore->getDebugLoc());
-  } else
-    CheckTerm = ThenBlock->getTerminator();
-  BranchInst *HeadNewTerm =
-      BranchInst::Create(/*ifTrue*/ ThenBlock, /*ifFalse*/ Tail, Cond);
-  if (DTU)
-    Updates.push_back({DominatorTree::Insert, Head, ThenBlock});
-  HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights);
-  ReplaceInstWithInst(HeadOldTerm, HeadNewTerm);
-
-  if (DTU)
-    DTU->applyUpdates(Updates);
-
-  if (LI) {
-    if (Loop *L = LI->getLoopFor(Head)) {
-      // unreachable-terminated blocks cannot belong to any loop.
-      if (!Unreachable)
-        L->addBasicBlockToLoop(ThenBlock, *LI);
-      L->addBasicBlockToLoop(Tail, *LI);
-    }
-  }
-
-  return CheckTerm;
+  SplitBlockAndInsertIfThenElse(
+      Cond, SplitBefore, &ThenBlock, /* ElseBlock */ nullptr,
+      /* UnreachableThen */ Unreachable,
+      /* UnreachableElse */ false, BranchWeights, DTU, LI);
+  return ThenBlock->getTerminator();
 }
 
 void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore,
                                          Instruction **ThenTerm,
                                          Instruction **ElseTerm,
                                          MDNode *BranchWeights,
-                                         DomTreeUpdater *DTU) {
-  BasicBlock *Head = SplitBefore->getParent();
+                                         DomTreeUpdater *DTU, LoopInfo *LI) {
+  BasicBlock *ThenBlock = nullptr;
+  BasicBlock *ElseBlock = nullptr;
+  SplitBlockAndInsertIfThenElse(
+      Cond, SplitBefore, &ThenBlock, &ElseBlock, /* UnreachableThen */ false,
+      /* UnreachableElse */ false, BranchWeights, DTU, LI);
+
+  *ThenTerm = ThenBlock->getTerminator();
+  *ElseTerm = ElseBlock->getTerminator();
+}
+
+void llvm::SplitBlockAndInsertIfThenElse(
+    Value *Cond, Instruction *SplitBefore, BasicBlock **ThenBlock,
+    BasicBlock **ElseBlock, bool UnreachableThen, bool UnreachableElse,
+    MDNode *BranchWeights, DomTreeUpdater *DTU, LoopInfo *LI) {
+  assert((ThenBlock || ElseBlock) &&
+         "At least one branch block must be created");
+  assert((!UnreachableThen || !UnreachableElse) &&
+         "Split block tail must be reachable");
 
+  SmallVector<DominatorTree::UpdateType, 8> Updates;
   SmallPtrSet<BasicBlock *, 8> UniqueOrigSuccessors;
-  if (DTU)
+  BasicBlock *Head = SplitBefore->getParent();
+  if (DTU) {
     UniqueOrigSuccessors.insert(succ_begin(Head), succ_end(Head));
+    Updates.reserve(4 + 2 * UniqueOrigSuccessors.size());
+  }
 
+  LLVMContext &C = Head->getContext();
   BasicBlock *Tail = Head->splitBasicBlock(SplitBefore->getIterator());
+  BasicBlock *TrueBlock = Tail;
+  BasicBlock *FalseBlock = Tail;
+  bool ThenToTailEdge = false;
+  bool ElseToTailEdge = false;
+
+  // Encapsulate the logic around creation/insertion/etc of a new block.
+  auto handleBlock = [&](BasicBlock **PBB, bool Unreachable, BasicBlock *&BB,
+                         bool &ToTailEdge) {
+    if (PBB == nullptr)
+      return; // Do not create/insert a block.
+
+    if (*PBB)
+      BB = *PBB; // Caller supplied block, use it.
+    else {
+      // Create a new block.
+      BB = BasicBlock::Create(C, "", Head->getParent(), Tail);
+      if (Unreachable)
+        (void)new UnreachableInst(C, BB);
+      else {
+        (void)BranchInst::Create(Tail, BB);
+        ToTailEdge = true;
+      }
+      BB->getTerminator()->setDebugLoc(SplitBefore->getDebugLoc());
+      // Pass the new block back to the caller.
+      *PBB = BB;
+    }
+  };
+
+  handleBlock(ThenBlock, UnreachableThen, TrueBlock, ThenToTailEdge);
+  handleBlock(ElseBlock, UnreachableElse, FalseBlock, ElseToTailEdge);
+
   Instruction *HeadOldTerm = Head->getTerminator();
-  LLVMContext &C = Head->getContext();
-  BasicBlock *ThenBlock = BasicBlock::Create(C, "", Head->getParent(), Tail);
-  BasicBlock *ElseBlock = BasicBlock::Create(C, "", Head->getParent(), Tail);
-  *ThenTerm = BranchInst::Create(Tail, ThenBlock);
-  (*ThenTerm)->setDebugLoc(SplitBefore->getDebugLoc());
-  *ElseTerm = BranchInst::Create(Tail, ElseBlock);
-  (*ElseTerm)->setDebugLoc(SplitBefore->getDebugLoc());
   BranchInst *HeadNewTerm =
-    BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/ElseBlock, Cond);
+      BranchInst::Create(/*ifTrue*/ TrueBlock, /*ifFalse*/ FalseBlock, Cond);
   HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights);
   ReplaceInstWithInst(HeadOldTerm, HeadNewTerm);
+
   if (DTU) {
-    SmallVector<DominatorTree::UpdateType, 8> Updates;
-    Updates.reserve(4 + 2 * UniqueOrigSuccessors.size());
-    for (BasicBlock *Succ : successors(Head)) {
-      Updates.push_back({DominatorTree::Insert, Head, Succ});
-      Updates.push_back({DominatorTree::Insert, Succ, Tail});
-    }
+    Updates.emplace_back(DominatorTree::Insert, Head, TrueBlock);
+    Updates.emplace_back(DominatorTree::Insert, Head, FalseBlock);
+    if (ThenToTailEdge)
+      Updates.emplace_back(DominatorTree::Insert, TrueBlock, Tail);
+    if (ElseToTailEdge)
+      Updates.emplace_back(DominatorTree::Insert, FalseBlock, Tail);
     for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors)
-      Updates.push_back({DominatorTree::Insert, Tail, UniqueOrigSuccessor});
+      Updates.emplace_back(DominatorTree::Insert, Tail, UniqueOrigSuccessor);
     for (BasicBlock *UniqueOrigSuccessor : UniqueOrigSuccessors)
-      Updates.push_back({DominatorTree::Delete, Head, UniqueOrigSuccessor});
+      Updates.emplace_back(DominatorTree::Delete, Head, UniqueOrigSuccessor);
     DTU->applyUpdates(Updates);
   }
+
+  if (LI) {
+    if (Loop *L = LI->getLoopFor(Head)) {
+      // Only blocks created here that branch to Tail join the loop:
+      // unreachable-terminated blocks cannot belong to any loop, and
+      // caller-supplied blocks are left untouched.
+      if (ThenToTailEdge)
+        L->addBasicBlockToLoop(TrueBlock, *LI);
+      if (ElseToTailEdge)
+        L->addBasicBlockToLoop(FalseBlock, *LI);
+      L->addBasicBlockToLoop(Tail, *LI);
+    }
+  }
 }
 
 std::pair