Index: llvm/include/llvm/Transforms/Scalar.h
===================================================================
--- llvm/include/llvm/Transforms/Scalar.h
+++ llvm/include/llvm/Transforms/Scalar.h
@@ -149,12 +149,6 @@
 //
 Pass *createLoopInterchangePass();
 
-//===----------------------------------------------------------------------===//
-//
-// LoopFlatten - This pass flattens nested loops into a single loop.
-//
-Pass *createLoopFlattenPass();
-
 //===----------------------------------------------------------------------===//
 //
 // LoopStrengthReduce - This pass is strength reduces GEP instructions that use
@@ -333,6 +327,12 @@
 //
 Pass *createLoopDeletionPass();
 
+//===----------------------------------------------------------------------===//
+//
+// LoopFlatten - This pass flattens nested loops into a single loop.
+//
+FunctionPass *createLoopFlattenPass();
+
 //===----------------------------------------------------------------------===//
 //
 // ConstantHoisting - This pass prepares a function for expensive constants.
Index: llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
===================================================================
--- llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
+++ llvm/include/llvm/Transforms/Scalar/LoopFlatten.h
@@ -24,8 +24,7 @@
 public:
   LoopFlattenPass() = default;
 
-  PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM,
-                        LoopStandardAnalysisResults &AR, LPMUpdater &U);
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 
 } // end namespace llvm
Index: llvm/lib/Passes/PassBuilder.cpp
===================================================================
--- llvm/lib/Passes/PassBuilder.cpp
+++ llvm/lib/Passes/PassBuilder.cpp
@@ -540,7 +540,7 @@
   LPM2.addPass(LoopDeletionPass());
 
   if (EnableLoopFlatten)
-    LPM2.addPass(LoopFlattenPass());
+    FPM.addPass(LoopFlattenPass());
   // Do not enable unrolling in PreLinkThinLTO phase during sample PGO
   // because it changes IR to makes profile annotation in back compile
   // inaccurate. The normal unroller doesn't pay attention to forced full unroll
Index: llvm/lib/Passes/PassRegistry.def
===================================================================
--- llvm/lib/Passes/PassRegistry.def
+++ llvm/lib/Passes/PassRegistry.def
@@ -239,6 +239,7 @@
 FUNCTION_PASS("loop-simplify", LoopSimplifyPass())
 FUNCTION_PASS("loop-sink", LoopSinkPass())
 FUNCTION_PASS("loop-unroll-and-jam", LoopUnrollAndJamPass())
+FUNCTION_PASS("loop-flatten", LoopFlattenPass())
 FUNCTION_PASS("lowerinvoke", LowerInvokePass())
 FUNCTION_PASS("lowerswitch", LowerSwitchPass())
 FUNCTION_PASS("mem2reg", PromotePass())
@@ -377,7 +378,6 @@
 LOOP_PASS("no-op-loop", NoOpLoopPass())
 LOOP_PASS("print", PrintLoopPass(dbgs()))
 LOOP_PASS("loop-deletion", LoopDeletionPass())
-LOOP_PASS("loop-flatten", LoopFlattenPass())
 LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass())
 LOOP_PASS("loop-reduce", LoopStrengthReducePass())
 LOOP_PASS("indvars", IndVarSimplifyPass())
Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -29,7 +29,6 @@
 #include "llvm/Transforms/Scalar/LoopFlatten.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -404,17 +403,14 @@
 
 static bool FlattenLoopPair(Loop *OuterLoop, Loop *InnerLoop, DominatorTree *DT,
                             LoopInfo *LI, ScalarEvolution *SE,
-                            AssumptionCache *AC, TargetTransformInfo *TTI,
-                            std::function<void(Loop *)> markLoopAsDeleted) {
+                            AssumptionCache *AC, TargetTransformInfo *TTI) {
   Function *F = OuterLoop->getHeader()->getParent();
-
   LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop "
                     << OuterLoop->getHeader()->getName() << " and inner loop "
                     << InnerLoop->getHeader()->getName() << " in "
                     << F->getName() << "\n");
 
   SmallPtrSet<Instruction *, 8> IterationInstructions;
-
   PHINode *InnerInductionPHI, *OuterInductionPHI;
   Value *InnerLimit, *OuterLimit;
   BinaryOperator *InnerIncrement, *OuterIncrement;
@@ -527,39 +523,60 @@
 
