[LSR] Check that the terminating value is safe to expand before folding

canFoldTermCondOfLoop() now computes, for each candidate helper IV, the SCEV of
the terminating value (Start + Stride * (BackedgeTakenCount + 1)) and rejects
the candidate when SCEVExpander::isSafeToExpand() reports that it cannot be
expanded. The accepted SCEV is returned to ReduceLoopStrength() alongside the
two phi nodes, so the transformation no longer recomputes it and no longer
reports an error through LLVMContext::emitError() when expansion is unsafe.

Review fixes folded into this revision of the patch:
  * the debug output for the new terminating condition ends its label with
    "\n" instead of the stray "\b" carried over from the old code;
  * the assert message names the function that actually performs the check
    (canFoldTermCondOfLoop).

diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6614,7 +6614,7 @@
   return nullptr;
 }
 
-static Optional<std::pair<PHINode *, PHINode *>>
+static Optional<std::pair<PHINode *, std::pair<PHINode *, const SCEV *>>>
 canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
                       const LoopInfo &LI) {
   if (!L->isInnermost()) {
@@ -6699,16 +6699,37 @@
 
   // For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to
   // replace the terminating condition
-  auto IsToHelpFold = [&](PHINode &PN) -> bool {
+  auto IsToHelpFold = [&](PHINode &PN) -> std::pair<bool, const SCEV *> {
+    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+    const SCEV *BECount = SE.getBackedgeTakenCount(L);
+    const SCEV *TermValueS = SE.getAddExpr(
+        AddRec->getOperand(0),
+        SE.getTruncateOrZeroExtend(
+            SE.getMulExpr(
+                AddRec->getOperand(1),
+                SE.getTruncateOrZeroExtend(
+                    SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
+                    AddRec->getOperand(1)->getType())),
+            AddRec->getOperand(0)->getType()));
+    const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+    SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+    if (!Expander.isSafeToExpand(TermValueS)) {
+      LLVM_DEBUG(
+          dbgs() << "Is not safe to expand terminating value for phi node" << PN
+                 << "\n");
+      return {false, nullptr};
+    }
     // TODO: Right now we limit the phi node to help the folding be of a start
     // value of getelementptr. We can extend to any kinds of IV as long as it is
     // an affine AddRec. Add a switch to cover more types of instructions here
     // and down in the actual transformation.
-    return isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
+    return {isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader)),
+            TermValueS};
   };
 
   PHINode *ToFold = nullptr;
   PHINode *ToHelpFold = nullptr;
+  const SCEV *TermValueS = nullptr;
 
   for (PHINode &PN : L->getHeader()->phis()) {
     if (!SE.isSCEVable(PN.getType())) {
@@ -6729,8 +6750,10 @@
 
     if (IsToFold(PN))
       ToFold = &PN;
-    else if (IsToHelpFold(PN))
+    else if (auto P = IsToHelpFold(PN); P.first) {
       ToHelpFold = &PN;
+      TermValueS = P.second;
+    }
   }
 
   LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
@@ -6746,7 +6769,7 @@
   if (!ToFold || !ToHelpFold)
     return None;
 
-  return {{ToFold, ToHelpFold}};
+  return {{ToFold, {ToHelpFold, TermValueS}}};
 }
 
 static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
@@ -6810,11 +6833,14 @@
   if (AllowTerminatingConditionFoldingAfterLSR) {
     auto CanFoldTerminatingCondition = canFoldTermCondOfLoop(L, SE, DT, LI);
     if (CanFoldTerminatingCondition) {
+      Changed = true;
+      NumTermFold++;
+
       BasicBlock *LoopPreheader = L->getLoopPreheader();
       BasicBlock *LoopLatch = L->getLoopLatch();
 
       PHINode *ToFold = CanFoldTerminatingCondition->first;
-      PHINode *ToHelpFold = CanFoldTerminatingCondition->second;
+      PHINode *ToHelpFold = CanFoldTerminatingCondition->second.first;
       (void)ToFold;
 
       LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
@@ -6834,56 +6860,35 @@
       GetElementPtrInst *StartValueGEP = cast<GetElementPtrInst>(StartValue);
       Type *PtrTy = StartValueGEP->getPointerOperand()->getType();
 
-      const SCEV *BECount = SE.getBackedgeTakenCount(L);
-      const SCEVAddRecExpr *AddRec =
-          cast<SCEVAddRecExpr>(SE.getSCEV(ToHelpFold));
-
-      // TermValue = Start + Stride * (BackedgeCount + 1)
-      const SCEV *TermValueS = SE.getAddExpr(
-          AddRec->getOperand(0),
-          SE.getTruncateOrZeroExtend(
-              SE.getMulExpr(
-                  AddRec->getOperand(1),
-                  SE.getTruncateOrZeroExtend(
-                      SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
-                      AddRec->getOperand(1)->getType())),
-              AddRec->getOperand(0)->getType()));
-
-      // NOTE: If this is triggered, we should add this into predicate
-      if (!Expander.isSafeToExpand(TermValueS)) {
-        LLVMContext &Ctx = L->getHeader()->getContext();
-        Ctx.emitError(
-            "Terminating value is not safe to expand, need to add it to "
-            "predicate");
-      } else { // Now we replace the condition with ToHelpFold and remove ToFold
-        Changed = true;
-        NumTermFold++;
-
-        Value *TermValue = Expander.expandCodeFor(
-            TermValueS, PtrTy, LoopPreheader->getTerminator());
-
-        LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
-                          << *StartValue << "\n"
-                          << "Terminating value of new term-cond phi-node:\n"
-                          << *TermValue << "\n");
-
-        // Create new terminating condition at loop latch
-        BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
-        ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
-        IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
-        Value *NewTermCond = LatchBuilder.CreateICmp(
-            OldTermCond->getPredicate(), LoopValue, TermValue,
-            "lsr_fold_term_cond.replaced_term_cond");
-
-        LLVM_DEBUG(dbgs() << "Old term-cond:\n"
-                          << *OldTermCond << "\n"
-                          << "New term-cond:\b" << *NewTermCond << "\n");
-
-        BI->setCondition(NewTermCond);
-
-        OldTermCond->eraseFromParent();
-        DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
-      }
+      const SCEV *TermValueS = CanFoldTerminatingCondition->second.second;
+      assert(
+          Expander.isSafeToExpand(TermValueS) &&
+          "Terminating value was checked safe in canFoldTermCondOfLoop");
+
+      Value *TermValue = Expander.expandCodeFor(TermValueS, PtrTy,
+                                                LoopPreheader->getTerminator());
+
+      LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+                        << *StartValue << "\n"
+                        << "Terminating value of new term-cond phi-node:\n"
+                        << *TermValue << "\n");
+
+      // Create new terminating condition at loop latch
+      BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+      ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+      IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+      Value *NewTermCond = LatchBuilder.CreateICmp(
+          OldTermCond->getPredicate(), LoopValue, TermValue,
+          "lsr_fold_term_cond.replaced_term_cond");
+
+      LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+                        << *OldTermCond << "\n"
+                        << "New term-cond:\n" << *NewTermCond << "\n");
+
+      BI->setCondition(NewTermCond);
+
+      OldTermCond->eraseFromParent();
+      DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
 
       ExpCleaner.markResultUsed();
     }
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-term-fold-negative-testcase.ll
@@ -158,3 +158,82 @@
 for.end:                                          ; preds = %for.body
   ret void
 }
