diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -964,6 +964,14 @@
   /// If S is guaranteed to be 0, it returns the bitwidth of S.
   uint32_t getMinTrailingZeros(const SCEV *S);
 
+  /// Returns the largest constant C that S is provably a multiple of.
+  /// A result of 0 means S is known to be exactly 0.
+  APInt getConstantMultiple(const SCEV *S);
+
+  /// Returns the largest constant C that S is provably a multiple of. If S
+  /// is known to be exactly 0, returns 1 so the result is never zero.
+  APInt getNonZeroConstantMultiple(const SCEV *S);
+
   /// Determine the unsigned range for a particular SCEV.
   /// NOTE: This returns a copy of the reference returned by getRangeRef.
   ConstantRange getUnsignedRange(const SCEV *S) {
@@ -1431,14 +1439,14 @@
   /// predicate by splitting it into a set of independent predicates.
   bool ProvingSplitPredicate = false;
 
-  /// Memoized values for the GetMinTrailingZeros
-  DenseMap<const SCEV *, uint32_t> MinTrailingZerosCache;
+  /// Memoized values for the getConstantMultiple
+  DenseMap<const SCEV *, APInt> ConstantMultipleCache;
 
   /// Return the Value set from which the SCEV expr is generated.
   ArrayRef<Value *> getSCEVValues(const SCEV *S);
 
-  /// Private helper method for the GetMinTrailingZeros method
-  uint32_t getMinTrailingZerosImpl(const SCEV *S);
+  /// Private helper method for the getConstantMultiple method.
+  APInt getConstantMultipleImpl(const SCEV *S);
 
   /// Information about the number of times a particular loop exit may be
   /// reached before exiting the loop.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -6281,59 +6281,93 @@
   return getGEPExpr(GEP, IndexExprs);
 }
 
-uint32_t ScalarEvolution::getMinTrailingZerosImpl(const SCEV *S) {
+APInt ScalarEvolution::getConstantMultipleImpl(const SCEV *S) {
+  uint64_t BitWidth = getTypeSizeInBits(S->getType());
+  auto GetShiftedByZeros = [BitWidth](uint32_t TrailingZeros) {
+    return TrailingZeros >= BitWidth
+               ? APInt::getZero(BitWidth)
+               : APInt::getOneBitSet(BitWidth, TrailingZeros);
+  };
+
   switch (S->getSCEVType()) {
   case scConstant:
-    return cast<SCEVConstant>(S)->getAPInt().countr_zero();
+    return cast<SCEVConstant>(S)->getAPInt();
+  case scPtrToInt:
+    return getConstantMultiple(cast<SCEVPtrToIntExpr>(S)->getOperand());
+  case scUDivExpr:
+  case scVScale:
+    return APInt(BitWidth, 1);
   case scTruncate: {
+    // Only multiples that are a power of 2 will hold after truncation.
     const SCEVTruncateExpr *T = cast<SCEVTruncateExpr>(S);
-    return std::min(getMinTrailingZeros(T->getOperand()),
-                    (uint32_t)getTypeSizeInBits(T->getType()));
+    uint32_t TZ = getMinTrailingZeros(T->getOperand());
+    return GetShiftedByZeros(TZ);
+  }
+  case scZeroExtend: {
+    const SCEVZeroExtendExpr *Z = cast<SCEVZeroExtendExpr>(S);
+    return getConstantMultiple(Z->getOperand()).zext(BitWidth);
   }
-  case scZeroExtend:
   case scSignExtend: {
-    const SCEVIntegralCastExpr *E = cast<SCEVIntegralCastExpr>(S);
-    uint32_t OpRes = getMinTrailingZeros(E->getOperand());
-    return OpRes == getTypeSizeInBits(E->getOperand()->getType())
-               ? getTypeSizeInBits(E->getType())
-               : OpRes;
+    // Only multiples that are a power of 2 will hold after sign extension
+    // (e.g. i8 252 is a multiple of 7, but its i16 sext 0xFFFC is not). A
+    // zero operand still sign-extends to zero.
+    const SCEVSignExtendExpr *E = cast<SCEVSignExtendExpr>(S);
+    APInt OpMultiple = getConstantMultiple(E->getOperand());
+    if (OpMultiple.isZero())
+      return APInt::getZero(BitWidth);
+    return GetShiftedByZeros(OpMultiple.countr_zero());
   }
   case scMulExpr: {
     const SCEVMulExpr *M = cast<SCEVMulExpr>(S);
-    // The result is the sum of all operands results.
-    uint32_t SumOpRes = getMinTrailingZeros(M->getOperand(0));
-    uint32_t BitWidth = getTypeSizeInBits(M->getType());
-    for (unsigned I = 1, E = M->getNumOperands();
-         SumOpRes != BitWidth && I != E; ++I)
-      SumOpRes =
-          std::min(SumOpRes + getMinTrailingZeros(M->getOperand(I)), BitWidth);
-    return SumOpRes;
+    if (M->hasNoUnsignedWrap()) {
+      // The result is the product of all operand results.
+      APInt Res = getConstantMultiple(M->getOperand(0));
+      for (const SCEV *Operand : M->operands().drop_front())
+        Res = Res * getConstantMultiple(Operand);
+      return Res;
+    }
+
+    // If there are no wrap guarantees, find the trailing zeros, which is the
+    // sum of trailing zeros for all its operands.
+    uint32_t TZ = 0;
+    for (const SCEV *Operand : M->operands())
+      TZ += getMinTrailingZeros(Operand);
+    return GetShiftedByZeros(TZ);
   }
-  case scVScale:
-    return 0;
-  case scUDivExpr:
-    return 0;
-  case scPtrToInt:
   case scAddExpr:
-  case scAddRecExpr:
+  case scAddRecExpr: {
+    const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
+    if (N->hasNoUnsignedWrap()) {
+      // The result is the GCD of all operand results.
+      APInt Res = getConstantMultiple(N->getOperand(0));
+      for (unsigned I = 1, E = N->getNumOperands(); I < E && Res != 1; ++I)
+        Res = APIntOps::GreatestCommonDivisor(
+            Res, getConstantMultiple(N->getOperand(I)));
+      return Res;
+    }
+  }
+    // If there are no unsigned wrap guarantees, fall through to find trailing
+    // bits.
+    LLVM_FALLTHROUGH;
   case scUMaxExpr:
   case scSMaxExpr:
   case scUMinExpr:
   case scSMinExpr:
   case scSequentialUMinExpr: {
-    // The result is the min of all operands results.
-    ArrayRef<const SCEV *> Ops = S->operands();
-    uint32_t MinOpRes = getMinTrailingZeros(Ops[0]);
-    for (unsigned I = 1, E = Ops.size(); MinOpRes && I != E; ++I)
-      MinOpRes = std::min(MinOpRes, getMinTrailingZeros(Ops[I]));
-    return MinOpRes;
+    const SCEVNAryExpr *N = cast<SCEVNAryExpr>(S);
+    // Find the trailing bits, which is the minimum of its operands.
+    uint32_t TZ = getMinTrailingZeros(N->getOperand(0));
+    for (const SCEV *Operand : N->operands().drop_front())
+      TZ = std::min(TZ, getMinTrailingZeros(Operand));
+    return GetShiftedByZeros(TZ);
   }
   case scUnknown: {
+    // Ask ValueTracking for known bits.
     const SCEVUnknown *U = cast<SCEVUnknown>(S);
-    // For a SCEVUnknown, ask ValueTracking.
-    KnownBits Known =
-        computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT);
-    return Known.countMinTrailingZeros();
+    unsigned Known =
+        computeKnownBits(U->getValue(), getDataLayout(), 0, &AC, nullptr, &DT)
+            .countMinTrailingZeros();
+    return GetShiftedByZeros(Known);
   }
   case scCouldNotCompute:
     llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
@@ -6341,17 +6375,27 @@
   llvm_unreachable("Unknown SCEV kind!");
 }
 
-uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
-  auto I = MinTrailingZerosCache.find(S);
-  if (I != MinTrailingZerosCache.end())
+APInt ScalarEvolution::getConstantMultiple(const SCEV *S) {
+  auto I = ConstantMultipleCache.find(S);
+  if (I != ConstantMultipleCache.end())
     return I->second;
 
-  uint32_t Result = getMinTrailingZerosImpl(S);
-  auto InsertPair = MinTrailingZerosCache.insert({S, Result});
+  APInt Result = getConstantMultipleImpl(S);
+  auto InsertPair = ConstantMultipleCache.insert({S, Result});
   assert(InsertPair.second && "Should insert a new key");
   return InsertPair.first->second;
 }
 
+APInt ScalarEvolution::getNonZeroConstantMultiple(const SCEV *S) {
+  APInt Multiple = getConstantMultiple(S);
+  return Multiple == 0 ? APInt(Multiple.getBitWidth(), 1) : Multiple;
+}
+
+uint32_t ScalarEvolution::getMinTrailingZeros(const SCEV *S) {
+  return std::min(getConstantMultiple(S).countr_zero(),
+                  (unsigned)getTypeSizeInBits(S->getType()));
+}
+
 /// Helper method to assign a range to V from metadata present in the IR.
 static std::optional<ConstantRange> GetRangeFromMetadata(Value *V) {
   if (Instruction *I = dyn_cast<Instruction>(V))
@@ -6600,16 +6644,20 @@
 
-  // If the value has known zeros, the maximum value will have those known zeros
-  // as well.
-  uint32_t TZ = getMinTrailingZeros(S);
-  if (TZ != 0) {
-    if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED)
+  // If the value is known to be a multiple of some constant, the maximum
+  // value will be a multiple of it as well.
+  if (SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED) {
+    APInt Multiple = getNonZeroConstantMultiple(S);
+    APInt Remainder = APInt::getMaxValue(BitWidth).urem(Multiple);
+    if (!Remainder.isZero())
       ConservativeResult =
           ConstantRange(APInt::getMinValue(BitWidth),
-                        APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
-    else
+                        APInt::getMaxValue(BitWidth) - Remainder + 1);
+  } else {
+    uint32_t TZ = getMinTrailingZeros(S);
+    if (TZ != 0) {
       ConservativeResult = ConstantRange(
           APInt::getSignedMinValue(BitWidth),
           APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
+    }
   }
 
   switch (S->getSCEVType()) {
@@ -8228,10 +8276,14 @@
   };
 
   const SCEVConstant *TC = dyn_cast<SCEVConstant>(TCExpr);
-  if (!TC)
-    // Attempt to factor more general cases. Returns the greatest power of
-    // two divisor.
-    return GetSmallMultiple(getMinTrailingZeros(TCExpr));
+  if (!TC) {
+    // Attempt to factor more general cases. If the multiple does not fit in
+    // 32 bits, fall back to its greatest power-of-two divisor.
+    APInt Multiple = getNonZeroConstantMultiple(TCExpr);
+    return Multiple.getActiveBits() > 32
+               ? GetSmallMultiple(Multiple.countr_zero())
+               : (unsigned)Multiple.getZExtValue();
+  }
 
   ConstantInt *Result = TC->getValue();
   assert(Result && "SCEVConstant expected to have non-null ConstantInt");
@@ -8412,7 +8464,7 @@
   SignedRanges.clear();
   ExprValueMap.clear();
   HasRecMap.clear();
-  MinTrailingZerosCache.clear();
+  ConstantMultipleCache.clear();
   PredicatedSCEVRewrites.clear();
   FoldCache.clear();
   FoldCacheUser.clear();
@@ -13437,7 +13489,7 @@
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
       PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
       PendingMerges(std::move(Arg.PendingMerges)),
-      MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
+      ConstantMultipleCache(std::move(Arg.ConstantMultipleCache)),
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
       PredicatedBackedgeTakenCounts(
           std::move(Arg.PredicatedBackedgeTakenCounts)),
@@ -13912,7 +13964,7 @@
   UnsignedRanges.erase(S);
   SignedRanges.erase(S);
   HasRecMap.erase(S);
-  MinTrailingZerosCache.erase(S);
+  ConstantMultipleCache.erase(S);
 
   if (auto *AR = dyn_cast<SCEVAddRecExpr>(S)) {
     UnsignedWrapViaInductionTried.erase(AR);
@@ -14292,6 +14344,22 @@
       }
     }
   }
+
+  // Verify that ConstantMultipleCache computations are correct. A cached
+  // multiple may legitimately differ from a recomputed one (e.g. no-wrap
+  // flags were strengthened after it was cached), so only report entries
+  // where both are non-zero and neither divides the other.
+  for (const auto &[S, Multiple] : ConstantMultipleCache) {
+    APInt RecomputedMultiple = SE2.getConstantMultipleImpl(S);
+    if (Multiple != 0 && RecomputedMultiple != 0 &&
+        RecomputedMultiple.urem(Multiple) != 0 &&
+        Multiple.urem(RecomputedMultiple) != 0) {
+      dbgs() << "Incorrect computation in ConstantMultipleCache for " << *S
+             << " : Expected " << RecomputedMultiple << " but got " << Multiple
+             << "!\n";
+      std::abort();
+    }
+  }
 }
 
 bool ScalarEvolution::invalidate(
diff --git a/llvm/test/Analysis/ScalarEvolution/nsw.ll b/llvm/test/Analysis/ScalarEvolution/nsw.ll
--- a/llvm/test/Analysis/ScalarEvolution/nsw.ll
+++ b/llvm/test/Analysis/ScalarEvolution/nsw.ll
@@ -322,7 +322,7 @@
 ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
 ; CHECK-NEXT: --> {0,+,7}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 7
-; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,0) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,-3) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: Determining loop execution counts for: @bad_postinc_nsw_a
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n))
 ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 613566756
