diff --git a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
--- a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
@@ -24,7 +24,8 @@
 public:
   LoopFlattenPass() = default;
 
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &LAM,
+                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -621,7 +621,7 @@
   FPM.addPass(SimplifyCFGPass());
   FPM.addPass(InstCombinePass());
   if (EnableLoopFlatten)
-    FPM.addPass(LoopFlattenPass());
+    FPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass()));
   // The loop passes in LPM2 (LoopFullUnrollPass) do not preserve MemorySSA.
   // *All* loop passes must preserve it, in order to be able to use it.
   FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM2),
@@ -796,7 +796,7 @@
   FPM.addPass(SimplifyCFGPass());
   FPM.addPass(InstCombinePass());
   if (EnableLoopFlatten)
-    FPM.addPass(LoopFlattenPass());
+    FPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass()));
   // The loop passes in LPM2 (LoopIdiomRecognizePass, IndVarSimplifyPass,
   // LoopDeletionPass and LoopFullUnrollPass) do not preserve MemorySSA.
   // *All* loop passes must preserve it, in order to be able to use it.
@@ -1849,7 +1849,7 @@
 
   // More loops are countable; try to optimize them.
   if (EnableLoopFlatten && Level.getSpeedupLevel() > 1)
-    MainFPM.addPass(LoopFlattenPass());
+    MainFPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass()));
 
   if (EnableConstraintElimination)
     MainFPM.addPass(ConstraintEliminationPass());
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -249,7 +249,6 @@
 FUNCTION_PASS("loop-simplify", LoopSimplifyPass())
 FUNCTION_PASS("loop-sink", LoopSinkPass())
 FUNCTION_PASS("loop-unroll-and-jam", LoopUnrollAndJamPass())
-FUNCTION_PASS("loop-flatten", LoopFlattenPass())
 FUNCTION_PASS("lowerinvoke", LowerInvokePass())
 FUNCTION_PASS("lowerswitch", LowerSwitchPass())
 FUNCTION_PASS("mem2reg", PromotePass())
@@ -390,6 +389,7 @@
 LOOP_PASS("dot-ddg", DDGDotPrinterPass())
 LOOP_PASS("invalidate<all>", InvalidateAllAnalysesPass())
 LOOP_PASS("licm", LICMPass())
+LOOP_PASS("loop-flatten", LoopFlattenPass())
 LOOP_PASS("loop-idiom", LoopIdiomRecognizePass())
 LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass())
 LOOP_PASS("loop-interchange", LoopInterchangePass())
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -28,7 +28,9 @@
 
 #include "llvm/Transforms/Scalar/LoopFlatten.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -343,8 +345,7 @@
   // transformation wouldn't be profitable.
 
   Value *InnerLimit = FI.InnerLimit;
-  if (FI.Widened &&
-      (isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit)))
+  if (FI.Widened && (isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit)))
     InnerLimit = cast<Instruction>(InnerLimit)->getOperand(0);
 
   // Check that all uses of the inner loop's induction variable match the
@@ -354,8 +355,8 @@
     if (U == FI.InnerIncrement)
       continue;
 
-    // After widening the IVs, a trunc instruction might have been introduced, so
-    // look through truncs.
+    // After widening the IVs, a trunc instruction might have been introduced,
+    // so look through truncs.
     if (isa<TruncInst>(U)) {
       if (!U->hasOneUse())
         return false;
@@ -373,11 +374,11 @@
 
     // Matches the same pattern as above, except it also looks for truncs
     // on the phi, which can be the result of widening the induction variables.
-    bool IsAddTrunc = match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
-                                       m_Value(MatchedMul))) &&
-                      match(MatchedMul,
-                            m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
-                            m_Value(MatchedItCount)));
+    bool IsAddTrunc =
+        match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)),
+                         m_Value(MatchedMul))) &&
+        match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
+                                  m_Value(MatchedItCount)));
 
     if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
@@ -395,8 +396,9 @@
     if (U == FI.OuterIncrement)
       continue;
 
-    auto IsValidOuterPHIUses = [&] (User *U) -> bool {
-      LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump());
+    auto IsValidOuterPHIUses = [&](User *U) -> bool {
+      LLVM_DEBUG(dbgs() << "Found use of outer induction variable: ";
+                 U->dump());
       if (!ValidOuterPHIUses.count(U)) {
         LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n");
         return false;
@@ -472,11 +474,13 @@
                                ScalarEvolution *SE, AssumptionCache *AC,
                                const TargetTransformInfo *TTI) {
   SmallPtrSet<Instruction *, 8> IterationInstructions;
-  if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
-                          FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
+  if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
+                          FI.InnerInductionPHI, FI.InnerLimit,
+                          FI.InnerIncrement, FI.InnerBranch, SE))
     return false;
-  if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI,
-                          FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE))
+  if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
+                          FI.OuterInductionPHI, FI.OuterLimit,
+                          FI.OuterIncrement, FI.OuterBranch, SE))
     return false;
 
   // Both of the loop limit values must be invariant in the outer loop
