Index: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h =================================================================== --- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h +++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h @@ -551,19 +551,36 @@ const SCEV *ExactNotTaken; const SCEV *MaxNotTaken; - /// A predicate union guard for this ExitLimit. The result is only - /// valid if this predicate evaluates to 'true' at run-time. - SCEVUnionPredicate Predicate; + /// A set of predicate guards for this ExitLimit. The result is only valid + /// if all of the predicates in \c Predicates evaluate to 'true' at + /// run-time. + SmallPtrSet Predicates; + + void addPredicate(const SCEVPredicate *P) { + assert(!isa(P) && "Only add leaf predicates here!"); + Predicates.insert(P); + } /*implicit*/ ExitLimit(const SCEV *E) : ExactNotTaken(E), MaxNotTaken(E) {} - ExitLimit(const SCEV *E, const SCEV *M, SCEVUnionPredicate &P) - : ExactNotTaken(E), MaxNotTaken(M), Predicate(P) { + ExitLimit( + const SCEV *E, const SCEV *M, + ArrayRef *> PredSetList) + : ExactNotTaken(E), MaxNotTaken(M) { assert((isa(ExactNotTaken) || !isa(MaxNotTaken)) && "Exact is not allowed to be less precise than Max"); + for (auto *PredSet : PredSetList) + for (auto *P : *PredSet) + addPredicate(P); } + ExitLimit(const SCEV *E, const SCEV *M, + const SmallPtrSetImpl &PredSet) + : ExitLimit(E, M, {&PredSet}) {} + + ExitLimit(const SCEV *E, const SCEV *M) : ExitLimit(E, M, None) {} + /// Test whether this ExitLimit contains any computed information, or /// whether it's all SCEVCouldNotCompute values. bool hasAnyInfo() const { @@ -1581,9 +1598,9 @@ SCEVUnionPredicate &A); /// Tries to convert the \p S expression to an AddRec expression, /// adding additional predicates to \p Preds as required. - const SCEVAddRecExpr * - convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, - SCEVUnionPredicate &Preds); + const SCEVAddRecExpr *convertSCEVToAddRecWithPredicates( + const SCEV *S, const Loop *L, + SmallPtrSetImpl &Preds); private: /// Compute the backedge taken count knowing the interval difference, the Index: llvm/trunk/lib/Analysis/ScalarEvolution.cpp =================================================================== --- llvm/trunk/lib/Analysis/ScalarEvolution.cpp +++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp @@ -5656,11 +5656,14 @@ [&](const EdgeExitInfo &EEI) { BasicBlock *ExitBB = EEI.first; const ExitLimit &EL = EEI.second; - if (EL.Predicate.isAlwaysTrue()) + if (EL.Predicates.empty()) return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr); - return ExitNotTakenInfo( - ExitBB, EL.ExactNotTaken, - llvm::make_unique(std::move(EL.Predicate))); + + std::unique_ptr Predicate(new SCEVUnionPredicate); + for (auto *Pred : EL.Predicates) + Predicate->add(Pred); + + return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate)); }); } @@ -5691,7 +5694,7 @@ BasicBlock *ExitBB = ExitingBlocks[i]; ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); - assert((AllowPredicates || EL.Predicate.isAlwaysTrue()) && + assert((AllowPredicates || EL.Predicates.empty()) && "Predicated exit limit when predicates are not allowed!"); // 1. For each exit that can be computed, add an entry to ExitCounts. @@ -5861,9 +5864,6 @@ BECount = EL0.ExactNotTaken; } - SCEVUnionPredicate NP; - NP.add(&EL0.Predicate); - NP.add(&EL1.Predicate); // There are cases (e.g. PR26207) where computeExitLimitFromCond is able // to be more aggressive when computing BECount than when computing // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and @@ -5873,7 +5873,7 @@ !isa(BECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, NP); + return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates}); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. @@ -5912,10 +5912,7 @@ BECount = EL0.ExactNotTaken; } - SCEVUnionPredicate NP; - NP.add(&EL0.Predicate); - NP.add(&EL1.Predicate); - return ExitLimit(BECount, MaxBECount, NP); + return ExitLimit(BECount, MaxBECount, {&EL0.Predicates, &EL1.Predicates}); } } @@ -6300,8 +6297,7 @@ unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); - SCEVUnionPredicate P; - return ExitLimit(getCouldNotCompute(), UpperBound, P); + return ExitLimit(getCouldNotCompute(), UpperBound); } return getCouldNotCompute(); @@ -7062,7 +7058,7 @@ // effectively V != 0. We know and take advantage of the fact that this // expression only being used in a comparison by zero context. - SCEVUnionPredicate P; + SmallPtrSet Predicates; // If the value is a constant if (const SCEVConstant *C = dyn_cast(V)) { // If the value is already zero, the branch will execute zero times. @@ -7075,7 +7071,7 @@ // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. - AddRec = convertSCEVToAddRecWithPredicates(V, L, P); + AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates); if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); @@ -7097,7 +7093,7 @@ // should not accept a root of 2. const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); if (Val->isZero()) - return ExitLimit(R1, R1, P); // We found a quadratic root! + return ExitLimit(R1, R1, Predicates); // We found a quadratic root! } } return getCouldNotCompute(); @@ -7154,7 +7150,7 @@ else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount, P); + return ExitLimit(Distance, MaxBECount, Predicates); } // As a special case, handle the instance where Step is a positive power of @@ -7209,7 +7205,7 @@ const SCEV *Limit = getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); - return ExitLimit(Limit, Limit, P); + return ExitLimit(Limit, Limit, Predicates); } } @@ -7222,14 +7218,14 @@ loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact, P); + return ExitLimit(Exact, Exact, Predicates); } // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast(Start)) { const SCEV *E = SolveLinEquationWithOverflow( StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); - return ExitLimit(E, E, P); + return ExitLimit(E, E, Predicates); } return getCouldNotCompute(); } @@ -8634,7 +8630,7 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates) { - SCEVUnionPredicate P; + SmallPtrSet Predicates; // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -8646,7 +8642,7 @@ // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. - IV = convertSCEVToAddRecWithPredicates(LHS, L, P); + IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); PredicatedIV = true; } @@ -8762,14 +8758,14 @@ if (isa(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, P); + return ExitLimit(BECount, MaxBECount, Predicates); } ScalarEvolution::ExitLimit ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, bool ControlsExit, bool AllowPredicates) { - SCEVUnionPredicate P; + SmallPtrSet Predicates; // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -8779,7 +8775,7 @@ // Try to make this an AddRec using runtime tests, in the first X // iterations of this loop, where X is the SCEV expression found by the // algorithm below. - IV = convertSCEVToAddRecWithPredicates(LHS, L, P); + IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); // Avoid weird loops if (!IV || IV->getLoop() != L || !IV->isAffine()) @@ -8839,7 +8835,7 @@ if (isa(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, P); + return ExitLimit(BECount, MaxBECount, Predicates); } const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, @@ -10161,25 +10157,34 @@ class SCEVPredicateRewriter : public SCEVRewriteVisitor { public: - // Rewrites \p S in the context of a loop L and the predicate A. - // If Assume is true, rewrite is free to add further predicates to A - // such that the result will be an AddRecExpr. + /// Rewrites \p S in the context of a loop L and the SCEV predication + /// infrastructure. + /// + /// If \p Pred is non-null, the SCEV expression is rewritten to respect the + /// equivalences present in \p Pred. + /// + /// If \p NewPreds is non-null, rewrite is free to add further predicates to + /// \p NewPreds such that the result will be an AddRecExpr. static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, - SCEVUnionPredicate &A, bool Assume) { - SCEVPredicateRewriter Rewriter(L, SE, A, Assume); + SmallPtrSetImpl *NewPreds, + SCEVUnionPredicate *Pred) { + SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred); return Rewriter.visit(S); } SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, - SCEVUnionPredicate &P, bool Assume) - : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} + SmallPtrSetImpl *NewPreds, + SCEVUnionPredicate *Pred) + : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { - auto ExprPreds = P.getPredicatesForExpr(Expr); - for (auto *Pred : ExprPreds) - if (const auto *IPred = dyn_cast(Pred)) - if (IPred->getLHS() == Expr) - return IPred->getRHS(); + if (Pred) { + auto ExprPreds = Pred->getPredicatesForExpr(Expr); + for (auto *Pred : ExprPreds) + if (const auto *IPred = dyn_cast(Pred)) + if (IPred->getLHS() == Expr) + return IPred->getRHS(); + } return Expr; } @@ -10220,32 +10225,31 @@ bool addOverflowAssumption(const SCEVAddRecExpr *AR, SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { auto *A = SE.getWrapPredicate(AR, AddedFlags); - if (!Assume) { + if (!NewPreds) { // Check if we've already made this assumption. - if (P.implies(A)) - return true; - return false; + return Pred && Pred->implies(A); } - P.add(A); + NewPreds->insert(A); return true; } - SCEVUnionPredicate &P; + SmallPtrSetImpl *NewPreds; + SCEVUnionPredicate *Pred; const Loop *L; - bool Assume; }; } // end anonymous namespace const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L, SCEVUnionPredicate &Preds) { - return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false); + return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds); } -const SCEVAddRecExpr * -ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, - SCEVUnionPredicate &Preds) { - SCEVUnionPredicate TransformPreds; - S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true); +const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( + const SCEV *S, const Loop *L, + SmallPtrSetImpl &Preds) { + + SmallPtrSet TransformPreds; + S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr); auto *AddRec = dyn_cast(S); if (!AddRec) @@ -10253,7 +10257,9 @@ // Since the transformation was successful, we can now transfer the SCEV // predicates. - Preds.add(&TransformPreds); + for (auto *P : TransformPreds) + Preds.insert(P); + return AddRec; } @@ -10480,11 +10486,15 @@ const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { const SCEV *Expr = this->getSCEV(V); - auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); + SmallPtrSet NewPreds; + auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds); if (!New) return nullptr; + for (auto *P : NewPreds) + Preds.add(P); + updateGeneration(); RewriteMap[SE.getSCEV(V)] = {Generation, New}; return New;