Index: lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
===================================================================
--- lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
+++ lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -23,6 +23,8 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopIterator.h"
 #include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
@@ -112,7 +114,8 @@
 /// terminator.
 static void rewritePHINodesForUnswitchedExitBlock(BasicBlock &UnswitchedBB,
                                                   BasicBlock &OldExitingBB,
-                                                  BasicBlock &OldPH) {
+                                                  BasicBlock &OldPH,
+                                                  MemorySSA *MSSA) {
   for (PHINode &PN : UnswitchedBB.phis()) {
     // When the loop exit is directly unswitched we just need to update the
     // incoming basic block. We loop to handle weird cases with repeated
@@ -123,6 +126,16 @@
       PN.setIncomingBlock(i, &OldPH);
     }
   }
+
+  if (MSSA)
+    if (MemoryPhi *MPhi = MSSA->getMemoryAccess(&UnswitchedBB)) {
+      // Mirror the PHINode rewrite above. Only the incoming blocks change
+      // here; the incoming values are fixed up by updatePhisAfterUnswitch
+      // once the dominator tree has been updated.
+      for (unsigned i = 0, e = MPhi->getNumIncomingValues(); i != e; ++i)
+        if (MPhi->getIncomingBlock(i) == &OldExitingBB)
+          MPhi->setIncomingBlock(i, &OldPH);
+    }
 }
 
 /// Rewrite the PHI nodes in the loop exit basic block and the split off
@@ -135,7 +148,9 @@
 static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB,
                                                       BasicBlock &UnswitchedBB,
                                                       BasicBlock &OldExitingBB,
-                                                      BasicBlock &OldPH) {
+                                                      BasicBlock &OldPH,
+                                                      MemorySSA *MSSA,
+                                                      MemorySSAUpdater *MSSAUpdater) {
   assert(&ExitBB != &UnswitchedBB &&
          "Must have different loop exit and unswitched blocks!");
   Instruction *InsertPt = &*UnswitchedBB.begin();
@@ -165,6 +180,13 @@
     PN.replaceAllUsesWith(NewPN);
     NewPN->addIncoming(&PN, &ExitBB);
   }
+
+  if (MSSA) {
+    // Intended to mirror the PHINode rewrite above for ExitBB's MemoryPhi;
+    // incoming values are fixed up by updatePhisAfterUnswitch after DT update.
+    MSSAUpdater->wireOldPhiIntoNewPhiAfterUnswitch(&ExitBB, &UnswitchedBB,
+                                                   &OldExitingBB, &OldPH);
+  }
 }
 
 /// Unswitch a trivial branch if the condition is loop invariant.
@@ -182,7 +204,8 @@
 /// the loop to an unconditional branch but doesn't remove it entirely. Further
 /// cleanup can be done with some simplify-cfg like pass.
 static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
-                                  LoopInfo &LI) {
+                                  LoopInfo &LI, MemorySSA *MSSA,
+                                  MemorySSAUpdater *MSSAUpdater) {
   assert(BI.isConditional() && "Can only unswitch a conditional branch!");
   LLVM_DEBUG(dbgs() << "  Trying to unswitch branch: " << BI << "\n");
 
@@ -219,7 +242,8 @@
   // the conditional branch. We will change the preheader to have a conditional
   // branch on LoopCond.
   BasicBlock *OldPH = L.getLoopPreheader();
-  BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI);
+  BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI, MSSA,
+                                MSSAUpdater);
 
   // Now that we have a place to insert the conditional branch, create a place
   // to branch to: this is the exit block out of the loop that we are
@@ -232,7 +256,8 @@
            "A branch's parent isn't a predecessor!");
     UnswitchedBB = LoopExitBB;
   } else {
-    UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI);
+    UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI, MSSA,
+                              MSSAUpdater);
   }
 
   // Now splice the branch to gate reaching the new preheader and re-point its
