Index: lib/Analysis/ScalarEvolution.cpp
===================================================================
--- lib/Analysis/ScalarEvolution.cpp
+++ lib/Analysis/ScalarEvolution.cpp
@@ -4127,6 +4127,58 @@
   bool Valid = true;
 };
 
+/// Rewrites SCEV \p S relative to loop \p L by replacing every AddRec
+/// sub-expression whose loop is \p L with its post-increment expression.
+///
+/// The rewrite fails, and rewrite() returns SE.getCouldNotCompute(), when:
+///  - S contains a SCEVUnknown that is not invariant in L, or
+///  - S contains an AddRec for a loop other than L and IgnoreOtherLoops is
+///    false. With IgnoreOtherLoops set, such AddRecs are kept unchanged.
+class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
+public:
+  /// Entry point: visits \p S and returns the rewritten expression, or
+  /// SE.getCouldNotCompute() if any visited node invalidated the rewrite.
+  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+                             bool IgnoreOtherLoops = false) {
+    SCEVPostIncRewriter Rewriter(L, SE, IgnoreOtherLoops);
+    const SCEV *Result = Rewriter.visit(S);
+    return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+  }
+
+  /// An unknown value that is not invariant in L invalidates the rewrite.
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    if (!SE.isLoopInvariant(Expr, L))
+      Valid = false;
+    return Expr;
+  }
+
+  /// AddRecs of L become their post-increment form; AddRecs of any other loop
+  /// are returned untouched and are tolerated only if IgnoreOtherLoops is set.
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+    if (Expr->getLoop() == L)
+      return Expr->getPostIncExpr(SE);
+    // An AddRec of a different loop: the result stays valid only if the caller
+    // asked to ignore other loops (and nothing invalidated it earlier).
+    Valid = Valid && IgnoreOtherLoops;
+    return Expr;
+  }
+
+  /// Returns false once any visited node has invalidated the rewrite.
+  bool isValid() const { return Valid; }
+
+private:
+  explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE,
+                               bool IgnoreOtherLoops)
+      : SCEVRewriteVisitor(SE), L(L), IgnoreOtherLoops(IgnoreOtherLoops) {}
+
+  /// Loop whose AddRecs are replaced by their post-increment expressions.
+  const Loop *L;
+  /// When true, AddRecs of loops other than L do not invalidate the rewrite.
+  const bool IgnoreOtherLoops;
+  /// Cleared as soon as a node is visited that makes the rewrite impossible.
+  bool Valid = true;
+};
+
 /// This class evaluates the compare condition by matching it against the
 /// condition of loop latch. If there is a match we assume a true value
 /// for the condition while building SCEV nodes.