diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -13694,7 +13694,8 @@ /// in the map. It skips AddRecExpr because we cannot guarantee that the /// replacement is loop invariant in the loop of the AddRec. /// -/// At the moment only rewriting SCEVUnknown is supported. +/// At the moment only rewriting SCEVUnknown and SCEVZeroExtendExpr is +/// supported. class SCEVLoopGuardRewriter : public SCEVRewriteVisitor { const DenseMap ⤅ @@ -13711,9 +13712,18 @@ return Expr; return I->second; } + + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + auto I = Map.find(Expr); + if (I == Map.end()) + return SCEVRewriteVisitor::visitZeroExtendExpr( + Expr); + return I->second; + } }; const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { + SmallVector ExprsToRewrite; auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, DenseMap @@ -13736,6 +13746,7 @@ if (const SCEVUnknown *LHSUnknown = dyn_cast(URemLHS)) { auto Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS); RewriteMap[LHSUnknown] = Multiple; + ExprsToRewrite.push_back(LHSUnknown); return; } } @@ -13749,7 +13760,8 @@ // Check for a condition of the form (-C1 + X < C2). InstCombine will // create this form when combining two checks of the form (X u< C2 + C1) and // (X >=u C1). 
- auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap]() { + auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap, + &ExprsToRewrite]() { auto *AddExpr = dyn_cast(LHS); if (!AddExpr || AddExpr->getNumOperands() != 2) return false; @@ -13772,21 +13784,35 @@ RewriteMap[LHSUnknown] = getUMaxExpr( getConstant(ExactRegion.getUnsignedMin()), getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax()))); + ExprsToRewrite.push_back(LHSUnknown); return true; }; if (MatchRangeCheckIdiom()) return; - // For now, limit to conditions that provide information about unknown - // expressions. RHS also cannot contain add recurrences. - auto *LHSUnknown = dyn_cast(LHS); - if (!LHSUnknown || containsAddRecurrence(RHS)) + // If RHS is SCEVUnknown, make sure the information is applied to it. + if (isa(RHS)) { + std::swap(LHS, RHS); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + // If LHS is a constant, apply information to the other expression. + if (isa(LHS)) { + std::swap(LHS, RHS); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + // Do not apply information for constants or if RHS contains an AddRec. + if (isa(LHS) || containsAddRecurrence(RHS)) + return; + + // Limit to expressions that can be rewritten. + if (!isa(LHS) && !isa(LHS)) return; // Check whether LHS has already been rewritten. In that case we want to // chain further rewrites onto the already rewritten value. auto I = RewriteMap.find(LHS); const SCEV *RewrittenLHS = I != RewriteMap.end() ? 
I->second : LHS; + const SCEV *RewrittenRHS = nullptr; switch (Predicate) { case CmpInst::ICMP_ULT: @@ -13830,8 +13856,11 @@ break; } - if (RewrittenRHS) + if (RewrittenRHS) { RewriteMap[LHS] = RewrittenRHS; + if (LHS == RewrittenLHS) + ExprsToRewrite.push_back(LHS); + } }; // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors @@ -13887,6 +13916,19 @@ if (RewriteMap.empty()) return Expr; + + // Now that all rewrite information is collected, rewrite the collected + // expressions with the information in the map. This applies information to + // sub-expressions. + if (ExprsToRewrite.size() > 1) { + for (const SCEV *Expr : ExprsToRewrite) { + const SCEV *RewriteTo = RewriteMap[Expr]; + RewriteMap.erase(Expr); + SCEVLoopGuardRewriter Rewriter(*this, RewriteMap); + RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)}); + } + } + SCEVLoopGuardRewriter Rewriter(*this, RewriteMap); return Rewriter.visit(Expr); } diff --git a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll --- a/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll +++ b/llvm/test/Analysis/ScalarEvolution/max-backedge-taken-count-guard-info-rewrite-expressions.ll @@ -7,7 +7,7 @@ define void @rewrite_zext(i32 %n) { ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8) -; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951 +; CHECK-NEXT: Loop %loop: max backedge-taken count is 2 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8) ; CHECK-NEXT: Predicates: ; CHECK: Loop %loop: Trip multiple is 1 @@ -36,7 +36,7 @@ define i32 @rewrite_zext_min_max(i32 %N, 
i32* %arr) { ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_min_max ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))) /u 4) -; CHECK-NEXT: Loop %loop: max backedge-taken count is 4611686018427387903 +; CHECK-NEXT: Loop %loop: max backedge-taken count is 3 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-4 + (4 * ((zext i32 (16 umin %N) to i64) /u 4))) /u 4) ; CHECK-NEXT: Predicates: ; CHECK: Loop %loop: Trip multiple is 1 @@ -134,7 +134,7 @@ define void @rewrite_zext_and_base_1(i32 %n) { ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8) -; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951 +; CHECK-NEXT: Loop %loop: max backedge-taken count is 3 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8) ; CHECK-NEXT: Predicates: ; CHECK: Loop %loop: Trip multiple is 1 @@ -168,7 +168,7 @@ define void @rewrite_zext_and_base_2(i32 %n) { ; CHECK-LABEL: Determining loop execution counts for: @rewrite_zext_and_base ; CHECK-NEXT: Loop %loop: backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8) -; CHECK-NEXT: Loop %loop: max backedge-taken count is 2305843009213693951 +; CHECK-NEXT: Loop %loop: max backedge-taken count is 3 ; CHECK-NEXT: Loop %loop: Predicated backedge-taken count is ((-8 + (8 * ((zext i32 %n to i64) /u 8))) /u 8) ; CHECK-NEXT: Predicates: ; CHECK: Loop %loop: Trip multiple is 1