@@ -249,10 +274,12 @@
 
   // Rewrite the relevant PHI nodes.
   if (UnswitchedBB == LoopExitBB)
-    rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH);
+    rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH,
+                                          MSSA);
   else
     rewritePHINodesForExitAndUnswitchedBlocks(*LoopExitBB, *UnswitchedBB,
-                                              *ParentBB, *OldPH);
+                                              *ParentBB, *OldPH, MSSA,
+                                              MSSAUpdater);
 
   // Now we need to update the dominator tree.
   DT.applyUpdates(
@@ -262,6 +289,15 @@
   // within the loop with a constant.
   replaceLoopUsesWithConstant(L, *LoopCond, *Replacement);
 
+  // After the dominator tree was updated, make MemoryPhis no longer dominated
+  // by ParentBB use the value incoming from NewPH (their incoming block was
+  // already set to OldPH by the PHI rewriting above).
+  if (MSSA) {
+    MSSAUpdater->updatePhisAfterUnswitch(ParentBB, NewPH, &DT);
+    if (VerifyMemorySSA)
+      MSSA->verifyMemorySSA();
+  }
+
   ++NumTrivial;
   ++NumBranches;
   return true;
@@ -291,7 +327,8 @@
 /// in-loop successor, the switch is further simplified to an unconditional
 /// branch. Still more cleanup can be done with some simplify-cfg like pass.
 static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
-                                  LoopInfo &LI) {
+                                  LoopInfo &LI, MemorySSA *MSSA,
+                                  MemorySSAUpdater *MSSAUpdater) {
   LLVM_DEBUG(dbgs() << "  Trying to unswitch switch: " << SI << "\n");
   Value *LoopCond = SI.getCondition();
 
@@ -369,7 +406,8 @@
   // Split the preheader, so that we know that there is a safe place to insert
   // the switch.
   BasicBlock *OldPH = L.getLoopPreheader();
-  BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI);
+  BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI, MSSA,
+                                MSSAUpdater);
   OldPH->getTerminator()->eraseFromParent();
 
   // Now add the unswitched switch.
@@ -389,12 +427,15 @@
   if (DefaultExitBB) {
     if (pred_empty(DefaultExitBB)) {
       UnswitchedExitBBs.insert(DefaultExitBB);
-      rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH);
+      rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH,
+                                            MSSA);
     } else {
       auto *SplitBB =
-          SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI);
+          SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI, MSSA,
+                     MSSAUpdater);
       rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB,
-                                                *ParentBB, *OldPH);
+                                                *ParentBB, *OldPH, MSSA,
+                                                MSSAUpdater);
       DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB;
     }
   }
@@ -409,7 +450,7 @@
     if (pred_empty(ExitBB)) {
       // Only rewrite once.
       if (UnswitchedExitBBs.insert(ExitBB).second)
-        rewritePHINodesForUnswitchedExitBlock(*ExitBB, *ParentBB, *OldPH);
+        rewritePHINodesForUnswitchedExitBlock(*ExitBB, *ParentBB, *OldPH, MSSA);
       continue;
     }
 
@@ -418,9 +459,11 @@
     BasicBlock *&SplitExitBB = SplitExitBBMap[ExitBB];
     if (!SplitExitBB) {
       // If this is the first time we see this, do the split and remember it.
-      SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI);
+      SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSA,
+                               MSSAUpdater);
       rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB,
-                                                *ParentBB, *OldPH);
+                                                *ParentBB, *OldPH, MSSA,
+                                                MSSAUpdater);
     }
     // Update the case pair to point to the split block.
     CasePair.second = SplitExitBB;
@@ -472,6 +515,13 @@
   }
   DT.applyUpdates(DTUpdates);
 
