diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h --- a/llvm/include/llvm/Analysis/ScalarEvolution.h +++ b/llvm/include/llvm/Analysis/ScalarEvolution.h @@ -1273,6 +1273,9 @@ /// Mark SCEVUnknown Phis currently being processed by getRangeRef. SmallPtrSet PendingPhiRanges; + /// Mark SCEVUnknown Phis currently being processed by getRangeRefIter. + SmallPtrSet PendingPhiRangesIter; + // Mark SCEVUnknown Phis currently being processed by isImpliedViaMerge. SmallPtrSet PendingMerges; @@ -1550,7 +1553,12 @@ /// Determine the range for a particular SCEV. /// NOTE: This returns a reference to an entry in a cache. It must be /// copied if its needed for longer. - const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint); + const ConstantRange &getRangeRef(const SCEV *S, RangeSignHint Hint, + unsigned Depth = 0); + + /// Determine the range for a particular SCEV like getRangeRef does, but first + /// compute the ranges of its transitive operands iteratively from a worklist. getRangeRef switches to this function for deeply nested expressions. + const ConstantRange &getRangeRefIter(const SCEV *S, RangeSignHint Hint); /// Determines the range for the affine SCEVAddRecExpr {\p Start,+,\p Step}. /// Helper for \c getRange. diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -6394,18 +6394,75 @@ return FullSet; } +const ConstantRange & +ScalarEvolution::getRangeRefIter(const SCEV *S, + ScalarEvolution::RangeSignHint SignHint) { + DenseMap &Cache = + SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges + : SignedRanges; + SmallVector WorkList; + SmallPtrSet Seen; + + // Add Expr to the worklist, if Expr is either an N-ary expression or a + // SCEVUnknown PHI node. Expressions that were already seen or that already have an entry in the range cache selected by SignHint are skipped. 
+ auto AddToWorklist = [&WorkList, &Seen, &Cache](const SCEV *Expr) { + if (!Seen.insert(Expr).second) + return; + if (Cache.find(Expr) != Cache.end()) + return; + if (isa(Expr)) + WorkList.push_back(Expr); + else if (auto *UnknownS = dyn_cast(Expr)) + if (isa(UnknownS->getValue())) + WorkList.push_back(Expr); + }; + AddToWorklist(S); + + // Build worklist by queuing operands of N-ary expressions and phi nodes. WorkList grows while it is being scanned, so it is indexed by position and size() is re-read on every iteration; a phi that is already in PendingPhiRangesIter is not expanded again. + for (unsigned I = 0; I != WorkList.size(); ++I) { + const SCEV *P = WorkList[I]; + if (auto *NaryS = dyn_cast(P)) { + for (const SCEV *Op : NaryS->operands()) + AddToWorklist(Op); + } else { + auto *UnknownS = cast(P); + if (const PHINode *P = dyn_cast(UnknownS->getValue())) { + if (!PendingPhiRangesIter.insert(P).second) + continue; + for (auto &Op : reverse(P->operands())) + AddToWorklist(getSCEV(Op)); + } + } + } + + if (!WorkList.empty()) { + // Use getRangeRef to compute ranges for items in the worklist in reverse + // order. This will force ranges for earlier operands to be computed before + // their users in most cases. WorkList[0] (S itself, whenever the worklist is non-empty) is skipped here; its range is computed by the final getRangeRef call below. NOTE(review): a PHI registered in PendingPhiRangesIter for S itself is therefore not erased by this loop - confirm that this is intended or handled elsewhere. + for (const SCEV *P : + reverse(make_range(WorkList.begin() + 1, WorkList.end()))) { + getRangeRef(P, SignHint); + + if (auto *UnknownS = dyn_cast(P)) + if (const PHINode *P = dyn_cast(UnknownS->getValue())) + PendingPhiRangesIter.erase(P); + } + } + + return getRangeRef(S, SignHint, 0); +} + /// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. -const ConstantRange & -ScalarEvolution::getRangeRef(const SCEV *S, - ScalarEvolution::RangeSignHint SignHint) { +const ConstantRange &ScalarEvolution::getRangeRef( + const SCEV *S, ScalarEvolution::RangeSignHint SignHint, unsigned Depth) { DenseMap &Cache = SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? UnsignedRanges : SignedRanges; ConstantRange::PreferredRangeType RangeType = - SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED - ? 
ConstantRange::Unsigned : ConstantRange::Signed; + SignHint == ScalarEvolution::HINT_RANGE_UNSIGNED ? ConstantRange::Unsigned + : ConstantRange::Signed; // See if we've computed this range already. DenseMap::iterator I = Cache.find(S); @@ -6415,6 +6472,11 @@ if (const SCEVConstant *C = dyn_cast(S)) return setRange(C, SignHint, ConstantRange(C->getAPInt())); + // Switch to iteratively computing the range for S, if it is part of a deeply + // nested expression. + if (Depth > 32) + return getRangeRefIter(S, SignHint); + unsigned BitWidth = getTypeSizeInBits(S->getType()); ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true); using OBO = OverflowingBinaryOperator; @@ -6434,23 +6496,23 @@ } if (const SCEVAddExpr *Add = dyn_cast(S)) { - ConstantRange X = getRangeRef(Add->getOperand(0), SignHint); + ConstantRange X = getRangeRef(Add->getOperand(0), SignHint, Depth + 1); unsigned WrapType = OBO::AnyWrap; if (Add->hasNoSignedWrap()) WrapType |= OBO::NoSignedWrap; if (Add->hasNoUnsignedWrap()) WrapType |= OBO::NoUnsignedWrap; for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i) - X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint), + X = X.addWithNoWrap(getRangeRef(Add->getOperand(i), SignHint, Depth + 1), WrapType, RangeType); return setRange(Add, SignHint, ConservativeResult.intersectWith(X, RangeType)); } if (const SCEVMulExpr *Mul = dyn_cast(S)) { - ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint); + ConstantRange X = getRangeRef(Mul->getOperand(0), SignHint, Depth + 1); for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i) - X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint)); + X = X.multiply(getRangeRef(Mul->getOperand(i), SignHint, Depth + 1)); return setRange(Mul, SignHint, ConservativeResult.intersectWith(X, RangeType)); } @@ -6476,41 +6538,42 @@ } const auto *NAry = cast(S); - ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint); + ConstantRange X = getRangeRef(NAry->getOperand(0), SignHint, Depth + 1); 
for (unsigned i = 1, e = NAry->getNumOperands(); i != e; ++i) - X = X.intrinsic(ID, {X, getRangeRef(NAry->getOperand(i), SignHint)}); + X = X.intrinsic( + ID, {X, getRangeRef(NAry->getOperand(i), SignHint, Depth + 1)}); return setRange(S, SignHint, ConservativeResult.intersectWith(X, RangeType)); } if (const SCEVUDivExpr *UDiv = dyn_cast(S)) { - ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint); - ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint); + ConstantRange X = getRangeRef(UDiv->getLHS(), SignHint, Depth + 1); + ConstantRange Y = getRangeRef(UDiv->getRHS(), SignHint, Depth + 1); return setRange(UDiv, SignHint, ConservativeResult.intersectWith(X.udiv(Y), RangeType)); } if (const SCEVZeroExtendExpr *ZExt = dyn_cast(S)) { - ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint); + ConstantRange X = getRangeRef(ZExt->getOperand(), SignHint, Depth + 1); return setRange(ZExt, SignHint, ConservativeResult.intersectWith(X.zeroExtend(BitWidth), RangeType)); } if (const SCEVSignExtendExpr *SExt = dyn_cast(S)) { - ConstantRange X = getRangeRef(SExt->getOperand(), SignHint); + ConstantRange X = getRangeRef(SExt->getOperand(), SignHint, Depth + 1); return setRange(SExt, SignHint, ConservativeResult.intersectWith(X.signExtend(BitWidth), RangeType)); } if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast(S)) { - ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint); + ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint, Depth + 1); return setRange(PtrToInt, SignHint, X); } if (const SCEVTruncateExpr *Trunc = dyn_cast(S)) { - ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint); + ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint, Depth + 1); return setRange(Trunc, SignHint, ConservativeResult.intersectWith(X.truncate(BitWidth), RangeType)); @@ -6640,12 +6703,13 @@ RangeType); // A range of Phi is a subset of union of all ranges of its input. 
- if (const PHINode *Phi = dyn_cast(U->getValue())) { + if (PHINode *Phi = dyn_cast(U->getValue())) { // Make sure that we do not run over cycled Phis. if (PendingPhiRanges.insert(Phi).second) { ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false); + for (const auto &Op : Phi->operands()) { - auto OpRange = getRangeRef(getSCEV(Op), SignHint); + auto OpRange = getRangeRef(getSCEV(Op), SignHint, Depth + 1); RangeFromOps = RangeFromOps.unionWith(OpRange); // No point to continue if we already have a full set. if (RangeFromOps.isFullSet())