@@ -519,16 +523,17 @@
   LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
   {
     using namespace ore;
-    OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(),
+    OptimizationRemark Remark(DEBUG_TYPE, "Flattened",
+                              FI.InnerLoop->getStartLoc(),
                               FI.InnerLoop->getHeader());
     OptimizationRemarkEmitter ORE(F);
     Remark << "Flattened into outer loop";
     ORE.emit(Remark);
   }
 
-  Value *NewTripCount =
-      BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount",
-                                FI.OuterLoop->getLoopPreheader()->getTerminator());
+  Value *NewTripCount = BinaryOperator::CreateMul(
+      FI.InnerLimit, FI.OuterLimit, "flatten.tripcount",
+      FI.OuterLoop->getLoopPreheader()->getTerminator());
   LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
              NewTripCount->dump());
 
@@ -561,8 +566,8 @@
       OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
                                        "flatten.trunciv");
 
-    LLVM_DEBUG(dbgs() << "Replacing: "; V->dump();
-               dbgs() << "with:      "; OuterValue->dump());
+    LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with:      ";
+               OuterValue->dump());
     V->replaceAllUsesWith(OuterValue);
   }
 
@@ -595,7 +600,8 @@
   // (OuterLimit * InnerLimit) as the new trip count is safe.
   if (InnerType != OuterType ||
       InnerType->getScalarSizeInBits() >= MaxLegalSize ||
-      MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {
+      MaxLegalType->getScalarSizeInBits() <
+          InnerType->getScalarSizeInBits() * 2) {
     LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
     return false;
   }
@@ -603,15 +609,15 @@
   SCEVExpander Rewriter(*SE, DL, "loopflatten");
   SmallVector<WideIVInfo, 2> WideIVs;
   SmallVector<WeakTrackingVH, 4> DeadInsts;
-  WideIVs.push_back( {FI.InnerInductionPHI, MaxLegalType, false });
-  WideIVs.push_back( {FI.OuterInductionPHI, MaxLegalType, false });
+  WideIVs.push_back({FI.InnerInductionPHI, MaxLegalType, false});
+  WideIVs.push_back({FI.OuterInductionPHI, MaxLegalType, false});
   unsigned ElimExt = 0;
   unsigned Widened = 0;
 
   for (const auto &WideIV : WideIVs) {
-    PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts,
-                                    ElimExt, Widened, true /* HasGuards */,
-                                    true /* UsePostIncrementRanges */);
+    PHINode *WidePhi =
+        createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened,
+                     true /* HasGuards */, true /* UsePostIncrementRanges */);
     if (!WidePhi)
       return false;
     LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
@@ -658,10 +664,10 @@
   return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI);
 }
 
-bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
+bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
              AssumptionCache *AC, TargetTransformInfo *TTI) {
   bool Changed = false;
-  for (auto *InnerLoop : LI->getLoopsInPreorder()) {
+  for (Loop *InnerLoop : LN.getLoops()) {
     auto *OuterLoop = InnerLoop->getParentLoop();
     if (!OuterLoop)
       continue;
@@ -671,13 +677,9 @@
   return Changed;
 }
 
-PreservedAnalyses LoopFlattenPass::run(Function &F,
-                                       FunctionAnalysisManager &AM) {
-  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
-  auto *LI = &AM.getResult<LoopAnalysis>(F);
-  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
-  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
-  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
+PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
+                                       LoopStandardAnalysisResults &AR,
+                                       LPMUpdater &U) {
 
   bool Changed = false;
 
@@ -685,15 +687,9 @@
   // in simplified form, and also needs LCSSA. Running
-  // this pass will simplify all loops that contain inner loops,
-  // regardless of whether anything ends up being flattened.
-  for (const auto &L : *LI) {
-    if (L->isInnermost())
-      continue;
-    Changed |=
-        simplifyLoop(L, DT, LI, SE, AC, nullptr, false /* PreserveLCSSA */);
-    Changed |= formLCSSARecursively(*L, *DT, LI, SE);
-  }
-
-  Changed |= Flatten(DT, LI, SE, AC, TTI);
+  // it under the loop pass manager guarantees both, since the manager
+  // canonicalizes loops before it invokes loop passes.
+  // TODO: U is unused. If flattening erases the inner Loop from LoopInfo,
+  // that loop should first be reported via U.markLoopAsDeleted(); confirm.
+  Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI);
 
   if (!Changed)
     return PreservedAnalyses::all();
@@ -727,18 +723,29 @@
                       false, false)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
 INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
                     false, false)
 
 FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
 
 bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
   LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
   DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
   auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
   auto *TTI = &TTIP.getTTI(F);
   auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
-  return Flatten(DT, LI, SE, AC, TTI);
+  // Build and flatten one LoopNest per top-level loop. Iterating over the
+  // top-level loops also covers functions with no loops at all (nothing to
+  // do) and functions containing several independent loop nests.
+  bool Changed = false;
+  for (Loop *L : *LI) {
+    auto LN = LoopNest::getLoopNest(*L, *SE);
+    Changed |= Flatten(*LN, DT, LI, SE, AC, TTI);
+  }
+  return Changed;
 }