Index: lib/Transforms/Scalar/LoopUnswitch.cpp =================================================================== --- lib/Transforms/Scalar/LoopUnswitch.cpp +++ lib/Transforms/Scalar/LoopUnswitch.cpp @@ -37,6 +37,8 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/Utils/Local.h" @@ -65,6 +67,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" @@ -180,6 +183,8 @@ Loop *currentLoop = nullptr; DominatorTree *DT = nullptr; + MemorySSA *MSSA = nullptr; + std::unique_ptr MSSAUpdater; BasicBlock *loopHeader = nullptr; BasicBlock *loopPreheader = nullptr; @@ -214,6 +219,8 @@ void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.addRequired(); + if (EnableMSSALoopDependency) + AU.addRequired(); if (hasBranchDivergence) AU.addRequired(); getLoopAnalysisUsage(AU); @@ -383,6 +390,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis) INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) @@ -515,6 +523,10 @@ LI = &getAnalysis().getLoopInfo(); LPM = &LPM_Ref; DT = &getAnalysis().getDomTree(); + if (EnableMSSALoopDependency) { + MSSA = &getAnalysis().getMSSA(); + MSSAUpdater = make_unique(MSSA); + } currentLoop = L; Function *F = currentLoop->getHeader()->getParent(); @@ -522,6 +534,9 @@ if (SanitizeMemory) 
computeLoopSafetyInfo(&SafetyInfo, L); + if (MSSA) + MSSA->verifyMemorySSA(); + bool Changed = false; do { assert(currentLoop->isLCSSAForm(*DT)); @@ -529,6 +544,14 @@ Changed |= processCurrentLoop(); } while(redoLoop); + if (MSSA) { + // TODO: Compare the cost of updating MSSA vs. rebuilding it. Hack to rebuild in place: + // auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + // MSSA->~MemorySSA(); + // new (MSSA) MemorySSA(*F, &AA, DT); + MSSA->verifyMemorySSA(); + } + return Changed; } @@ -979,7 +1002,8 @@ // First step, split the preheader, so that we know that there is a safe place // to insert the conditional branch. We will change loopPreheader to have a // conditional branch on Cond. - BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, DT, LI); + BasicBlock *NewPH = + SplitEdge(loopPreheader, loopHeader, DT, LI, MSSA, MSSAUpdater.get()); // Now that we have a place to insert the conditional branch, create a place // to branch to: this is the exit block out of the loop that we should @@ -990,7 +1014,8 @@ // without actually branching to it (the exit block should be dominated by the // loop header, not the preheader). assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); - BasicBlock *NewExit = SplitBlock(ExitBlock, &ExitBlock->front(), DT, LI); + BasicBlock *NewExit = SplitBlock(ExitBlock, &ExitBlock->front(), DT, LI, MSSA, + MSSAUpdater.get()); // Okay, now we have a position to branch from and a position to branch to, // insert the new conditional branch. @@ -1003,6 +1028,9 @@ // Delete it, as it is no longer needed. delete OldBranch; + if (MSSA) + MSSAUpdater->updatePhisWhenAddingBBPredecessors(NewExit, ExitBlock, DT); + // We need to reprocess this loop, it could be unswitched again. redoLoop = true; @@ -1185,8 +1213,8 @@ // Although SplitBlockPredecessors doesn't preserve loop-simplify in // general, if we call it on all predecessors of all exits then it does. 
- SplitBlockPredecessors(ExitBlock, Preds, ".us-lcssa", DT, LI, nullptr, - nullptr, /*PreserveLCSSA*/ true); + SplitBlockPredecessors(ExitBlock, Preds, ".us-lcssa", DT, LI, MSSA, + MSSAUpdater.get(), /*PreserveLCSSA*/ true); } } @@ -1209,7 +1237,8 @@ // First step, split the preheader and exit blocks, and add these blocks to // the LoopBlocks list. - BasicBlock *NewPreheader = SplitEdge(loopPreheader, loopHeader, DT, LI); + BasicBlock *NewPreheader = + SplitEdge(loopPreheader, loopHeader, DT, LI, MSSA, MSSAUpdater.get()); LoopBlocks.push_back(NewPreheader); // We want the loop to come after the preheader, but before the exit blocks. @@ -1330,6 +1359,18 @@ // iteration. WeakTrackingVH LICHandle(LIC); + if (MSSA) { + // Update MemorySSA after cloning, and before splitting to unreachables, + // since that invalidates the 1:1 mapping of clones in VMap. + SmallPtrSet subLoopHeaders; + subLoopHeaders.insert(loopHeader); + for (Loop *InnerL : L->getSubLoops()) + subLoopHeaders.insert(InnerL->getHeader()); + MSSAUpdater->updateForClonedLoop(loopHeader, ExitBlocks, VMap, + subLoopHeaders, DT); + MSSA->verifyMemorySSA(); + } + // Now we rewrite the original code to know that the condition is true and the // new code to know that the condition is false. RewriteLoopBodyWithConditionConstant(L, LIC, Val, false); @@ -1340,6 +1381,10 @@ if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop && LICHandle && !isa(LICHandle)) RewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, true); + + if (MSSA) { + MSSA->verifyMemorySSA(); + } } /// Remove all instances of I from the worklist vector specified. @@ -1478,7 +1523,7 @@ // and hooked up so as to preserve the loop structure, because // trying to update it is complicated. So instead we preserve the // loop structure and put the block on a dead code path. 
- SplitEdge(Switch, SISucc, DT, LI); + SplitEdge(Switch, SISucc, DT, LI, MSSA, MSSAUpdater.get()); // Compute the successors instead of relying on the return value // of SplitEdge, since it may have split the switch successor // after PHI nodes. @@ -1532,6 +1577,9 @@ Worklist.push_back(Use); LPM->deleteSimpleAnalysisValue(I, L); RemoveFromWorklist(I, Worklist); + if (MSSA) + if (MemoryAccess *MA = MSSA->getMemoryAccess(I)) + MSSAUpdater->removeMemoryAccess(MA); I->eraseFromParent(); ++NumSimplify; continue; @@ -1571,6 +1619,8 @@ // Move all of the successor contents from Succ to Pred. Pred->getInstList().splice(BI->getIterator(), Succ->getInstList(), Succ->begin(), Succ->end()); + if (MSSA) + MSSAUpdater->moveAllAfterMergeBlocks(Succ, Pred, BI); LPM->deleteSimpleAnalysisValue(BI, L); RemoveFromWorklist(BI, Worklist); BI->eraseFromParent();