+  // Update MemorySSA only after the dominator tree was updated.
+  if (MSSA) {
+    MSSAUpdater->updatePhisAfterUnswitch(ParentBB, NewPH, &DT);
+    if (VerifyMemorySSA)
+      MSSA->verifyMemorySSA();
+  }
+
   assert(DT.verify(DominatorTree::VerificationLevel::Fast));
   ++NumTrivial;
   ++NumSwitches;
@@ -488,7 +538,8 @@
 /// The return value indicates whether anything was unswitched (and therefore
 /// changed).
 static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
-                                         LoopInfo &LI) {
+                                         LoopInfo &LI, MemorySSA *MSSA,
+                                         MemorySSAUpdater *MSSAUpdater) {
   bool Changed = false;
 
   // If loop header has only one reachable successor we should keep looking for
@@ -522,7 +573,7 @@
       if (isa<Constant>(SI->getCondition()))
         return Changed;
 
-      if (!unswitchTrivialSwitch(L, *SI, DT, LI))
+      if (!unswitchTrivialSwitch(L, *SI, DT, LI, MSSA, MSSAUpdater))
         // Coludn't unswitch this one so we're done.
         return Changed;
 
@@ -554,7 +605,7 @@
 
     // Found a trivial condition candidate: non-foldable conditional branch. If
     // we fail to unswitch this, we can't do anything else that is trivial.
-    if (!unswitchTrivialBranch(L, *BI, DT, LI))
+    if (!unswitchTrivialBranch(L, *BI, DT, LI, MSSA, MSSAUpdater))
       return Changed;
 
     // Mark that we managed to unswitch something.
@@ -601,7 +652,8 @@
     const SmallPtrSetImpl<BasicBlock *> &SkippedLoopAndExitBlocks,
     ValueToValueMapTy &VMap,
     SmallVectorImpl<DominatorTree::UpdateType> &DTUpdates, AssumptionCache &AC,
-    DominatorTree &DT, LoopInfo &LI) {
+    DominatorTree &DT, LoopInfo &LI, MemorySSA *MSSA,
+    MemorySSAUpdater *MSSAUpdater) {
   SmallVector<BasicBlock *, 4> NewBlocks;
   NewBlocks.reserve(L.getNumBlocks() + ExitBlocks.size());
 
@@ -639,7 +691,8 @@
     // place to merge the CFG, so split the exit first. This is always safe to
     // do because there cannot be any non-loop predecessors of a loop exit in
     // loop simplified form.
-    auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI);
+    auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI, MSSA,
+                               MSSAUpdater);
 
     // Rearrange the names to make it easier to write test cases by having the
     // exit block carry the suffix rather than the merge block carrying the
