diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp --- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp +++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp @@ -63,7 +63,7 @@ AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden, cl::init(false), cl::desc("Assume that the product of the two iteration " - "limits will never overflow")); + "trip counts will never overflow")); static cl::opt<bool> WidenIV("loop-flatten-widen-iv", cl::Hidden, @@ -74,10 +74,12 @@ struct FlattenInfo { Loop *OuterLoop = nullptr; Loop *InnerLoop = nullptr; + // These PHINodes correspond to loop induction variables, which are expected + // to start at zero and increment by one on each iteration. PHINode *InnerInductionPHI = nullptr; PHINode *OuterInductionPHI = nullptr; - Value *InnerLimit = nullptr; - Value *OuterLimit = nullptr; + Value *InnerTripCount = nullptr; + Value *OuterTripCount = nullptr; BinaryOperator *InnerIncrement = nullptr; BinaryOperator *OuterIncrement = nullptr; BranchInst *InnerBranch = nullptr; @@ -91,12 +93,12 @@ FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {}; }; -// Finds the induction variable, increment and limit for a simple loop that we -// can flatten. +// Finds the induction variable, increment and trip count for a simple loop that +// we can flatten. static bool findLoopComponents( Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions, - PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment, - BranchInst *&BackBranch, ScalarEvolution *SE) { + PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment, + BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) { LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n"); if (!L->isLoopSimplifyForm()) { @@ -104,6 +106,13 @@ return false; } + // Currently, to simplify the implementation, the loop induction variable must + // start at zero and increment with a step size of one.
+ if (!L->isCanonical(*SE)) { + LLVM_DEBUG(dbgs() << "Loop is not canonical\n"); + return false; + } + // There must be exactly one exiting block, and it must be the same at the // latch. BasicBlock *Latch = L->getLoopLatch(); @@ -144,40 +153,50 @@ IterationInstructions.insert(Compare); LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump()); - // Find increment and limit from the compare - Increment = nullptr; - if (match(Compare->getOperand(0), - m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { - Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0)); - Limit = Compare->getOperand(1); - } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE && - match(Compare->getOperand(1), - m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) { - Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1)); - Limit = Compare->getOperand(0); - } - if (!Increment || Increment->hasNUsesOrMore(3)) { - LLVM_DEBUG(dbgs() << "Cound not find valid increment\n"); + // Find increment and trip count. + // There are exactly 2 incoming values to the induction phi; one from the + // pre-header and one from the latch. The incoming latch value is the + // increment variable, which must be a BinaryOperator. + Increment = + dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch)); + if (!Increment || Increment->hasNUsesOrMore(3)) { + LLVM_DEBUG(dbgs() << "Could not find valid increment\n"); return false; } + // The trip count is the RHS of the compare. If this doesn't match the trip + // count computed by SCEV then this is either because the trip count variable + // has been widened (then leave the trip count as it is), or because it is a + // constant and another transformation has changed the compare, e.g. + // icmp ult %inc, tripcount -> icmp ult %j, tripcount-1, then we don't flatten + // the loop (yet).
+ TripCount = Compare->getOperand(1); + // A trip count can only be derived from a computable backedge-taken count. + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) { + LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n"); + return false; + } + const SCEV *SCEVTripCount = + SE->getTripCountFromExitCount(BackedgeTakenCount); + if (SE->getSCEV(TripCount) != SCEVTripCount) { + if (!IsWidened) { + LLVM_DEBUG(dbgs() << "Could not find valid trip count\n"); + return false; + } + auto *TripCountInst = dyn_cast<Instruction>(TripCount); + if (!TripCountInst) { + LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); + return false; + } + if ((!isa<ZExtInst>(TripCountInst) && !isa<SExtInst>(TripCountInst)) || + SE->getSCEV(TripCountInst->getOperand(0)) != SCEVTripCount) { + LLVM_DEBUG(dbgs() << "Could not find valid extended trip count\n"); + return false; + } + } IterationInstructions.insert(Increment); LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump()); - LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump()); - - assert(InductionPHI->getNumIncomingValues() == 2); - - if (InductionPHI->getIncomingValueForBlock(Latch) != Increment) { - LLVM_DEBUG( - dbgs() << "Incoming value from latch is not the increment inst\n"); - return false; - } - - auto *CI = dyn_cast<ConstantInt>( - InductionPHI->getIncomingValueForBlock(L->getLoopPreheader())); - if (!CI || !CI->isZero()) { - LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump()); - return false; - } + LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump()); LLVM_DEBUG(dbgs() << "Successfully found all loop components\n"); return true; @@ -300,7 +313,7 @@ // Multiplies of the outer iteration variable and inner iteration // count will be optimised out.
if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI), - m_Specific(FI.InnerLimit)))) + m_Specific(FI.InnerTripCount)))) continue; InstructionCost Cost = TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency); @@ -325,16 +338,16 @@ static bool checkIVUsers(FlattenInfo &FI) { // We require all uses of both induction variables to match this pattern: // - // (OuterPHI * InnerLimit) + InnerPHI + // (OuterPHI * InnerTripCount) + InnerPHI // // Any uses of the induction variables not matching that pattern would // require a div/mod to reconstruct in the flattened loop, so the // transformation wouldn't be profitable. - Value *InnerLimit = FI.InnerLimit; + Value *InnerTripCount = FI.InnerTripCount; if (FI.Widened && - (isa(InnerLimit) || isa(InnerLimit))) - InnerLimit = cast(InnerLimit)->getOperand(0); + (isa(InnerTripCount) || isa(InnerTripCount))) + InnerTripCount = cast(InnerTripCount)->getOperand(0); // Check that all uses of the inner loop's induction variable match the // expected pattern, recording the uses of the outer IV. @@ -368,7 +381,7 @@ m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)), m_Value(MatchedItCount))); - if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) { + if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) { LLVM_DEBUG(dbgs() << "Use is optimisable\n"); ValidOuterPHIUses.insert(MatchedMul); FI.LinearIVUses.insert(U); @@ -417,7 +430,7 @@ } // Return an OverflowResult dependant on if overflow of the multiplication of -// InnerLimit and OuterLimit can be assumed not to happen. +// InnerTripCount and OuterTripCount can be assumed not to happen. static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT, AssumptionCache *AC) { Function *F = FI.OuterLoop->getHeader()->getParent(); @@ -430,7 +443,7 @@ // Check if the multiply could not overflow due to known ranges of the // input values.
OverflowResult OR = computeOverflowForUnsignedMul( - FI.InnerLimit, FI.OuterLimit, DL, AC, + FI.InnerTripCount, FI.OuterTripCount, DL, AC, FI.OuterLoop->getLoopPreheader()->getTerminator(), DT); if (OR != OverflowResult::MayOverflow) return OR; @@ -461,21 +474,23 @@ ScalarEvolution *SE, AssumptionCache *AC, const TargetTransformInfo *TTI) { SmallPtrSet IterationInstructions; - if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI, - FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE)) + if (!findLoopComponents(FI.InnerLoop, IterationInstructions, + FI.InnerInductionPHI, FI.InnerTripCount, + FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened)) return false; - if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI, - FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE)) + if (!findLoopComponents(FI.OuterLoop, IterationInstructions, + FI.OuterInductionPHI, FI.OuterTripCount, + FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened)) return false; - // Both of the loop limit values must be invariant in the outer loop + // Both loop trip count values must be invariant in the outer loop // (non-instructions are all inherently invariant).
- if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) { - LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n"); + if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) { + LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n"); return false; } - if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) { - LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n"); + if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) { + LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n"); return false; } @@ -515,9 +530,9 @@ ORE.emit(Remark); } - Value *NewTripCount = - BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount", - FI.OuterLoop->getLoopPreheader()->getTerminator()); + Value *NewTripCount = BinaryOperator::CreateMul( + FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount", + FI.OuterLoop->getLoopPreheader()->getTerminator()); LLVM_DEBUG(dbgs() << "Created new trip count in preheader: "; NewTripCount->dump()); @@ -581,7 +596,7 @@ // If both induction types are less than the maximum legal integer width, // promote both to the widest type available so we know calculating - // (OuterLimit * InnerLimit) as the new trip count is safe. + // (OuterTripCount * InnerTripCount) as the new trip count is safe. if (InnerType != OuterType || InnerType->getScalarSizeInBits() >= MaxLegalSize || MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) { diff --git a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll --- a/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll +++ b/llvm/test/Transforms/LoopFlatten/loop-flatten-negative.ll @@ -341,6 +341,37 @@ ret i32 10 } +; When the inner loop trip count is a constant (20 here) and the step size is +; 1, InstCombine rewrites icmp ult i32 %inc, 20 into icmp ult i32 %j, 19. +; The compare RHS then no longer equals the trip count computed by SCEV, so +; no valid trip count is found and the loop is not flattened.
+define i32 @test9(i32* nocapture %A) { +entry: + br label %for.cond1.preheader + +for.cond1.preheader: + %i.017 = phi i32 [ 0, %entry ], [ %inc6, %for.cond.cleanup3 ] + %mul = mul i32 %i.017, 20 + br label %for.body4 + +for.cond.cleanup3: + %inc6 = add i32 %i.017, 1 + %cmp = icmp ult i32 %inc6, 11 + br i1 %cmp, label %for.cond1.preheader, label %for.cond.cleanup + +for.body4: + %j.016 = phi i32 [ 0, %for.cond1.preheader ], [ %inc, %for.body4 ] + %add = add i32 %j.016, %mul + %arrayidx = getelementptr inbounds i32, i32* %A, i32 %add + store i32 30, i32* %arrayidx, align 4 + %inc = add nuw nsw i32 %j.016, 1 + %cmp2 = icmp ult i32 %j.016, 19 + br i1 %cmp2, label %for.body4, label %for.cond.cleanup3 + +for.cond.cleanup: + %0 = load i32, i32* %A, align 4 + ret i32 %0 +} ; Outer loop conditional phi define i32 @e() {