Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp =================================================================== --- llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -611,7 +612,8 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, - const TargetTransformInfo *TTI, LPMUpdater *U) { + const TargetTransformInfo *TTI, LPMUpdater *U, + MemorySSAUpdater *MSSAU) { Function *F = FI.OuterLoop->getHeader()->getParent(); LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n"); { @@ -647,7 +649,11 @@ BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock(); InnerExitingBlock->getTerminator()->eraseFromParent(); BranchInst::Create(InnerExitBlock, InnerExitingBlock); + + // Update the DomTree and MemorySSA. DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader()); + if (MSSAU) + MSSAU->removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader()); // Replace all uses of the polynomial calculated from the two induction // variables with the one new one. @@ -744,7 +750,8 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, - const TargetTransformInfo *TTI, LPMUpdater *U) { + const TargetTransformInfo *TTI, LPMUpdater *U, + MemorySSAUpdater *MSSAU) { LLVM_DEBUG( dbgs() << "Loop flattening running on outer loop " << FI.OuterLoop->getHeader()->getName() << " and inner loop " @@ -773,7 +780,7 @@ // If we have widened and can perform the transformation, do that here. 
if (CanFlatten) - return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U); + return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); // Otherwise, if we haven't widened the IV, check if the new iteration // variable might overflow. In this case, we need to version the loop, and @@ -791,18 +798,19 @@ } LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n"); - return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U); + return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); } bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U) { + AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U, + MemorySSAUpdater *MSSAU) { bool Changed = false; for (Loop *InnerLoop : LN.getLoops()) { auto *OuterLoop = InnerLoop->getParentLoop(); if (!OuterLoop) continue; FlattenInfo FI(OuterLoop, InnerLoop); - Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U); + Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU); } return Changed; } @@ -813,16 +821,27 @@ bool Changed = false; + MemorySSAUpdater MSSAU = MemorySSAUpdater(AR.MSSA); + if (AR.MSSA && VerifyMemorySSA) + AR.MSSA->verifyMemorySSA(); + // The loop flattening pass requires loops to be // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. - Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U); + Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U, + AR.MSSA ? 
&MSSAU : nullptr); if (!Changed) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + if (AR.MSSA && VerifyMemorySSA) + AR.MSSA->verifyMemorySSA(); + + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; } namespace { @@ -864,10 +883,13 @@ auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>(); auto *TTI = &TTIP.getTTI(F); auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + MemorySSA *MSSA = &getAnalysis<MemorySSAWrapperPass>().getMSSA(); + MemorySSAUpdater MSSAU(MSSA); + bool Changed = false; for (Loop *L : *LI) { auto LN = LoopNest::getLoopNest(*L, *SE); - Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr); + Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr, &MSSAU); } return Changed; }