Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -63,7 +63,7 @@
     AssumeNoOverflow("loop-flatten-assume-no-overflow", cl::Hidden,
                      cl::init(false),
                      cl::desc("Assume that the product of the two iteration "
-                              "limits will never overflow"));
+                              "trip counts will never overflow"));
 
 static cl::opt<bool>
     WidenIV("loop-flatten-widen-iv", cl::Hidden,
@@ -76,8 +76,8 @@
   Loop *InnerLoop = nullptr;
   PHINode *InnerInductionPHI = nullptr;
   PHINode *OuterInductionPHI = nullptr;
-  Value *InnerLimit = nullptr;
-  Value *OuterLimit = nullptr;
+  Value *InnerTripCount = nullptr;
+  Value *OuterTripCount = nullptr;
   BinaryOperator *InnerIncrement = nullptr;
   BinaryOperator *OuterIncrement = nullptr;
   BranchInst *InnerBranch = nullptr;
@@ -91,12 +91,14 @@
   FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
 };
 
-// Finds the induction variable, increment and limit for a simple loop that we
-// can flatten.
+// Finds the induction variable, increment and trip count for a simple loop that
+// we can flatten. Returns false if any component cannot be identified. When
+// IsWidened is set, the RHS of the compare is kept as the trip count even if
+// it differs from the trip count computed by SCEV.
 static bool findLoopComponents(
     Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
-    PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment,
-    BranchInst *&BackBranch, ScalarEvolution *SE) {
+    PHINode *&InductionPHI, Value *&TripCount, BinaryOperator *&Increment,
+    BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
   LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
 
   if (!L->isLoopSimplifyForm()) {
@@ -122,6 +124,15 @@
   }
   LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());
 
+  Value *StartValue =
+      InductionPHI->getIncomingValueForBlock(L->getLoopPreheader());
+  auto *CI = dyn_cast<ConstantInt>(StartValue);
+  if (!CI || !CI->isZero()) {
+    // CI is null for a non-constant start value, so dump StartValue instead.
+    LLVM_DEBUG(dbgs() << "PHI value is not zero: "; StartValue->dump());
+    return false;
+  }
+
   bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0));
   auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {
     if (ContinueOnTrue)
@@ -144,40 +155,49 @@
   IterationInstructions.insert(Compare);
   LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
 
-  // Find increment and limit from the compare
+  // Find increment and trip count.
+  // The loop must be canonical, and there are exactly 2 incoming values to the
+  // induction phi; one from the pre-header and one from the latch. The incoming
+  // latch value is the increment variable.
   Increment = nullptr;
-  if (match(Compare->getOperand(0),
-            m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
-    Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0));
-    Limit = Compare->getOperand(1);
-  } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE &&
-             match(Compare->getOperand(1),
-                   m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
-    Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1));
-    Limit = Compare->getOperand(0);
-  }
+  if (L->isCanonical(*SE))
+    Increment =
+        dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));
   if (!Increment || Increment->hasNUsesOrMore(3)) {
-    LLVM_DEBUG(dbgs() << "Cound not find valid increment\n");
+    LLVM_DEBUG(dbgs() << "Could not find valid increment\n");
     return false;
   }
+  // The trip count is the RHS of the compare. If this doesn't match the trip
+  // count computed by SCEV then this is either because the trip count variable
+  // has been widened (then leave the trip count as it is), or because it is a
+  // constant and another transformation has changed the compare (e.g. icmp ult
+  // %inc, tripcount -> icmp ult %j, tripcount-1), then change the trip count to
+  // be the SCEV trip count.
+  TripCount = Compare->getOperand(1);
+  // Bail out if SCEV cannot compute how often the backedge is taken.
+  const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
+    LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
+    return false;
+  }
+  const SCEV *SCEVTripCount =
+      SE->getTripCountFromExitCount(BackedgeTakenCount);
+  if (SE->getSCEV(TripCount) != SCEVTripCount && !IsWidened) {
+    // Only a constant SCEV trip count can be materialised here; anything else
+    // cannot be reconciled with the compare, so give up on this loop.
+    auto *ConstantTripCount = dyn_cast<SCEVConstant>(SCEVTripCount);
+    if (!ConstantTripCount) {
+      LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
+      return false;
+    }
+    TripCount = ConstantInt::get(Compare->getContext(),
+                                 ConstantTripCount->getValue()->getValue());
+    assert(SE->getSCEV(TripCount) == SCEVTripCount &&
+           "Expected constant trip count to match SCEV trip count");
+  }
   IterationInstructions.insert(Increment);
   LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
-  LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump());
-
-  assert(InductionPHI->getNumIncomingValues() == 2);
-
-  if (InductionPHI->getIncomingValueForBlock(Latch) != Increment) {
-    LLVM_DEBUG(
-        dbgs() << "Incoming value from latch is not the increment inst\n");
-    return false;
-  }
-
-  auto *CI = dyn_cast<ConstantInt>(
-      InductionPHI->getIncomingValueForBlock(L->getLoopPreheader()));
-  if (!CI || !CI->isZero()) {
-    LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump());
-    return false;
-  }
+  LLVM_DEBUG(dbgs() << "Found trip count: "; TripCount->dump());
 
   LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
   return true;
