diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -759,9 +759,15 @@
   /// Convert from an "exit count" (i.e. "backedge taken count") to a "trip
   /// count". A "trip count" is the number of times the header of the loop
   /// will execute if an exit is taken after the specified number of backedges
-  /// have been taken. (e.g. TripCount = ExitCount + 1) A zero result
-  /// must be interpreted as a loop having an unknown trip count.
-  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount);
+  /// have been taken. (e.g. TripCount = ExitCount + 1). Note that the
+  /// expression can overflow if ExitCount is the maximum unsigned value of
+  /// its type. \p Extend controls how potential overflow is handled. If
+  /// true, a result type one bit wider is returned, e.g. EC = 255 (i8),
+  /// TC = 256 (i9). If false, the result wraps with two's-complement
+  /// semantics, e.g. EC = 255 (i8), TC = 0 (i8). If \p ExitCount is
+  /// SCEVCouldNotCompute, SCEVCouldNotCompute is returned.
+  const SCEV *getTripCountFromExitCount(const SCEV *ExitCount,
+                                        bool Extend = true);
 
   /// Returns the exact trip count of the loop if we can compute it, and
   /// the result is a small constant. '0' is used to represent an unknown
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -7216,10 +7216,24 @@
 //                   Iteration Count Computation Code
 //
 
-const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount) {
-  // Get the trip count from the BE count by adding 1.  Overflow, results
-  // in zero which means "unknown".
-  return getAddExpr(ExitCount, getOne(ExitCount->getType()));
+const SCEV *ScalarEvolution::getTripCountFromExitCount(const SCEV *ExitCount,
+                                                       bool Extend) {
+  if (isa<SCEVCouldNotCompute>(ExitCount))
+    return getCouldNotCompute();
+
+  auto *ExitCountType = ExitCount->getType();
+  assert(ExitCountType->isIntegerTy());
+
+  // Without extension, ExitCount + 1 wraps to zero when ExitCount is the
+  // all-ones value of its type; the caller is responsible for that case.
+  if (!Extend)
+    return getAddExpr(ExitCount, getOne(ExitCountType));
+
+  // One extra bit is enough for ExitCount + 1 never to wrap.
+  auto *WiderType = Type::getIntNTy(ExitCountType->getContext(),
+                                    1 + ExitCountType->getScalarSizeInBits());
+  return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType),
+                    getOne(WiderType));
 }
 
 static unsigned getConstantTripCount(const SCEVConstant *ExitCount) {
diff --git a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
--- a/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFlatten.cpp
@@ -202,7 +202,12 @@
     LLVM_DEBUG(dbgs() << "Backedge-taken count is not predictable\n");
     return false;
   }
-  const SCEV *SCEVTripCount = SE->getTripCountFromExitCount(BackedgeTakenCount);
+  // The use of the Extend=false flag on getTripCountFromExitCount was added
+  // during a refactoring to preserve existing behavior. However, there's
+  // nothing obvious in the surrounding code that handles the overflow case.
+  // FIXME: audit code to establish whether there's a latent bug here.
+  const SCEV *SCEVTripCount =
+      SE->getTripCountFromExitCount(BackedgeTakenCount, /*Extend=*/false);
   const SCEV *SCEVRHS = SE->getSCEV(RHS);
   if (SCEVRHS == SCEVTripCount)
     return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
@@ -214,7 +219,8 @@
   // Find the extended backedge taken count and extended trip count using
   // SCEV. One of these should now match the RHS of the compare.
   BackedgeTCExt = SE->getZeroExtendExpr(BackedgeTakenCount, RHS->getType());
-  SCEVTripCountExt = SE->getTripCountFromExitCount(BackedgeTCExt);
+  SCEVTripCountExt =
+      SE->getTripCountFromExitCount(BackedgeTCExt, /*Extend=*/false);
   if (SCEVRHS != BackedgeTCExt && SCEVRHS != SCEVTripCountExt) {
     LLVM_DEBUG(dbgs() << "Could not find valid trip count\n");
     return false;