diff --git a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
--- a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
+++ b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
@@ -24,7 +24,8 @@
 public:
   LoopFlattenPass() = default;
 
-  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+  PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &LAM,
+                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
 };
 
 } // end namespace llvm
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -621,7 +621,7 @@
   FPM.addPass(SimplifyCFGPass());
   FPM.addPass(InstCombinePass());
   if (EnableLoopFlatten)
-    FPM.addPass(LoopFlattenPass());
+    FPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass()));
   // The loop passes in LPM2 (LoopFullUnrollPass) do not preserve MemorySSA.
   // *All* loop passes must preserve it, in order to be able to use it.
   FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM2),
@@ -796,7 +796,7 @@
   FPM.addPass(SimplifyCFGPass());
   FPM.addPass(InstCombinePass());
   if (EnableLoopFlatten)
-    FPM.addPass(LoopFlattenPass());
+    FPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass()));
   // The loop passes in LPM2 (LoopIdiomRecognizePass, IndVarSimplifyPass,
   // LoopDeletionPass and LoopFullUnrollPass) do not preserve MemorySSA.
   // *All* loop passes must preserve it, in order to be able to use it.
@@ -1849,7 +1849,7 @@
 
   // More loops are countable; try to optimize them.
   if (EnableLoopFlatten && Level.getSpeedupLevel() > 1)
-    MainFPM.addPass(LoopFlattenPass());
+    MainFPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass()));
 
   if (EnableConstraintElimination)
     MainFPM.addPass(ConstraintEliminationPass());
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -249,7 +249,6 @@
 FUNCTION_PASS("loop-simplify", LoopSimplifyPass())
 FUNCTION_PASS("loop-sink", LoopSinkPass())
 FUNCTION_PASS("loop-unroll-and-jam", LoopUnrollAndJamPass())
-FUNCTION_PASS("loop-flatten", LoopFlattenPass())
 FUNCTION_PASS("lowerinvoke", LowerInvokePass())
 FUNCTION_PASS("lowerswitch", LowerSwitchPass())
 FUNCTION_PASS("mem2reg", PromotePass())
@@ -396,6 +395,7 @@
 LOOP_PASS("loop-rotate", LoopRotatePass())
 LOOP_PASS("no-op-loop", NoOpLoopPass())
 LOOP_PASS("print", PrintLoopPass(dbgs()))
+LOOP_PASS("loop-flatten", LoopFlattenPass())
 LOOP_PASS("loop-deletion", LoopDeletionPass())
 LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass())
 LOOP_PASS("loop-reduce", LoopStrengthReducePass())
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -28,6 +28,7 @@
 
 #include "llvm/Transforms/Scalar/LoopFlatten.h"
 #include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
@@ -671,29 +672,22 @@
   return Changed;
 }
 
-PreservedAnalyses LoopFlattenPass::run(Function &F,
-                                       FunctionAnalysisManager &AM) {
-  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
-  auto *LI = &AM.getResult<LoopAnalysis>(F);
-  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
-  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
-  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
+PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
+                                       LoopStandardAnalysisResults &AR,
+                                       LPMUpdater &U) {
 
   bool Changed = false;
 
-  // The loop flattening pass requires loops to be
-  // in simplified form, and also needs LCSSA. Running
-  // this pass will simplify all loops that contain inner loops,
-  // regardless of whether anything ends up being flattened.
-  for (const auto &L : *LI) {
-    if (L->isInnermost())
-      continue;
-    Changed |=
-        simplifyLoop(L, DT, LI, SE, AC, nullptr, false /* PreserveLCSSA */);
-    Changed |= formLCSSARecursively(*L, *DT, LI, SE);
-  }
-
-  Changed |= Flatten(DT, LI, SE, AC, TTI);
+  // The loop flattening pass requires loops to be in simplified form and in
+  // LCSSA form. The function-to-loop pass adaptor that runs this pass already
+  // canonicalizes every loop (LoopSimplify + LCSSA) before any loop pass is
+  // invoked, so the loops in LN do not need to be re-simplified here.
+  //
+  // TODO: Flatten() is still handed the function-wide LoopInfo rather than
+  // LN, so it presumably visits loops outside this nest (confirm), and loops
+  // it erases are not reported through U.markLoopAsDeleted(). Thread LN and U
+  // through Flatten() so the pass stays within the loop-nest pass contract.
+  Changed |= Flatten(&AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI);
 
   if (!Changed)
     return PreservedAnalyses::all();