@@ -1049,10 +1102,15 @@
 deleteDeadBlocksFromLoop(Loop &L,
                          const SmallVectorImpl<BasicBlock *> &DeadBlocks,
                          SmallVectorImpl<BasicBlock *> &ExitBlocks,
-                         DominatorTree &DT, LoopInfo &LI) {
+                         DominatorTree &DT, LoopInfo &LI,
+                         MemorySSAUpdater *MSSAUpdater) {
   SmallPtrSet<BasicBlock *, 16> DeadBlockSet(DeadBlocks.begin(),
                                              DeadBlocks.end());
 
+  // Remove the MemorySSA accesses that live in the dead blocks.
+  if (MSSAUpdater)
+    MSSAUpdater->removeBlocks(DeadBlockSet);
+
   // Filter out the dead blocks from the exit blocks list so that it can be
   // used in the caller.
   llvm::erase_if(ExitBlocks,
@@ -1465,7 +1523,7 @@
 /// the new loops and no-longer valid loops to the caller.
 static bool unswitchInvariantBranch(
     Loop &L, BranchInst &BI, DominatorTree &DT, LoopInfo &LI,
-    AssumptionCache &AC,
+    MemorySSA *MSSA, MemorySSAUpdater *MSSAUpdater, AssumptionCache &AC,
     function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) {
   assert(BI.isConditional() && "Can only unswitch a conditional branch!");
   assert(L.isLoopInvariant(BI.getCondition()) &&
@@ -1547,7 +1605,8 @@
   // between the unswitched versions, and we will have a new preheader for the
   // original loop.
   BasicBlock *SplitBB = L.getLoopPreheader();
-  BasicBlock *LoopPH = SplitEdge(SplitBB, L.getHeader(), &DT, &LI);
+  BasicBlock *LoopPH = SplitEdge(SplitBB, L.getHeader(), &DT, &LI, MSSA,
+                                 MSSAUpdater);
 
   // Keep a mapping for the cloned values.
   ValueToValueMapTy VMap;
@@ -1558,7 +1617,8 @@
   // Build the cloned blocks from the loop.
   auto *ClonedPH = buildClonedLoopBlocks(
       L, LoopPH, SplitBB, ExitBlocks, ParentBB, UnswitchedSuccBB,
-      ContinueSuccBB, SkippedLoopAndExitBlocks, VMap, DTUpdates, AC, DT, LI);
+      ContinueSuccBB, SkippedLoopAndExitBlocks, VMap, DTUpdates, AC, DT, LI,
+      MSSA, MSSAUpdater);
 
   // Remove the parent as a predecessor of the unswitched successor.
   UnswitchedSuccBB->removePredecessor(ParentBB, /*DontDeleteUselessPHIs*/ true);
@@ -1572,7 +1632,7 @@
 
   // Create a new unconditional branch to the continuing block (as opposed to
   // the one cloned).
-  BranchInst::Create(ContinueSuccBB, ParentBB);
+  BranchInst *GoodBranch = BranchInst::Create(ContinueSuccBB, ParentBB);
 
   // Before we update the dominator tree, collect the dead blocks if we're going
   // to end up deleting the unswitched successor.
@@ -1597,6 +1657,32 @@
     }
   }
 
+  if (MSSA) {
+    // Temporarily give ParentBB a two-way branch to both the unswitched and
+    // the continuing successor; the MemorySSA clone update below needs that
+    // "reverse" branch in place to compute correct updates.
+    GoodBranch->eraseFromParent();
+    BranchInst *RevBranch = BranchInst::Create(
+        UnswitchedSuccBB, ContinueSuccBB,
+        UndefValue::get(Type::getInt1Ty(ParentBB->getContext())), ParentBB);
+
+    SmallPtrSet<BasicBlock *, 4> SubLoopHeaders;
+    SubLoopHeaders.insert(L.getHeader());
+    for (Loop *InnerL : L.getSubLoops())
+      SubLoopHeaders.insert(InnerL->getHeader());
+
+    // Note: this works with a partially invalid DT (the clone's DT updates are
+    // only applied below), but needs the reverse branch set.
+    MSSAUpdater->updateForClonedLoop(L.getHeader(), ExitBlocks, VMap,
+                                     SubLoopHeaders, &DT,
+                                     /*IgnoreIncomingWithNoClones=*/true);
+
+    // Drop the temporary branch and restore the unconditional branch to the
+    // continuing successor.
+    RevBranch->eraseFromParent();
+    BranchInst::Create(ContinueSuccBB, ParentBB);
+  }
+
   // Add the remaining edges to our updates and apply them to get an up-to-date
   // dominator tree. Note that this will cause the dead blocks above to be
   // unreachable and no longer in the dominator tree.
@@ -1614,7 +1700,10 @@
   // Delete anything that was made dead in the original loop due to
   // unswitching.
   if (!DeadBlocks.empty())
-    deleteDeadBlocksFromLoop(L, DeadBlocks, ExitBlocks, DT, LI);
+    deleteDeadBlocksFromLoop(L, DeadBlocks, ExitBlocks, DT, LI, MSSAUpdater);
+
+  if (MSSA && VerifyMemorySSA)
+    MSSA->verifyMemorySSA();
 
   SmallVector<Loop *, 4> HoistedLoops;
   bool IsStillLoop = rebuildLoopAfterUnswitch(L, ExitBlocks, LI, HoistedLoops);
@@ -1753,7 +1842,8 @@
 /// well by cloning the loop if the result is small enough.
 static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
                          AssumptionCache &AC,
-                         TargetTransformInfo &TTI, bool NonTrivial,
+                         TargetTransformInfo &TTI, MemorySSA *MSSA,
+                         MemorySSAUpdater *MSSAUpdater, bool NonTrivial,
                          function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
@@ -1764,7 +1854,7 @@
     return false;
 
   // Try trivial unswitch first before loop over other basic blocks in the loop.
-  Changed |= unswitchAllTrivialConditions(L, DT, LI);
+  Changed |= unswitchAllTrivialConditions(L, DT, LI, MSSA, MSSAUpdater);
 
   // If we're not doing non-trivial unswitching, we're done. We both accept
   // a parameter but also check a local flag that can be used for testing
@@ -1903,7 +1993,8 @@
                       << BestUnswitchCost << ") branch: " << *BestUnswitchTI
                       << "\n");
     Changed |= unswitchInvariantBranch(L, cast<BranchInst>(*BestUnswitchTI), DT,
-                                       LI, AC, NonTrivialUnswitchCB);
+                                       LI, MSSA, MSSAUpdater, AC,
+                                       NonTrivialUnswitchCB);
   } else {
     LLVM_DEBUG(dbgs() << "Cannot unswitch, lowest cost found: "
                       << BestUnswitchCost << "\n");
@@ -1938,10 +2029,16 @@
       U.markLoopAsDeleted(L, LoopName);
   };
 
