diff --git a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h --- a/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h +++ b/llvm/include/llvm/Transforms/Scalar/LoopFlatten.h @@ -24,7 +24,8 @@ public: LoopFlattenPass() = default; - PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + PreservedAnalyses run(LoopNest &LN, LoopAnalysisManager &LAM, + LoopStandardAnalysisResults &AR, LPMUpdater &U); }; } // end namespace llvm diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp --- a/llvm/lib/Passes/PassBuilder.cpp +++ b/llvm/lib/Passes/PassBuilder.cpp @@ -621,7 +621,7 @@ FPM.addPass(SimplifyCFGPass()); FPM.addPass(InstCombinePass()); if (EnableLoopFlatten) - FPM.addPass(LoopFlattenPass()); + FPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass())); // The loop passes in LPM2 (LoopFullUnrollPass) do not preserve MemorySSA. // *All* loop passes must preserve it, in order to be able to use it. FPM.addPass(createFunctionToLoopPassAdaptor(std::move(LPM2), @@ -796,7 +796,7 @@ FPM.addPass(SimplifyCFGPass()); FPM.addPass(InstCombinePass()); if (EnableLoopFlatten) - FPM.addPass(LoopFlattenPass()); + FPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass())); // The loop passes in LPM2 (LoopIdiomRecognizePass, IndVarSimplifyPass, // LoopDeletionPass and LoopFullUnrollPass) do not preserve MemorySSA. // *All* loop passes must preserve it, in order to be able to use it. @@ -1849,7 +1849,7 @@ // More loops are countable; try to optimize them. if (EnableLoopFlatten && Level.getSpeedupLevel() > 1) - MainFPM.addPass(LoopFlattenPass()); + MainFPM.addPass(createFunctionToLoopPassAdaptor(LoopFlattenPass())); if (EnableConstraintElimination) MainFPM.addPass(ConstraintEliminationPass()); diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def --- a/llvm/lib/Passes/PassRegistry.def +++ b/llvm/lib/Passes/PassRegistry.def @@ -249,7 +249,6 @@ FUNCTION_PASS("loop-simplify", LoopSimplifyPass()) FUNCTION_PASS("loop-sink", LoopSinkPass()) FUNCTION_PASS("loop-unroll-and-jam", LoopUnrollAndJamPass()) -FUNCTION_PASS("loop-flatten", LoopFlattenPass()) FUNCTION_PASS("lowerinvoke", LowerInvokePass()) FUNCTION_PASS("lowerswitch", LowerSwitchPass()) FUNCTION_PASS("mem2reg", PromotePass()) @@ -390,6 +389,7 @@ LOOP_PASS("dot-ddg", DDGDotPrinterPass()) LOOP_PASS("invalidate", InvalidateAllAnalysesPass()) LOOP_PASS("licm", LICMPass()) +LOOP_PASS("loop-flatten", LoopFlattenPass()) LOOP_PASS("loop-idiom", LoopIdiomRecognizePass()) LOOP_PASS("loop-instsimplify", LoopInstSimplifyPass()) LOOP_PASS("loop-interchange", LoopInterchangePass()) diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -28,7 +28,9 @@ #include "llvm/Transforms/Scalar/LoopFlatten.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -88,7 +90,7 @@ // Whether this holds the flatten info before or after widening. bool Widened = false; - FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; + FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){}; }; // Finds the induction variable, increment and limit for a simple loop that we @@ -343,8 +345,7 @@ // transformation wouldn't be profitable. Value *InnerLimit = FI.InnerLimit; - if (FI.Widened && - (isa(InnerLimit) || isa(InnerLimit))) + if (FI.Widened && (isa(InnerLimit) || isa(InnerLimit))) InnerLimit = cast(InnerLimit)->getOperand(0); // Check that all uses of the inner loop's induction variable match the @@ -354,8 +355,8 @@ if (U == FI.InnerIncrement) continue; - // After widening the IVs, a trunc instruction might have been introduced, so - // look through truncs. + // After widening the IVs, a trunc instruction might have been introduced, + // so look through truncs. if (isa(U)) { if (!U->hasOneUse()) return false; @@ -373,11 +374,11 @@ // Matches the same pattern as above, except it also looks for truncs // on the phi, which can be the result of widening the induction variables. - bool IsAddTrunc = match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)), - m_Value(MatchedMul))) && - match(MatchedMul, - m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), - m_Value(MatchedItCount))); + bool IsAddTrunc = + match(U, m_c_Add(m_Trunc(m_Specific(FI.InnerInductionPHI)), + m_Value(MatchedMul))) && + match(MatchedMul, m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), + m_Value(MatchedItCount))); if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) { LLVM_DEBUG(dbgs() << "Use is optimisable\n"); @@ -395,8 +396,9 @@ if (U == FI.OuterIncrement) continue; - auto IsValidOuterPHIUses = [&] (User *U) -> bool { - LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump()); + auto IsValidOuterPHIUses = [&](User *U) -> bool { + LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; + U->dump()); if (!ValidOuterPHIUses.count(U)) { LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); return false; @@ -420,7 +422,8 @@ LLVM_DEBUG(dbgs() << "checkIVUsers: OK\n"; dbgs() << "Found " << FI.LinearIVUses.size() << " value(s) that can be replaced:\n"; - for (Value *V : FI.LinearIVUses) { + for (Value *V + : FI.LinearIVUses) { dbgs() << " "; V->dump(); }); @@ -472,11 +475,13 @@ ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI) { SmallPtrSet IterationInstructions; - if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI, - FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE)) + if (!findLoopComponents(FI.InnerLoop, IterationInstructions, + FI.InnerInductionPHI, FI.InnerLimit, + FI.InnerIncrement, FI.InnerBranch, SE)) return false; - if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI, - FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE)) + if (!findLoopComponents(FI.OuterLoop, IterationInstructions, + FI.OuterInductionPHI, FI.OuterLimit, + FI.OuterIncrement, FI.OuterBranch, SE)) return false; // Both of the loop limit values must be invariant in the outer loop @@ -519,16 +524,17 @@ LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n"); { using namespace ore; - OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(), + OptimizationRemark Remark(DEBUG_TYPE, "Flattened", + FI.InnerLoop->getStartLoc(), FI.InnerLoop->getHeader()); OptimizationRemarkEmitter ORE(F); Remark << "Flattened into outer loop"; ORE.emit(Remark); } - Value *NewTripCount = - BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount", - FI.OuterLoop->getLoopPreheader()->getTerminator()); + Value *NewTripCount = BinaryOperator::CreateMul( + FI.InnerLimit, FI.OuterLimit, "flatten.tripcount", + FI.OuterLoop->getLoopPreheader()->getTerminator()); LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; NewTripCount->dump()); @@ -561,8 +567,8 @@ OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(), "flatten.trunciv"); - LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); - dbgs() << "with: "; OuterValue->dump()); + LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: "; + OuterValue->dump()); V->replaceAllUsesWith(OuterValue); } @@ -595,7 +601,8 @@ // (OuterLimit * InnerLimit) as the new trip count is safe. if (InnerType != OuterType || InnerType->getScalarSizeInBits() >= MaxLegalSize || - MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) { + MaxLegalType->getScalarSizeInBits() < + InnerType->getScalarSizeInBits() * 2) { LLVM_DEBUG(dbgs() << "Can't widen the IV\n"); return false; } @@ -603,15 +610,15 @@ SCEVExpander Rewriter(*SE, DL, "loopflatten"); SmallVector WideIVs; SmallVector DeadInsts; - WideIVs.push_back( {FI.InnerInductionPHI, MaxLegalType, false }); - WideIVs.push_back( {FI.OuterInductionPHI, MaxLegalType, false }); + WideIVs.push_back({FI.InnerInductionPHI, MaxLegalType, false}); + WideIVs.push_back({FI.OuterInductionPHI, MaxLegalType, false}); unsigned ElimExt = 0; unsigned Widened = 0; for (const auto &WideIV : WideIVs) { - PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, - ElimExt, Widened, true /* HasGuards */, - true /* UsePostIncrementRanges */); + PHINode *WidePhi = + createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened, + true /* HasGuards */, true /* UsePostIncrementRanges */); if (!WidePhi) return false; LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump()); @@ -658,10 +665,10 @@ return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI); } -bool Flatten(DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, +bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE, AssumptionCache *AC, TargetTransformInfo *TTI) { bool Changed = false; - for (auto *InnerLoop : LI->getLoopsInPreorder()) { + for (Loop *InnerLoop : LN.getLoops()) { auto *OuterLoop = InnerLoop->getParentLoop(); if (!OuterLoop) continue; @@ -671,13 +678,9 @@ return Changed; } -PreservedAnalyses LoopFlattenPass::run(Function &F, - FunctionAnalysisManager &AM) { - auto *DT = &AM.getResult(F); - auto *LI = &AM.getResult(F); - auto *SE = &AM.getResult(F); - auto *AC = &AM.getResult(F); - auto *TTI = &AM.getResult(F); +PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { bool Changed = false; @@ -685,15 +688,7 @@ // in simplified form, and also needs LCSSA. Running // this pass will simplify all loops that contain inner loops, // regardless of whether anything ends up being flattened. - for (const auto &L : *LI) { - if (L->isInnermost()) - continue; - Changed |= - simplifyLoop(L, DT, LI, SE, AC, nullptr, false /* PreserveLCSSA */); - Changed |= formLCSSARecursively(*L, *DT, LI, SE); - } - - Changed |= Flatten(DT, LI, SE, AC, TTI); + Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI); if (!Changed) return PreservedAnalyses::all(); @@ -710,7 +705,7 @@ } // Possibly flatten loop L into its child. - bool runOnFunction(Function &F) override; + bool runOnFunction(Function& F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { getLoopAnalysisUsage(AU); @@ -727,18 +722,22 @@ false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", false, false) FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); } -bool LoopFlattenLegacyPass::runOnFunction(Function &F) { +bool LoopFlattenLegacyPass::runOnFunction(Function& F) { ScalarEvolution *SE = &getAnalysis().getSE(); LoopInfo *LI = &getAnalysis().getLoopInfo(); + auto LN = LoopNest::getLoopNest(*LI->getLoopsInPreorder()[0], *SE); auto *DTWP = getAnalysisIfAvailable(); DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; auto &TTIP = getAnalysis(); auto *TTI = &TTIP.getTTI(F); auto *AC = &getAnalysis().getAssumptionCache(F); - return Flatten(DT, LI, SE, AC, TTI); + return Flatten(*LN, DT, LI, SE, AC, TTI); }