diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -590,8 +590,8 @@
   /// Handle all cross-iteration phis in the header.
   void fixCrossIterationPHIs(VPTransformState &State);
 
-  /// Fix a first-order recurrence. This is the second phase of vectorizing
-  /// this phi node.
+  /// Create the exit value of first-order recurrences in the middle block and
+  /// update their users.
   void fixFirstOrderRecurrence(VPWidenPHIRecipe *PhiR, VPTransformState &State);
 
   /// Fix a reduction cross-iteration phi. This is the second phase of
@@ -4222,17 +4222,12 @@
   //
   // After execution completes the vector loop, we extract the next value of
   // the recurrence (x) to use as the initial value in the scalar loop.
-  auto *IdxTy = Builder.getInt32Ty();
-  auto *VecPhi = cast<PHINode>(State.get(PhiR, 0));
-
-  // Fix the latch value of the new recurrence in the vector loop.
-  VPValue *PreviousDef = PhiR->getBackedgeValue();
-  Value *Incoming = State.get(PreviousDef, UF - 1);
-  VecPhi->addIncoming(Incoming, LI->getLoopFor(LoopVectorBody)->getLoopLatch());
-
   // Extract the last vector element in the middle block. This will be the
   // initial value for the recurrence when jumping to the scalar loop.
+  VPValue *PreviousDef = PhiR->getBackedgeValue();
+  Value *Incoming = State.get(PreviousDef, UF - 1);
   auto *ExtractForScalar = Incoming;
+  auto *IdxTy = Builder.getInt32Ty();
   if (VF.isVector()) {
     auto *One = ConstantInt::get(IdxTy, 1);
     Builder.SetInsertPoint(LoopMiddleBlock->getTerminator());
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -815,6 +815,19 @@
   for (VPBlockBase *Block : depth_first(Entry))
     Block->execute(State);
 
+  // Fix the latch value of the first-order recurrences in the vector loop. Only
+  // a single part is generated, regardless of the UF.
+  VPBasicBlock *Header = Entry->getEntryBasicBlock();
+  for (VPRecipeBase &R : Header->phis()) {
+    if (auto *FOR = dyn_cast<VPFirstOrderRecurrencePHIRecipe>(&R)) {
+      auto *VecPhi = cast<PHINode>(State->get(FOR, 0));
+
+      VPValue *PreviousDef = FOR->getBackedgeValue();
+      Value *Incoming = State->get(PreviousDef, State->UF - 1);
+      VecPhi->addIncoming(Incoming, VectorLatchBB);
+    }
+  }
+
   // Setup branch terminator successors for VPBBs in VPBBsToFix based on
   // VPBB's successors.
   for (auto VPBB : State->CFG.VPBBsToFix) {