@@ -300,7 +320,7 @@
       // Multiplies of the outer iteration variable and inner iteration
       // count will be optimised out.
       if (match(&I, m_c_Mul(m_Specific(FI.OuterInductionPHI),
-                            m_Specific(FI.InnerLimit))))
+                            m_Specific(FI.InnerTripCount))))
         continue;
       InstructionCost Cost =
           TTI->getUserCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
@@ -325,16 +345,16 @@
 static bool checkIVUsers(FlattenInfo &FI) {
   // We require all uses of both induction variables to match this pattern:
   //
-  //   (OuterPHI * InnerLimit) + InnerPHI
+  //   (OuterPHI * InnerTripCount) + InnerPHI
   //
   // Any uses of the induction variables not matching that pattern would
   // require a div/mod to reconstruct in the flattened loop, so the
   // transformation wouldn't be profitable.
 
-  Value *InnerLimit = FI.InnerLimit;
+  Value *InnerTripCount = FI.InnerTripCount;
   if (FI.Widened &&
-      (isa<SExtInst>(InnerLimit) || isa<ZExtInst>(InnerLimit)))
-    InnerLimit = cast<Instruction>(InnerLimit)->getOperand(0);
+      (isa<SExtInst>(InnerTripCount) || isa<ZExtInst>(InnerTripCount)))
+    InnerTripCount = cast<Instruction>(InnerTripCount)->getOperand(0);
 
   // Check that all uses of the inner loop's induction variable match the
   // expected pattern, recording the uses of the outer IV.
@@ -368,7 +388,7 @@
                             m_c_Mul(m_Trunc(m_Specific(FI.OuterInductionPHI)),
                                     m_Value(MatchedItCount)));
 
-    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerLimit) {
+    if ((IsAdd || IsAddTrunc) && MatchedItCount == InnerTripCount) {
       LLVM_DEBUG(dbgs() << "Use is optimisable\n");
       ValidOuterPHIUses.insert(MatchedMul);
       FI.LinearIVUses.insert(U);
@@ -417,7 +437,7 @@
 }
 
 // Return an OverflowResult dependant on if overflow of the multiplication of
-// InnerLimit and OuterLimit can be assumed not to happen.
+// InnerTripCount and OuterTripCount can be assumed not to happen.
 static OverflowResult checkOverflow(FlattenInfo &FI, DominatorTree *DT,
                                     AssumptionCache *AC) {
   Function *F = FI.OuterLoop->getHeader()->getParent();
@@ -430,7 +450,7 @@
   // Check if the multiply could not overflow due to known ranges of the
   // input values.
   OverflowResult OR = computeOverflowForUnsignedMul(
-      FI.InnerLimit, FI.OuterLimit, DL, AC,
+      FI.InnerTripCount, FI.OuterTripCount, DL, AC,
       FI.OuterLoop->getLoopPreheader()->getTerminator(), DT);
   if (OR != OverflowResult::MayOverflow)
     return OR;
@@ -461,21 +481,23 @@
                                ScalarEvolution *SE, AssumptionCache *AC,
                                const TargetTransformInfo *TTI) {
   SmallPtrSet<Instruction *, 8> IterationInstructions;
-  if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
-                          FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
+  if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
+                          FI.InnerInductionPHI, FI.InnerTripCount,
+                          FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
     return false;
-  if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI,
-                          FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE))
+  if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
+                          FI.OuterInductionPHI, FI.OuterTripCount,
+                          FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
     return false;
 
-  // Both of the loop limit values must be invariant in the outer loop
+  // Both of the loop trip count values must be invariant in the outer loop
   // (non-instructions are all inherently invariant).
-  if (!FI.OuterLoop->isLoopInvariant(FI.InnerLimit)) {
-    LLVM_DEBUG(dbgs() << "inner loop limit not invariant\n");
+  if (!FI.OuterLoop->isLoopInvariant(FI.InnerTripCount)) {
+    LLVM_DEBUG(dbgs() << "inner loop trip count not invariant\n");
     return false;
   }
-  if (!FI.OuterLoop->isLoopInvariant(FI.OuterLimit)) {
-    LLVM_DEBUG(dbgs() << "outer loop limit not invariant\n");
+  if (!FI.OuterLoop->isLoopInvariant(FI.OuterTripCount)) {
+    LLVM_DEBUG(dbgs() << "outer loop trip count not invariant\n");
     return false;
   }
 
@@ -515,9 +537,9 @@
     ORE.emit(Remark);
   }
 
-  Value *NewTripCount =
-      BinaryOperator::CreateMul(FI.InnerLimit, FI.OuterLimit, "flatten.tripcount",
-                                FI.OuterLoop->getLoopPreheader()->getTerminator());
+  Value *NewTripCount = BinaryOperator::CreateMul(
+      FI.InnerTripCount, FI.OuterTripCount, "flatten.tripcount",
+      FI.OuterLoop->getLoopPreheader()->getTerminator());
   LLVM_DEBUG(dbgs() << "Created new trip count in preheader: ";
              NewTripCount->dump());
 
@@ -581,7 +603,7 @@
 
   // If both induction types are less than the maximum legal integer width,
   // promote both to the widest type available so we know calculating
-  // (OuterLimit * InnerLimit) as the new trip count is safe.
+  // (OuterTripCount * InnerTripCount) as the new trip count is safe.
   if (InnerType != OuterType ||
       InnerType->getScalarSizeInBits() >= MaxLegalSize ||
       MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {