diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6572,6 +6572,206 @@
   return nullptr;
 }
 
+/// Check whether the exit test in the latch of \p L can be re-expressed on a
+/// different induction variable.
+///
+/// The loop must be an innermost loop in loop-simplify and LCSSA form with a
+/// loop-invariant backedge-taken count, and its latch must end in a
+/// conditional branch on an eq/ne icmp. Among the SCEVable header PHIs that
+/// are affine AddRecs, the function looks for
+///  - ToFold: a PHI whose single-use chain ends at a value with exactly two
+///    uses, one leading to a PHI and one leading to the latch compare;
+///  - ToHelpFold: a PHI with one incoming value from the preheader and one
+///    from the latch whose preheader value is a GEP.
+///
+/// \returns {true, {ToFold, ToHelpFold}} if both were found, otherwise a pair
+/// whose first member is false.
+static std::pair<bool, std::pair<PHINode *, PHINode *>>
+canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
+                      const LoopInfo &LI) {
+  const std::pair<bool, std::pair<PHINode *, PHINode *>> CantFold = {
+      false, {nullptr, nullptr}};
+
+  // Only inspect on simple loop structure
+  if (!L->isLoopSimplifyForm() || !L->isRecursivelyLCSSAForm(DT, LI) ||
+      !L->isInnermost())
+    return CantFold;
+
+  BasicBlock *LoopLatch = L->getLoopLatch();
+  BasicBlock *LoopPreheader = L->getLoopPreheader();
+
+  if (!LoopLatch || !LoopPreheader || !SE.hasLoopInvariantBackedgeTakenCount(L))
+    return CantFold;
+
+  // TODO: Can we do something for greater than and less than?
+  // Terminal condition is foldable when it is an eq/neq icmp. Use dyn_cast:
+  // cast<> asserts when the latch terminator is not a BranchInst.
+  auto *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
+  if (!BI || BI->isUnconditional())
+    return CantFold;
+  auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!TermCond || !TermCond->isEquality())
+    return CantFold;
+
+  auto IsToFold = [&](PHINode &PN) -> bool {
+    if (PN.getNumIncomingValues() != 2)
+      return false;
+
+    // Walk forward from the PHI while each value has exactly one use.
+    Value *V = &PN;
+    while (V->hasOneUse()) {
+      V = *V->user_begin();
+      // A PHI whose update feeds nothing but the PHI itself forms a closed
+      // single-use cycle; stop instead of walking it forever.
+      if (V == &PN)
+        return false;
+    }
+
+    if (!V->hasNUses(2))
+      return false;
+
+    // One user chain has to reach a PHI and one the terminating compare.
+    Value *VToPN = nullptr;
+    Value *VToTermCond = nullptr;
+    for (User *U : V->users()) {
+      while (U->hasOneUse()) {
+        if (isa<PHINode>(U))
+          VToPN = U;
+        if (U == TermCond)
+          VToTermCond = U;
+        U = *U->user_begin();
+      }
+    }
+    return VToPN && VToTermCond;
+  };
+
+  auto IsToHelpFold = [&](PHINode &PN) -> bool {
+    if (PN.getNumIncomingValues() != 2)
+      return false;
+    int FromPreheader = PN.getBasicBlockIndex(LoopPreheader);
+    int FromLoopLatch = PN.getBasicBlockIndex(LoopLatch);
+    if (FromPreheader == -1 || FromLoopLatch == -1)
+      return false;
+
+    return isa<GetElementPtrInst>(PN.getIncomingValue(FromPreheader));
+  };
+
+  PHINode *ToFold = nullptr;
+  PHINode *ToHelpFold = nullptr;
+
+  for (PHINode &PN : L->getHeader()->phis()) {
+    if (!SE.isSCEVable(PN.getType()))
+      continue;
+    // Only speculate on affine AddRec
+    const auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+
+    if (IsToFold(PN))
+      ToFold = &PN;
+    else if (IsToHelpFold(PN))
+      ToHelpFold = &PN;
+  }
+
+  LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
+             << "\nFound loop that can fold terminal condition\n"
+             << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
+             << " TermCond: " << *TermCond << "\n"
+             << " BranchInst: " << *BI << "\n"
+             << " ToFold: " << *ToFold << "\n"
+             << " ToHelpFold: " << *ToHelpFold << "\n");
+
+  return {ToFold != nullptr && ToHelpFold != nullptr, {ToFold, ToHelpFold}};
+}
+
+/// Rewrite the latch exit test of \p L in terms of \p ToHelpFold and delete
+/// \p ToFold.
+///
+/// The value \p ToHelpFold reaches when the loop exits,
+/// Start + Stride * (BackedgeTakenCount + 1), is expanded in the preheader and
+/// the latch compare is replaced by a compare of \p ToHelpFold's latch value
+/// against it, keeping the original predicate. Both PHIs must come from a
+/// successful canFoldTermCondOfLoop() query on \p L.
+///
+/// \returns true if the IR was changed, false if the terminal value cannot be
+/// expanded safely (the loop is then left untouched).
+static bool foldTermCondOfLoop(Loop *L, ScalarEvolution &SE, PHINode *ToFold,
+                               PHINode *ToHelpFold) {
+  BasicBlock *LoopPreheader = L->getLoopPreheader();
+  BasicBlock *LoopLatch = L->getLoopLatch();
+
+  LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
+                    << *ToFold << "\n"
+                    << "New term-cond phi-node:\n"
+                    << *ToHelpFold << "\n");
+
+  // SCEVExpander for both use in preheader and latch
+  const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+  SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+  SCEVExpanderCleaner ExpCleaner(Expander);
+
+  const SCEV *BECount = SE.getBackedgeTakenCount(L);
+  const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(ToHelpFold));
+
+  // TermValue = Start + Stride * (BackedgeCount + 1)
+  const SCEV *TermValueS = SE.getAddExpr(
+      AddRec->getOperand(0),
+      SE.getTruncateOrZeroExtend(
+          SE.getMulExpr(
+              AddRec->getOperand(1),
+              SE.getTruncateOrZeroExtend(
+                  SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
+                  AddRec->getOperand(1)->getType())),
+          AddRec->getOperand(0)->getType()));
+
+  // Give up before touching the IR when the expression cannot be expanded
+  // safely; expanding it anyway could introduce faulting code.
+  // TODO: guard the fold with a predicate instead of skipping it.
+  if (!Expander.isSafeToExpand(TermValueS)) {
+    LLVM_DEBUG(
+        dbgs() << "Terminal value is not safe to expand, skip folding\n");
+    return false;
+  }
+
+  // Expand to the PHI's own type so both operands of the new compare agree.
+  Value *TermValue = Expander.expandCodeFor(
+      TermValueS, ToHelpFold->getType(), LoopPreheader->getTerminator());
+
+  LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+                    << *ToHelpFold->getIncomingValueForBlock(LoopPreheader)
+                    << "\n"
+                    << "Terminal value of new term-cond phi-node:\n"
+                    << *TermValue << "\n");
+
+  // Create new terminal condition at loop latch
+  BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+  ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+  IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+  Value *NewTermCond = LatchBuilder.CreateICmp(
+      OldTermCond->getPredicate(),
+      ToHelpFold->getIncomingValueForBlock(LoopLatch), TermValue,
+      "lsr_fold_term_cond.replaced_term_cond");
+
+  LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+                    << *OldTermCond << "\n"
+                    << "New term-cond:\n"
+                    << *NewTermCond << "\n");
+
+  BI->setCondition(NewTermCond);
+
+  OldTermCond->replaceAllUsesWith(PoisonValue::get(OldTermCond->getType()));
+  OldTermCond->eraseFromParent();
+
+  // Retire the old IV: poison its remaining uses and erase the PHI, DCE will
+  // do the rest. Draining it with removeIncomingValue() would delete the PHI
+  // on the last removal while the loop still queries it.
+  ToFold->replaceAllUsesWith(PoisonValue::get(ToFold->getType()));
+  ToFold->eraseFromParent();
+
+  ExpCleaner.markResultUsed();
+  return true;
+}
+
 static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
                                DominatorTree &DT, LoopInfo &LI,
                                const TargetTransformInfo &TTI,
