Index: lib/Analysis/ScalarEvolution.cpp =================================================================== --- lib/Analysis/ScalarEvolution.cpp +++ lib/Analysis/ScalarEvolution.cpp @@ -8247,10 +8247,9 @@ return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); } -/// Find the roots of the quadratic equation for the given quadratic chrec -/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or -/// two SCEVCouldNotCompute objects. -static Optional> +/// Find the smallest non-negative number n, such that after n iterations +/// the quadratic chrec {L,+,M,+,N} equals 0 or changes sign. +static Optional SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast(AddRec->getOperand(0)); @@ -8262,54 +8261,112 @@ return None; uint32_t BitWidth = LC->getAPInt().getBitWidth(); - const APInt &L = LC->getAPInt(); - const APInt &M = MC->getAPInt(); - const APInt &N = NC->getAPInt(); - APInt Two(BitWidth, 2); - - // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C - - // The A coefficient is N/2 - APInt A = N.sdiv(Two); + APInt L = LC->getAPInt(); + APInt M = MC->getAPInt(); + APInt N = NC->getAPInt(); + + uint32_t NewWidth = std::max(32u, 2*BitWidth); + if (BitWidth == 1) { + N = N.zext(NewWidth); + M = M.zext(NewWidth); + L = L.zext(NewWidth); + } else { + N = N.sext(NewWidth); + M = M.sext(NewWidth); + L = L.sext(NewWidth); + } - // The B coefficient is M-N/2 - APInt B = M; - B -= A; // A is the same as N/2. + // The increments are M, M+N, M+2N, ..., so the accumulated values are + // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is, + // L+M, L+2M+N, L+3M+3N, ... + // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N. + // + // The equation Acc = 0 is then + // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0. 
+ // In a quadratic form it becomes: + // N n^2 + (2M-N) n + 2L = 0. + + APInt A = N; + APInt B = 2 * M - A; + APInt C = 2 * L; + APInt R = APInt::getOneBitSet(NewWidth, BitWidth); + + // Make A >= 0 for simplicity. + if (A.isNegative()) { + A.negate(); + B.negate(); + C.negate(); + R.negate(); + } + LLVM_DEBUG(dbgs() << __func__ << ": solving " << A << "x^2 + " << B + << "x + " << C << " for AddRec: " << *AddRec << '\n'); + + APInt D = B*B - 4*A*C; + + // If there are no solutions, or if both are negative, try solving + // An^2 + Bn + C = 2R, + // where R is 2^BitWidth (i.e. zero in the original type). It's 2R + // because the whole equation was multiplied by 2. + bool NegB = B.isNegative(), NegC = C.isNegative(); + bool BothNeg = !NegB && !NegC; + + if (D.isNegative() || BothNeg) { + C -= 2 * R; + D += 8 * A * R; + NegC = C.isNegative(); + } + if (D.isNegative()) + return None; - // The C coefficient is L. - const APInt& C = L; + auto TruncToOrig = [BitWidth,&SE] (const APInt &V) { + if (V.isIntN(BitWidth)) + return cast(SE.getConstant(V.trunc(BitWidth))); + return cast(SE.getConstant(V)); + }; - // Compute the B^2-4ac term. - APInt SqrtTerm = B; - SqrtTerm *= B; - SqrtTerm -= 4 * (A * C); + APInt X; - if (SqrtTerm.isNegative()) { - // The loop is provably infinite. - return None; - } + APInt SQ = D.sqrt(); + APInt TwoA = 2 * A; - // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest - // integer value or else APInt::sqrt() will assert. - APInt SqrtVal = SqrtTerm.sqrt(); - - // Compute the two solutions for the quadratic formula. - // The divisions must be performed as signed divisions. - APInt NegB = -std::move(B); - APInt TwoA = std::move(A); - TwoA <<= 1; - if (TwoA.isNullValue()) + APInt Rem; + // Calculate the smallest positive root. If C > 0, both roots have + // the same sign, if C < 0 they have different signs. 
+ if (NegC) + APInt::sdivrem(-B + SQ, TwoA, X, Rem); + else + APInt::sdivrem(-B - SQ, TwoA, X, Rem); + + if (SQ * SQ == D) { + // If the division was not exact, add 1. + if (!Rem.isNullValue()) + X += 1; + LLVM_DEBUG(dbgs() << __func__ << ": solution (exact): " << X << '\n'); + return TruncToOrig(X); + } + + // The actual value of the square root should be between SQ and 2*SQ. + // Calculate the "overestimate" of the root. + APInt Y = NegC ? (-B + 2*SQ).sdiv(TwoA) + : (-B - 2*SQ).sdiv(TwoA); + // If AX^2 + BX + C < 0 and AY^2 + BY + C > 0, then use bisection + // to find the best approximation of the root. + APInt VX = (A*X + B)*X + C; + APInt VY = (A*Y + B)*Y + C; + if (!VX.isNegative() || VY.isNegative()) return None; - LLVMContext &Context = SE.getContext(); - - ConstantInt *Solution1 = - ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); - ConstantInt *Solution2 = - ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + while (X != Y) { + APInt T = (X + Y).sdiv(2); + APInt V = (A*T + B)*T + C; + if (V.isNegative()) + X = T+1; + else + Y = T; + } - return std::make_pair(cast(SE.getConstant(Solution1)), - cast(SE.getConstant(Solution2))); + LLVM_DEBUG(dbgs() << __func__ << ": solution (approx): " << X << '\n'); + return TruncToOrig(X); } ScalarEvolution::ExitLimit @@ -8344,23 +8401,15 @@ // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { - if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { - const SCEVConstant *R1 = Roots->first; - const SCEVConstant *R2 = Roots->second; - // Pick the smallest positive root value. - if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp( - CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. - - // We can only use this value if the chrec ends up with an exact zero - // value at this index. 
When solving for "X*X != 5", for example, we - // should not accept a root of 2. - const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); - if (Val->isZero()) - // We found a quadratic root! - return ExitLimit(R1, R1, false, Predicates); - } + if (auto Root = SolveQuadraticEquation(AddRec, *this)) { + // We can only use this value if the chrec ends up with an exact zero + // value at this index. When solving for "X*X != 5", for example, we + // should not accept a root of 2. + const SCEVConstant *R = Root.getValue(); + const SCEV *Val = AddRec->evaluateAtIteration(R, *this); + if (Val->isZero()) + // We found a quadratic root! + return ExitLimit(R, R, false, Predicates); } return getCouldNotCompute(); } @@ -10477,41 +10526,33 @@ const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap); // Next, solve the constructed addrec - if (auto Roots = + if (auto Root = SolveQuadraticEquation(cast(NewAddRec), SE)) { - const SCEVConstant *R1 = Roots->first; - const SCEVConstant *R2 = Roots->second; - // Pick the smallest positive root value. - if (ConstantInt *CB = dyn_cast(ConstantExpr::getICmp( - ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. - + const SCEVConstant *R = Root.getValue(); // Make sure the root is not off by one. The returned iteration should // not be in the range, but the previous one should be. When solving // for "X*X < 5", for example, we should not return a root of 2. - ConstantInt *R1Val = - EvaluateConstantChrecAtConstant(this, R1->getValue(), SE); - if (Range.contains(R1Val->getValue())) { + ConstantInt *RootVal = + EvaluateConstantChrecAtConstant(this, R->getValue(), SE); + if (Range.contains(RootVal->getValue())) { // The next iteration must be out of the range... 
ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getAPInt() + 1); + ConstantInt::get(SE.getContext(), R->getAPInt() + 1); - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (!Range.contains(R1Val->getValue())) + RootVal = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (!Range.contains(RootVal->getValue())) return SE.getConstant(NextVal); return SE.getCouldNotCompute(); // Something strange happened } - // If R1 was not in the range, then it is a good return value. Make - // sure that R1-1 WAS in the range though, just in case. + // If the value at R was not in the range, then R is a good return value. + // Make sure that the value at R-1 WAS in the range though, just in case. ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getAPInt() - 1); - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (Range.contains(R1Val->getValue())) - return R1; + ConstantInt::get(SE.getContext(), R->getAPInt() - 1); + RootVal = EvaluateConstantChrecAtConstant(this, NextVal, SE); + if (Range.contains(RootVal->getValue())) + return R; return SE.getCouldNotCompute(); // Something strange happened - } } } Index: test/Analysis/ScalarEvolution/solve-quadratic-i1.ll =================================================================== --- /dev/null +++ test/Analysis/ScalarEvolution/solve-quadratic-i1.ll @@ -0,0 +1,39 @@ +; RUN: opt -analyze -scalar-evolution -S < %s | FileCheck %s + +; CHECK: Printing analysis 'Scalar Evolution Analysis' for function 'f1': +; CHECK-NEXT: Classifying expressions for: @f1 +; CHECK-NEXT: %a.0 = phi i16 [ 2, %entry ], [ %inc, %lbl1 ] +; CHECK-NEXT: --> {2,+,1}<%lbl1> U: [2,-32768) S: [2,-32768) Exits: 3 LoopDispositions: { %lbl1: Computable } +; CHECK-NEXT: %b.0 = phi i16 [ 1, %entry ], [ %add, %lbl1 ] +; CHECK-NEXT: --> {1,+,2,+,1}<%lbl1> U: full-set S: full-set Exits: 3 LoopDispositions: { %lbl1: Computable } +; CHECK-NEXT: %inc = add nsw i16 %a.0, 1 +; CHECK-NEXT: --> {3,+,1}<%lbl1> U: [3,0) S: [3,0) Exits: 
4 LoopDispositions: { %lbl1: Computable } +; CHECK-NEXT: %add = add nsw i16 %b.0, %a.0 +; CHECK-NEXT: --> {3,+,3,+,1}<%lbl1> U: full-set S: full-set Exits: 6 LoopDispositions: { %lbl1: Computable } +; CHECK-NEXT: %and = and i16 %add, 1 +; CHECK-NEXT: --> (zext i1 {true,+,true,+,true}<%lbl1> to i16) U: [0,2) S: [0,2) Exits: 0 LoopDispositions: { %lbl1: Computable } +; CHECK-NEXT: Determining loop execution counts for: @f1 +; CHECK-NEXT: Loop %lbl1: backedge-taken count is 1 +; CHECK-NEXT: Loop %lbl1: max backedge-taken count is 1 +; CHECK-NEXT: Loop %lbl1: Predicated backedge-taken count is 1 +; CHECK-NEXT: Predicates: +; CHECK: Loop %lbl1: Trip multiple is 2 + +target triple = "x86_64-unknown-linux-gnu" + +define void @f1() { +entry: + br label %lbl1 + +lbl1: ; preds = %lbl1, %entry + %a.0 = phi i16 [ 2, %entry ], [ %inc, %lbl1 ] + %b.0 = phi i16 [ 1, %entry ], [ %add, %lbl1 ] + %inc = add nsw i16 %a.0, 1 + %add = add nsw i16 %b.0, %a.0 + %and = and i16 %add, 1 + %tobool = icmp ne i16 %and, 0 + br i1 %tobool, label %lbl1, label %if.end + +if.end: ; preds = %lbl1 + ret void +} Index: test/Analysis/ScalarEvolution/solve-quadratic.ll =================================================================== --- /dev/null +++ test/Analysis/ScalarEvolution/solve-quadratic.ll @@ -0,0 +1,27 @@ +; RUN: opt -analyze -scalar-evolution -S < %s | FileCheck %s + +; CHECK: Loop %b1: backedge-taken count is 255 + +@g0 = global i32 0, align 4 +@g1 = global i16 0, align 2 + +define signext i32 @f0() { +b0: + br label %b1 + +b1: ; preds = %b1, %b0 + %v1 = phi i16 [ 0, %b0 ], [ %v2, %b1 ] + %v2 = add i16 %v1, -1 + %v3 = mul i16 %v2, %v2 + %v4 = icmp eq i16 %v3, 0 + br i1 %v4, label %b2, label %b1 + +b2: ; preds = %b1 + %v5 = phi i16 [ %v2, %b1 ] + %v6 = phi i16 [ %v3, %b1 ] + %v7 = sext i16 %v5 to i32 + store i32 %v7, i32* @g0, align 4 + store i16 %v6, i16* @g1, align 2 + ret i32 0 +} +