diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -13694,7 +13694,8 @@
 /// in the map. It skips AddRecExpr because we cannot guarantee that the
 /// replacement is loop invariant in the loop of the AddRec.
 ///
-/// At the moment only rewriting SCEVUnknown is supported.
+/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is
+/// supported.
 class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
   const DenseMap<const SCEV *, const SCEV *> &Map;
 
@@ -13711,6 +13712,14 @@
       return Expr;
     return I->second;
   }
+
+  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+    auto I = Map.find(Expr);
+    if (I == Map.end())
+      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
+          Expr);
+    return I->second;
+  }
 };
 
 const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
@@ -13777,10 +13786,23 @@
     if (MatchRangeCheckIdiom())
       return;
 
-    // For now, limit to conditions that provide information about unknown
-    // expressions. RHS also cannot contain add recurrences.
-    auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
-    if (!LHSUnknown || containsAddRecurrence(RHS))
+    // If RHS is SCEVUnknown and LHS is not, make sure the information is
+    // applied to RHS. If both are SCEVUnknown, keep applying it to LHS.
+    if (!isa<SCEVUnknown>(LHS) && isa<SCEVUnknown>(RHS)) {
+      std::swap(LHS, RHS);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+    }
+    // If LHS is a constant, apply information to the other expression.
+    if (isa<SCEVConstant>(LHS)) {
+      std::swap(LHS, RHS);
+      Predicate = CmpInst::getSwappedPredicate(Predicate);
+    }
+    // Do not apply information for constants or if RHS contains an AddRec.
+    if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
+      return;
+
+    // Limit to expressions that can be rewritten.
+    if (!isa<SCEVUnknown>(LHS) && !isa<SCEVZeroExtendExpr>(LHS))
       return;
 
     // Check whether LHS has already been rewritten. In that case we want to
@@ -13887,6 +13909,30 @@
 
   if (RewriteMap.empty())
     return Expr;
+
+  // Now that all rewrite information is collected, rewrite the collected
+  // expressions with the information in the map. This applies information to
+  // sub-expressions.
+  if (RewriteMap.size() > 1) {
+    // NOTE(review): DenseMap iteration order depends on pointer values, so the
+    // order in which entries are rewritten (and thus the result, since later
+    // rewrites see earlier ones) may differ between runs. Consider recording
+    // the expressions in insertion order while collecting the conditions.
+    SmallVector<const SCEV *> ExprsToRewrite;
+    for (const auto &KV : RewriteMap)
+      ExprsToRewrite.push_back(KV.first);
+
+    for (const SCEV *Expr : ExprsToRewrite) {
+      const SCEV *RewriteTo = RewriteMap[Expr];
+      RewriteMap.erase(Expr);
+      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+      // Rewrite before re-inserting Expr: with `RewriteMap[Expr] = ...` the
+      // default-constructed (null) entry may be created before the visit
+      // (unspecified order pre-C++17), and occurrences of Expr inside
+      // RewriteTo would then be rewritten to null.
+      RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
+    }
+  }
+
   SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
   return Rewriter.visit(Expr);
 }
diff --git a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
--- a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
+++ b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll
@@ -7,7 +7,7 @@
 define void @rewrite_zext(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
 ; CHECK-NEXT:   Predicates:
 ; CHECK:       Loop %loop: Trip multiple is 1
@@ -36,7 +36,7 @@
 define i32 @rewrite_zext_min_max(i32 %N, i32* %arr) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_min_max
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))) /u 4)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is 4611686018427387903
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))) /u 4)
 ; CHECK-NEXT:   Predicates:
 ; CHECK:       Loop %loop: Trip multiple is 1
@@ -134,7 +134,7 @@
 define void @rewrite_zext_and_base_1(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
 ; CHECK-NEXT:   Predicates:
 ; CHECK:       Loop %loop: Trip multiple is 1
@@ -168,7 +168,7 @@
 define void @rewrite_zext_and_base_2(i32 %n) {
 ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base
 ; CHECK-NEXT:  Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
-; CHECK-NEXT:  Loop %loop: max backedge-taken count is 2305843009213693951
+; CHECK-NEXT:  Loop %loop: max backedge-taken count is 3
 ; CHECK-NEXT:  Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8)
 ; CHECK-NEXT:   Predicates:
 ; CHECK:       Loop %loop: Trip multiple is 1