Index: llvm/include/llvm/InitializePasses.h =================================================================== --- llvm/include/llvm/InitializePasses.h +++ llvm/include/llvm/InitializePasses.h @@ -242,7 +242,6 @@ void initializeLoopInfoWrapperPassPass(PassRegistry&); void initializeLoopInstSimplifyLegacyPassPass(PassRegistry&); void initializeLoopInterchangeLegacyPassPass(PassRegistry &); -void initializeLoopFlattenLegacyPassPass(PassRegistry&); void initializeLoopLoadEliminationPass(PassRegistry&); void initializeLoopPassPass(PassRegistry&); void initializeLoopPredicationLegacyPassPass(PassRegistry&); Index: llvm/include/llvm/LinkAllPasses.h =================================================================== --- llvm/include/llvm/LinkAllPasses.h +++ llvm/include/llvm/LinkAllPasses.h @@ -127,7 +127,6 @@ (void) llvm::createLazyValueInfoPass(); (void) llvm::createLoopExtractorPass(); (void) llvm::createLoopInterchangePass(); - (void) llvm::createLoopFlattenPass(); (void) llvm::createLoopPredicationPass(); (void) llvm::createLoopSimplifyPass(); (void) llvm::createLoopSimplifyCFGPass(); Index: llvm/include/llvm/Transforms/Scalar.h =================================================================== --- llvm/include/llvm/Transforms/Scalar.h +++ llvm/include/llvm/Transforms/Scalar.h @@ -149,12 +149,6 @@ // Pass *createLoopInterchangePass(); -//===----------------------------------------------------------------------===// -// -// LoopFlatten - This pass flattens nested loops into a single loop. 
-// -Pass *createLoopFlattenPass(); - //===----------------------------------------------------------------------===// // // LoopStrengthReduce - This pass is strength reduces GEP instructions that use Index: llvm/include/llvm/Transforms/Scalar/LoopFlatten.h =================================================================== --- llvm/include/llvm/Transforms/Scalar/LoopFlatten.h +++ /dev/null @@ -1,33 +0,0 @@ -//===- LoopFlatten.h - Loop Flatten ---------------- -----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file provides the interface for the Loop Flatten Pass. -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_TRANSFORMS_SCALAR_LOOPFLATTEN_H -#define LLVM_TRANSFORMS_SCALAR_LOOPFLATTEN_H - -#include "llvm/Analysis/LoopAnalysisManager.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/IR/PassManager.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" - -namespace llvm { - -class LoopFlattenPass : public PassInfoMixin { -public: - LoopFlattenPass() = default; - - PreservedAnalyses run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, LPMUpdater &U); -}; - -} // end namespace llvm - -#endif // LLVM_TRANSFORMS_SCALAR_LOOPFLATTEN_H Index: llvm/lib/Passes/PassBuilder.cpp =================================================================== --- llvm/lib/Passes/PassBuilder.cpp +++ llvm/lib/Passes/PassBuilder.cpp @@ -151,7 +151,6 @@ #include "llvm/Transforms/Scalar/LoopDataPrefetch.h" #include "llvm/Transforms/Scalar/LoopDeletion.h" #include "llvm/Transforms/Scalar/LoopDistribute.h" -#include "llvm/Transforms/Scalar/LoopFlatten.h" #include "llvm/Transforms/Scalar/LoopFuse.h" #include 
"llvm/Transforms/Scalar/LoopIdiomRecognize.h" #include "llvm/Transforms/Scalar/LoopInstSimplify.h" @@ -257,10 +256,6 @@ "enable-npm-unroll-and-jam", cl::init(false), cl::Hidden, cl::desc("Enable the Unroll and Jam pass for the new PM (default = off)")); -static cl::opt EnableLoopFlatten( - "enable-npm-loop-flatten", cl::init(false), cl::Hidden, - cl::desc("Enable the Loop flattening pass for the new PM (default = off)")); - static cl::opt EnableSyntheticCounts( "enable-npm-synthetic-counts", cl::init(false), cl::Hidden, cl::ZeroOrMore, cl::desc("Run synthetic function entry count generation " @@ -521,8 +516,6 @@ C(LPM2, Level); LPM2.addPass(LoopDeletionPass()); - if (EnableLoopFlatten) - LPM2.addPass(LoopFlattenPass()); // Do not enable unrolling in PreLinkThinLTO phase during sample PGO // because it changes IR to makes profile annotation in back compile // inaccurate. The normal unroller doesn't pay attention to forced full unroll Index: llvm/lib/Passes/PassRegistry.def =================================================================== --- llvm/lib/Passes/PassRegistry.def +++ llvm/lib/Passes/PassRegistry.def @@ -366,7 +366,6 @@ LOOP_PASS("no-op-loop", NoOpLoopPass()) LOOP_PASS("print", PrintLoopPass(dbgs())) LOOP_PASS("loop-deletion", LoopDeletionPass()) -LOOP_PASS("loop-flatten", LoopFlattenPass()) LOOP_PASS("loop-simplifycfg", LoopSimplifyCFGPass()) LOOP_PASS("loop-reduce", LoopStrengthReducePass()) LOOP_PASS("indvars", IndVarSimplifyPass()) Index: llvm/lib/Transforms/IPO/PassManagerBuilder.cpp =================================================================== --- llvm/lib/Transforms/IPO/PassManagerBuilder.cpp +++ llvm/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -92,9 +92,7 @@ cl::init(false), cl::Hidden, cl::desc("Enable Unroll And Jam Pass")); -static cl::opt EnableLoopFlatten("enable-loop-flatten", cl::init(false), - cl::Hidden, - cl::desc("Enable the LoopFlatten Pass")); +extern cl::opt EnableLoopFlatten; static cl::opt 
EnablePrepareForThinLTO("prepare-for-thinlto", cl::init(false), cl::Hidden, @@ -442,16 +440,15 @@ MPM.add(createInstructionCombiningPass()); // We resume loop passes creating a second loop pipeline here. MPM.add(createIndVarSimplifyPass()); // Canonicalize indvars + if (EnableLoopFlatten) + MPM.add(createLoopSimplifyCFGPass()); + MPM.add(createLoopIdiomPass()); // Recognize idioms like memset. addExtensionsToPM(EP_LateLoopOptimizations, MPM); MPM.add(createLoopDeletionPass()); // Delete dead loops if (EnableLoopInterchange) MPM.add(createLoopInterchangePass()); // Interchange loops - if (EnableLoopFlatten) { - MPM.add(createLoopFlattenPass()); // Flatten loops - MPM.add(createLoopSimplifyCFGPass()); - } // Unroll small loops MPM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, @@ -1046,8 +1043,6 @@ PM.add(createLoopDeletionPass()); if (EnableLoopInterchange) PM.add(createLoopInterchangePass()); - if (EnableLoopFlatten) - PM.add(createLoopFlattenPass()); // Unroll small loops PM.add(createSimpleLoopUnrollPass(OptLevel, DisableUnrollLoops, Index: llvm/lib/Transforms/Scalar/CMakeLists.txt =================================================================== --- llvm/lib/Transforms/Scalar/CMakeLists.txt +++ llvm/lib/Transforms/Scalar/CMakeLists.txt @@ -32,7 +32,6 @@ LoopIdiomRecognize.cpp LoopInstSimplify.cpp LoopInterchange.cpp - LoopFlatten.cpp LoopLoadElimination.cpp LoopPassManager.cpp LoopPredication.cpp Index: llvm/lib/Transforms/Scalar/IndVarSimplify.cpp =================================================================== --- llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -36,10 +36,12 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" +#include 
"llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -89,6 +91,7 @@ #include using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "indvars" @@ -135,6 +138,22 @@ AllowIVWidening("indvars-widen-indvars", cl::Hidden, cl::init(true), cl::desc("Allow widening of indvars to eliminate s/zext")); +static cl::opt RepeatedInstructionThreshold( + "loop-flatten-cost-threshold", cl::Hidden, cl::init(2), + cl::desc("Limit on the cost of instructions that can be repeated due to " + "loop flattening")); + +static cl::opt + AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden, + cl::init(false), + cl::desc("Assume that the product of the two iteration " + "limits will never overflow")); + +cl::opt EnableLoopFlatten("enable-loop-flatten", cl::init(false), + cl::Hidden, + cl::desc("Enable loop flattening in " + "IndVarSimplify")); + namespace { struct RewritePhi; @@ -147,6 +166,8 @@ TargetLibraryInfo *TLI; const TargetTransformInfo *TTI; std::unique_ptr MSSAU; + AssumptionCache *AC; + std::function markLoopAsDeleted; SmallVector DeadInsts; @@ -154,6 +175,11 @@ bool rewriteNonIntegerIVs(Loop *L); bool simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); + + bool tryFlattenLoopPair(Loop *L, + SmallVectorImpl &DeadInsts, + SCEVExpander &Rewriter); + /// Try to eliminate loop exits based on analyzeable exit counts bool optimizeLoopExits(Loop *L, SCEVExpander &Rewriter); /// Try to form loop invariant tests for loop exits by changing how many @@ -171,8 +197,10 @@ public: IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, const DataLayout &DL, TargetLibraryInfo *TLI, - TargetTransformInfo *TTI, MemorySSA *MSSA) - : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) { + TargetTransformInfo *TTI, MemorySSA *MSSA, AssumptionCache *AC, + std::function markLoopAsDeleted) + : LI(LI), SE(SE), 
DT(DT), DL(DL), TLI(TLI), TTI(TTI), AC(AC), + markLoopAsDeleted(markLoopAsDeleted) { if (MSSA) MSSAU = std::make_unique(MSSA); } @@ -2682,6 +2710,538 @@ return Changed; } +//===----------------------------------------------------------------------===// +// Flatten nested loops. Remove outer-loop induction variables. +//===----------------------------------------------------------------------===// +// +// The intention is to optimise loop nests like this, which together access an +// array linearly: +// for (int i = 0; i < N; ++i) +// for (int j = 0; j < M; ++j) +// f(A[i*M+j]); +// into one loop: +// for (int i = 0; i < (N*M); ++i) +// f(A[i]); +// +// It can also flatten loops where the induction variables are not used in the +// loop. This is only worth doing if the induction variables are only used in an +// expression like i*M+j. If they had any other uses, we would have to insert a +// div/mod to reconstruct the original values, so this wouldn't be profitable. +// +// We also need to prove that N*M will not overflow. + +struct FlattenInfo { + Loop *OuterLoop; + Loop *InnerLoop; + PHINode *InnerInductionPHI; + PHINode *OuterInductionPHI; + Value *InnerLimit; + Value *OuterLimit; + BinaryOperator *InnerIncrement; + BinaryOperator *OuterIncrement; + BranchInst *InnerBranch; + BranchInst *OuterBranch; + SmallPtrSet LinearIVUses; + SmallPtrSet InnerPHIsToTransform; + + FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; +}; + +// Finds the induction variable, increment and limit for a simple loop that we +// can flatten. 
+static bool findLoopComponents( + Loop *L, SmallPtrSetImpl &IterationInstructions, + PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment, + BranchInst *&BackBranch, ScalarEvolution *SE) { + LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n"); + + if (!L->isLoopSimplifyForm()) { + LLVM_DEBUG(dbgs() << "Loop is not in normal form\n"); + return false; + } + + // There must be exactly one exiting block, and it must be the same as the + // latch. + BasicBlock *Latch = L->getLoopLatch(); + if (L->getExitingBlock() != Latch) { + LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n"); + return false; + } + // Latch block must end in a conditional branch. + BackBranch = dyn_cast(Latch->getTerminator()); + if (!BackBranch || !BackBranch->isConditional()) { + LLVM_DEBUG(dbgs() << "Could not find back-branch\n"); + return false; + } + IterationInstructions.insert(BackBranch); + LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump()); + bool ContinueOnTrue = L->contains(BackBranch->getSuccessor(0)); + + // Find the induction PHI. If there is no induction PHI, we can't do the + // transformation. TODO: could other variables trigger this? Do we have to + // search for the best one? 
+ InductionPHI = nullptr; + for (PHINode &PHI : L->getHeader()->phis()) { + InductionDescriptor ID; + if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) { + InductionPHI = &PHI; + LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump()); + break; + } + } + if (!InductionPHI) { + LLVM_DEBUG(dbgs() << "Could not find induction PHI\n"); + return false; + } + + auto IsValidPredicate = [&](ICmpInst::Predicate Pred) { + if (ContinueOnTrue) + return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT; + else + return Pred == CmpInst::ICMP_EQ; + }; + + // Find Compare and make sure it is valid + ICmpInst *Compare = dyn_cast(BackBranch->getCondition()); + if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) || + Compare->hasNUsesOrMore(2)) { + LLVM_DEBUG(dbgs() << "Could not find valid comparison\n"); + return false; + } + IterationInstructions.insert(Compare); + LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump()); + + // Find increment and limit from the compare + Increment = nullptr; + if (match(Compare->getOperand(0), + m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { + Increment = dyn_cast(Compare->getOperand(0)); + Limit = Compare->getOperand(1); + } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE && + match(Compare->getOperand(1), + m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { + Increment = dyn_cast(Compare->getOperand(1)); + Limit = Compare->getOperand(0); + } + if (!Increment || Increment->hasNUsesOrMore(3)) { + LLVM_DEBUG(dbgs() << "Cound not find valid increment\n"); + return false; + } + IterationInstructions.insert(Increment); + LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump()); + LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump()); + + assert(InductionPHI->getNumIncomingValues() == 2); + assert(InductionPHI->getIncomingValueForBlock(Latch) == Increment && + "PHI value is not increment inst"); + + auto *CI = dyn_cast( + 
InductionPHI->getIncomingValueForBlock(L->getLoopPreheader())); + if (!CI || !CI->isZero()) { + LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump()); + return false; + } + + LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); + return true; +} + +static bool checkPHIs(struct FlattenInfo &FI, + const TargetTransformInfo *TTI) { + // All PHIs in the inner and outer headers must either be: + // - The induction PHI, which we are going to rewrite as one induction in + // the new loop. This is already checked by findLoopComponents. + // - An outer header PHI with all incoming values from outside the loop. + // LoopSimplify guarantees we have a pre-header, so we don't need to + // worry about that here. + // - Pairs of PHIs in the inner and outer headers, which implement a + // loop-carried dependency that will still be valid in the new loop. To + // be valid, this variable must be modified only in the inner loop. + + // The set of PHI nodes in the outer loop header that we know will still be + // valid after the transformation. These will not need to be modified (with + // the exception of the induction variable), but we do need to check that + // there are no unsafe PHI nodes. + SmallPtrSet SafeOuterPHIs; + SafeOuterPHIs.insert(FI.OuterInductionPHI); + + // Check that all PHI nodes in the inner loop header match one of the valid + // patterns. + for (PHINode &InnerPHI : FI.InnerLoop->getHeader()->phis()) { + // The induction PHIs break these rules, and that's OK because we treat + // them specially when doing the transformation. + if (&InnerPHI == FI.InnerInductionPHI) + continue; + + // Each inner loop PHI node must have two incoming values/blocks - one + // from the pre-header, and one from the latch. 
+ assert(InnerPHI.getNumIncomingValues() == 2); + Value *PreHeaderValue = + InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopPreheader()); + Value *LatchValue = + InnerPHI.getIncomingValueForBlock(FI.InnerLoop->getLoopLatch()); + + // The incoming value from the outer loop must be the PHI node in the + // outer loop header, with no modifications made in the top of the outer + // loop. + PHINode *OuterPHI = dyn_cast(PreHeaderValue); + if (!OuterPHI || OuterPHI->getParent() != FI.OuterLoop->getHeader()) { + LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n"); + return false; + } + + // The other incoming value must come from the inner loop, without any + // modifications in the tail end of the outer loop. We are in LCSSA form, + // so this will actually be a PHI in the inner loop's exit block, which + // only uses values from inside the inner loop. + PHINode *LCSSAPHI = dyn_cast( + OuterPHI->getIncomingValueForBlock(FI.OuterLoop->getLoopLatch())); + if (!LCSSAPHI) { + LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n"); + return false; + } + + // The value used by the LCSSA PHI must be the same one that the inner + // loop's PHI uses. + if (LCSSAPHI->hasConstantValue() != LatchValue) { + LLVM_DEBUG( + dbgs() << "LCSSA PHI incoming value does not match latch value\n"); + return false; + } + + LLVM_DEBUG(dbgs() << "PHI pair is safe:\n"); + LLVM_DEBUG(dbgs() << " Inner: "; InnerPHI.dump()); + LLVM_DEBUG(dbgs() << " Outer: "; OuterPHI->dump()); + SafeOuterPHIs.insert(OuterPHI); + FI.InnerPHIsToTransform.insert(&InnerPHI); + } + + for (PHINode &OuterPHI : FI.OuterLoop->getHeader()->phis()) { + if (!SafeOuterPHIs.count(&OuterPHI)) { + LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump()); + return false; + } + } + + return true; +} + +static bool +checkOuterLoopInsts(struct FlattenInfo &FI, + SmallPtrSetImpl &IterationInstructions, + const TargetTransformInfo *TTI) { + // Check for instructions in the outer but not inner loop. 
If any of these + // have side-effects then this transformation is not legal, and if there is + // a significant amount of code here which can't be optimised out then it's + // not profitable (as these instructions would get executed for each + // iteration of the inner loop). + unsigned RepeatedInstrCost = 0; + for (auto *B : FI.OuterLoop->getBlocks()) { + if (FI.InnerLoop->contains(B)) + continue; + + for (auto &I : *B) { + if (!isa(&I) && !I.isTerminator() && + !isSafeToSpeculativelyExecute(&I)) { + LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have " + "side effects: "; + I.dump()); + return false; + } + // The execution count of the outer loop's iteration instructions + // (increment, compare and branch) will be increased, but the + // equivalent instructions will be removed from the inner loop, so + // they make a net difference of zero. + if (IterationInstructions.count(&I)) + continue; + // The unconditional branch to the inner loop's header will turn into + // a fall-through, so adds no cost. + BranchInst *Br = dyn_cast(&I); + if (Br && Br->isUnconditional() && + Br->getSuccessor(0) == FI.InnerLoop->getHeader()) + continue; + // Multiplies of the outer iteration variable and inner iteration + // count will be optimised out. + if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI), + m_Specific(FI.InnerLimit)))) + continue; + int Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); + LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump()); + RepeatedInstrCost += Cost; + } + } + + LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: " + << RepeatedInstrCost << "\n"); + // Bail out if flattening the loops would cause instructions in the outer + // loop but not in the inner loop to be executed extra times. 
+ if (RepeatedInstrCost > RepeatedInstructionThreshold) + return false; + + return true; +} + +static bool checkIVUsers(struct FlattenInfo &FI) { + //SmallPtrSetImpl &LinearIVUses) { + // We require all uses of both induction variables to match this pattern: + // + // (OuterPHI * InnerLimit) + InnerPHI + // + // Any uses of the induction variables not matching that pattern would + // require a div/mod to reconstruct in the flattened loop, so the + // transformation wouldn't be profitable. + + // Check that all uses of the inner loop's induction variable match the + // expected pattern, recording the uses of the outer IV. + SmallPtrSet ValidOuterPHIUses; + for (User *U : FI.InnerInductionPHI->users()) { + if (U == FI.InnerIncrement) + continue; + + LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); + + Value *MatchedMul, *MatchedItCount; + if (match(U, m_c_Add(m_Specific(FI.InnerInductionPHI), m_Value(MatchedMul))) && + match(MatchedMul, + m_c_Mul(m_Specific(FI.OuterInductionPHI), m_Value(MatchedItCount))) && + MatchedItCount == FI.InnerLimit) { + LLVM_DEBUG(dbgs() << "Use is optimisable\n"); + ValidOuterPHIUses.insert(MatchedMul); + FI.LinearIVUses.insert(U); + } else { + LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); + return false; + } + } + + // Check that there are no uses of the outer IV other than the ones found + // as part of the pattern above. 
+ for (User *U : FI.OuterInductionPHI->users()) { + if (U == FI.OuterIncrement) + continue; + + LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump()); + + if (!ValidOuterPHIUses.count(U)) { + LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); + return false; + } else { + LLVM_DEBUG(dbgs() << "Use is optimisable\n"); + } + } + + LLVM_DEBUG(dbgs() << "Found " << FI.LinearIVUses.size() + << " value(s) that can be replaced:\n"; + for (Value *V : FI.LinearIVUses) { + dbgs() << " "; + V->dump(); + }); + + return true; +} + +// Return an OverflowResult dependent on whether overflow of the +// multiplication of InnerLimit and OuterLimit can be assumed not to happen. +static OverflowResult checkOverflow(struct FlattenInfo &FI, + DominatorTree *DT, AssumptionCache *AC) { + Function *F = FI.OuterLoop->getHeader()->getParent(); + const DataLayout &DL = F->getParent()->getDataLayout(); + + // For debugging/testing. + if (AssumeNoOverflow) + return OverflowResult::NeverOverflows; + + // Check if the multiply could not overflow due to known ranges of the + // input values. + OverflowResult OR = computeOverflowForUnsignedMul( + FI.InnerLimit, FI.OuterLimit, DL, AC, + FI.OuterLoop->getLoopPreheader()->getTerminator(), DT); + if (OR != OverflowResult::MayOverflow) + return OR; + + for (Value *V : FI.LinearIVUses) { + for (Value *U : V->users()) { + if (auto *GEP = dyn_cast(U)) { + // The IV is used as the operand of a GEP, and the IV is at least as + // wide as the address space of the GEP. In this case, the GEP would + // wrap around the address space before the IV increment wraps, which + // would be UB. 
+ if (GEP->isInBounds() && + V->getType()->getIntegerBitWidth() >= + DL.getPointerTypeSizeInBits(GEP->getType())) { + LLVM_DEBUG( + dbgs() << "use of linear IV would be UB if overflow occurred: "; + GEP->dump()); + return OverflowResult::NeverOverflows; + } + } + } + } + + return OverflowResult::MayOverflow; +} + +static bool CanFlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, + LoopInfo *LI, ScalarEvolution *SE, + AssumptionCache *AC, const TargetTransformInfo *TTI, + std::function markLoopAsDeleted) { + Function *F = FI.OuterLoop->getHeader()->getParent(); + + LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop " + << FI.OuterLoop->getHeader()->getName() << " and inner loop " + << FI.InnerLoop->getHeader()->getName() << " in " + << F->getName() << "\n"); + + SmallPtrSet IterationInstructions; + + if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI, + FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE)) + return false; + if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI, + FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE)) + return false; + + // Both of the loop limit values must be invariant in the outer loop + // (non-instructions are all inherently invariant). + bool Changed; + if (!FI.OuterLoop->makeLoopInvariant(FI.InnerLimit, Changed)) { + LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n"); + return false; + } + if (!FI.OuterLoop->makeLoopInvariant(FI.OuterLimit, Changed)) { + LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n"); + return false; + } + + if (!checkPHIs(FI, TTI)) + return false; + + if (!checkOuterLoopInsts(FI, IterationInstructions, TTI)) + return false; + + // Find the values in the loop that can be replaced with the linearized + // induction variable, and check that there are no other uses of the inner + // or outer induction variable. 
If there were, we could still do this + // transformation, but we'd have to insert a div/mod to calculate the + // original IVs, so it wouldn't be profitable. + if (!checkIVUsers(FI)) + return false; + + return true; +} + +static void FlattenLoopPair(struct FlattenInfo &FI, DominatorTree *DT, + LoopInfo *LI, ScalarEvolution *SE, + AssumptionCache *AC, const TargetTransformInfo *TTI, + std::function markLoopAsDeleted) { + LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n"); + Function *F = FI.OuterLoop->getHeader()->getParent(); + + { + using namespace ore; + OptimizationRemark Remark(DEBUG_TYPE, "Flattened", FI.InnerLoop->getStartLoc(), + FI.InnerLoop->getHeader()); + OptimizationRemarkEmitter ORE(F); + Remark << "Flattened into outer loop"; + ORE.emit(Remark); + } + + Value *NewTripCount = + BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount", + FI.OuterLoop->getLoopPreheader()->getTerminator()); + LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; + NewTripCount->dump()); + + // Fix up PHI nodes that take values from the inner loop back-edge, which + // we are about to remove. + FI.InnerInductionPHI->removeIncomingValue(FI.InnerLoop->getLoopLatch()); + for (PHINode *PHI : FI.InnerPHIsToTransform) + PHI->removeIncomingValue(FI.InnerLoop->getLoopLatch()); + + // Modify the trip count of the outer loop to be the product of the two + // trip counts. + cast(FI.OuterBranch->getCondition())->setOperand(1, NewTripCount); + + // Replace the inner loop backedge with an unconditional branch to the exit. + BasicBlock *InnerExitBlock = FI.InnerLoop->getExitBlock(); + BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock(); + InnerExitingBlock->getTerminator()->eraseFromParent(); + BranchInst::Create(InnerExitBlock, InnerExitingBlock); + DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader()); + + // Replace all uses of the polynomial calculated from the two induction + // variables with the one new one. 
+ for (Value *V : FI.LinearIVUses) + V->replaceAllUsesWith(FI.OuterInductionPHI); + + // Tell LoopInfo, SCEV and the pass manager that the inner loop has been + // deleted, and invalidate any information they have about the outer loop. + markLoopAsDeleted(FI.InnerLoop); + SE->forgetLoop(FI.OuterLoop); + SE->forgetLoop(FI.InnerLoop); + LI->erase(FI.InnerLoop); +} + +bool IndVarSimplify::tryFlattenLoopPair(Loop *L, + SmallVectorImpl &DeadInsts, SCEVExpander &Rewriter) { + if (!EnableLoopFlatten) + return false; + if (!L->getParentLoop()) { + return false; + } + + struct FlattenInfo FI(L->getParentLoop(), L); + + if (!CanFlattenLoopPair(FI, DT, LI, SE, AC, TTI, markLoopAsDeleted)) + return false; + + LLVM_DEBUG(dbgs() << "INDVARS: flattening loop nest!\n"); + + Module *M = L->getHeader()->getParent()->getParent(); + auto &DL = M->getDataLayout(); + auto *InnerType = FI.InnerInductionPHI->getType(); + auto *OuterType = FI.OuterInductionPHI->getType(); + unsigned MaxLegalSize = DL.getLargestLegalIntTypeSizeInBits(); + auto *MaxLegalType = DL.getLargestLegalIntType(M->getContext()); + + // If both induction types are less than maximum integer width, promote + // both to the widest type available so we know calculating Limit * Limit + // as the new trip count is safe. 
+ if (InnerType == OuterType && + InnerType->getScalarSizeInBits() < MaxLegalSize) { + LLVM_DEBUG(dbgs() << "Promote induction phis to " << MaxLegalSize << "\n"); + + SmallVector WideIVs; + auto AddCandidatePhi = [&] (BasicBlock::iterator I) { + for ( ; isa(I); ++I) { + LLVM_DEBUG(dbgs() << "INDVARS: widen phi: "; cast(I)->dump()); + WideIVs.push_back( {cast(I), MaxLegalType, false }); + } + }; + + AddCandidatePhi(L->getHeader()->begin()); + AddCandidatePhi(L->getParentLoop()->getHeader()->begin()); + + for (; !WideIVs.empty(); WideIVs.pop_back()) { + WidenIV Widener(WideIVs.back(), LI, SE, DT, DeadInsts, true); + if (PHINode *WidePhi = Widener.createWideIV(Rewriter)) { + LLVM_DEBUG(dbgs() << "INDVARS: created wide phi: "; WidePhi->dump()); + } else { + return false; + } + } + } else if (checkOverflow(FI, DT, AC) == OverflowResult::MayOverflow) { + LLVM_DEBUG(dbgs() << "INDVARS: overflow possible, bailing...\n"); + return false; + } else { + // TODO: support different sized phis. + return false; + } + + FlattenLoopPair(FI, DT, LI, SE, AC, TTI, markLoopAsDeleted); + return true; +} + //===----------------------------------------------------------------------===// // IndVarSimplify driver. Manage several subpasses of IV simplification. //===----------------------------------------------------------------------===// @@ -2731,6 +3291,8 @@ // other expressions involving loop IVs have been evaluated. This helps SCEV // set no-wrap flags before normalizing sign/zero extension. 
Rewriter.disableCanonicalMode(); + + Changed |= tryFlattenLoopPair(L, DeadInsts, Rewriter); Changed |= simplifyAndExtend(L, Rewriter, LI); // Check to see if we can compute the final value of any expressions @@ -2879,11 +3441,14 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, - LPMUpdater &) { + LPMUpdater &Updater) { Function *F = L.getHeader()->getParent(); const DataLayout &DL = F->getParent()->getDataLayout(); - IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI, AR.MSSA); + std::string LoopName(L.getName()); + IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI, AR.MSSA, &AR.AC, + [&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); }); + if (!IVS.run(&L)) return PreservedAnalyses::all(); @@ -2916,11 +3481,15 @@ auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); auto *MSSAAnalysis = getAnalysisIfAvailable(); + auto *AC = &getAnalysis().getAssumptionCache( + *L->getHeader()->getParent()); + MemorySSA *MSSA = nullptr; if (MSSAAnalysis) MSSA = &MSSAAnalysis->getMSSA(); - IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA); + IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI, MSSA, AC, + [&](Loop *L) { LPM.markLoopAsDeleted(*L); }); return IVS.run(L); } @@ -2928,6 +3497,8 @@ AU.setPreservesCFG(); AU.addPreserved(); getLoopAnalysisUsage(AU); + AU.addRequired(); + AU.addPreserved(); } }; @@ -2938,6 +3509,7 @@ INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars", "Induction Variable Simplification", false, false) INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(IndVarSimplifyLegacyPass, "indvars", "Induction Variable Simplification", false, false) Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp =================================================================== --- llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ /dev/null @@ 
-1,606 +0,0 @@ -//===- LoopFlatten.cpp - Loop flattening pass------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This pass flattens pairs nested loops into a single loop. -// -// The intention is to optimise loop nests like this, which together access an -// array linearly: -// for (int i = 0; i < N; ++i) -// for (int j = 0; j < M; ++j) -// f(A[i*M+j]); -// into one loop: -// for (int i = 0; i < (N*M); ++i) -// f(A[i]); -// -// It can also flatten loops where the induction variables are not used in the -// loop. This is only worth doing if the induction variables are only used in an -// expression like i*M+j. If they had any other uses, we would have to insert a -// div/mod to reconstruct the original values, so this wouldn't be profitable. -// -// We also need to prove that N*M will not overflow. 
-// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Scalar/LoopFlatten.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/OptimizationRemarkEmitter.h" -#include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Dominators.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Module.h" -#include "llvm/IR/PatternMatch.h" -#include "llvm/IR/Verifier.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/LoopUtils.h" - -#define DEBUG_TYPE "loop-flatten" - -using namespace llvm; -using namespace llvm::PatternMatch; - -static cl::opt RepeatedInstructionThreshold( - "loop-flatten-cost-threshold", cl::Hidden, cl::init(2), - cl::desc("Limit on the cost of instructions that can be repeated due to " - "loop flattening")); - -static cl::opt - AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden, - cl::init(false), - cl::desc("Assume that the product of the two iteration " - "limits will never overflow")); - -// Finds the induction variable, increment and limit for a simple loop that we -// can flatten. -static bool findLoopComponents( - Loop *L, SmallPtrSetImpl &IterationInstructions, - PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment, - BranchInst *&BackBranch, ScalarEvolution *SE) { - LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n"); - - if (!L->isLoopSimplifyForm()) { - LLVM_DEBUG(dbgs() << "Loop is not in normal form\n"); - return false; - } - - // There must be exactly one exiting block, and it must be the same at the - // latch. 
- BasicBlock *Latch = L->getLoopLatch(); - if (L->getExitingBlock() != Latch) { - LLVM_DEBUG(dbgs() << "Exiting and latch block are different\n"); - return false; - } - // Latch block must end in a conditional branch. - BackBranch = dyn_cast(Latch->getTerminator()); - if (!BackBranch || !BackBranch->isConditional()) { - LLVM_DEBUG(dbgs() << "Could not find back-branch\n"); - return false; - } - IterationInstructions.insert(BackBranch); - LLVM_DEBUG(dbgs() << "Found back branch: "; BackBranch->dump()); - bool ContinueOnTrue = L->contains(BackBranch->getSuccessor(0)); - - // Find the induction PHI. If there is no induction PHI, we can't do the - // transformation. TODO: could other variables trigger this? Do we have to - // search for the best one? - InductionPHI = nullptr; - for (PHINode &PHI : L->getHeader()->phis()) { - InductionDescriptor ID; - if (InductionDescriptor::isInductionPHI(&PHI, L, SE, ID)) { - InductionPHI = &PHI; - LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump()); - break; - } - } - if (!InductionPHI) { - LLVM_DEBUG(dbgs() << "Could not find induction PHI\n"); - return false; - } - - auto IsValidPredicate = [&](ICmpInst::Predicate Pred) { - if (ContinueOnTrue) - return Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT; - else - return Pred == CmpInst::ICMP_EQ; - }; - - // Find Compare and make sure it is valid - ICmpInst *Compare = dyn_cast(BackBranch->getCondition()); - if (!Compare || !IsValidPredicate(Compare->getUnsignedPredicate()) || - Compare->hasNUsesOrMore(2)) { - LLVM_DEBUG(dbgs() << "Could not find valid comparison\n"); - return false; - } - IterationInstructions.insert(Compare); - LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump()); - - // Find increment and limit from the compare - Increment = nullptr; - if (match(Compare->getOperand(0), - m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { - Increment = dyn_cast(Compare->getOperand(0)); - Limit = Compare->getOperand(1); - } else if 
(Compare->getUnsignedPredicate() == CmpInst::ICMP_NE && - match(Compare->getOperand(1), - m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { - Increment = dyn_cast(Compare->getOperand(1)); - Limit = Compare->getOperand(0); - } - if (!Increment || Increment->hasNUsesOrMore(3)) { - LLVM_DEBUG(dbgs() << "Cound not find valid increment\n"); - return false; - } - IterationInstructions.insert(Increment); - LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump()); - LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump()); - - assert(InductionPHI->getNumIncomingValues() == 2); - assert(InductionPHI->getIncomingValueForBlock(Latch) == Increment && - "PHI value is not increment inst"); - - auto *CI = dyn_cast( - InductionPHI->getIncomingValueForBlock(L->getLoopPreheader())); - if (!CI || !CI->isZero()) { - LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump()); - return false; - } - - LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); - return true; -} - -static bool checkPHIs(Loop *OuterLoop, Loop *InnerLoop, - SmallPtrSetImpl &InnerPHIsToTransform, - PHINode *InnerInductionPHI, PHINode *OuterInductionPHI, - TargetTransformInfo *TTI) { - // All PHIs in the inner and outer headers must either be: - // - The induction PHI, which we are going to rewrite as one induction in - // the new loop. This is already checked by findLoopComponents. - // - An outer header PHI with all incoming values from outside the loop. - // LoopSimplify guarantees we have a pre-header, so we don't need to - // worry about that here. - // - Pairs of PHIs in the inner and outer headers, which implement a - // loop-carried dependency that will still be valid in the new loop. To - // be valid, this variable must be modified only in the inner loop. - - // The set of PHI nodes in the outer loop header that we know will still be - // valid after the transformation. 
These will not need to be modified (with - // the exception of the induction variable), but we do need to check that - // there are no unsafe PHI nodes. - SmallPtrSet SafeOuterPHIs; - SafeOuterPHIs.insert(OuterInductionPHI); - - // Check that all PHI nodes in the inner loop header match one of the valid - // patterns. - for (PHINode &InnerPHI : InnerLoop->getHeader()->phis()) { - // The induction PHIs break these rules, and that's OK because we treat - // them specially when doing the transformation. - if (&InnerPHI == InnerInductionPHI) - continue; - - // Each inner loop PHI node must have two incoming values/blocks - one - // from the pre-header, and one from the latch. - assert(InnerPHI.getNumIncomingValues() == 2); - Value *PreHeaderValue = - InnerPHI.getIncomingValueForBlock(InnerLoop->getLoopPreheader()); - Value *LatchValue = - InnerPHI.getIncomingValueForBlock(InnerLoop->getLoopLatch()); - - // The incoming value from the outer loop must be the PHI node in the - // outer loop header, with no modifications made in the top of the outer - // loop. - PHINode *OuterPHI = dyn_cast(PreHeaderValue); - if (!OuterPHI || OuterPHI->getParent() != OuterLoop->getHeader()) { - LLVM_DEBUG(dbgs() << "value modified in top of outer loop\n"); - return false; - } - - // The other incoming value must come from the inner loop, without any - // modifications in the tail end of the outer loop. We are in LCSSA form, - // so this will actually be a PHI in the inner loop's exit block, which - // only uses values from inside the inner loop. - PHINode *LCSSAPHI = dyn_cast( - OuterPHI->getIncomingValueForBlock(OuterLoop->getLoopLatch())); - if (!LCSSAPHI) { - LLVM_DEBUG(dbgs() << "could not find LCSSA PHI\n"); - return false; - } - - // The value used by the LCSSA PHI must be the same one that the inner - // loop's PHI uses. 
- if (LCSSAPHI->hasConstantValue() != LatchValue) { - LLVM_DEBUG( - dbgs() << "LCSSA PHI incoming value does not match latch value\n"); - return false; - } - - LLVM_DEBUG(dbgs() << "PHI pair is safe:\n"); - LLVM_DEBUG(dbgs() << " Inner: "; InnerPHI.dump()); - LLVM_DEBUG(dbgs() << " Outer: "; OuterPHI->dump()); - SafeOuterPHIs.insert(OuterPHI); - InnerPHIsToTransform.insert(&InnerPHI); - } - - for (PHINode &OuterPHI : OuterLoop->getHeader()->phis()) { - if (!SafeOuterPHIs.count(&OuterPHI)) { - LLVM_DEBUG(dbgs() << "found unsafe PHI in outer loop: "; OuterPHI.dump()); - return false; - } - } - - return true; -} - -static bool -checkOuterLoopInsts(Loop *OuterLoop, Loop *InnerLoop, - SmallPtrSetImpl &IterationInstructions, - Value *InnerLimit, PHINode *OuterPHI, - TargetTransformInfo *TTI) { - // Check for instructions in the outer but not inner loop. If any of these - // have side-effects then this transformation is not legal, and if there is - // a significant amount of code here which can't be optimised out that it's - // not profitable (as these instructions would get executed for each - // iteration of the inner loop). - unsigned RepeatedInstrCost = 0; - for (auto *B : OuterLoop->getBlocks()) { - if (InnerLoop->contains(B)) - continue; - - for (auto &I : *B) { - if (!isa(&I) && !I.isTerminator() && - !isSafeToSpeculativelyExecute(&I)) { - LLVM_DEBUG(dbgs() << "Cannot flatten because instruction may have " - "side effects: "; - I.dump()); - return false; - } - // The execution count of the outer loop's iteration instructions - // (increment, compare and branch) will be increased, but the - // equivalent instructions will be removed from the inner loop, so - // they make a net difference of zero. - if (IterationInstructions.count(&I)) - continue; - // The uncoditional branch to the inner loop's header will turn into - // a fall-through, so adds no cost. 
- BranchInst *Br = dyn_cast(&I); - if (Br && Br->isUnconditional() && - Br->getSuccessor(0) == InnerLoop->getHeader()) - continue; - // Multiplies of the outer iteration variable and inner iteration - // count will be optimised out. - if (match(&I, m_c_Mul(m_Specific(OuterPHI), m_Specific(InnerLimit)))) - continue; - int Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); - LLVM_DEBUG(dbgs() << "Cost " << Cost << ": "; I.dump()); - RepeatedInstrCost += Cost; - } - } - - LLVM_DEBUG(dbgs() << "Cost of instructions that will be repeated: " - << RepeatedInstrCost << "\n"); - // Bail out if flattening the loops would cause instructions in the outer - // loop but not in the inner loop to be executed extra times. - if (RepeatedInstrCost > RepeatedInstructionThreshold) - return false; - - return true; -} - -static bool checkIVUsers(PHINode *InnerPHI, PHINode *OuterPHI, - BinaryOperator *InnerIncrement, - BinaryOperator *OuterIncrement, Value *InnerLimit, - SmallPtrSetImpl &LinearIVUses) { - // We require all uses of both induction variables to match this pattern: - // - // (OuterPHI * InnerLimit) + InnerPHI - // - // Any uses of the induction variables not matching that pattern would - // require a div/mod to reconstruct in the flattened loop, so the - // transformation wouldn't be profitable. - - // Check that all uses of the inner loop's induction variable match the - // expected pattern, recording the uses of the outer IV. 
- SmallPtrSet ValidOuterPHIUses; - for (User *U : InnerPHI->users()) { - if (U == InnerIncrement) - continue; - - LLVM_DEBUG(dbgs() << "Found use of inner induction variable: "; U->dump()); - - Value *MatchedMul, *MatchedItCount; - if (match(U, m_c_Add(m_Specific(InnerPHI), m_Value(MatchedMul))) && - match(MatchedMul, - m_c_Mul(m_Specific(OuterPHI), m_Value(MatchedItCount))) && - MatchedItCount == InnerLimit) { - LLVM_DEBUG(dbgs() << "Use is optimisable\n"); - ValidOuterPHIUses.insert(MatchedMul); - LinearIVUses.insert(U); - } else { - LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); - return false; - } - } - - // Check that there are no uses of the outer IV other than the ones found - // as part of the pattern above. - for (User *U : OuterPHI->users()) { - if (U == OuterIncrement) - continue; - - LLVM_DEBUG(dbgs() << "Found use of outer induction variable: "; U->dump()); - - if (!ValidOuterPHIUses.count(U)) { - LLVM_DEBUG(dbgs() << "Did not match expected pattern, bailing\n"); - return false; - } else { - LLVM_DEBUG(dbgs() << "Use is optimisable\n"); - } - } - - LLVM_DEBUG(dbgs() << "Found " << LinearIVUses.size() - << " value(s) that can be replaced:\n"; - for (Value *V : LinearIVUses) { - dbgs() << " "; - V->dump(); - }); - - return true; -} - -// Return an OverflowResult dependant on if overflow of the multiplication of -// InnerLimit and OuterLimit can be assumed not to happen. -static OverflowResult checkOverflow(Loop *OuterLoop, Value *InnerLimit, - Value *OuterLimit, - SmallPtrSetImpl &LinearIVUses, - DominatorTree *DT, AssumptionCache *AC) { - Function *F = OuterLoop->getHeader()->getParent(); - const DataLayout &DL = F->getParent()->getDataLayout(); - - // For debugging/testing. - if (AssumeNoOverflow) - return OverflowResult::NeverOverflows; - - // Check if the multiply could not overflow due to known ranges of the - // input values. 
- OverflowResult OR = computeOverflowForUnsignedMul( - InnerLimit, OuterLimit, DL, AC, - OuterLoop->getLoopPreheader()->getTerminator(), DT); - if (OR != OverflowResult::MayOverflow) - return OR; - - for (Value *V : LinearIVUses) { - for (Value *U : V->users()) { - if (auto *GEP = dyn_cast(U)) { - // The IV is used as the operand of a GEP, and the IV is at least as - // wide as the address space of the GEP. In this case, the GEP would - // wrap around the address space before the IV increment wraps, which - // would be UB. - if (GEP->isInBounds() && - V->getType()->getIntegerBitWidth() >= - DL.getPointerTypeSizeInBits(GEP->getType())) { - LLVM_DEBUG( - dbgs() << "use of linear IV would be UB if overflow occurred: "; - GEP->dump()); - return OverflowResult::NeverOverflows; - } - } - } - } - - return OverflowResult::MayOverflow; -} - -static bool FlattenLoopPair(Loop *OuterLoop, Loop *InnerLoop, DominatorTree *DT, - LoopInfo *LI, ScalarEvolution *SE, - AssumptionCache *AC, TargetTransformInfo *TTI, - std::function markLoopAsDeleted) { - Function *F = OuterLoop->getHeader()->getParent(); - - LLVM_DEBUG(dbgs() << "Loop flattening running on outer loop " - << OuterLoop->getHeader()->getName() << " and inner loop " - << InnerLoop->getHeader()->getName() << " in " - << F->getName() << "\n"); - - SmallPtrSet IterationInstructions; - - PHINode *InnerInductionPHI, *OuterInductionPHI; - Value *InnerLimit, *OuterLimit; - BinaryOperator *InnerIncrement, *OuterIncrement; - BranchInst *InnerBranch, *OuterBranch; - - if (!findLoopComponents(InnerLoop, IterationInstructions, InnerInductionPHI, - InnerLimit, InnerIncrement, InnerBranch, SE)) - return false; - if (!findLoopComponents(OuterLoop, IterationInstructions, OuterInductionPHI, - OuterLimit, OuterIncrement, OuterBranch, SE)) - return false; - - // Both of the loop limit values must be invariant in the outer loop - // (non-instructions are all inherently invariant). 
- bool Changed; - if (!OuterLoop->makeLoopInvariant(InnerLimit, Changed)) { - LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n"); - return false; - } - if (!OuterLoop->makeLoopInvariant(OuterLimit, Changed)) { - LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n"); - return false; - } - - SmallPtrSet InnerPHIsToTransform; - if (!checkPHIs(OuterLoop, InnerLoop, InnerPHIsToTransform, InnerInductionPHI, - OuterInductionPHI, TTI)) - return false; - - // FIXME: it should be possible to handle different types correctly. - if (InnerInductionPHI->getType() != OuterInductionPHI->getType()) - return false; - - if (!checkOuterLoopInsts(OuterLoop, InnerLoop, IterationInstructions, - InnerLimit, OuterInductionPHI, TTI)) - return false; - - // Find the values in the loop that can be replaced with the linearized - // induction variable, and check that there are no other uses of the inner - // or outer induction variable. If there were, we could still do this - // transformation, but we'd have to insert a div/mod to calculate the - // original IVs, so it wouldn't be profitable. - SmallPtrSet LinearIVUses; - if (!checkIVUsers(InnerInductionPHI, OuterInductionPHI, InnerIncrement, - OuterIncrement, InnerLimit, LinearIVUses)) - return false; - - // Check if the new iteration variable might overflow. In this case, we - // need to version the loop, and select the original version at runtime if - // the iteration space is too large. - // TODO: We currently don't version the loop. - // TODO: it might be worth using a wider iteration variable rather than - // versioning the loop, if a wide enough type is legal. 
- bool MustVersionLoop = true; - OverflowResult OR = - checkOverflow(OuterLoop, InnerLimit, OuterLimit, LinearIVUses, DT, AC); - if (OR == OverflowResult::AlwaysOverflowsHigh || - OR == OverflowResult::AlwaysOverflowsLow) { - LLVM_DEBUG(dbgs() << "Multiply would always overflow, so not profitable\n"); - return false; - } else if (OR == OverflowResult::MayOverflow) { - LLVM_DEBUG(dbgs() << "Multiply might overflow, not flattening\n"); - } else { - LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n"); - MustVersionLoop = false; - } - - // We cannot safely flatten the loop. Exit now. - if (MustVersionLoop) - return false; - - // Do the actual transformation. - LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n"); - - { - using namespace ore; - OptimizationRemark Remark(DEBUG_TYPE, "Flattened", InnerLoop->getStartLoc(), - InnerLoop->getHeader()); - OptimizationRemarkEmitter ORE(F); - Remark << "Flattened into outer loop"; - ORE.emit(Remark); - } - - Value *NewTripCount = - BinaryOperator::CreateMul(InnerLimit, OuterLimit, "flatten.tripcount", - OuterLoop->getLoopPreheader()->getTerminator()); - LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; - NewTripCount->dump()); - - // Fix up PHI nodes that take values from the inner loop back-edge, which - // we are about to remove. - InnerInductionPHI->removeIncomingValue(InnerLoop->getLoopLatch()); - for (PHINode *PHI : InnerPHIsToTransform) - PHI->removeIncomingValue(InnerLoop->getLoopLatch()); - - // Modify the trip count of the outer loop to be the product of the two - // trip counts. - cast(OuterBranch->getCondition())->setOperand(1, NewTripCount); - - // Replace the inner loop backedge with an unconditional branch to the exit. 
- BasicBlock *InnerExitBlock = InnerLoop->getExitBlock(); - BasicBlock *InnerExitingBlock = InnerLoop->getExitingBlock(); - InnerExitingBlock->getTerminator()->eraseFromParent(); - BranchInst::Create(InnerExitBlock, InnerExitingBlock); - DT->deleteEdge(InnerExitingBlock, InnerLoop->getHeader()); - - // Replace all uses of the polynomial calculated from the two induction - // variables with the one new one. - for (Value *V : LinearIVUses) - V->replaceAllUsesWith(OuterInductionPHI); - - // Tell LoopInfo, SCEV and the pass manager that the inner loop has been - // deleted, and any information that have about the outer loop invalidated. - markLoopAsDeleted(InnerLoop); - SE->forgetLoop(OuterLoop); - SE->forgetLoop(InnerLoop); - LI->erase(InnerLoop); - - return true; -} - -PreservedAnalyses LoopFlattenPass::run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, - LPMUpdater &Updater) { - if (L.getSubLoops().size() != 1) - return PreservedAnalyses::all(); - - Loop *InnerLoop = *L.begin(); - std::string LoopName(InnerLoop->getName()); - if (!FlattenLoopPair( - &L, InnerLoop, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, - [&](Loop *L) { Updater.markLoopAsDeleted(*L, LoopName); })) - return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); -} - -namespace { -class LoopFlattenLegacyPass : public LoopPass { -public: - static char ID; // Pass ID, replacement for typeid - LoopFlattenLegacyPass() : LoopPass(ID) { - initializeLoopFlattenLegacyPassPass(*PassRegistry::getPassRegistry()); - } - - // Possibly flatten loop L into its child. 
- bool runOnLoop(Loop *L, LPPassManager &) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - getLoopAnalysisUsage(AU); - AU.addRequired(); - AU.addPreserved(); - AU.addRequired(); - AU.addPreserved(); - } -}; -} // namespace - -char LoopFlattenLegacyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", - false, false) -INITIALIZE_PASS_DEPENDENCY(LoopPass) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops", - false, false) - -Pass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); } - -bool LoopFlattenLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { - if (skipLoop(L)) - return false; - - if (L->getSubLoops().size() != 1) - return false; - - ScalarEvolution *SE = &getAnalysis().getSE(); - LoopInfo *LI = &getAnalysis().getLoopInfo(); - auto *DTWP = getAnalysisIfAvailable(); - DominatorTree *DT = DTWP ? 
&DTWP->getDomTree() : nullptr; - auto &TTIP = getAnalysis(); - TargetTransformInfo *TTI = &TTIP.getTTI(*L->getHeader()->getParent()); - AssumptionCache *AC = - &getAnalysis().getAssumptionCache( - *L->getHeader()->getParent()); - - Loop *InnerLoop = *L->begin(); - return FlattenLoopPair(L, InnerLoop, DT, LI, SE, AC, TTI, - [&](Loop *L) { LPM.markLoopAsDeleted(*L); }); -} Index: llvm/lib/Transforms/Scalar/Scalar.cpp =================================================================== --- llvm/lib/Transforms/Scalar/Scalar.cpp +++ llvm/lib/Transforms/Scalar/Scalar.cpp @@ -67,7 +67,6 @@ initializeLoopAccessLegacyAnalysisPass(Registry); initializeLoopInstSimplifyLegacyPassPass(Registry); initializeLoopInterchangeLegacyPassPass(Registry); - initializeLoopFlattenLegacyPassPass(Registry); initializeLoopPredicationLegacyPassPass(Registry); initializeLoopRotateLegacyPassPass(Registry); initializeLoopStrengthReducePass(Registry); @@ -187,10 +186,6 @@ unwrap(PM)->add(createLoopDeletionPass()); } -void LLVMAddLoopFlattenPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLoopFlattenPass()); -} - void LLVMAddLoopIdiomPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createLoopIdiomPass()); } Index: llvm/utils/gn/secondary/llvm/lib/Transforms/Scalar/BUILD.gn =================================================================== --- llvm/utils/gn/secondary/llvm/lib/Transforms/Scalar/BUILD.gn +++ llvm/utils/gn/secondary/llvm/lib/Transforms/Scalar/BUILD.gn @@ -38,7 +38,6 @@ "LoopDataPrefetch.cpp", "LoopDeletion.cpp", "LoopDistribute.cpp", - "LoopFlatten.cpp", "LoopFuse.cpp", "LoopIdiomRecognize.cpp", "LoopInstSimplify.cpp",