+
+; The test case is reduced from FFmpeg/libavfilter/ebur128.c
+; Check that the terminating value is verified to be safe to expand
+%struct.FFEBUR128State = type { i32, ptr, i64, i64 }
+
+@histogram_energy_boundaries = global [1001 x double] zeroinitializer, align 8
+
+define void @ebur128_calc_gating_block(ptr %st, ptr %optional_output) {
+; CHECK: Is not safe to expand terminating value for phi node %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
+entry:
+  %0 = load i32, ptr %st, align 8
+  %conv = zext i32 %0 to i64
+  %cmp28.not = icmp eq i32 %0, 0
+  br i1 %cmp28.not, label %for.end13, label %for.cond2.preheader.lr.ph
+
+for.cond2.preheader.lr.ph:                        ; preds = %entry
+  %audio_data_index = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 3
+  %1 = load i64, ptr %audio_data_index, align 8
+  %div = udiv i64 %1, %conv
+  %cmp525.not = icmp ult i64 %1, %conv
+  %audio_data = getelementptr inbounds %struct.FFEBUR128State, ptr %st, i64 0, i32 1
+  %umax = tail call i64 @llvm.umax.i64(i64 %div, i64 1)
+  br label %for.cond2.preheader
+
+for.cond2.preheader:                              ; preds = %for.cond2.preheader.lr.ph, %for.inc11
+  %channel_sum.030 = phi double [ 0.000000e+00, %for.cond2.preheader.lr.ph ], [ %channel_sum.1.lcssa, %for.inc11 ]
+  %c.029 = phi i64 [ 0, %for.cond2.preheader.lr.ph ], [ %inc12, %for.inc11 ]
+  br i1 %cmp525.not, label %for.inc11, label %for.body7.lr.ph
+
+for.body7.lr.ph:                                  ; preds = %for.cond2.preheader
+  %2 = load ptr, ptr %audio_data, align 8
+  br label %for.body7
+
+for.body7:                                        ; preds = %for.body7.lr.ph, %for.body7
+  %channel_sum.127 = phi double [ %channel_sum.030, %for.body7.lr.ph ], [ %add10, %for.body7 ]
+  %i.026 = phi i64 [ 0, %for.body7.lr.ph ], [ %inc, %for.body7 ]
+  %mul = mul i64 %i.026, %conv
+  %add = add i64 %mul, %c.029
+  %arrayidx = getelementptr inbounds double, ptr %2, i64 %add
+  %3 = load double, ptr %arrayidx, align 8
+  %add10 = fadd double %channel_sum.127, %3
+  %inc = add nuw i64 %i.026, 1
+  %exitcond.not = icmp eq i64 %inc, %umax
+  br i1 %exitcond.not, label %for.inc11, label %for.body7
+
+for.inc11:                                        ; preds = %for.body7, %for.cond2.preheader
+  %channel_sum.1.lcssa = phi double [ %channel_sum.030, %for.cond2.preheader ], [ %add10, %for.body7 ]
+  %inc12 = add nuw nsw i64 %c.029, 1
+  %exitcond32.not = icmp eq i64 %inc12, %conv
+  br i1 %exitcond32.not, label %for.end13, label %for.cond2.preheader
+
+for.end13:                                        ; preds = %for.inc11, %entry
+  %channel_sum.0.lcssa = phi double [ 0.000000e+00, %entry ], [ %channel_sum.1.lcssa, %for.inc11 ]
+  %add14 = fadd double %channel_sum.0.lcssa, 0.000000e+00
+  store double %add14, ptr %optional_output, align 8
+  ret void
+}
+
+declare i64 @llvm.umax.i64(i64, i64)
+
+%struct.PAKT_INFO = type { i32, i32, i32, [0 x i32] }
+
+define i64 @alac_seek(ptr %0) {
+; CHECK: Is not safe to expand terminating value for phi node %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
+entry:
+  %div = udiv i64 1, 0
+  br label %for.body.i
+
+for.body.i:                                       ; preds = %for.body.i, %entry
+  %indvars.iv.i = phi i64 [ 0, %entry ], [ %indvars.iv.next.i, %for.body.i ]
+  %arrayidx.i = getelementptr %struct.PAKT_INFO, ptr %0, i64 0, i32 3, i64 %indvars.iv.i
+  %1 = load i32, ptr %arrayidx.i, align 4
+  %indvars.iv.next.i = add i64 %indvars.iv.i, 1
+  %exitcond.not.i = icmp eq i64 %indvars.iv.i, %div
+  br i1 %exitcond.not.i, label %alac_pakt_block_offset.exit, label %for.body.i
+
+alac_pakt_block_offset.exit:                      ; preds = %for.body.i
+  ret i64 0
+}