diff --git a/llvm/test/Analysis/ScalarEvolution/ranges.ll b/llvm/test/Analysis/ScalarEvolution/ranges.ll
--- a/llvm/test/Analysis/ScalarEvolution/ranges.ll
+++ b/llvm/test/Analysis/ScalarEvolution/ranges.ll
@@ -1,6 +1,6 @@
 ; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
- ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
- ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" -scev-range-iter-threshold=1 2>&1 | FileCheck %s
+ ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>,verify<scalar-evolution>" 2>&1 | FileCheck %s
+ ; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>,verify<scalar-evolution>" -scev-range-iter-threshold=1 2>&1 | FileCheck %s

 target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64"

@@ -133,7 +133,7 @@
 ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
 ; CHECK-NEXT: --> {0,+,6}<%loop> U: [0,-2147483648) S: [0,2147483647) Exits: (6 * ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 6
-; CHECK-NEXT: --> {6,+,6}<%loop> U: [6,-1) S: [-2147483648,2147483647) Exits: (6 + (6 * ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {6,+,6}<%loop> U: [6,-3) S: [-2147483648,2147483647) Exits: (6 + (6 * ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: Determining loop execution counts for: @add_6
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 6) + (1 umin %n))
 ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 715827882
@@ -160,7 +160,7 @@
 ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
 ; CHECK-NEXT: --> {0,+,7}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 7
-; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,0) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {7,+,7}<%loop> U: [7,-3) S: [7,0) Exits: (7 + (7 * ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: Determining loop execution counts for: @add_7
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 7) + (1 umin %n))
 ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 613566756
@@ -215,7 +215,7 @@
 ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
 ; CHECK-NEXT: --> {0,+,9}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: (9 * ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin %n))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 9
-; CHECK-NEXT: --> {9,+,9}<%loop> U: [9,0) S: [9,0) Exits: (9 + (9 * ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {9,+,9}<%loop> U: [9,-3) S: [9,0) Exits: (9 + (9 * ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: Determining loop execution counts for: @add_9
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 9) + (1 umin %n))
 ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 477218588
@@ -243,7 +243,7 @@
 ; CHECK-NEXT: %iv = phi i32 [ 0, %entry ], [ %iv.inc, %loop ]
 ; CHECK-NEXT: --> {0,+,10}<%loop> U: [0,-2147483648) S: [0,2147483647) Exits: (10 * ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: %iv.inc = add nsw i32 %iv, 10
-; CHECK-NEXT: --> {10,+,10}<%loop> U: [10,-1) S: [-2147483648,2147483647) Exits: (10 + (10 * ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
+; CHECK-NEXT: --> {10,+,10}<%loop> U: [10,-5) S: [-2147483648,2147483647) Exits: (10 + (10 * ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n)))) LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT: Determining loop execution counts for: @add_10
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((((-1 * (1 umin %n)) + %n) /u 10) + (1 umin %n))
 ; CHECK-NEXT: Loop %loop: constant max backedge-taken count is 429496729
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -520,7 +520,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 5
 ;
 entry:
   %u = urem i32 %num, 5
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple.ll
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py
-; RUN: opt -passes='print<scalar-evolution>' -disable-output %s 2>&1 | FileCheck %s
+; RUN: opt -passes='print<scalar-evolution>,verify<scalar-evolution>' -disable-output %s 2>&1 | FileCheck %s

 ; Test trip multiples with functions that look like:

@@ -29,7 +29,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 3
 ;
 entry:
   %rem = urem i32 %num, 3
@@ -102,7 +102,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 5
 ;
 entry:
   %rem = urem i32 %num, 5
@@ -139,7 +139,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 2
+; CHECK: Loop %for.body: Trip multiple is 6
 ;
 entry:
   %rem = urem i32 %num, 6
@@ -176,7 +176,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 7
 ;
 entry:
   %rem = urem i32 %num, 7
@@ -249,7 +249,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 9
 ;
 entry:
   %rem = urem i32 %num, 9
@@ -285,7 +285,7 @@
 ; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
 ; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 2
+; CHECK: Loop %for.body: Trip multiple is 10
 ;
 entry:
   %rem = urem i32 %num, 10