diff --git a/llvm/lib/Analysis/LoopCacheAnalysis.cpp b/llvm/lib/Analysis/LoopCacheAnalysis.cpp --- a/llvm/lib/Analysis/LoopCacheAnalysis.cpp +++ b/llvm/lib/Analysis/LoopCacheAnalysis.cpp @@ -297,7 +297,7 @@ Type *WiderType = SE.getWiderType(Stride->getType(), TripCount->getType()); const SCEV *CacheLineSize = SE.getConstant(WiderType, CLS); Stride = SE.getNoopOrAnyExtend(Stride, WiderType); - TripCount = SE.getNoopOrAnyExtend(TripCount, WiderType); + TripCount = SE.getNoopOrZeroExtend(TripCount, WiderType); const SCEV *Numerator = SE.getMulExpr(Stride, TripCount); RefCost = SE.getUDivExpr(Numerator, CacheLineSize); @@ -323,8 +323,8 @@ const SCEV *TripCount = computeTripCount(*AR->getLoop(), *Sizes.back(), SE); Type *WiderType = SE.getWiderType(RefCost->getType(), TripCount->getType()); - RefCost = SE.getMulExpr(SE.getNoopOrAnyExtend(RefCost, WiderType), - SE.getNoopOrAnyExtend(TripCount, WiderType)); + RefCost = SE.getMulExpr(SE.getNoopOrZeroExtend(RefCost, WiderType), + SE.getNoopOrZeroExtend(TripCount, WiderType)); } LLVM_DEBUG(dbgs().indent(4) @@ -334,7 +334,7 @@ // Attempt to fold RefCost into a constant. if (auto ConstantCost = dyn_cast<SCEVConstant>(RefCost)) - return ConstantCost->getValue()->getSExtValue(); + return ConstantCost->getValue()->getZExtValue(); LLVM_DEBUG(dbgs().indent(4) << "RefCost is not a constant! 
Setting to RefCost=InvalidCost " diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -8046,6 +8046,12 @@ if (!Extend) return getAddExpr(ExitCount, getOne(ExitCountType)); + ConstantRange ExitCountRange = + getRangeRef(ExitCount, RangeSignHint::HINT_RANGE_UNSIGNED); + if (!ExitCountRange.contains( + APInt::getMaxValue(ExitCountRange.getBitWidth()))) + return getAddExpr(ExitCount, getOne(ExitCountType)); + auto *WiderType = Type::getIntNTy(ExitCountType->getContext(), 1 + ExitCountType->getScalarSizeInBits()); return getAddExpr(getNoopOrZeroExtend(ExitCount, WiderType), @@ -8228,15 +8234,14 @@ return 1; // Get the trip count - const SCEV *TCExpr = getTripCountFromExitCount(ExitCount); + const SCEV *TCExpr = getTripCountFromExitCount(applyLoopGuards(ExitCount, L)); const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr); if (!TC) // Attempt to factor more general cases. Returns the greatest power of // two divisor. If overflow happens, the trip count expression is still // divisible by the greatest power of 2 divisor returned. - return 1U << std::min((uint32_t)31, - GetMinTrailingZeros(applyLoopGuards(TCExpr, L))); + return 1U << std::min((uint32_t)31, GetMinTrailingZeros(TCExpr)); ConstantInt *Result = TC->getValue();