-  // Tell LoopInfo, SCEV and the pass manager that the inner loop has been
-  // deleted, and any information that have about the outer loop invalidated.
-  markLoopAsDeleted(InnerLoop);
+  // Tell LoopInfo and SCEV that the inner loop has been deleted, and that any
+  // information they have about the outer loop is invalidated.
   SE->forgetLoop(OuterLoop);
   SE->forgetLoop(InnerLoop);
   LI->erase(InnerLoop);
-
   return true;
 }
 
-PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM,
-                                       LoopStandardAnalysisResults &AR,
-                                       LPMUpdater &Updater) {
-  if (L.getSubLoops().size() != 1)
-    return PreservedAnalyses::all();
+/// Try to flatten each top-level loop of \p LI that has exactly one direct
+/// sub-loop into that sub-loop; deeper levels of a nest are not visited.
+/// \returns true if any call to FlattenLoopPair returned true.
+static bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
+                    AssumptionCache *AC, TargetTransformInfo *TTI) {
+  bool Changed = false;
+  // NOTE(review): FlattenLoopPair calls LI->erase(InnerLoop) on success. This
+  // presumably leaves the top-level loop list being iterated intact, since
+  // InnerLoop is a sub-loop -- confirm against LoopInfo::erase.
+  for (auto *OuterLoop : LI->getTopLevelLoops()) {
+    if (OuterLoop->getSubLoops().size() != 1)
+      continue;
+    Loop *InnerLoop = *OuterLoop->begin();
+    Changed |= FlattenLoopPair(OuterLoop, InnerLoop, DT, LI, SE, AC, TTI);
+  }
+  return Changed;
+}
 
-  Loop *InnerLoop = *L.begin();
-  std::string LoopName(InnerLoop->getName());
-  if (!FlattenLoopPair(
-          &L, InnerLoop, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI,
-          [&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); }))
+/// New pass manager entry point. Fetches the function analyses that Flatten
+/// needs and reports all analyses as preserved when no loop was flattened.
+PreservedAnalyses LoopFlattenPass::run(Function &F,
+                                       FunctionAnalysisManager &AM) {
+  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
+  auto *LI = &AM.getResult<LoopAnalysis>(F);
+  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
+  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
+
+  if (!Flatten(DT, LI, SE, AC, TTI))
     return PreservedAnalyses::all();
-  return getLoopPassPreservedAnalyses();
+
+  // NOTE(review): loops were rewritten; confirm that keeping the CFG analyses
+  // set is valid for the transformation done by FlattenLoopPair.
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
 }
 
 namespace {
-class LoopFlattenLegacyPass : public LoopPass {
+class LoopFlattenLegacyPass : public FunctionPass {
 public:
   static char ID; // Pass ID, replacement for typeid
-  LoopFlattenLegacyPass() : LoopPass(ID) {
+  LoopFlattenLegacyPass() : FunctionPass(ID) {
     initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry());
   }
 
-  // Possibly flatten loop L into its child.
-  bool runOnLoop(Loop *L, LPPassManager &) override;
+  // Possibly flatten the loop nests of function F.
+  bool runOnFunction(Function &F) override;
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     getLoopAnalysisUsage(AU);
@@ -574,32 +591,28 @@
 char LoopFlattenLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
                       false, false)
-INITIALIZE_PASS_DEPENDENCY(LoopPass)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
                     false, false)
 
-Pass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
-
-bool LoopFlattenLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) {
-  if (skipLoop(L))
-    return false;
-
-  if (L->getSubLoops().size() != 1)
-    return false;
+FunctionPass *llvm::createLoopFlattenPass() {
+  return new LoopFlattenLegacyPass();
+}
 
+/// Legacy pass manager entry point: gathers the analyses and runs Flatten on
+/// the loops of \p F. Returns false without doing anything if F is skipped.
+bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
+  // Keep the skip check that the loop-pass version did via skipLoop().
+  if (skipFunction(F))
+    return false;
+
   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
   LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
   auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
   DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr;
   auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
-  TargetTransformInfo *TTI = &TTIP.getTTI(*L->getHeader()->getParent());
-  AssumptionCache *AC =
-      &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
-          *L->getHeader()->getParent());
-
-  Loop *InnerLoop = *L->begin();
-  return FlattenLoopPair(L, InnerLoop, DT, LI, SE, AC, TTI,
-                         [&](Loop *L) { LPM.markLoopAsDeleted(*L); });
+  auto *TTI = &TTIP.getTTI(F);
+  auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+  return Flatten(DT, LI, SE, AC, TTI);
 }