@@ -6630,6 +6830,13 @@
     }
   }
 
+  // Try to express the latch exit test on another IV so that the PHI which
+  // only drives the exit test can be deleted.
+  auto CanFoldTerminalCondition = canFoldTermCondOfLoop(L, SE, DT, LI);
+  if (CanFoldTerminalCondition.first)
+    Changed |= foldTermCondOfLoop(L, SE, CanFoldTerminalCondition.second.first,
+                                  CanFoldTerminalCondition.second.second);
+
   if (SalvageableDVIRecords.empty())
     return Changed;
 
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll
@@ -20,21 +20,22 @@
 ; CHECK-NEXT: [[TOBOOL_NOT3:%.*]] = icmp eq i32 [[LENGTH:%.*]], 0
 ; CHECK-NEXT: br i1 [[TOBOOL_NOT3]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY_PREHEADER:%.*]]
 ; CHECK: for.body.preheader:
-; CHECK-NEXT: [[ADD_PTR:%.*]] = getelementptr inbounds ptr, ptr [[MARK:%.*]]
+; CHECK-NEXT: [[ADD_PTR:%.*]] = getelementptr ptr, ptr [[MARK:%.*]]
+; CHECK-NEXT: [[TMP0:%.*]] = shl i32 [[LENGTH]], 2
+; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[ADD_PTR]], i32 [[TMP0]]
 ; CHECK-NEXT: br label [[FOR_BODY:%.*]]
 ; CHECK: for.cond.cleanup.loopexit:
 ; CHECK-NEXT: br label [[FOR_COND_CLEANUP]]
 ; CHECK: for.cond.cleanup:
 ; CHECK-NEXT: ret void
 ; CHECK: for.body:
