diff --git a/llvm/include/llvm/Analysis/LoopUnrollAnalyzer.h b/llvm/include/llvm/Analysis/LoopUnrollAnalyzer.h --- a/llvm/include/llvm/Analysis/LoopUnrollAnalyzer.h +++ b/llvm/include/llvm/Analysis/LoopUnrollAnalyzer.h @@ -47,6 +47,8 @@ ConstantInt *Offset = nullptr; }; + bool NeedsResimplify = false; + public: UnrolledInstAnalyzer(unsigned Iteration, DenseMap &SimplifiedValues, @@ -58,6 +60,8 @@ // Allow access to the initial visit method. using Base::visit; + bool needsResimplify() const { return NeedsResimplify; } + private: /// A cache of pointer bases and constant-folded offsets corresponding /// to GEP (or derived from GEP) instructions. diff --git a/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp --- a/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp +++ b/llvm/lib/Analysis/LoopUnrollAnalyzer.cpp @@ -143,6 +143,7 @@ Constant *CV = CDS->getElementAsConstant(Index); assert(CV && "Constant expected."); SimplifiedValues[&I] = CV; + NeedsResimplify = true; return true; } diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp --- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -394,7 +394,8 @@ SmallVector PHIUsedList; // Helper function to accumulate cost for instructions in the loop. - auto AddCostRecursively = [&](Instruction &RootI, int Iteration) { + auto AddCostRecursively = [&](Instruction &RootI, int Iteration, + bool SingleIterEvalOnly = false) { assert(Iteration >= 0 && "Cannot have a negative iteration!"); assert(CostWorklist.empty() && "Must start with an empty cost list"); assert(PHIUsedList.empty() && "Must start with an empty phi used list"); @@ -428,7 +429,7 @@ if (PhiI->getParent() == L->getHeader()) { assert(Cost.IsFree && "Loop PHIs shouldn't be evaluated as they " "inherently simplify during unrolling."); - if (Iteration == 0) + if (!SingleIterEvalOnly && Iteration == 0) continue; // Push the incoming value from the backedge into the PHI used list @@ -436,8 +437,12 @@ // cost worklist for the next iteration (as we count backwards). if (auto *OpI = dyn_cast( PhiI->getIncomingValueForBlock(L->getLoopLatch()))) - if (L->contains(OpI)) - PHIUsedList.push_back(OpI); + if (L->contains(OpI)) { + if (SingleIterEvalOnly) + CostWorklist.push_back(OpI); + else + PHIUsedList.push_back(OpI); + } continue; } @@ -468,6 +473,8 @@ // We've exhausted the search. break; + if (SingleIterEvalOnly) + break; assert(Iteration > 0 && "Cannot track PHI-used values past the first iteration!"); CostWorklist.append(PHIUsedList.begin(), PHIUsedList.end()); @@ -486,6 +493,11 @@ TargetTransformInfo::TargetCostKind CostKind = L->getHeader()->getParent()->hasMinSize() ? TargetTransformInfo::TCK_CodeSize : TargetTransformInfo::TCK_SizeAndLatency; + + bool SingleIterEvalOnly = false; + + InstructionCost FirstIterUnrolledCost = 0; + // Simulate execution of each iteration of the loop counting instructions, // which would be simplified. // Since the same load will take different values on different iterations, @@ -633,6 +645,24 @@ << " UnrolledCost: " << UnrolledCost << "\n"); return None; } + + if (Iteration == 0) + FirstIterUnrolledCost = UnrolledCost; + + if (Iteration == 1 && !Analyzer.needsResimplify()) { + SingleIterEvalOnly = true; + UnrolledCost -= FirstIterUnrolledCost; + SmallVector ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + for (BasicBlock *Exiting : ExitingBlocks) { + for (BasicBlock *Succ : successors(Exiting)) { + if (L->contains(Succ)) + continue; + ExitWorklist.insert({Exiting, Succ}); + } + } + break; + } } while (!ExitWorklist.empty()) { @@ -647,10 +677,16 @@ Value *Op = PN->getIncomingValueForBlock(ExitingBB); if (auto *OpI = dyn_cast(Op)) if (L->contains(OpI)) - AddCostRecursively(*OpI, TripCount - 1); + AddCostRecursively(*OpI, SingleIterEvalOnly ? 1 : TripCount - 1, + SingleIterEvalOnly); } } + if (SingleIterEvalOnly) { + UnrolledCost *= TripCount; + RolledDynamicCost = (RolledDynamicCost / 2) * TripCount; + } + assert(UnrolledCost.isValid() && RolledDynamicCost.isValid() && "All instructions must have a valid cost, whether the " "loop is rolled or unrolled.");