diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13696,30 +13696,46 @@
   return getUMinFromMismatchedTypes(ExitCounts);
 }
 
-/// This rewriter is similar to SCEVParameterRewriter (it replaces SCEVUnknown
-/// components following the Map (Value -> SCEV)), but skips AddRecExpr because
-/// we cannot guarantee that the replacement is loop invariant in the loop of
-/// the AddRec.
+/// A rewriter to replace SCEV expressions in Map with the corresponding entry
+/// in the map. It skips AddRecExpr because we cannot guarantee that the
+/// replacement is loop invariant in the loop of the AddRec.
+///
+/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
+/// supported.
+///
+/// The map is held by reference rather than copied; it must outlive the
+/// rewriter and must not be modified while a rewrite is in progress.
 class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
-  ValueToSCEVMapTy &Map;
+  const DenseMap<const SCEV *, const SCEV *> &Map;
 
 public:
-  SCEVLoopGuardRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
+  SCEVLoopGuardRewriter(ScalarEvolution &SE,
+                        const DenseMap<const SCEV *, const SCEV *> &M)
       : SCEVRewriteVisitor(SE), Map(M) {}
 
   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
 
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
-    auto I = Map.find(Expr->getValue());
+    auto I = Map.find(Expr);
     if (I == Map.end())
       return Expr;
     return I->second;
   }
+
+  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    auto I = Map.find(Expr);
+    if (I == Map.end())
+      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
+          Expr);
+    return I->second;
+  }
 };
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
-                              const SCEV *RHS, ValueToSCEVMapTy &RewriteMap) {
+                              const SCEV *RHS,
+                              DenseMap<const SCEV *, const SCEV *>
+                                  &RewriteMap) {
     // WARNING: It is generally unsound to apply any wrap flags to the proposed
     // replacement SCEV which isn't directly implied by the structure of that
     // SCEV. In particular, using contextual facts to imply flags is *NOT*
@@ -13736,8 +13752,8 @@
       const SCEV *URemRHS = nullptr;
       if (matchURem(LHS, URemLHS, URemRHS)) {
         if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
-          Value *V = LHSUnknown->getValue();
-          RewriteMap[V] = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
+          auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
+          RewriteMap[LHSUnknown] = Multiple;
           return;
         }
       }
@@ -13769,9 +13785,9 @@
       // Bail out, unless we have a non-wrapping, monotonic range.
       if (ExactRegion.isWrappedSet() || ExactRegion.isFullSet())
         return false;
-      auto I = RewriteMap.find(LHSUnknown->getValue());
+      auto I = RewriteMap.find(LHSUnknown);
       const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
-      RewriteMap[LHSUnknown->getValue()] = getUMaxExpr(
+      RewriteMap[LHSUnknown] = getUMaxExpr(
           getConstant(ExactRegion.getUnsignedMin()),
           getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
       return true;
@@ -13779,15 +13795,27 @@
     if (MatchRangeCheckIdiom())
       return;
 
-    // For now, limit to conditions that provide information about unknown
-    // expressions. RHS also cannot contain add recurrences.
-    auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
-    if (!LHSUnknown || containsAddRecurrence(RHS))
+    // If RHS is SCEVUnknown, make sure the information is applied to it.
+    if (isa<SCEVUnknown>(RHS)) {
+      std::swap(LHS, RHS);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+    }
+    // If LHS is a constant, apply information to the other expression.
+    if (isa<SCEVConstant>(LHS)) {
+      std::swap(LHS, RHS);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+    }
+    // Do not apply information for constants or if RHS contains an AddRec.
+    if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
+      return;
+
+    // Limit to expressions that can be rewritten.
+    if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
       return;
 
     // Check whether LHS has already been rewritten. In that case we want to
     // chain further rewrites onto the already rewritten value.
-    auto I = RewriteMap.find(LHSUnknown->getValue());
+    auto I = RewriteMap.find(LHS);
     const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHS;
     const SCEV *RewrittenRHS = nullptr;
     switch (Predicate) {
@@ -13833,13 +13861,13 @@
     }
 
     if (RewrittenRHS)
-      RewriteMap[LHSUnknown->getValue()] = RewrittenRHS;
+      RewriteMap[LHS] = RewrittenRHS;
   };
   // Starting at the loop predecessor, climb up the predecessor chain, as long
   // as there are predecessors that can be found that have unique successors
   // leading to the original header.
   // TODO: share this logic with isLoopEntryGuardedByCond.
-  ValueToSCEVMapTy RewriteMap;
+  DenseMap<const SCEV *, const SCEV *> RewriteMap;
   for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
            L->getLoopPredecessor(), L->getHeader());
        Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
@@ -13889,6 +13917,28 @@
 
   if (RewriteMap.empty())
     return Expr;
+
+  // Now that all rewrite information is collected, rewrite the collected
+  // expressions with the information in the map. This applies information to
+  // sub-expressions.
+  //
+  // Every value is rewritten against a snapshot of the collected map with its
+  // own entry removed, so a replacement is never substituted into itself.
+  // Rewriting entries in place would make each result depend on which entries
+  // had already been rewritten, i.e. on DenseMap's pointer-based iteration
+  // order, so the result could differ between otherwise identical runs.
+  if (RewriteMap.size() > 1) {
+    DenseMap<const SCEV *, const SCEV *> Collected(RewriteMap);
+    for (auto &KV : RewriteMap) {
+      const SCEV *From = KV.first;
+      const SCEV *To = KV.second;
+      Collected.erase(From);
+      SCEVLoopGuardRewriter Rewriter(*this, Collected);
+      KV.second = Rewriter.visit(To);
+      Collected[From] = To;
+    }
+  }
+
   SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
   return Rewriter.visit(Expr);
 }
diff --git a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
--- a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
+++ b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
@@ -7,7 +7,7 @@
 define void @rewrite_zext(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 2
 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
 ; CHECK-NEXT: Predicates:
 ; CHECK: Loop %loop: Trip multiple is 1
@@ -36,7 +36,7 @@
 define i32 @rewrite_zext_min_max(i32 %N, i32* %arr) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_min_max
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))) /u 4)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 4611686018427387903
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))) /u 4)
 ; CHECK-NEXT: Predicates:
 ; CHECK: Loop %loop: Trip multiple is 1
@@ -134,7 +134,7 @@
 define void @rewrite_zext_and_base_1(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
 ; CHECK-NEXT: Predicates:
 ; CHECK: Loop %loop: Trip multiple is 1
@@ -168,7 +168,7 @@
 define void @rewrite_zext_and_base_2(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
 ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
-; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT: Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
 ; CHECK-NEXT: Predicates:
 ; CHECK: Loop %loop: Trip multiple is 1