-; CHECK-NEXT: [[I_05:%.*]] = phi i32 [ [[DEC:%.*]], [[FOR_BODY]] ], [ [[LENGTH]], [[FOR_BODY_PREHEADER]] ]
 ; CHECK-NEXT: [[DST_04:%.*]] = phi ptr [ [[INCDEC_PTR:%.*]], [[FOR_BODY]] ], [ [[ADD_PTR]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT: [[TMP0:%.*]] = load ptr, ptr [[DST_04]], align 8
-; CHECK-NEXT: call void @sv_2mortal(ptr [[TMP0]])
+; CHECK-NEXT: [[TMP1:%.*]] = load ptr, ptr [[DST_04]], align 8
+; CHECK-NEXT: call void @sv_2mortal(ptr [[TMP1]])
 ; CHECK-NEXT: [[INCDEC_PTR]] = getelementptr inbounds ptr, ptr [[DST_04]], i64 1
-; CHECK-NEXT: [[DEC]] = add i32 [[I_05]], -1
-; CHECK-NEXT: [[TOBOOL_NOT:%.*]] = icmp eq i32 [[DEC]], 0
-; CHECK-NEXT: br i1 [[TOBOOL_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY]]
+; CHECK-NEXT: [[DEC:%.*]] = add i32 poison, -1
+; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[INCDEC_PTR]], [[UGLYGEP]]
+; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY]]
 ;
 entry:
   %tobool.not3 = icmp eq i32 %length, 0
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll
@@ -23,12 +23,12 @@
 ; CHECK-NEXT: entry:
 ; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[BUBBLES:%.*]], i64 84
 ; CHECK-NEXT: [[UGLYGEP3:%.*]] = getelementptr i8, ptr [[GB:%.*]], i64 84
+; CHECK-NEXT: [[UGLYGEP1:%.*]] = getelementptr i8, ptr [[BUBBLES]], i64 1600
 ; CHECK-NEXT: br label [[FOR_BODY:%.*]]
 ; CHECK: for.body:
 ; CHECK-NEXT: [[LSR_IV6:%.*]] = phi ptr [ [[UGLYGEP7:%.*]], [[FOR_INC:%.*]] ], [ getelementptr (i8, ptr @board, i64 21), [[ENTRY:%.*]] ]
 ; CHECK-NEXT: [[LSR_IV4:%.*]] = phi ptr [ [[UGLYGEP5:%.*]], [[FOR_INC]] ], [ [[UGLYGEP3]], [[ENTRY]] ]
 ; CHECK-NEXT: [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_INC]] ], [ [[UGLYGEP]], [[ENTRY]] ]
-; CHECK-NEXT: [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT:%.*]], [[FOR_INC]] ], [ 379, [[ENTRY]] ]
 ; CHECK-NEXT: [[TMP0:%.*]] = load i8, ptr [[LSR_IV6]], align 1
 ; CHECK-NEXT: [[CMP1_NOT:%.*]] = icmp eq i8 [[TMP0]], 3
 ; CHECK-NEXT: br i1 [[CMP1_NOT]], label [[FOR_INC]], label [[LOR_LHS_FALSE:%.*]]
