diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6664,11 +6664,30 @@
   // For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to
   // replace the terminating condition
   auto IsToHelpFold = [&](PHINode &PN) -> bool {
+    // Build the terminating value up front, i.e.
+    // operand(0) + operand(1) * (BECount + 1) of the PHI's AddRec, so a PHI
+    // whose terminating value is not safe to expand is rejected as a candidate.
+    // NOTE(review): cast<> asserts PN is an AddRec; confirm callers check.
+    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+    const SCEV *BECount = SE.getBackedgeTakenCount(L);
+    const SCEV *TermValueS = SE.getAddExpr(
+        AddRec->getOperand(0),
+        SE.getTruncateOrZeroExtend(
+            SE.getMulExpr(
+                AddRec->getOperand(1),
+                SE.getTruncateOrZeroExtend(
+                    SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
+                    AddRec->getOperand(1)->getType())),
+            AddRec->getOperand(0)->getType()));
+    const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+    SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+
     // TODO: Right now we limit the phi node to help the folding be of a start
     // value of getelementptr. We can extend to any kinds of IV as long as it is
     // an affine AddRec. Add a switch to cover more types of instructions here
     // and down in the actual transformation.
-    return isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
+    return Expander.isSafeToExpand(TermValueS) &&
+           isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
   };
 
   PHINode *ToFold = nullptr;
@@ -6813,41 +6832,33 @@
                       AddRec->getOperand(1)->getType())),
               AddRec->getOperand(0)->getType()));
 
-      // NOTE: If this is triggered, we should add this into predicate
-      if (!Expander.isSafeToExpand(TermValueS)) {
-        LLVMContext &Ctx = L->getHeader()->getContext();
-        Ctx.emitError(
-            "Terminating value is not safe to expand, need to add it to "
-            "predicate");
-      } else { // Now we replace the condition with ToHelpFold and remove ToFold
-        Changed = true;
-        NumTermFold++;
-
-        Value *TermValue = Expander.expandCodeFor(
-            TermValueS, PtrTy, LoopPreheader->getTerminator());
-
-        LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
-                          << *StartValue << "\n"
-                          << "Terminating value of new term-cond phi-node:\n"
-                          << *TermValue << "\n");
-
-        // Create new terminating condition at loop latch
-        BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
-        ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
-        IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
-        Value *NewTermCond = LatchBuilder.CreateICmp(
-            OldTermCond->getPredicate(), LoopValue, TermValue,
-            "lsr_fold_term_cond.replaced_term_cond");
-
-        LLVM_DEBUG(dbgs() << "Old term-cond:\n"
-                          << *OldTermCond << "\n"
-                          << "New term-cond:\b" << *NewTermCond << "\n");
-
-        BI->setCondition(NewTermCond);
-
-        OldTermCond->eraseFromParent();
-        DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
-      }
+      Changed = true;
+      NumTermFold++;
+
+      Value *TermValue = Expander.expandCodeFor(TermValueS, PtrTy,
+                                                LoopPreheader->getTerminator());
+
+      LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+                        << *StartValue << "\n"
+                        << "Terminating value of new term-cond phi-node:\n"
+                        << *TermValue << "\n");
+
+      // Create new terminating condition at loop latch
+      BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+      ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+      IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+      Value *NewTermCond = LatchBuilder.CreateICmp(
+          OldTermCond->getPredicate(), LoopValue, TermValue,
+          "lsr_fold_term_cond.replaced_term_cond");
+
+      LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+                        << *OldTermCond << "\n"
+                        << "New term-cond:\n" << *NewTermCond << "\n");
+
+      BI->setCondition(NewTermCond);
+
+      OldTermCond->eraseFromParent();
+      DeleteDeadPHIs(L->getHeader(), &TLI, MSSAU.get());
 
       ExpCleaner.markResultUsed();
     }