diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -594,8 +594,7 @@
   /// update their users.
   void fixFirstOrderRecurrence(VPWidenPHIRecipe *PhiR, VPTransformState &State);
 
-  /// Fix a reduction cross-iteration phi. This is the second phase of
-  /// vectorizing this phi node.
+  /// Create code for the loop exit value of the reduction.
   void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State);
 
   /// Clear NSW/NUW flags from reduction instructions if necessary.
@@ -4303,22 +4302,6 @@
   // Wrap flags are in general invalid after vectorization, clear them.
   clearReductionWrapFlags(RdxDesc, State);
 
-  // Fix the vector-loop phi.
-
-  // Reductions do not have to start at zero. They can start with
-  // any loop invariant values.
-  BasicBlock *VectorLoopLatch = LI->getLoopFor(LoopVectorBody)->getLoopLatch();
-
-  unsigned LastPartForNewPhi = PhiR->isOrdered() ? 1 : UF;
-  for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) {
-    Value *VecRdxPhi = State.get(PhiR->getVPSingleValue(), Part);
-    Value *Val = State.get(PhiR->getBackedgeValue(), Part);
-    if (PhiR->isOrdered())
-      Val = State.get(PhiR->getBackedgeValue(), UF - 1);
-
-    cast<PHINode>(VecRdxPhi)->addIncoming(Val, VectorLoopLatch);
-  }
-
   // Before each round, move the insertion point right between
   // the PHIs and the values we are going to write.
   // This allows us to write both PHINodes and the extractelement
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.cpp b/llvm/lib/Transforms/Vectorize/VPlan.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlan.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlan.cpp
@@ -867,6 +867,24 @@
   if (!EnableVPlanNativePath)
     updateDominatorTree(State->DT, VectorPreHeaderBB, VectorLatchBB,
                         L->getExitBlock());
+
+  // Fixup reduction PHI nodes in the vectorized loop header by adding the
+  // incoming value from the vector loop latch for each unrolled part.
+  // Ordered reductions have a single phi (part 0) whose backedge value is
+  // produced by the last unrolled part (UF - 1).
+  for (VPRecipeBase &R : Header->phis()) {
+    auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+    if (!PhiR)
+      continue;
+
+    unsigned LastPartForNewPhi = PhiR->isOrdered() ? 1 : State->UF;
+    for (unsigned Part = 0; Part < LastPartForNewPhi; ++Part) {
+      Value *VecRdxPhi = State->get(PhiR->getVPSingleValue(), Part);
+      Value *Val = State->get(PhiR->getBackedgeValue(),
+                              PhiR->isOrdered() ? State->UF - 1 : Part);
+      cast<PHINode>(VecRdxPhi)->addIncoming(Val, VectorLatchBB);
+    }
+  }
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)