diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4131,6 +4131,41 @@
   return getMinusSCEV(getMinusOne(Ty), V);
 }
 
+/// Compute an expression equivalent to P - getPointerBase(P).
+static const SCEV *removePointerBase(ScalarEvolution *SE, const SCEV *P) {
+  assert(P->getType()->isPointerTy());
+
+  if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(P)) {
+    // The base of an AddRec is the first operand.
+    SmallVector<const SCEV *> Ops{AddRec->operands()};
+    Ops[0] = removePointerBase(SE, Ops[0]);
+    // Don't try to transfer nowrap flags for now. We could in some cases
+    // (for example, if pointer operand of the AddRec is a SCEVUnknown).
+    return SE->getAddRecExpr(Ops, AddRec->getLoop(), SCEV::FlagAnyWrap);
+  }
+  if (auto *Add = dyn_cast<SCEVAddExpr>(P)) {
+    // The base of an Add is the pointer operand.
+    SmallVector<const SCEV *> Ops{Add->operands()};
+    const SCEV **PtrOp = nullptr;
+    for (const SCEV *&AddOp : Ops) {
+      if (AddOp->getType()->isPointerTy()) {
+        // If we find an Add with multiple pointer operands, treat it as a
+        // pointer base to be consistent with getPointerBase. Eventually
+        // we should be able to assert this is impossible.
+        if (PtrOp)
+          return SE->getZero(P->getType());
+        PtrOp = &AddOp;
+      }
+    }
+    *PtrOp = removePointerBase(SE, *PtrOp);
+    // Don't try to transfer nowrap flags for now. We could in some cases
+    // (for example, if the pointer operand of the Add is a SCEVUnknown).
+    return SE->getAddExpr(Ops);
+  }
+  // Any other expression must be a pointer base.
+  return SE->getZero(P->getType());
+}
+
 const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
                                           SCEV::NoWrapFlags Flags,
                                           unsigned Depth) {
@@ -4145,6 +4180,8 @@
     if (!LHS->getType()->isPointerTy() ||
         getPointerBase(LHS) != getPointerBase(RHS))
       return getCouldNotCompute();
+    LHS = removePointerBase(this, LHS);
+    RHS = removePointerBase(this, RHS);
   }
 
   // We represent LHS - RHS as LHS + (-1)*RHS.
This transformation