-  if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial,
-                    NonTrivialUnswitchCB))
+  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
+  if (AR.MSSA)
+    MSSAUpdater = make_unique<MemorySSAUpdater>(AR.MSSA);
+  if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, AR.MSSA, MSSAUpdater.get(),
+                    NonTrivial, NonTrivialUnswitchCB))
     return PreservedAnalyses::all();
 
+  if (AR.MSSA && VerifyMemorySSA)
+    AR.MSSA->verifyMemorySSA();
+
   // Historically this pass has had issues with the dominator tree so verify it
   // in asserts builds.
   assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
@@ -1967,6 +2064,10 @@
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<TargetTransformInfoWrapperPass>();
+    if (EnableMSSALoopDependency) {
+      AU.addRequired<MemorySSAWrapperPass>();
+      AU.addPreserved<MemorySSAWrapperPass>();
+    }
     getLoopAnalysisUsage(AU);
   }
 };
@@ -1986,6 +2087,12 @@
   auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  MemorySSA *MSSA = nullptr;
+  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
+  if (EnableMSSALoopDependency) {
+    MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA();
+    MSSAUpdater = make_unique<MemorySSAUpdater>(MSSA);
+  }
 
   auto NonTrivialUnswitchCB = [&L, &LPM](bool CurrentLoopValid,
                                          ArrayRef<Loop *> NewLoops) {
@@ -2002,8 +2109,17 @@
       LPM.markLoopAsDeleted(*L);
   };
 
+  if (MSSA && VerifyMemorySSA)
+    MSSA->verifyMemorySSA();
+
   bool Changed =
-      unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, NonTrivialUnswitchCB);
+      unswitchLoop(*L, DT, LI, AC, TTI, MSSA, MSSAUpdater.get(), NonTrivial,
+                   NonTrivialUnswitchCB);
+
+  // TODO: Compare the cost of updating MemorySSA in place against rebuilding
+  // it after unswitching.
+  if (MSSA && VerifyMemorySSA)
+    MSSA->verifyMemorySSA();
 
   // If anything was unswitched, also clear any cached information about this
   // loop.
@@ -2023,6 +2139,7 @@
 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch",
                     "Simple unswitch loops", false, false)