diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -6664,6 +6664,25 @@ // For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to // replace the terminating condition auto IsToHelpFold = [&](PHINode &PN) -> bool { + const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); + const SCEV *BECount = SE.getBackedgeTakenCount(L); + const SCEV *TermValueS = SE.getAddExpr( + AddRec->getOperand(0), + SE.getTruncateOrZeroExtend( + SE.getMulExpr( + AddRec->getOperand(1), + SE.getTruncateOrZeroExtend( + SE.getAddExpr(BECount, SE.getOne(BECount->getType())), + AddRec->getOperand(1)->getType())), + AddRec->getOperand(0)->getType())); + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); + if (!Expander.isSafeToExpand(TermValueS)) { + LLVM_DEBUG( + dbgs() << "Is not safe to expand terminating value for phi node" << PN + << "\n"); + return false; + } // TODO: Right now we limit the phi node to help the folding be of a start // value of getelementptr. We can extend to any kinds of IV as long as it is // an affine AddRec. 
Add a switch to cover more types of instructions here @@ -6813,41 +6832,33 @@ AddRec->getOperand(1)->getType())), AddRec->getOperand(0)->getType())); - // NOTE: If this is triggered, we should add this into predicate - if (!Expander.isSafeToExpand(TermValueS)) { - LLVMContext &Ctx = L->getHeader()->getContext(); - Ctx.emitError( - "Terminating value is not safe to expand, need to add it to " - "predicate"); - } else { // Now we replace the condition with ToHelpFold and remove ToFold - Changed = true; - NumTermFold++; - - Value *TermValue = Expander.expandCodeFor( - TermValueS, PtrTy, LoopPreheader->getTerminator()); - - LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" - << *StartValue << "\n" - << "Terminating value of new term-cond phi-node:\n" - << *TermValue << "\n"); - - // Create new terminating condition at loop latch - BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); - ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); - IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); - Value *NewTermCond = LatchBuilder.CreateICmp( - OldTermCond->getPredicate(), LoopValue, TermValue, - "lsr_fold_term_cond.replaced_term_cond"); - - LLVM_DEBUG(dbgs() << "Old term-cond:\n" - << *OldTermCond << "\n" - << "New term-cond:\b" << *NewTermCond << "\n"); - - BI->setCondition(NewTermCond); - - OldTermCond->eraseFromParent(); - DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); - } + Changed = true; + NumTermFold++; + + Value *TermValue = Expander.expandCodeFor(TermValueS, PtrTy, + LoopPreheader->getTerminator()); + + LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" + << *StartValue << "\n" + << "Terminating value of new term-cond phi-node:\n" + << *TermValue << "\n"); + + // Create new terminating condition at loop latch + BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); + ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); + IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); + Value *NewTermCond = LatchBuilder.CreateICmp( + 
OldTermCond->getPredicate(), LoopValue, TermValue, + "lsr_fold_term_cond.replaced_term_cond"); + + LLVM_DEBUG(dbgs() << "Old term-cond:\n" + << *OldTermCond << "\n" + << "New term-cond:\n" << *NewTermCond << "\n"); + + BI->setCondition(NewTermCond); + + OldTermCond->eraseFromParent(); + DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get()); ExpCleaner.markResultUsed(); } diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll --- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll +++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll @@ -158,3 +158,62 @@ for.end: ; preds = %for.body ret void } + +; The test case is reduced from FFmpeg/libavfilter/ebur128.c +; Verify that term folding bails out when the terminating value is not safe to expand +%struct.FFEBUR128State = type { i32, ptr, i64, i64 } + +@histogram_energy_boundaries = global [1001 x double] zeroinitializer, align 8 + +define void @ebur128_calc_gating_block(ptr %st, ptr %optional_output) { +; CHECK: Is not safe to expand terminating value for phi node %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ] +entry: + %0 = load i32, ptr %st, align 8 + %conv = zext i32 %0 to i64 + %cmp28.not = icmp eq i32 %0, 0 + br i1 %cmp28.not, label %for.end13, label %for.cond2.preheader.lr.ph + +for.cond2.preheader.lr.ph: ; preds = %entry + %audio_data_index = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 3 + %1 = load i64, ptr %audio_data_index, align 8 + %div = udiv i64 %1, %conv + %cmp525.not = icmp ult i64 %1, %conv + %audio_data = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 1 + %umax = tail call i64 @llvm.umax.i64(i64 %div, i64 1) + br label %for.cond2.preheader + +for.cond2.preheader: ; preds = %for.cond2.preheader.lr.ph, %for.inc11 + %channel_sum.030 = phi double [ 0.000000e+00, %for.cond2.preheader.lr.ph ], [ %channel_sum.1.lcssa, 
%for.inc11 ] + %c.029 = phi i64 [ 0, %for.cond2.preheader.lr.ph ], [ %inc12, %for.inc11 ] + br i1 %cmp525.not, label %for.inc11, label %for.body7.lr.ph + +for.body7.lr.ph: ; preds = %for.cond2.preheader + %2 = load ptr, ptr %audio_data, align 8 + br label %for.body7 + +for.body7: ; preds = %for.body7.lr.ph, %for.body7 + %channel_sum.127 = phi double [ %channel_sum.030, %for.body7.lr.ph ], [ %add10, %for.body7 ] + %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ] + %mul = mul i64 %i.026, %conv + %add = add i64 %mul, %c.029 + %arrayidx = getelementptr inbounds double, ptr %2, i64 %add + %3 = load double, ptr %arrayidx, align 8 + %add10 = fadd double %channel_sum.127, %3 + %inc = add nuw i64 %i.026, 1 + %exitcond.not = icmp eq i64 %inc, %umax + br i1 %exitcond.not, label %for.inc11, label %for.body7 + +for.inc11: ; preds = %for.body7, %for.cond2.preheader + %channel_sum.1.lcssa = phi double [ %channel_sum.030, %for.cond2.preheader ], [ %add10, %for.body7 ] + %inc12 = add nuw nsw i64 %c.029, 1 + %exitcond32.not = icmp eq i64 %inc12, %conv + br i1 %exitcond32.not, label %for.end13, label %for.cond2.preheader + +for.end13: ; preds = %for.inc11, %entry + %channel_sum.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %channel_sum.1.lcssa, %for.inc11 ] + %add14 = fadd double %channel_sum.0.lcssa, 0.000000e+00 + store double %add14, ptr %optional_output, align 8 + ret void +} + +declare i64 @llvm.umax.i64(i64, i64)