Index: llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -5649,6 +5649,63 @@
     DeadInsts.emplace_back(OperandIsInstr);
 }
 
+/// Return true if every user of \p I lives in \p BB and is either free of
+/// side effects or is a store that uses \p I as its address.
+static bool AllUseInBB(Instruction *I, BasicBlock *BB) {
+  for (User *U : I->users()) {
+    Instruction *UI = cast<Instruction>(U);
+    if (UI->getParent() != BB)
+      return false;
+    if (UI->mayHaveSideEffects()) {
+      // A store through I is acceptable; any other side-effecting user
+      // (including a store of I as the stored value) is not.
+      if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
+        if (SI->getPointerOperand() == I)
+          continue;
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+/// Return true if the IV increment for \p Fixup can be hoisted from the loop
+/// latch into the loop header. This is only done for address uses whose
+/// replaced operand is used exclusively in the header (see AllUseInBB) and
+/// whose memory access is a load or store for which the target reports
+/// post-indexed addressing as legal. Placing the increment next to the access
+/// helps the backend form post-indexed loads/stores when the latch block
+/// differs from the header.
+static bool CanHoistIVInc(const TargetTransformInfo &TTI, const LSRFixup &Fixup,
+                          const LSRUse &LU, Instruction *IVIncInsertPos,
+                          Loop *L) {
+  if (LU.Kind != LSRUse::Address)
+    return false;
+
+  // Nothing to hoist if the increment is already placed in the header.
+  BasicBlock *LHeader = L->getHeader();
+  if (IVIncInsertPos->getParent() == LHeader)
+    return false;
+
+  // Only handle increments that currently sit in the loop latch.
+  if (IVIncInsertPos->getParent() != L->getLoopLatch())
+    return false;
+
+  Instruction *User = dyn_cast<Instruction>(Fixup.OperandValToReplace);
+  if (!User || !AllUseInBB(User, LHeader))
+    return false;
+
+  // Query legality with the type of the value being accessed. A store
+  // instruction itself has void type, so use its stored value's type.
+  Instruction *I = Fixup.UserInst;
+  if (auto *LI = dyn_cast<LoadInst>(I))
+    return TTI.isIndexedLoadLegal(TTI.MIM_PostInc, LI->getType());
+  if (auto *SI = dyn_cast<StoreInst>(I))
+    return TTI.isIndexedStoreLegal(TTI.MIM_PostInc,
+                                   SI->getValueOperand()->getType());
+  return false;
+}
+
 /// Rewrite all the fixup locations with new values, following the chosen
 /// solution.
 void LSRInstance::ImplementSolution(
@@ -5657,8 +5714,6 @@
   // we can remove them after we are done working.
   SmallVector<WeakTrackingVH, 16> DeadInsts;
 
-  Rewriter.setIVIncInsertPos(L, IVIncInsertPos);
-
   // Mark phi nodes that terminate chains so the expander tries to reuse them.
   for (const IVChain &Chain : IVChainVec) {
     if (PHINode *PN = dyn_cast<PHINode>(Chain.tailUserInst()))
@@ -5666,11 +5721,17 @@
   }
 
   // Expand the new value definitions and update the users.
-  for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx)
-    for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) {
-      Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], DeadInsts);
+  for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+    const LSRUse &LU = Uses[LUIdx];
+    for (const LSRFixup &Fixup : LU.Fixups) {
+      Instruction *InsertPos = CanHoistIVInc(TTI, Fixup, LU, IVIncInsertPos, L)
+                                   ? L->getHeader()->getTerminator()
+                                   : IVIncInsertPos;
+      Rewriter.setIVIncInsertPos(L, InsertPos);
+      Rewrite(LU, Fixup, *Solution[LUIdx], DeadInsts);
       Changed = true;
     }
+  }
 
   for (const IVChain &Chain : IVChainVec) {
     GenerateIVChain(Chain, DeadInsts);
Index: llvm/test/Transforms/LoopStrengthReduce/AArch64/pr53625.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopStrengthReduce/AArch64/pr53625.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-unknown-unknown | FileCheck %s
+
+target datalayout = "e-m:w-p:64:64-i32:32-i64:64-i128:128-n32:64-S128"
+
+; Check that post-indexed loads are generated for the loads in the loop header.
+define i32 @test(i32 %c, ptr %a, ptr %b) {
+; CHECK-LABEL: test:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    cmp w0, #1
+; CHECK-NEXT:    b.lt .LBB0_4
+; CHECK-NEXT:  // %bb.1: // %for.body.preheader
+; CHECK-NEXT:    mov w8, w0
+; CHECK-NEXT:  .LBB0_2: // %for.body
+; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    ldr w9, [x1], #4
+; CHECK-NEXT:    ldr w10, [x2], #4
+; CHECK-NEXT:    tst w10, w9
+; CHECK-NEXT:    b.ne .LBB0_5
+; CHECK-NEXT:  // %bb.3: // %for.cond
+; CHECK-NEXT:    // in Loop: Header=BB0_2 Depth=1
+; CHECK-NEXT:    subs x8, x8, #1
+; CHECK-NEXT:    b.ne .LBB0_2
+; CHECK-NEXT:  .LBB0_4:
+; CHECK-NEXT:    mov w0, wzr
+; CHECK-NEXT:    ret
+; CHECK-NEXT:  .LBB0_5:
+; CHECK-NEXT:    mov w0, #1
+; CHECK-NEXT:    ret
+entry:
+  %cmp13 = icmp sgt i32 %c, 0
+  br i1 %cmp13, label %for.body.preheader, label %return
+
+for.body.preheader:                               ; preds = %entry
+  %wide.trip.count = zext i32 %c to i64
+  br label %for.body
+
+for.cond:                                         ; preds = %for.body
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %wide.trip.count
+  br i1 %exitcond.not, label %return, label %for.body
+
+for.body:                                         ; preds = %for.body.preheader, %for.cond
+  %indvars.iv = phi i64 [ 0, %for.body.preheader ], [ %indvars.iv.next, %for.cond ]
+  %arrayidx = getelementptr inbounds i32, ptr %a, i64 %indvars.iv
+  %0 = load i32, ptr %arrayidx, align 4
+  %arrayidx2 = getelementptr inbounds i32, ptr %b, i64 %indvars.iv
+  %1 = load i32, ptr %arrayidx2, align 4
+  %and = and i32 %1, %0
+  %tobool3.not = icmp eq i32 %and, 0
+  br i1 %tobool3.not, label %for.cond, label %return
+
+return:                                           ; preds = %for.cond, %for.body, %entry
+  %retval.1 = phi i32 [ 0, %entry ], [ 0, %for.cond ], [ 1, %for.body ]
+  ret i32 %retval.1
+}