Index: llvm/lib/Transforms/Scalar/LoopFlatten.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -96,7 +96,7 @@
 static bool findLoopComponents(
     Loop *L, SmallPtrSetImpl<Instruction *> &IterationInstructions,
     PHINode *&InductionPHI, Value *&Limit, BinaryOperator *&Increment,
-    BranchInst *&BackBranch, ScalarEvolution *SE) {
+    BranchInst *&BackBranch, ScalarEvolution *SE, bool IsWidened) {
   LLVM_DEBUG(dbgs() << "Finding components of loop: " << L->getName() << "\n");
 
   if (!L->isLoopSimplifyForm()) {
@@ -122,6 +122,18 @@
   }
   LLVM_DEBUG(dbgs() << "Found induction PHI: "; InductionPHI->dump());
 
+  assert(InductionPHI->getNumIncomingValues() == 2);
+
+  // The induction PHI must start at zero. Keep the raw incoming value so the
+  // debug output below can print it even when it is not a ConstantInt.
+  Value *StartValue =
+      InductionPHI->getIncomingValueForBlock(L->getLoopPreheader());
+  auto *CI = dyn_cast<ConstantInt>(StartValue);
+  if (!CI || !CI->isZero()) {
+    LLVM_DEBUG(dbgs() << "PHI value is not zero: "; StartValue->dump());
+    return false;
+  }
+
   bool ContinueOnTrue = L->contains(Latch->getTerminator()->getSuccessor(0));
   auto IsValidPredicate = [&](ICmpInst::Predicate Pred) {
     if (ContinueOnTrue)
@@ -144,41 +156,46 @@
   IterationInstructions.insert(Compare);
   LLVM_DEBUG(dbgs() << "Found comparison: "; Compare->dump());
 
-  // Find increment and limit from the compare
+  // Find increment and limit.
+  // The loop must be canonical, and there are exactly 2 incoming values to the
+  // induction phi; one from the pre-header and one from the latch. The incoming
+  // latch value is the increment variable.
   Increment = nullptr;
-  if (match(Compare->getOperand(0),
-            m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
-    Increment = dyn_cast<BinaryOperator>(Compare->getOperand(0));
-    Limit = Compare->getOperand(1);
-  } else if (Compare->getUnsignedPredicate() == CmpInst::ICMP_NE &&
-             match(Compare->getOperand(1),
-                   m_c_Add(m_Specific(InductionPHI), m_ConstantInt<1>()))) {
-    Increment = dyn_cast<BinaryOperator>(Compare->getOperand(1));
-    Limit = Compare->getOperand(0);
-  }
+  if (L->isCanonical(*SE))
+    Increment =
+        dyn_cast<BinaryOperator>(InductionPHI->getIncomingValueForBlock(Latch));
   if (!Increment || Increment->hasNUsesOrMore(3)) {
-    LLVM_DEBUG(dbgs() << "Cound not find valid increment\n");
+    LLVM_DEBUG(dbgs() << "Could not find valid increment\n");
     return false;
   }
+  // The limit is the RHS of the compare. If this doesn't match the trip count
+  // computed by SCEV then this is either because the limit variable has been
+  // widened (then leave the limit as it is), or because the limit is a constant
+  // and another transformation has changed the compare (e.g. icmp ult %inc,
+  // limit -> icmp ult %j, limit-1), then change the limit to the trip count.
+  // Give up if SCEV cannot compute the trip count, or if a mismatching trip
+  // count is not a constant that can be used as the limit.
+  Limit = Compare->getOperand(1);
+  const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount)) {
+    LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
+    return false;
+  }
+  const SCEV *TripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
+  if (SE->getSCEV(Limit) != TripCount && !IsWidened) {
+    auto *ConstantTripCount = dyn_cast<SCEVConstant>(TripCount);
+    if (!ConstantTripCount) {
+      LLVM_DEBUG(dbgs() << "Limit does not match the trip count\n");
+      return false;
+    }
+    Limit = ConstantTripCount->getValue();
+    assert(SE->getSCEV(Limit) == TripCount &&
+           "Expected constant Limit to match trip count");
+  }
   IterationInstructions.insert(Increment);
   LLVM_DEBUG(dbgs() << "Found increment: "; Increment->dump());
   LLVM_DEBUG(dbgs() << "Found limit: "; Limit->dump());
 
-  assert(InductionPHI->getNumIncomingValues() == 2);
-
-  if (InductionPHI->getIncomingValueForBlock(Latch) != Increment) {
-    LLVM_DEBUG(
-        dbgs() << "Incoming value from latch is not the increment inst\n");
-    return false;
-  }
-
-  auto *CI = dyn_cast<ConstantInt>(
-      InductionPHI->getIncomingValueForBlock(L->getLoopPreheader()));
-  if (!CI || !CI->isZero()) {
-    LLVM_DEBUG(dbgs() << "PHI value is not zero: "; CI->dump());
-    return false;
-  }
-
   LLVM_DEBUG(dbgs() << "Successfully found all loop components\n");
   return true;
 }
@@ -461,11 +478,13 @@
                                ScalarEvolution *SE, AssumptionCache *AC,
                                const TargetTransformInfo *TTI) {
   SmallPtrSet<Instruction *, 8> IterationInstructions;
-  if (!findLoopComponents(FI.InnerLoop, IterationInstructions, FI.InnerInductionPHI,
-                          FI.InnerLimit, FI.InnerIncrement, FI.InnerBranch, SE))
+  if (!findLoopComponents(FI.InnerLoop, IterationInstructions,
+                          FI.InnerInductionPHI, FI.InnerLimit,
+                          FI.InnerIncrement, FI.InnerBranch, SE, FI.Widened))
     return false;
-  if (!findLoopComponents(FI.OuterLoop, IterationInstructions, FI.OuterInductionPHI,
-                          FI.OuterLimit, FI.OuterIncrement, FI.OuterBranch, SE))
+  if (!findLoopComponents(FI.OuterLoop, IterationInstructions,
+                          FI.OuterInductionPHI, FI.OuterLimit,
+                          FI.OuterIncrement, FI.OuterBranch, SE, FI.Widened))
     return false;
 
   // Both of the loop limit values must be invariant in the outer loop