diff --git a/llvm/include/llvm/Transforms/Scalar.h b/llvm/include/llvm/Transforms/Scalar.h --- a/llvm/include/llvm/Transforms/Scalar.h +++ b/llvm/include/llvm/Transforms/Scalar.h @@ -153,7 +153,7 @@ // // LoopFlatten - This pass flattens nested loops into a single loop. // -Pass *createLoopFlattenPass(); +FunctionPass *createLoopFlattenPass(); //===----------------------------------------------------------------------===// // diff --git a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h --- a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h +++ b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h @@ -24,8 +24,7 @@ public: LoopFlattenPass() = default; - PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, LPMUpdater &U); + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); }; } // end namespace llvm diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -543,7 +543,7 @@ LPM2.addPass(LoopDeletionPass()); if (EnableLoopFlatten) - LPM2.addPass(LoopFlattenPass()); + FPM.addPass(LoopFlattenPass()); // Do not enable unrolling in PreLinkThinLTO phase during sample PGO // because it changes IR to makes profile annotation in back compile // inaccurate. The normal unroller doesn't pay attention to forced full unroll diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -240,6 +240,7 @@ FUNCTION_PASS("loop-simplify", LoopSimplifyPass()) FUNCTION_PASS("loop-sink", LoopSinkPass()) FUNCTION_PASS("loop-unroll-and-jam", LoopUnrollAndJamPass()) +FUNCTION_PASS("loop-flatten", LoopFlattenPass()) FUNCTION_PASS("lowerinvoke", LowerInvokePass()) FUNCTION_PASS("lowerswitch", LowerSwitchPass()) FUNCTION_PASS("mem2reg", PromotePass()) @@ -380,7 +381,6 @@ LOOP_PASS("no-op-loop", NoOpLoopPass()) LOOP_PASS("print", PrintLoopPass(dbgs())) LOOP_PASS("loop-deletion", LoopDeletionPass()) -LOOP_PASS("loop-flatten", LoopFlattenPass()) LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass()) LOOP_PASS("loop-reduce", LoopStrengthReducePass()) LOOP_PASS("indvars", IndVarSimplifyPass()) diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -29,7 +29,6 @@ #include "llvm/Transforms/Scalar/LoopFlatten.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -416,17 +415,14 @@ static bool FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, const TargetTransformInfo *TTI, - std::function markLoopAsDeleted) { + AssumptionCache *AC, TargetTransformInfo *TTI) { Function *F = FI.OuterLoop->getHeader()->getParent(); - LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop " << FI.OuterLoop->getHeader()->getName() << " and inner loop " << FI.InnerLoop->getHeader()->getName() << " in " << F->getName() << "\n"); SmallPtrSet IterationInstructions; - if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI, FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE)) return false; @@ -528,40 +524,51 @@ // Tell LoopInfo, SCEV and the pass manager that the inner loop has been // deleted, and any information that have about the outer loop invalidated. - markLoopAsDeleted(FI.InnerLoop); SE->forgetLoop(FI.OuterLoop); SE->forgetLoop(FI.InnerLoop); LI->erase(FI.InnerLoop); - return true; } -PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, - LPMUpdater &Updater) { - if (L.getSubLoops().size() != 1) - return PreservedAnalyses::all(); +bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, + AssumptionCache *AC, TargetTransformInfo *TTI) { + bool Changed = false; + for (auto *InnerLoop : LI->getLoopsInPreorder()) { + auto *OuterLoop = InnerLoop->getParentLoop(); + if (!OuterLoop) + continue; + struct FlattenInfo FI(OuterLoop, InnerLoop); + Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI); + } + return Changed; +} - Loop *InnerLoop = *L.begin(); - std::string LoopName(InnerLoop->getName()); - struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop); - if (!FlattenLoopPair( - FI, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, - [&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); })) +PreservedAnalyses LoopFlattenPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto *DT = &AM.getResult(F); + auto *LI = &AM.getResult(F); + auto *SE = &AM.getResult(F); + auto *AC = &AM.getResult(F); + auto *TTI = &AM.getResult(F); + + if (!Flatten(DT, LI, SE, AC, TTI)) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + + PreservedAnalyses PA; + PA.preserveSet(); + return PA; } namespace { -class LoopFlattenLegacyPass : public LoopPass { +class LoopFlattenLegacyPass : public FunctionPass { public: static char ID; // Pass ID, replacement for typeid - LoopFlattenLegacyPass() : LoopPass(ID) { + LoopFlattenLegacyPass() : FunctionPass(ID) { initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry()); } // Possibly flatten loop L into its child. - bool runOnLoop(Loop *L, LPPassManager &) override; + bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { getLoopAnalysisUsage(AU); @@ -576,33 +583,20 @@ char LoopFlattenLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", false, false) -Pass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); } - -bool LoopFlattenLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipLoop(L)) - return false; - - if (L->getSubLoops().size() != 1) - return false; +FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); } +bool LoopFlattenLegacyPass::runOnFunction(Function &F) { ScalarEvolution *SE = &getAnalysis().getSE(); LoopInfo *LI = &getAnalysis().getLoopInfo(); auto *DTWP = getAnalysisIfAvailable(); DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; auto &TTIP = getAnalysis(); - TargetTransformInfo *TTI = &TTIP.getTTI(*L->getHeader()->getParent()); - AssumptionCache *AC = - &getAnalysis().getAssumptionCache( - *L->getHeader()->getParent()); - - Loop *InnerLoop = *L->begin(); - struct FlattenInfo FI(InnerLoop->getParentLoop(), InnerLoop); - return FlattenLoopPair(FI, DT, LI, SE, AC, TTI, - [&](Loop *L) { LPM.markLoopAsDeleted(*L); }); + auto *TTI = &TTIP.getTTI(F); + auto *AC = &getAnalysis().getAssumptionCache(F); + return Flatten(DT, LI, SE, AC, TTI); } diff --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll --- a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll +++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll @@ -393,3 +393,216 @@ for.end19: ; preds = %for.end16 ret i32 undef } + +; A 3d loop corresponding to: +; +; for (int i = 0; i < N; ++i) +; for (int j = 0; j < N; ++j) +; for (int k = 0; k < N; ++k) +; f(&A[i + N * (j + N * k)]); +; +define void @d3_1(i32* %A, i32 %N) { +entry: + %cmp35 = icmp sgt i32 %N, 0 + br i1 %cmp35, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup + +for.cond1.preheader.lr.ph: + br label %for.cond1.preheader.us + +for.cond1.preheader.us: + %i.036.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc15.us, %for.cond1.for.cond.cleanup3_crit_edge.us ] + br i1 true, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us52.preheader + +for.cond5.preheader.us52.preheader: + br label %for.cond5.preheader.us52 + +for.cond5.preheader.us.us.preheader: + br label %for.cond5.preheader.us.us + +for.cond5.preheader.us52: + br i1 false, label %for.cond5.preheader.us52, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit58 + +for.cond1.for.cond.cleanup3_crit_edge.us.loopexit: + br label %for.cond1.for.cond.cleanup3_crit_edge.us + +for.cond1.for.cond.cleanup3_crit_edge.us.loopexit58: + br label %for.cond1.for.cond.cleanup3_crit_edge.us + +for.cond1.for.cond.cleanup3_crit_edge.us: + %inc15.us = add nuw nsw i32 %i.036.us, 1 + %cmp.us = icmp slt i32 %inc15.us, %N + br i1 %cmp.us, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit + +for.cond5.preheader.us.us: + %j.033.us.us = phi i32 [ %inc12.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ] + br label %for.body8.us.us + +for.cond5.for.cond.cleanup7_crit_edge.us.us: + %inc12.us.us = add nuw nsw i32 %j.033.us.us, 1 + %cmp2.us.us = icmp slt i32 %inc12.us.us, %N + br i1 %cmp2.us.us, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit + +for.body8.us.us: + %k.031.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ] + %mul.us.us = mul nsw i32 %k.031.us.us, %N + %add.us.us = add nsw i32 %mul.us.us, %j.033.us.us + %mul9.us.us = mul nsw i32 %add.us.us, %N + %add10.us.us = add nsw i32 %mul9.us.us, %i.036.us + %idxprom.us.us = sext i32 %add10.us.us to i64 + %arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us + tail call void @f(i32* %arrayidx.us.us) #2 + %inc.us.us = add nuw nsw i32 %k.031.us.us, 1 + %cmp6.us.us = icmp slt i32 %inc.us.us, %N + br i1 %cmp6.us.us, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us + +for.cond.cleanup.loopexit: + br label %for.cond.cleanup + +for.cond.cleanup: + ret void +} + +; A 3d loop corresponding to: +; +; for (int k = 0; k < N; ++k) +; for (int i = 0; i < N; ++i) +; for (int j = 0; j < M; ++j) +; f(&A[i*M+j]); +; +; This could be supported, but isn't at the moment. +; +define void @d3_2(i32* %A, i32 %N, i32 %M) { +entry: + %cmp30 = icmp sgt i32 %N, 0 + br i1 %cmp30, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup + +for.cond1.preheader.lr.ph: + %cmp625 = icmp sgt i32 %M, 0 + br label %for.cond1.preheader.us + +for.cond1.preheader.us: + %k.031.us = phi i32 [ 0, %for.cond1.preheader.lr.ph ], [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ] + br i1 %cmp625, label %for.cond5.preheader.us.us.preheader, label %for.cond5.preheader.us43.preheader + +for.cond5.preheader.us43.preheader: + br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50 + +for.cond5.preheader.us.us.preheader: + br label %for.cond5.preheader.us.us + +for.cond1.for.cond.cleanup3_crit_edge.us.loopexit: + br label %for.cond1.for.cond.cleanup3_crit_edge.us + +for.cond1.for.cond.cleanup3_crit_edge.us.loopexit50: + br label %for.cond1.for.cond.cleanup3_crit_edge.us + +for.cond1.for.cond.cleanup3_crit_edge.us: + %inc13.us = add nuw nsw i32 %k.031.us, 1 + %exitcond52 = icmp ne i32 %inc13.us, %N + br i1 %exitcond52, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit + +for.cond5.preheader.us.us: + %i.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.cond5.preheader.us.us.preheader ] + %mul.us.us = mul nsw i32 %i.028.us.us, %M + br label %for.body8.us.us + +for.cond5.for.cond.cleanup7_crit_edge.us.us: + %inc10.us.us = add nuw nsw i32 %i.028.us.us, 1 + %exitcond51 = icmp ne i32 %inc10.us.us, %N + br i1 %exitcond51, label %for.cond5.preheader.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit + +for.body8.us.us: + %j.026.us.us = phi i32 [ 0, %for.cond5.preheader.us.us ], [ %inc.us.us, %for.body8.us.us ] + %add.us.us = add nsw i32 %j.026.us.us, %mul.us.us + %idxprom.us.us = sext i32 %add.us.us to i64 + %arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us + tail call void @f(i32* %arrayidx.us.us) #2 + %inc.us.us = add nuw nsw i32 %j.026.us.us, 1 + %exitcond = icmp ne i32 %inc.us.us, %M + br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us + +for.cond.cleanup.loopexit: + br label %for.cond.cleanup + +for.cond.cleanup: + ret void +} + +; A 3d loop corresponding to: +; +; for (int i = 0; i < N; ++i) +; for (int j = 0; j < M; ++j) { +; A[i*M+j] = 0; +; for (int k = 0; k < N; ++k) +; g(); +; } +; +define void @d3_3(i32* nocapture %A, i32 %N, i32 %M) { +entry: + %cmp29 = icmp sgt i32 %N, 0 + br i1 %cmp29, label %for.cond1.preheader.lr.ph, label %for.cond.cleanup + +for.cond1.preheader.lr.ph: + %cmp227 = icmp sgt i32 %M, 0 + br i1 %cmp227, label %for.cond1.preheader.us.preheader, label %for.cond1.preheader.preheader + +for.cond1.preheader.preheader: + br label %for.cond.cleanup.loopexit49 + +for.cond1.preheader.us.preheader: + br label %for.cond1.preheader.us + +for.cond1.preheader.us: + %i.030.us = phi i32 [ %inc13.us, %for.cond1.for.cond.cleanup3_crit_edge.us ], [ 0, %for.cond1.preheader.us.preheader ] + %mul.us = mul nsw i32 %i.030.us, %M + br i1 true, label %for.body4.us.us.preheader, label %for.body4.us32.preheader + +for.body4.us32.preheader: + br label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit48 + +for.body4.us.us.preheader: + br label %for.body4.us.us + +for.cond1.for.cond.cleanup3_crit_edge.us.loopexit: + br label %for.cond1.for.cond.cleanup3_crit_edge.us + +for.cond1.for.cond.cleanup3_crit_edge.us.loopexit48: + br label %for.cond1.for.cond.cleanup3_crit_edge.us + +for.cond1.for.cond.cleanup3_crit_edge.us: + %inc13.us = add nuw nsw i32 %i.030.us, 1 + %exitcond51 = icmp ne i32 %inc13.us, %N + br i1 %exitcond51, label %for.cond1.preheader.us, label %for.cond.cleanup.loopexit + +for.body4.us.us: + %j.028.us.us = phi i32 [ %inc10.us.us, %for.cond5.for.cond.cleanup7_crit_edge.us.us ], [ 0, %for.body4.us.us.preheader ] + %add.us.us = add nsw i32 %j.028.us.us, %mul.us + %idxprom.us.us = sext i32 %add.us.us to i64 + %arrayidx.us.us = getelementptr inbounds i32, i32* %A, i64 %idxprom.us.us + store i32 0, i32* %arrayidx.us.us, align 4 + br label %for.body8.us.us + +for.cond5.for.cond.cleanup7_crit_edge.us.us: + %inc10.us.us = add nuw nsw i32 %j.028.us.us, 1 + %exitcond50 = icmp ne i32 %inc10.us.us, %M + br i1 %exitcond50, label %for.body4.us.us, label %for.cond1.for.cond.cleanup3_crit_edge.us.loopexit + +for.body8.us.us: + %k.026.us.us = phi i32 [ 0, %for.body4.us.us ], [ %inc.us.us, %for.body8.us.us ] + tail call void bitcast (void (...)* @g to void ()*)() #2 + %inc.us.us = add nuw nsw i32 %k.026.us.us, 1 + %exitcond = icmp ne i32 %inc.us.us, %N + br i1 %exitcond, label %for.body8.us.us, label %for.cond5.for.cond.cleanup7_crit_edge.us.us + +for.cond.cleanup.loopexit: + br label %for.cond.cleanup + +for.cond.cleanup.loopexit49: + br label %for.cond.cleanup + +for.cond.cleanup: + ret void +} + +declare dso_local void @f(i32*) +declare dso_local void @g(...)