diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -953,11 +953,22 @@ /// and recompute is simpler. void forgetBlockAndLoopDispositions(Value *V = nullptr); + /// Determine the maximum constant multiple C such that `S % C == 0`. This + /// assumes that S is an unsigned SCEV. If S does not guarantee no unsigned + /// wrap, then we compute the max multiple such that the multiple is a power + /// of 2. Multiples that are powers of 2 will still divide S even if S wraps. + /// + /// If S is guaranteed to be 0, then every integer divides S. We return 0 + /// in this case (getMaxNonZeroConstantMultiple returns 1 instead). + APInt getMaxConstantMultiple(const SCEV *S); + + APInt getMaxNonZeroConstantMultiple(const SCEV *S); + /// Determine the minimum number of zero bits that S is guaranteed to end in /// (at every loop iteration). It is, at the same time, the minimum number /// of times S is divisible by 2. For example, given {4,+,8} it returns 2. /// If S is guaranteed to be 0, it returns the bitwidth of S. - uint32_t GetMinTrailingZeros(const SCEV *S); + uint32_t getMinTrailingZeros(const SCEV *S); /// Determine the unsigned range for a particular SCEV. /// NOTE: This returns a copy of the reference returned by getRangeRef. @@ -1356,14 +1367,14 @@ /// predicate by splitting it into a set of independent predicates. bool ProvingSplitPredicate = false; - /// Memoized values for the GetMinTrailingZeros - DenseMap MinTrailingZerosCache; + /// Memoized values for getMaxConstantMultiple + DenseMap MaxConstantTripMultipleCache; /// Return the Value set from which the SCEV expr is generated.
ArrayRef getSCEVValues(const SCEV *S); - /// Private helper method for the GetMinTrailingZeros method - uint32_t GetMinTrailingZerosImpl(const SCEV *S); + /// Private helper method for the getMaxConstantMultiple method + APInt getMaxConstantMultipleImpl(const SCEV *S); /// Information about the number of loop iterations for which a loop exit's /// branch condition evaluates to the not-taken path. This is a temporary diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -1315,7 +1315,7 @@ } // Return zero if truncating to known zeros. - uint32_t MinTrailingZeros = GetMinTrailingZeros(Op); + uint32_t MinTrailingZeros = getMinTrailingZeros(Op); if (MinTrailingZeros >= getTypeSizeInBits(Ty)) return getZero(Ty); @@ -1600,7 +1600,7 @@ // Find number of trailing zeros of (x + y + ...) w/o the C first: uint32_t TZ = BitWidth; for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I) - TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I))); + TZ = std::min(TZ, SE.getMinTrailingZeros(WholeAddExpr->getOperand(I))); if (TZ) { // Set D to be as many least significant bits of C as possible while still // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap: @@ -1617,7 +1617,7 @@ const APInt &ConstantStart, const SCEV *Step) { const unsigned BitWidth = ConstantStart.getBitWidth(); - const uint32_t TZ = SE.GetMinTrailingZeros(Step); + const uint32_t TZ = SE.getMinTrailingZeros(Step); if (TZ) return TZ < BitWidth ?
ConstantStart.trunc(TZ).zext(BitWidth) : ConstantStart; @@ -6302,71 +6302,117 @@ return getGEPExpr(GEP, IndexExprs); } -uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { - if (const SCEVConstant *C = dyn_cast(S)) - return C->getAPInt().countTrailingZeros(); - - if (const SCEVPtrToIntExpr *I = dyn_cast(S)) - return GetMinTrailingZeros(I->getOperand()); - - if (const SCEVTruncateExpr *T = dyn_cast(S)) - return std::min(GetMinTrailingZeros(T->getOperand()), - (uint32_t)getTypeSizeInBits(T->getType())); +APInt ScalarEvolution::getMaxConstantMultipleImpl(const SCEV *S) { + uint64_t BitWidth = getTypeSizeInBits(S->getType()); + // Returns 2^TrailingZeros, or 0 (meaning "S is zero") when that power of 2 + // does not fit in BitWidth bits. + auto GetPowerOfTwo = [&BitWidth](uint32_t TrailingZeros) { + return TrailingZeros >= BitWidth + ? APInt::getZero(BitWidth) + : APInt::getOneBitSet(BitWidth, TrailingZeros); + }; - if (const SCEVZeroExtendExpr *E = dyn_cast(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) - ? getTypeSizeInBits(E->getType()) - : OpRes; + switch (S->getSCEVType()) { + case scConstant: + return cast<SCEVConstant>(S)->getAPInt(); + case scUnknown: { + const SCEVUnknown *U = cast<SCEVUnknown>(S); + // For a SCEVUnknown, ask ValueTracking for known bits. + unsigned TZ = + computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT) + .countMinTrailingZeros(); + return GetPowerOfTwo(TZ); } - - if (const SCEVSignExtendExpr *E = dyn_cast(S)) { - uint32_t OpRes = GetMinTrailingZeros(E->getOperand()); - return OpRes == getTypeSizeInBits(E->getOperand()->getType()) - ? getTypeSizeInBits(E->getType()) - : OpRes; + case scPtrToInt: + return getMaxConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand()); + case scTruncate: { + // Only multiples that are a power of 2 will hold after truncation. + const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S); + uint32_t TZ = getMinTrailingZeros(T->getOperand()); + return GetPowerOfTwo(TZ); } - - if (const SCEVMulExpr *M = dyn_cast(S)) { - // The result is the sum of all operands results.
- uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0)); - uint32_t BitWidth = getTypeSizeInBits(M->getType()); - for (unsigned i = 1, e = M->getNumOperands(); - SumOpRes != BitWidth && i != e; ++i) - SumOpRes = - std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)), BitWidth); - return SumOpRes; + case scZeroExtend: { + // Zero extension keeps the unsigned value, so every multiple still holds. + const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S); + return getMaxConstantMultiple(Z->getOperand()).zext(BitWidth); } + case scSignExtend: { + // Only multiples that are a power of 2 will hold after sign extension: + // e.g. i8 252 is divisible by 7, but sext to i16 gives 65532, which is not. + const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S); + uint32_t TZ = getMinTrailingZeros(E->getOperand()); + return GetPowerOfTwo(TZ); + } + + case scMulExpr: { + const SCEVMulExpr *M = cast<SCEVMulExpr>(S); + if (M->hasNoUnsignedWrap()) { + // The result is the product of all operand results. + APInt Res = APInt(BitWidth, 1); + for (unsigned I = 0, E = M->getNumOperands(); I != E; ++I) + Res = Res * getMaxConstantMultiple(M->getOperand(I)); + return Res; + } - if (isa(S) || isa(S) || isa(S) || - isa(S)) { - // The result is the min of all operands results. - const SCEVNAryExpr *M = cast(S); - uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0)); - for (unsigned I = 1, E = M->getNumOperands(); MinOpRes && I != E; ++I) - MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(I))); - return MinOpRes; + // If there are no wrap guarantees, find the trailing zeros, which is the + // sum of trailing zeros for all its operands. + uint32_t TZ = 0; + for (const SCEV *Operand : M->operands()) { + TZ += getMinTrailingZeros(Operand); + } + return GetPowerOfTwo(TZ); } - if (const SCEVUnknown *U = dyn_cast(S)) { - // For a SCEVUnknown, ask ValueTracking. - KnownBits Known = computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT); - return Known.countMinTrailingZeros(); + case scAddExpr: + case scAddRecExpr: + case scUMinExpr: + case scUMaxExpr: + case scSMinExpr: + case scSMaxExpr: + case scSequentialUMinExpr: { + const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S); + if (N->hasNoUnsignedWrap()) { + // The result is GCD of all operands results.
+ APInt Res = getMaxConstantMultiple(N->getOperand(0)); + for (unsigned I = 1, E = N->getNumOperands(); I != E && Res != 1; ++I) + Res = APIntOps::GreatestCommonDivisor( + Res, getMaxConstantMultiple(N->getOperand(I))); + return Res; + } + + // If there are no wrap guarantees, find the trailing bits, which is the + // minimum of its operands. + uint32_t TZ = getMinTrailingZeros(N->getOperand(0)); + for (unsigned I = 1, E = N->getNumOperands(); I != E && TZ; ++I) + TZ = std::min(TZ, getMinTrailingZeros(N->getOperand(I))); + return GetPowerOfTwo(TZ); } - // SCEVUDivExpr - return 0; + default: + return APInt(BitWidth, 1); + } } -uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { - auto I = MinTrailingZerosCache.find(S); - if (I != MinTrailingZerosCache.end()) +APInt ScalarEvolution::getMaxConstantMultiple(const SCEV *S) { + auto I = MaxConstantTripMultipleCache.find(S); + if (I != MaxConstantTripMultipleCache.end()) return I->second; - uint32_t Result = GetMinTrailingZerosImpl(S); - auto InsertPair = MinTrailingZerosCache.insert({S, Result}); + APInt Multiple = getMaxConstantMultipleImpl(S); + auto InsertPair = MaxConstantTripMultipleCache.insert({S, Multiple}); assert(InsertPair.second && "Should insert a new key"); return InsertPair.first->second; +} + +APInt ScalarEvolution::getMaxNonZeroConstantMultiple(const SCEV *S) { + // Clamp the "S is known to be zero" result (0) to 1, a valid divisor. + return APIntOps::umax(getMaxConstantMultiple(S), + APInt(getTypeSizeInBits(S->getType()), 1)); +} + +uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) { + return std::min(getMaxConstantMultiple(S).countTrailingZeros(), + (unsigned)getTypeSizeInBits(S->getType())); } /// Helper method to assign a range to V from metadata present in the IR. @@ -6596,15 +6635,18 @@ ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); using OBO = OverflowingBinaryOperator; - // If the value has known zeros, the maximum value will have those known zeros - // as well.
- uint32_t TZ = GetMinTrailingZeros(S); - if (TZ != 0) { - if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) + // The largest possible value is `Max - (Max % Multiple)`. + if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) { + APInt MaxMultiple = getMaxNonZeroConstantMultiple(S); + APInt Remainder = APInt::getMaxValue(BitWidth).urem(MaxMultiple); + if (!Remainder.isZero()) { ConservativeResult = ConstantRange(APInt::getMinValue(BitWidth), - APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1); - else + APInt::getMaxValue(BitWidth) - Remainder + 1); + } + } else { + uint32_t TZ = getMinTrailingZeros(S); + if (TZ != 0) ConservativeResult = ConstantRange( APInt::getSignedMinValue(BitWidth), APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1); @@ -8222,12 +8264,15 @@ const SCEV *TCExpr = getTripCountFromExitCount(ExitCount); const SCEVConstant *TC = dyn_cast(TCExpr); - if (!TC) - // Attempt to factor more general cases. Returns the greatest power of - // two divisor. If overflow happens, the trip count expression is still - // divisible by the greatest power of 2 divisor returned. - return 1U << std::min((uint32_t)31, - GetMinTrailingZeros(applyLoopGuards(TCExpr, L))); + if (!TC) { + APInt Multiple = getMaxNonZeroConstantMultiple(applyLoopGuards(TCExpr, L)); + // The multiple may not fit in the unsigned result. Truncating it could + // produce 0 or a value that is not a divisor, so fall back to the largest + // power of 2 below 2^32 that divides it. + if (Multiple.getActiveBits() > 32) + return 1U << std::min(31u, Multiple.countTrailingZeros()); + return static_cast<unsigned>(Multiple.getZExtValue()); + } ConstantInt *Result = TC->getValue(); @@ -8409,7 +8449,7 @@ SignedRanges.clear(); ExprValueMap.clear(); HasRecMap.clear(); - MinTrailingZerosCache.clear(); + MaxConstantTripMultipleCache.clear(); PredicatedSCEVRewrites.clear(); FoldCache.clear(); FoldCacheUser.clear(); @@ -10044,7 +10084,7 @@ // // B is divisible by D if and only if the multiplicity of prime factor 2 for B // is not less than multiplicity of this prime factor for D. - if (SE.GetMinTrailingZeros(B) < Mult2) + if (SE.getMinTrailingZeros(B) < Mult2) return SE.getCouldNotCompute(); // 3.
Compute I: the multiplicative inverse of (A / D) in arithmetic @@ -13396,7 +13436,7 @@ PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)), PendingPhiRanges(std::move(Arg.PendingPhiRanges)), PendingMerges(std::move(Arg.PendingMerges)), - MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)), + MaxConstantTripMultipleCache(std::move(Arg.MaxConstantTripMultipleCache)), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), PredicatedBackedgeTakenCounts( std::move(Arg.PredicatedBackedgeTakenCounts)), @@ -13893,7 +13933,7 @@ UnsignedRanges.erase(S); SignedRanges.erase(S); HasRecMap.erase(S); - MinTrailingZerosCache.erase(S); + MaxConstantTripMultipleCache.erase(S); if (auto *AR = dyn_cast(S)) { UnsignedWrapViaInductionTried.erase(AR); diff --git a/llvm/test/Analysis/ScalarEvolution/nsw.ll b/llvm/test/Analysis/ScalarEvolution/nsw.ll --- a/llvm/test/Analysis/ScalarEvolution/nsw.ll +++ b/llvm/test/Analysis/ScalarEvolution/nsw.ll @@ -322,7 +322,7 @@ ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ] ; CHECK-NEXT: --> {0,+,7}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 7 -; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,0) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,-3) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: Determining loop execution counts for: @bad_postinc_nsw_a ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)) ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 613566756 diff --git a/llvm/test/Analysis/ScalarEvolution/ranges.ll b/llvm/test/Analysis/ScalarEvolution/ranges.ll --- a/llvm/test/Analysis/ScalarEvolution/ranges.ll +++ 
b/llvm/test/Analysis/ScalarEvolution/ranges.ll @@ -133,7 +133,7 @@ ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ] ; CHECK-NEXT: --> {0,+,6}<%loop> U: [0,-2147483648) S: [0,2147483647) Exits: (6 * ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 6 -; CHECK-NEXT: --> {6,+,6}<%loop> U: [6,-1) S: [-2147483648,2147483647) Exits: (6 + (6 * ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n)))) LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {6,+,6}<%loop> U: [6,-3) S: [-2147483648,2147483647) Exits: (6 + (6 * ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n)))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: Determining loop execution counts for: @add_6 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n)) ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 715827882 @@ -160,7 +160,7 @@ ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ] ; CHECK-NEXT: --> {0,+,7}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 7 -; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,0) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,-3) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: Determining loop execution counts for: @add_7 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)) ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 613566756 @@ -215,7 +215,7 @@ ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ] ; CHECK-NEXT: --> {0,+,9}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (9 * ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin 
%n))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 9 -; CHECK-NEXT: --> {9,+,9}<%loop> U: [9,0) S: [9,0) Exits: (9 + (9 * ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin %n)))) LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {9,+,9}<%loop> U: [9,-3) S: [9,0) Exits: (9 + (9 * ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin %n)))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: Determining loop execution counts for: @add_9 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin %n)) ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 477218588 @@ -243,7 +243,7 @@ ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ] ; CHECK-NEXT: --> {0,+,10}<%loop> U: [0,-2147483648) S: [0,2147483647) Exits: (10 * ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 10 -; CHECK-NEXT: --> {10,+,10}<%loop> U: [10,-1) S: [-2147483648,2147483647) Exits: (10 + (10 * ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n)))) LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {10,+,10}<%loop> U: [10,-5) S: [-2147483648,2147483647) Exits: (10 + (10 * ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n)))) LoopDispositions: { %loop: Computable } ; CHECK-NEXT: Determining loop execution counts for: @add_10 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n)) ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 429496729