Index: llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6694,7 +6694,7 @@
     return std::nullopt;
   }
 
-  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+  if (isa<SCEVCouldNotCompute>(SE.getSymbolicMaxBackedgeTakenCount(L))) {
     LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
     return std::nullopt;
   }
@@ -6768,7 +6768,7 @@
   auto getAlternateIVEnd = [&](PHINode &PN) -> const SCEV * {
     // FIXME: This does not properly account for overflow.
     const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
-    const SCEV *BECount = SE.getBackedgeTakenCount(L);
+    const SCEV *BECount = SE.getSymbolicMaxBackedgeTakenCount(L);
     const SCEV *TermValueS = SE.getAddExpr(
         AddRec->getOperand(0),
         SE.getTruncateOrZeroExtend(
@@ -6823,7 +6823,8 @@
 
   LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
              << "\nFound loop that can fold terminating condition\n"
-             << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
+             << " BECount (SCEV): " << *SE.getSymbolicMaxBackedgeTakenCount(L)
+             << "\n"
              << " TermCond: " << *TermCond << "\n"
              << " BrandInst: " << *BI << "\n"
              << " ToFold: " << *ToFold << "\n"
Index: llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-find-if.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-find-if.ll
@@ -0,0 +1,70 @@
+; RUN: opt < %s -passes="loop-reduce" -S -lsr-term-fold -debug-only="loop-reduce" 2>&1 | FileCheck %s
+
+; Verify that -lsr-term-fold rewrites the terminating condition of a loop with
+; two exiting blocks (the data-dependent early exit in for.body and the counted
+; exit in if.end), and that the debug output reports the BECount it used.
+
+target datalayout = "e-m:e-p:32:32:32-n32:64:64-i64:64-i128:128-n32:64-S128"
+target triple = "riscv32"
+
+; std::find_if C source code
+; int* __find_if(int *__first, int *__last, int x) {
+;   int __trip_count = (__last - __first);
+;   for (; __trip_count > 0 ; --__trip_count)
+;   {
+;     if (*__first == x)
+;       return __first;
+;     ++__first;
+;   }
+;   return __first;
+; }
+
+define dso_local ptr @__find_if(ptr noundef %__first, ptr noundef %__last, i32 noundef %x) {
+
+; CHECK: BECount (SCEV): (-1 + (((-1 * (ptrtoint ptr %__first to i32)) + (ptrtoint ptr %__last to i32)) /u 4))
+
+; CHECK: for.body:
+; CHECK-NEXT: [[LSR_IV:%.*]] = phi ptr [ [[INDEC_PTR:%.*]], %if.end ], [ %__first, %for.body.preheader ]
+; CHECK-NEXT: [[LOAD:%.*]] = load i32, ptr [[LSR_IV]], align 4
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[LOAD]], %x
+; CHECK-NEXT: br i1 [[CMP]], label [[END_BRANCH:%.*]], label %[[IF_END:.*]]
+
+; CHECK: [[IF_END]]:
+; CHECK-NEXT: [[INDEC_PTR]] = getelementptr inbounds i32, ptr [[LSR_IV]], i32 1
+; CHECK-NEXT: [[TERM_COND:%.*]] = icmp ne ptr [[INDEC_PTR]], [[LAST_PTR:%.*]]
+; CHECK-NEXT: br i1 [[TERM_COND]], label %for.body, label [[END_BRANCH2:%.*]]
+
+entry:
+  %sub.ptr.lhs.cast = ptrtoint ptr %__last to i32
+  %sub.ptr.rhs.cast = ptrtoint ptr %__first to i32
+  %sub.ptr.sub = sub i32 %sub.ptr.lhs.cast, %sub.ptr.rhs.cast
+  %cmp6 = icmp sgt i32 %sub.ptr.sub, 0
+  br i1 %cmp6, label %for.body.preheader, label %cleanup
+
+for.body.preheader:                               ; preds = %entry
+  %sub.ptr.div = lshr exact i32 %sub.ptr.sub, 2
+  br label %for.body
+
+; Loop:
+for.body:                                         ; preds = %for.body.preheader, %if.end
+  %__trip_count.08 = phi i32 [ %dec, %if.end ], [ %sub.ptr.div, %for.body.preheader ]
+  %__first.addr.07 = phi ptr [ %incdec.ptr, %if.end ], [ %__first, %for.body.preheader ]
+  %load = load i32, ptr %__first.addr.07, align 4
+  %cmp1 = icmp eq i32 %load, %x
+  br i1 %cmp1, label %cleanup.loopexit, label %if.end
+
+if.end:                                           ; preds = %for.body
+  %incdec.ptr = getelementptr inbounds i32, ptr %__first.addr.07, i32 1
+  %dec = add nsw i32 %__trip_count.08, -1
+  %cmp = icmp ne i32 %__trip_count.08, 1
+  br i1 %cmp, label %for.body, label %cleanup.loopexit
+
+; Exit blocks
+cleanup.loopexit:                                 ; preds = %for.body, %if.end
+  %__first.addr.0.lcssa.ph = phi ptr [ %incdec.ptr, %if.end ], [ %__first.addr.07, %for.body ]
+  br label %cleanup
+
+cleanup:                                          ; preds = %cleanup.loopexit, %entry
+  %__first.addr.0.lcssa = phi ptr [ %__first, %entry ], [ %__first.addr.0.lcssa.ph, %cleanup.loopexit ]
+  ret ptr %__first.addr.0.lcssa
+}