diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -653,8 +653,9 @@
   /// Transitively follow the chain of pointer-type operands until reaching a
   /// SCEV that does not have a single pointer operand. This returns a
   /// SCEVUnknown pointer for well-formed pointer-type expressions, but corner
-  /// cases do exist.
-  const SCEV *getPointerBase(const SCEV *V);
+  /// cases do exist. DepthLimit bounds how many operand steps are taken; a
+  /// limit of 0 (the default) imposes no restriction.
+  const SCEV *getPointerBase(const SCEV *V, unsigned DepthLimit = 0);
 
   /// Return a SCEV expression for the specified value at the specified scope
   /// in the program. The L value specifies a loop nest to evaluate the
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -3948,27 +3948,38 @@
   return getUMinExpr(PromotedOps);
 }
 
-const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
+const SCEV *ScalarEvolution::getPointerBase(const SCEV *V,
+                                            unsigned DepthLimit) {
   // A pointer operand may evaluate to a nonpointer expression, such as null.
   if (!V->getType()->isPointerTy())
     return V;
 
-  if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
-    return getPointerBase(Cast->getOperand());
-  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
-    const SCEV *PtrOp = nullptr;
-    for (const SCEV *NAryOp : NAry->operands()) {
-      if (NAryOp->getType()->isPointerTy()) {
-        // Cannot find the base of an expression with multiple pointer operands.
-        if (PtrOp)
-          return V;
-        PtrOp = NAryOp;
+  // Step down one pointer operand per iteration. With a DepthLimit of 0 the
+  // loop condition is always true and the walk ends only via a return.
+  for (unsigned Depth = 0; DepthLimit == 0 || Depth < DepthLimit; ++Depth) {
+    if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
+      V = Cast->getOperand();
+      // As at entry, a nonpointer operand is the result; do not look into it.
+      // (The n-ary path below only ever steps onto pointer-typed operands.)
+      if (!V->getType()->isPointerTy())
+        return V;
+    } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
+      const SCEV *PtrOp = nullptr;
+      for (const SCEV *NAryOp : NAry->operands()) {
+        if (NAryOp->getType()->isPointerTy()) {
+          // Cannot find the base of an expression with multiple pointer ops.
+          if (PtrOp)
+            return V;
+          PtrOp = NAryOp;
+        }
       }
-    }
-    if (!PtrOp)
+      if (!PtrOp) // All operands were non-pointer.
+        return V;
+      V = PtrOp;
+    } else // Not something we can look into.
       return V;
-    return getPointerBase(PtrOp);
   }
+  // Depth cut-off reached; not allowed to look any further.
   return V;
 }
 