@@ -52,12 +52,12 @@
 ; CHECK-NEXT: store i32 -1, ptr [[LSR_IV4]], align 4
 ; CHECK-NEXT: br label [[FOR_INC]]
 ; CHECK: for.inc:
-; CHECK-NEXT: [[LSR_IV_NEXT]] = add nsw i64 [[LSR_IV]], -1
+; CHECK-NEXT: [[LSR_IV_NEXT:%.*]] = add nsw i64 poison, -1
 ; CHECK-NEXT: [[UGLYGEP2]] = getelementptr i8, ptr [[LSR_IV1]], i64 4
 ; CHECK-NEXT: [[UGLYGEP5]] = getelementptr i8, ptr [[LSR_IV4]], i64 4
 ; CHECK-NEXT: [[UGLYGEP7]] = getelementptr i8, ptr [[LSR_IV6]], i64 1
-; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i64 [[LSR_IV_NEXT]], 0
-; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END:%.*]], label [[FOR_BODY]]
+; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[UGLYGEP1]]
+; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END:%.*]], label [[FOR_BODY]]
 ; CHECK: for.end:
 ; CHECK-NEXT: ret void
 ;
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll
@@ -44,11 +44,11 @@
 ; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[ADD2]] to i64
 ; CHECK-NEXT: [[TMP2:%.*]] = add nsw i64 [[WIDE_TRIP_COUNT]], -1
 ; CHECK-NEXT: [[UGLYGEP:%.*]] = getelementptr i8, ptr [[CALL]], i64 1
+; CHECK-NEXT: [[UGLYGEP1:%.*]] = getelementptr i8, ptr [[CALL]], i32 [[ADD2]]
 ; CHECK-NEXT: br label [[FOR_BODY:%.*]]
 ; CHECK: for.body:
 ; CHECK-NEXT: [[LSR_IV3:%.*]] = phi ptr [ [[UGLYGEP4:%.*]], [[FOR_BODY]] ], [ [[SEQ:%.*]], [[FOR_BODY_LR_PH]] ]
 ; CHECK-NEXT: [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_BODY]] ], [ [[UGLYGEP]], [[FOR_BODY_LR_PH]] ]
-; CHECK-NEXT: [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[TMP2]], [[FOR_BODY_LR_PH]] ]
 ; CHECK-NEXT: [[TMP3:%.*]] = load i8, ptr [[LSR_IV3]], align 1
 ; CHECK-NEXT: [[IDXPROM9:%.*]] = zext i8 [[TMP3]] to i64
 ; CHECK-NEXT: [[ARRAYIDX10:%.*]] = getelementptr inbounds i32, ptr [[TMP1]], i64 [[IDXPROM9]]
@@ -61,11 +61,11 @@
 ; CHECK-NEXT: [[COND:%.*]] = select i1 [[CMP14]], i64 [[CONV1730]], i64 [[SUB_PTR_SUB]]
 ; CHECK-NEXT: [[CONV18:%.*]] = trunc i64 [[COND]] to i8
 ; CHECK-NEXT: store i8 [[CONV18]], ptr [[LSR_IV1]], align 1
-; CHECK-NEXT: [[LSR_IV_NEXT]] = add nsw i64 [[LSR_IV]], -1
+; CHECK-NEXT: [[LSR_IV_NEXT:%.*]] = add nsw i64 poison, -1
 ; CHECK-NEXT: [[UGLYGEP2]] = getelementptr i8, ptr [[LSR_IV1]], i64 1
 ; CHECK-NEXT: [[UGLYGEP4]] = getelementptr i8, ptr [[LSR_IV3]], i64 1
-; CHECK-NEXT: [[EXITCOND_NOT:%.*]] = icmp eq i64 [[LSR_IV_NEXT]], 0
-; CHECK-NEXT: br i1 [[EXITCOND_NOT]], label [[FOR_END_LOOPEXIT:%.*]], label [[FOR_BODY]]
+; CHECK-NEXT: [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[UGLYGEP1]]
+; CHECK-NEXT: br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END_LOOPEXIT:%.*]], label [[FOR_BODY]]
 ; CHECK: for.end.loopexit:
 ; CHECK-NEXT: br label [[FOR_END]]
 ; CHECK: for.end: