diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -186,6 +186,10 @@
     "lsr-setupcost-depth-limit", cl::Hidden, cl::init(7),
     cl::desc("The limit on recursion depth for LSRs setup cost"));
 
+static cl::opt<bool> AllowTerminalConditionFoldingAfterLSR(
+    "lsr-term-fold", cl::Hidden, cl::init(false),
+    cl::desc("Attempt to replace primary IV with other IV."));
+
 #ifndef NDEBUG
 // Stress test IV chain generation.
 static cl::opt<bool> StressIVChain(
@@ -6572,6 +6576,145 @@
   return nullptr;
 }
 
+/// Check whether the latch terminating condition of \p L can be rewritten in
+/// terms of another induction variable so that the primary IV becomes dead.
+///
+/// \returns a tuple {CanFold, ToFold, ToHelpFold}. CanFold is true only when
+/// both PHIs were found: ToFold is the header PHI whose only purpose is
+/// feeding the terminating condition, and ToHelpFold is a header PHI whose
+/// preheader incoming value is a getelementptr and that will take over the
+/// terminating condition. Either PHI may be null when CanFold is false.
+static std::tuple<bool, PHINode *, PHINode *>
+canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
+                      const LoopInfo &LI) {
+  const std::tuple<bool, PHINode *, PHINode *> CantFold = {false, nullptr,
+                                                           nullptr};
+
+  if (!L->isInnermost()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
+    return CantFold;
+  }
+  // Only inspect on simple loop structure
+  if (!L->isLoopSimplifyForm()) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
+    return CantFold;
+  }
+
+  if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
+    LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
+    return CantFold;
+  }
+
+  BasicBlock *LoopPreheader = L->getLoopPreheader();
+  BasicBlock *LoopLatch = L->getLoopLatch();
+
+  // TODO: Can we do something for greater than and less than?
+  // Terminating condition is foldable when it is an eq/ne icmp
+  BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
+  if (!BI || BI->isUnconditional())
+    return CantFold;
+  Value *TermCond = BI->getCondition();
+  if (!isa<ICmpInst>(TermCond) || !cast<ICmpInst>(TermCond)->isEquality())
+    return CantFold;
+
+  // For `IsToFold`, a primary IV can be replaced by other affine AddRec when it
+  // is only used by the terminating condition. To check for this, we may need
+  // to traverse through a chain of use-def until we can examine the final
+  // usage.
+  //                         *----------------------*
+  //                   *---->|  LoopHeader:         |
+  //                   |     |  PrimaryIV = phi ... |
+  //                   |     *----------------------*
+  //                   |              |
+  //                   |              |
+  //                   |          chain of
+  //                   |          single use
+  //                used by           |
+  //                  phi             |
+  //                   |            Value
+  //                   |          /       \
+  //                   |     chain of     chain of
+  //                   |    single use   single use
+  //                   |       /             \
+  //                   |      /               \
+  //                   *-  Value              Value --> used by terminating condition
+  auto IsToFold = [&](PHINode &PN) -> bool {
+    if (PN.getNumIncomingValues() != 2)
+      return false;
+    Value *V = &PN;
+
+    // Follow the single-use chain. A dead IV whose only user chain leads back
+    // to the PHI itself would make this walk spin forever, so give up on it.
+    while (V->getNumUses() == 1) {
+      V = *V->user_begin();
+      if (V == &PN)
+        return false;
+    }
+
+    if (V->getNumUses() != 2)
+      return false;
+
+    Value *VToPN = nullptr;
+    Value *VToTermCond = nullptr;
+    for (User *U : V->users()) {
+      while (U->getNumUses() == 1) {
+        if (isa<PHINode>(U))
+          VToPN = U;
+        if (U == TermCond)
+          VToTermCond = U;
+        U = *U->user_begin();
+      }
+    }
+    return VToPN && VToTermCond;
+  };
+
+  // For `IsToHelpFold`, other IV that is an affine AddRec will be sufficient to
+  // replace the terminating condition
+  auto IsToHelpFold = [&](PHINode &PN) -> bool {
+    if (PN.getNumIncomingValues() != 2)
+      return false;
+
+    // TODO: Right now we limit the phi node to help the folding be of a start
+    // value of getelementptr. We can extend to any kinds of IV as long as it is
+    // an affine AddRec. Add a switch to cover more types of instructions here
+    // and down in the actual transformation.
+    return PN.getBasicBlockIndex(LoopPreheader) != -1 &&
+           PN.getBasicBlockIndex(LoopLatch) != -1 &&
+           isa<GetElementPtrInst>(PN.getIncomingValueForBlock(LoopPreheader));
+  };
+
+  PHINode *ToFold = nullptr;
+  PHINode *ToHelpFold = nullptr;
+
+  for (PHINode &PN : L->getHeader()->phis()) {
+    if (!SE.isSCEVable(PN.getType()))
+      continue;
+    const SCEV *S = SE.getSCEV(&PN);
+    const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+    // Only speculate on affine AddRec
+    if (!AddRec || !AddRec->isAffine())
+      continue;
+
+    if (IsToFold(PN))
+      ToFold = &PN;
+    else if (IsToHelpFold(PN))
+      ToHelpFold = &PN;
+  }
+
+  LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
+             << "Cannot find other AddRec IV to help folding\n";);
+
+  LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
+             << "\nFound loop that can fold terminating condition\n"
+             << "  BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
+             << "  TermCond: " << *TermCond << "\n"
+             << "  BranchInst: " << *BI << "\n"
+             << "  ToFold: " << *ToFold << "\n"
+             << "  ToHelpFold: " << *ToHelpFold << "\n");
+
+  return {ToFold != nullptr && ToHelpFold != nullptr, ToFold, ToHelpFold};
+}
+
 static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
                                DominatorTree &DT, LoopInfo &LI,
                                const TargetTransformInfo &TTI,
@@ -6630,6 +6773,92 @@
     }
   }
 
+  if (AllowTerminalConditionFoldingAfterLSR) {
+    auto CanFoldTerminalCondition = canFoldTermCondOfLoop(L, SE, DT, LI);
+    if (std::get<0>(CanFoldTerminalCondition)) {
+      BasicBlock *LoopPreheader = L->getLoopPreheader();
+      BasicBlock *LoopLatch = L->getLoopLatch();
+
+      PHINode *ToFold = std::get<1>(CanFoldTerminalCondition);
+      PHINode *ToHelpFold = std::get<2>(CanFoldTerminalCondition);
+
+      LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
+                        << *ToFold << "\n"
+                        << "New term-cond phi-node:\n"
+                        << *ToHelpFold << "\n");
+
+      Value *StartValue = ToHelpFold->getIncomingValueForBlock(LoopPreheader);
+      Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
+
+      // SCEVExpander for both use in preheader and latch
+      const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+      SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
+      SCEVExpanderCleaner ExpCleaner(Expander);
+
+      // Create new terminating value at loop header
+      GetElementPtrInst *StartValueGEP = cast<GetElementPtrInst>(StartValue);
+      Type *PtrTy = StartValueGEP->getPointerOperand()->getType();
+
+      const SCEV *BECount = SE.getBackedgeTakenCount(L);
+      const SCEVAddRecExpr *AddRec =
+          cast<SCEVAddRecExpr>(SE.getSCEV(ToHelpFold));
+
+      // TermValue = Start + Stride * (BackedgeCount + 1)
+      const SCEV *TermValueS = SE.getAddExpr(
+          AddRec->getOperand(0),
+          SE.getTruncateOrZeroExtend(
+              SE.getMulExpr(
+                  AddRec->getOperand(1),
+                  SE.getTruncateOrZeroExtend(
+                      SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
+                      AddRec->getOperand(1)->getType())),
+              AddRec->getOperand(0)->getType()));
+
+      // NOTE: If this is triggered, we should add this into predicate
+      if (!Expander.isSafeToExpand(TermValueS)) {
+        LLVMContext &Ctx = L->getHeader()->getContext();
+        Ctx.emitError(
+            "Terminating value is not safe to expand, need to add it to "
+            "predicate");
+      } else { // Now we replace the condition with ToHelpFold and remove ToFold
+        Changed = true;
+
+        Value *TermValue = Expander.expandCodeFor(
+            TermValueS, PtrTy, LoopPreheader->getTerminator());
+
+        LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
+                          << *StartValue << "\n"
+                          << "Terminating value of new term-cond phi-node:\n"
+                          << *TermValue << "\n");
+
+        // Create new terminating condition at loop latch
+        BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
+        ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
+        IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
+        Value *NewTermCond = LatchBuilder.CreateICmp(
+            OldTermCond->getPredicate(), LoopValue, TermValue,
+            "lsr_fold_term_cond.replaced_term_cond");
+
+        LLVM_DEBUG(dbgs() << "Old term-cond:\n"
+                          << *OldTermCond << "\n"
+                          << "New term-cond:\n" << *NewTermCond << "\n");
+
+        BI->setCondition(NewTermCond);
+
+        OldTermCond->replaceAllUsesWith(
+            PoisonValue::get(OldTermCond->getType()));
+        OldTermCond->eraseFromParent();
+
+        // Erase the old primary IV explicitly instead of draining its operands:
+        // removing the last incoming value deletes the PHI while it is in use.
+        ToFold->replaceAllUsesWith(PoisonValue::get(ToFold->getType()));
+        ToFold->eraseFromParent();
+      }
+
+      ExpCleaner.markResultUsed();
+    }
+  }
+
   if (SalvageableDVIRecords.empty())
     return Changed;
 
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-complicate-add-rec.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes="loop-reduce" -S | FileCheck %s
+; RUN: opt < %s -passes="loop-reduce" -S -lsr-term-fold | FileCheck %s
 ; This is compiled from the following code
 ; void foo(int*);
 ; void ptr_of_ptr_addrec(int **ptrptr, int length) {
@@ -15,19 +15,20 @@
 define void @ptr_of_ptr_addrec(ptr %ptrptr, i32 %length) {
 ; CHECK-LABEL: @ptr_of_ptr_addrec(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[START_PTRPTR:%.*]] = getelementptr inbounds ptr, ptr [[PTRPTR:%.*]]
+; CHECK-NEXT:    [[START_PTRPTR:%.*]] = getelementptr ptr, ptr [[PTRPTR:%.*]]
+; CHECK-NEXT:    [[TMP0:%.*]] = shl i32 [[LENGTH:%.*]], 2
+; CHECK-NEXT:    [[UGLYGEP:%.*]] = getelementptr i8, ptr [[START_PTRPTR]], i32 [[TMP0]]
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup:
 ; CHECK-NEXT:    ret void
 ; CHECK:       for.body:
-; CHECK-NEXT:    [[I_05:%.*]] = phi i32 [ [[DEC:%.*]], [[FOR_BODY]] ], [ [[LENGTH:%.*]], [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[IT_04:%.*]] = phi ptr [ [[INCDEC_PTR:%.*]], [[FOR_BODY]] ], [ [[START_PTRPTR]], [[ENTRY]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = load ptr, ptr [[IT_04]], align 8
-; CHECK-NEXT:    tail call void @foo(ptr [[TMP0]])
+; CHECK-NEXT:    [[IT_04:%.*]] = phi ptr [ [[INCDEC_PTR:%.*]], [[FOR_BODY]] ], [ [[START_PTRPTR]], [[ENTRY:%.*]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = load ptr, ptr [[IT_04]], align 8
+; CHECK-NEXT:    tail call void @foo(ptr [[TMP1]])
 ; CHECK-NEXT:    [[INCDEC_PTR]] = getelementptr inbounds ptr, ptr [[IT_04]], i64 1
-; CHECK-NEXT:    [[DEC]] = add i32 [[I_05]], -1
-; CHECK-NEXT:    [[TOBOOL_NOT:%.*]] = icmp eq i32 [[DEC]], 0
-; CHECK-NEXT:    br i1 [[TOBOOL_NOT]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
+; CHECK-NEXT:    [[DEC:%.*]] = add i32 poison, -1
+; CHECK-NEXT:    [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[INCDEC_PTR]], [[UGLYGEP]]
+; CHECK-NEXT:    br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY]]
 ;
 entry:
   %start.ptrptr = getelementptr inbounds ptr, ptr %ptrptr
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-const-tripcount.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes="loop-reduce" -S | FileCheck %s
+; RUN: opt < %s -passes="loop-reduce" -S -lsr-term-fold | FileCheck %s
 ; This is compiled from the following code
 ; void close_bubbles(int a[400]) {
 ;   for (int i = 21; i < 400; i++)
@@ -12,15 +12,15 @@
 ; CHECK-LABEL: @close_bubbles(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[UGLYGEP:%.*]] = getelementptr i8, ptr [[A:%.*]], i64 84
+; CHECK-NEXT:    [[UGLYGEP1:%.*]] = getelementptr i8, ptr [[A]], i32 1600
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_BODY]] ], [ [[UGLYGEP]], [[ENTRY:%.*]] ]
-; CHECK-NEXT:    [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT:%.*]], [[FOR_BODY]] ], [ 379, [[ENTRY]] ]
 ; CHECK-NEXT:    store i32 1, ptr [[LSR_IV1]], align 4
-; CHECK-NEXT:    [[LSR_IV_NEXT]] = add nsw i64 [[LSR_IV]], -1
+; CHECK-NEXT:    [[LSR_IV_NEXT:%.*]] = add nsw i64 poison, -1
 ; CHECK-NEXT:    [[UGLYGEP2]] = getelementptr i8, ptr [[LSR_IV1]], i64 4
-; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[LSR_IV_NEXT]], 0
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_END:%.*]], label [[FOR_BODY]]
+; CHECK-NEXT:    [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[UGLYGEP1]]
+; CHECK-NEXT:    br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_END:%.*]], label [[FOR_BODY]]
 ; CHECK:       for.end:
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-negative-testcase.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-negative-testcase.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-negative-testcase.ll
@@ -0,0 +1,74 @@
+; REQUIRES: asserts
+; RUN: opt < %s -passes="loop-reduce" -S -debug -lsr-term-fold 2>&1 | FileCheck %s
+
+target datalayout = "e-p:32:32:32-n32"
+
+define i32 @loop_variant(ptr %ar, i32 %n, i32 %m) {
+; CHECK: Cannot fold on backedge that is loop variant
+entry:
+  br label %for.cond
+
+for.cond:                                         ; preds = %for.cond, %entry
+  %n.addr.0 = phi i32 [ %n, %entry ], [ %mul, %for.cond ]
+  %cmp = icmp slt i32 %n.addr.0, %m
+  %mul = shl nsw i32 %n.addr.0, 1
+  br i1 %cmp, label %for.cond, label %for.end
+
+for.end:                                          ; preds = %for.cond
+  ret i32 %n.addr.0
+}
+
+define i32 @nested_loop(ptr %ar, i32 %n, i32 %m, i32 %o) {
+; CHECK: Cannot fold on backedge that is loop variant
+; CHECK: Cannot fold on non-innermost loop
+entry:
+  %cmp15 = icmp sgt i32 %o, 0
+  br i1 %cmp15, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup3, %entry
+  %cnt.0.lcssa = phi i32 [ 0, %entry ], [ %cnt.1.lcssa, %for.cond.cleanup3 ]
+  ret i32 %cnt.0.lcssa
+
+for.body:                                         ; preds = %entry, %for.cond.cleanup3
+  %i.017 = phi i32 [ %inc6, %for.cond.cleanup3 ], [ 0, %entry ]
+  %cnt.016 = phi i32 [ %cnt.1.lcssa, %for.cond.cleanup3 ], [ 0, %entry ]
+  %sub = sub nsw i32 %n, %i.017
+  %cmp212 = icmp slt i32 %sub, %m
+  br i1 %cmp212, label %for.body4, label %for.cond.cleanup3
+
+for.cond.cleanup3:                                ; preds = %for.body4, %for.body
+  %cnt.1.lcssa = phi i32 [ %cnt.016, %for.body ], [ %inc, %for.body4 ]
+  %inc6 = add nuw nsw i32 %i.017, 1
+  %cmp = icmp slt i32 %inc6, %o
+  br i1 %cmp, label %for.body, label %for.cond.cleanup
+
+for.body4:                                        ; preds = %for.body, %for.body4
+  %j.014 = phi i32 [ %mul, %for.body4 ], [ %sub, %for.body ]
+  %cnt.113 = phi i32 [ %inc, %for.body4 ], [ %cnt.016, %for.body ]
+  %inc = add nsw i32 %cnt.113, 1
+  %mul = shl nsw i32 %j.014, 1
+  %cmp2 = icmp slt i32 %mul, %m
+  br i1 %cmp2, label %for.body4, label %for.cond.cleanup3
+}
+
+define void @no_iv_to_help(ptr %mark, i32 signext %length) {
+; CHECK: Cannot find other AddRec IV to help folding
+entry:
+  %tobool.not3 = icmp eq i32 %length, 0
+  br i1 %tobool.not3, label %for.cond.cleanup, label %for.body
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.05 = phi i32 [ %dec, %for.body ], [ %length, %entry ]
+  %dst.04 = phi ptr [ %incdec.ptr, %for.body ], [ %mark, %entry ]
+  %0 = load ptr, ptr %dst.04, align 8
+  call ptr @sv_2mortal(ptr %0)
+  %incdec.ptr = getelementptr inbounds ptr, ptr %dst.04, i64 1
+  %dec = add nsw i32 %i.05, -1
+  %tobool.not = icmp eq i32 %dec, 0
+  br i1 %tobool.not, label %for.cond.cleanup, label %for.body
+}
+
+declare void @sv_2mortal(ptr)
diff --git a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll
--- a/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll
+++ b/llvm/test/Transforms/LoopStrengthReduce/lsr-fold-iv-runtime-tripcount.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes="loop-reduce" -S | FileCheck %s
+; RUN: opt < %s -passes="loop-reduce" -S -lsr-term-fold | FileCheck %s
 ; This is compiled from the following code
 ; void foo(char*);
 ; void runtime_tripcount(char *seq, int L)
@@ -16,10 +16,11 @@
 ; CHECK-NEXT:    [[CMP_NOT3:%.*]] = icmp slt i32 [[L:%.*]], 1
 ; CHECK-NEXT:    br i1 [[CMP_NOT3]], label [[FOR_COND_CLEANUP:%.*]], label [[FOR_BODY_PREHEADER:%.*]]
 ; CHECK:       for.body.preheader:
-; CHECK-NEXT:    [[TMP0:%.*]] = add nuw i32 [[L]], 1
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[L]], 1
 ; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[TMP0]] to i64
 ; CHECK-NEXT:    [[TMP1:%.*]] = add nsw i64 [[WIDE_TRIP_COUNT]], -1
 ; CHECK-NEXT:    [[UGLYGEP:%.*]] = getelementptr i8, ptr [[SEQ:%.*]], i64 1
+; CHECK-NEXT:    [[UGLYGEP1:%.*]] = getelementptr i8, ptr [[SEQ]], i32 [[TMP0]]
 ; CHECK-NEXT:    br label [[FOR_BODY:%.*]]
 ; CHECK:       for.cond.cleanup.loopexit:
 ; CHECK-NEXT:    br label [[FOR_COND_CLEANUP]]
@@ -27,13 +28,12 @@
 ; CHECK-NEXT:    ret void
 ; CHECK:       for.body:
 ; CHECK-NEXT:    [[LSR_IV1:%.*]] = phi ptr [ [[UGLYGEP2:%.*]], [[FOR_BODY]] ], [ [[UGLYGEP]], [[FOR_BODY_PREHEADER]] ]
-; CHECK-NEXT:    [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT:%.*]], [[FOR_BODY]] ], [ [[TMP1]], [[FOR_BODY_PREHEADER]] ]
 ; CHECK-NEXT:    [[TMP2:%.*]] = load i8, ptr [[LSR_IV1]], align 1
 ; CHECK-NEXT:    tail call void @foo(i8 noundef zeroext [[TMP2]])
-; CHECK-NEXT:    [[LSR_IV_NEXT]] = add nsw i64 [[LSR_IV]], -1
+; CHECK-NEXT:    [[LSR_IV_NEXT:%.*]] = add nsw i64 poison, -1
 ; CHECK-NEXT:    [[UGLYGEP2]] = getelementptr i8, ptr [[LSR_IV1]], i64 1
-; CHECK-NEXT:    [[EXITCOND_NOT:%.*]] = icmp eq i64 [[LSR_IV_NEXT]], 0
-; CHECK-NEXT:    br i1 [[EXITCOND_NOT]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY]]
+; CHECK-NEXT:    [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND:%.*]] = icmp eq ptr [[UGLYGEP2]], [[UGLYGEP1]]
+; CHECK-NEXT:    br i1 [[LSR_FOLD_TERM_COND_REPLACED_TERM_COND]], label [[FOR_COND_CLEANUP_LOOPEXIT:%.*]], label [[FOR_BODY]]
 ;
 entry:
   %cmp.not3 = icmp slt i32 %L, 1