Index: include/llvm/Transforms/Utils/LoopRotationUtils.h =================================================================== --- include/llvm/Transforms/Utils/LoopRotationUtils.h +++ include/llvm/Transforms/Utils/LoopRotationUtils.h @@ -20,6 +20,7 @@ class DominatorTree; class Loop; class LoopInfo; +class MemorySSAUpdater; class ScalarEvolution; struct SimplifyQuery; class TargetTransformInfo; @@ -32,8 +33,8 @@ /// LoopRotation. If it is true, the profitability heuristic will be ignored. bool LoopRotation(Loop *L, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC, DominatorTree *DT, ScalarEvolution *SE, - const SimplifyQuery &SQ, bool RotationOnly, - unsigned Threshold, bool IsUtilMode); + MemorySSAUpdater *MSSAU, const SimplifyQuery &SQ, + bool RotationOnly, unsigned Threshold, bool IsUtilMode); } // namespace llvm Index: lib/Transforms/Scalar/LoopRotation.cpp =================================================================== --- lib/Transforms/Scalar/LoopRotation.cpp +++ lib/Transforms/Scalar/LoopRotation.cpp @@ -15,6 +15,8 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Support/Debug.h" @@ -40,12 +42,19 @@ const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); const SimplifyQuery SQ = getBestSimplifyQuery(AR, DL); - bool Changed = LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, SQ, - false, Threshold, false); + std::unique_ptr MSSAU; + if (AR.MSSA) + MSSAU = make_unique(AR.MSSA); + bool Changed = LoopRotation(&L, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, + MSSAU.get(), SQ, false, Threshold, false); if (!Changed) return PreservedAnalyses::all(); + if (AR.MSSA) { + AR.MSSA->verifyMemorySSA(); + } + return getLoopPassPreservedAnalyses(); } @@ -68,6 +77,10 @@ void 
getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired(); AU.addRequired(); + if (EnableMSSALoopDependency) { + AU.addRequired(); + AU.addPreserved(); + } getLoopAnalysisUsage(AU); } @@ -84,8 +97,13 @@ auto *SEWP = getAnalysisIfAvailable(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); - return LoopRotation(L, LI, TTI, AC, DT, SE, SQ, false, MaxHeaderSize, - false); + std::unique_ptr MSSAU; + if (EnableMSSALoopDependency) { + MemorySSA *MSSA = &getAnalysis().getMSSA(); + MSSAU = make_unique(MSSA); + } + return LoopRotation(L, LI, TTI, AC, DT, SE, MSSAU.get(), SQ, false, + MaxHeaderSize, false); } }; } @@ -96,6 +114,7 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_END(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", false, false) Index: lib/Transforms/Utils/LoopRotationUtils.cpp =================================================================== --- lib/Transforms/Utils/LoopRotationUtils.cpp +++ lib/Transforms/Utils/LoopRotationUtils.cpp @@ -20,6 +20,8 @@ #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -54,6 +56,7 @@ AssumptionCache *AC; DominatorTree *DT; ScalarEvolution *SE; + MemorySSAUpdater *MSSAU; const SimplifyQuery &SQ; bool RotationOnly; bool IsUtilMode; @@ -61,10 +64,11 @@ public: LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC, - DominatorTree *DT, ScalarEvolution *SE, const SimplifyQuery &SQ, - bool RotationOnly, bool IsUtilMode) + DominatorTree *DT, 
ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + const SimplifyQuery &SQ, bool RotationOnly, bool IsUtilMode) : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE), - SQ(SQ), RotationOnly(RotationOnly), IsUtilMode(IsUtilMode) {} + MSSAU(MSSAU), SQ(SQ), RotationOnly(RotationOnly), + IsUtilMode(IsUtilMode) {} bool processLoop(Loop *L); private: @@ -269,6 +273,8 @@ SE->forgetTopmostLoop(L); LLVM_DEBUG(dbgs() << "LoopRotation: rotating "; L->dump()); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); // Find new Loop header. NewHeader is a Header's one and only successor // that is inside loop. Header's other successor is outside the @@ -385,7 +391,6 @@ // remove the corresponding incoming values from the PHI nodes in OrigHeader. LoopEntryBranch->eraseFromParent(); - SmallVector InsertedPHIs; // If there were any uses of instructions in the duplicated block outside the // loop, update them, inserting PHI nodes as required @@ -411,6 +416,28 @@ Updates.push_back({DominatorTree::Insert, OrigPreheader, NewHeader}); Updates.push_back({DominatorTree::Delete, OrigPreheader, OrigHeader}); DT->applyUpdates(Updates); + + if (MSSAU) { + ValueMap[OrigHeader] = OrigPreheader; + // Above RewriteUses call may replace map keys (cloned instructions) + // with phis. Re-add mapping (ClonedInstruction, CloneOfInstruction). + // We do not delete existing entry of (PhiNode, CloneOfInstruction), + // since the map is not used further. Ideally we'd set + // Config::FollowRAUW = false for ValueMap before the above call to + // avoid the replacement in the first place. 
+ for (auto *PNI : InsertedPHIs) + if (Instruction *NewInsn = dyn_cast_or_null(ValueMap.lookup(PNI))) { + Instruction *OldInsn = cast(PNI->getIncomingValue(0)); + if (OldInsn == NewInsn) + OldInsn = cast(PNI->getIncomingValue(1)); + ValueMap[OldInsn] = NewInsn; + } + + MSSAU->updateForClonedBlockIntoPred(OrigHeader, OrigPreheader, ValueMap); + MSSAU->applyUpdates(Updates, *DT); + if (VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + } } // At this point, we've finished our major CFG changes. As part of cloning @@ -433,7 +460,7 @@ // Split the edge to form a real preheader. BasicBlock *NewPH = SplitCriticalEdge( OrigPreheader, NewHeader, - CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + CriticalEdgeSplittingOptions(DT, LI, MSSAU).setPreserveLCSSA()); NewPH->setName(NewHeader->getName() + ".lr.ph"); // Preserve canonical loop form, which means that 'Exit' should have only @@ -452,7 +479,7 @@ SplitLatchEdge |= L->getLoopLatch() == ExitPred; BasicBlock *ExitSplit = SplitCriticalEdge( ExitPred, Exit, - CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + CriticalEdgeSplittingOptions(DT, LI, MSSAU).setPreserveLCSSA()); ExitSplit->moveBefore(Exit); } assert(SplitLatchEdge && @@ -467,17 +494,27 @@ // With our CFG finalized, update DomTree if it is available. if (DT) DT->deleteEdge(OrigPreheader, Exit); + + // Update MSSA too, if available. + if (MSSAU) + MSSAU->removeEdge(OrigPreheader, Exit); } assert(L->getLoopPreheader() && "Invalid loop preheader after loop rotation"); assert(L->getLoopLatch() && "Invalid loop latch after loop rotation"); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + // Now that the CFG and DomTree are in a consistent state again, try to merge // the OrigHeader block into OrigLatch. This will succeed if they are // connected by an unconditional branch. This is just a cleanup so the // emitted code isn't too gross in this common case. 
DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); - MergeBlockIntoPredecessor(OrigHeader, &DTU, LI); + MergeBlockIntoPredecessor(OrigHeader, &DTU, LI, MSSAU); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); LLVM_DEBUG(dbgs() << "LoopRotation: into "; L->dump()); @@ -586,9 +623,14 @@ << LastExit->getName() << "\n"); // Hoist the instructions from Latch into LastExit. + Instruction *FirstLatchInst = &*(Latch->begin()); LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(), Latch->begin(), Jmp->getIterator()); + // Update MemorySSA + if (MSSAU) + MSSAU->moveAllAfterMergeBlocks(Latch, LastExit, FirstLatchInst); + unsigned FallThruPath = BI->getSuccessor(0) == Latch ? 0 : 1; BasicBlock *Header = Jmp->getSuccessor(0); assert(Header == L->getHeader() && "expected a backward branch"); @@ -604,6 +646,10 @@ if (DT) DT->eraseNode(Latch); Latch->eraseFromParent(); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + return true; } @@ -636,11 +682,16 @@ /// The utility to convert a loop into a loop with bottom test. 
bool llvm::LoopRotation(Loop *L, LoopInfo *LI, const TargetTransformInfo *TTI, AssumptionCache *AC, DominatorTree *DT, - ScalarEvolution *SE, const SimplifyQuery &SQ, - bool RotationOnly = true, + ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + const SimplifyQuery &SQ, bool RotationOnly = true, unsigned Threshold = unsigned(-1), bool IsUtilMode = true) { - LoopRotate LR(Threshold, LI, TTI, AC, DT, SE, SQ, RotationOnly, IsUtilMode); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); + LoopRotate LR(Threshold, LI, TTI, AC, DT, SE, MSSAU, SQ, RotationOnly, + IsUtilMode); + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); return LR.processLoop(L); } Index: test/Transforms/LoopRotate/preserve-scev.ll =================================================================== --- test/Transforms/LoopRotate/preserve-scev.ll +++ test/Transforms/LoopRotate/preserve-scev.ll @@ -1,27 +1,47 @@ ; RUN: opt < %s -loop-rotate -loop-reduce -verify-dom-info -verify-loop-info -disable-output -define fastcc void @foo() nounwind { +define fastcc void @foo(i32* %A, i64 %i) nounwind { BB: br label %BB1 BB1: ; preds = %BB19, %BB + %tttmp1 = getelementptr i32, i32* %A, i64 %i + %tttmp2 = load i32, i32* %tttmp1 + %tttmp3 = add i32 %tttmp2, 1 + store i32 %tttmp3, i32* %tttmp1 br label %BB4 BB2: ; preds = %BB4 %tmp = bitcast i32 undef to i32 ; [#uses=1] + %tttmp7 = getelementptr i32, i32* %A, i64 %i + %tttmp8 = load i32, i32* %tttmp7 + %tttmp9 = add i32 %tttmp8, 3 + store i32 %tttmp9, i32* %tttmp7 br label %BB4 -BB4: ; preds = %BB3, %BB1 +BB4: ; preds = %BB2, %BB1 %tmp5 = phi i32 [ undef, %BB1 ], [ %tmp, %BB2 ] ; [#uses=1] + %tttmp4 = getelementptr i32, i32* %A, i64 %i + %tttmp5 = load i32, i32* %tttmp4 + %tttmp6 = add i32 %tttmp5, 3 + store i32 %tttmp6, i32* %tttmp4 br i1 false, label %BB8, label %BB2 BB8: ; preds = %BB6 %tmp7 = bitcast i32 %tmp5 to i32 ; [#uses=2] + %tttmp10 = getelementptr i32, i32* %A, i64 %i + %tttmp11 = load i32, i32* %tttmp10 + 
%tttmp12 = add i32 %tttmp11, 3 + store i32 %tttmp12, i32* %tttmp10 br i1 false, label %BB9, label %BB13 BB9: ; preds = %BB12, %BB8 %tmp10 = phi i32 [ %tmp11, %BB12 ], [ %tmp7, %BB8 ] ; [#uses=2] %tmp11 = add i32 %tmp10, 1 ; [#uses=1] + %tttmp13 = getelementptr i32, i32* %A, i64 %i + %tttmp14 = load i32, i32* %tttmp13 + %tttmp15 = add i32 %tttmp14, 3 + store i32 %tttmp15, i32* %tttmp13 br label %BB12 BB12: ; preds = %BB9 @@ -29,16 +49,28 @@ BB13: ; preds = %BB15, %BB8 %tmp14 = phi i32 [ %tmp16, %BB15 ], [ %tmp7, %BB8 ] ; [#uses=1] + %tttmp16 = getelementptr i32, i32* %A, i64 %i + %tttmp17 = load i32, i32* %tttmp16 + %tttmp18 = add i32 %tttmp17, 3 + store i32 %tttmp18, i32* %tttmp16 br label %BB15 BB15: ; preds = %BB13 %tmp16 = add i32 %tmp14, -1 ; [#uses=1] + %tttmp19 = getelementptr i32, i32* %A, i64 %i + %tttmp20 = load i32, i32* %tttmp19 + %tttmp21 = add i32 %tttmp20, 3 + store i32 %tttmp21, i32* %tttmp19 br i1 false, label %BB13, label %BB18 BB17: ; preds = %BB12 br label %BB19 BB18: ; preds = %BB15 + %tttmp22 = getelementptr i32, i32* %A, i64 %i + %tttmp23 = load i32, i32* %tttmp22 + %tttmp24 = add i32 %tttmp23, 3 + store i32 %tttmp24, i32* %tttmp22 br label %BB19 BB19: ; preds = %BB18, %BB17