Index: llvm/trunk/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/trunk/lib/Analysis/ScalarEvolution.cpp +++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp @@ -3629,6 +3629,71 @@ } } +class SCEVInitRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVInitRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } + + SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + // Only allow AddRecExprs for this loop. + if (Expr->getLoop() == L) + return Expr->getStart(); + Valid = false; + return Expr; + } + + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; + +class SCEVShiftRewriter : public SCEVRewriteVisitor { +public: + static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + ScalarEvolution &SE) { + SCEVShiftRewriter Rewriter(L, SE); + const SCEV *Result = Rewriter.visit(Scev); + return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); + } + + SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), Valid(true) {} + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + // Only allow unknowns that are invariant in this loop. 
+ if (!(SE.getLoopDisposition(Expr, L) == ScalarEvolution::LoopInvariant)) + Valid = false; + return Expr; + } + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { + if (Expr->getLoop() == L && Expr->isAffine()) + return SE.getMinusSCEV(Expr, Expr->getStepRecurrence(SE)); + Valid = false; + return Expr; + } + bool isValid() { return Valid; } + +private: + const Loop *L; + bool Valid; +}; + const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) @@ -3741,30 +3806,28 @@ return PHISCEV; } } - } else if (const auto *AddRec = dyn_cast(BEValue)) { + } else { // Otherwise, this could be a loop like this: // i = 0; for (j = 1; ..; ++j) { .... i = j; } // In this case, j = {1,+,1} and BEValue is j. // Because the other in-value of i (0) fits the evolution of BEValue // i really is an addrec evolution. - if (AddRec->getLoop() == L && AddRec->isAffine()) { + // + // We can generalize this saying that i is the shifted value of BEValue + // by one iteration: + // PHI(f(0), f({1,+,1})) --> f({0,+,1}) + const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this); + const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this); + if (Shifted != getCouldNotCompute() && + Start != getCouldNotCompute()) { const SCEV *StartVal = getSCEV(StartValueV); - - // If StartVal = j.start - j.stride, we can use StartVal as the - // initial step of the addrec evolution. - if (StartVal == - getMinusSCEV(AddRec->getOperand(0), AddRec->getOperand(1))) { - // FIXME: For constant StartVal, we should be able to infer - // no-wrap flags. - const SCEV *PHISCEV = getAddRecExpr(StartVal, AddRec->getOperand(1), - L, SCEV::FlagAnyWrap); - + if (Start == StartVal) { // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. 
ForgetSymbolicName(PN, SymbolicName); - ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; - return PHISCEV; + ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; + return Shifted; } } } Index: llvm/trunk/test/Analysis/ScalarEvolution/non-IV-phi.ll =================================================================== --- llvm/trunk/test/Analysis/ScalarEvolution/non-IV-phi.ll +++ llvm/trunk/test/Analysis/ScalarEvolution/non-IV-phi.ll @@ -0,0 +1,59 @@ +; RUN: opt -scalar-evolution -analyze < %s | FileCheck %s + +define void @test1(i8 %t, i32 %len) { +; CHECK-LABEL: test1 +; CHECK: %sphi = phi i32 [ %ext, %entry ], [ %idx.inc.ext, %loop ] +; CHECK-NEXT: --> (zext i8 {%t,+,1}<%loop> to i32) + + entry: + %st = zext i8 %t to i16 + %ext = zext i8 %t to i32 + %ecmp = icmp ult i16 %st, 42 + br i1 %ecmp, label %loop, label %exit + + loop: + + %idx = phi i8 [ %t, %entry ], [ %idx.inc, %loop ] + %sphi = phi i32 [ %ext, %entry ], [%idx.inc.ext, %loop] + + %idx.inc = add i8 %idx, 1 + %idx.inc.ext = zext i8 %idx.inc to i32 + %idx.ext = zext i8 %idx to i32 + + %c = icmp ult i32 %idx.inc.ext, %len + br i1 %c, label %loop, label %exit + + exit: + ret void +} + +define void @test2(i8 %t, i32 %len) { +; CHECK-LABEL: test2 +; CHECK: %sphi = phi i32 [ %ext.mul, %entry ], [ %mul, %loop ] +; CHECK-NEXT: --> (4 * (zext i8 {%t,+,1}<%loop> to i32)) + + entry: + %st = zext i8 %t to i16 + %ext = zext i8 %t to i32 + %ext.mul = mul i32 %ext, 4 + + %ecmp = icmp ult i16 %st, 42 + br i1 %ecmp, label %loop, label %exit + + loop: + + %idx = phi i8 [ %t, %entry ], [ %idx.inc, %loop ] + %sphi = phi i32 [ %ext.mul, %entry ], [%mul, %loop] + + %idx.inc = add i8 %idx, 1 + %idx.inc.ext = zext i8 %idx.inc to i32 + %mul = mul i32 %idx.inc.ext, 4 + + %idx.ext = zext i8 %idx to i32 + + %c = icmp ult i32 %idx.inc.ext, %len + br i1 %c, label %loop, label %exit + + exit: + ret void +}