diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -630,10 +630,6 @@
   /// Create code for the loop exit value of the reduction.
   void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State);
 
-  /// Clear NSW/NUW flags from reduction instructions if necessary.
-  void clearReductionWrapFlags(VPReductionPHIRecipe *PhiR,
-                               VPTransformState &State);
-
   /// Iteratively sink the scalarized operands of a predicated instruction into
   /// the block that was created for it.
   void sinkScalarOperands(Instruction *PredInst);
@@ -3929,9 +3925,6 @@
   // This is the vector-clone of the value that leaves the loop.
   Type *VecTy = State.get(LoopExitInstDef, 0)->getType();
 
-  // Wrap flags are in general invalid after vectorization, clear them.
-  clearReductionWrapFlags(PhiR, State);
-
   // Before each round, move the insertion point right between
   // the PHIs and the values we are going to write.
   // This allows us to write both PHINodes and the extractelement
@@ -4111,38 +4104,6 @@
   OrigPhi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst);
 }
 
-void InnerLoopVectorizer::clearReductionWrapFlags(VPReductionPHIRecipe *PhiR,
-                                                  VPTransformState &State) {
-  const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
-  RecurKind RK = RdxDesc.getRecurrenceKind();
-  if (RK != RecurKind::Add && RK != RecurKind::Mul)
-    return;
-
-  SmallVector<VPValue *, 8> Worklist;
-  SmallPtrSet<VPValue *, 8> Visited;
-  Worklist.push_back(PhiR);
-  Visited.insert(PhiR);
-
-  while (!Worklist.empty()) {
-    VPValue *Cur = Worklist.pop_back_val();
-    for (unsigned Part = 0; Part < UF; ++Part) {
-      Value *V = State.get(Cur, Part);
-      if (!isa<OverflowingBinaryOperator>(V))
-        break;
-      cast<Instruction>(V)->dropPoisonGeneratingFlags();
-    }
-
-    for (VPUser *U : Cur->users()) {
-      auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
-      if (!UserRecipe)
-        continue;
-      for (VPValue *V : UserRecipe->definedValues())
-        if (Visited.insert(V).second)
-          Worklist.push_back(V);
-    }
-  }
-}
-
 void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) {
   // The basic block and loop containing the predicated instruction.
   auto *PredBB = PredInst->getParent();
@@ -9317,6 +9278,8 @@
       Builder.createNaryOp(Instruction::Select, {Cond, Red, PhiR});
     }
   }
+
+  VPlanTransforms::clearReductionWrapFlags(*Plan);
 }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -81,6 +81,9 @@
   /// not valid.
   static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder);
 
+  /// Clear NSW/NUW flags from reduction instructions if necessary.
+  static void clearReductionWrapFlags(VPlan &Plan);
+
   /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the
   /// resulting plan to \p BestVF and \p BestUF.
   static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -746,3 +746,39 @@
   }
   return true;
 }
+
+void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
+  for (VPRecipeBase &R :
+       Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
+    auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+    if (!PhiR)
+      continue;
+    // Wrap flags are in general invalid after vectorization. Only add and mul
+    // reductions are processed; all other recurrence kinds are skipped.
+    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+    RecurKind RK = RdxDesc.getRecurrenceKind();
+    if (RK != RecurKind::Add && RK != RecurKind::Mul)
+      continue;
+
+    // Walk all values reachable from the reduction phi through recipe users.
+    // The worklist is a set, so a value that is reached twice is visited once.
+    SmallSetVector<VPValue *, 8> Worklist;
+    Worklist.insert(PhiR);
+
+    for (unsigned I = 0; I != Worklist.size(); ++I) {
+      VPValue *Cur = Worklist[I];
+      if (auto *RecWithFlags =
+              dyn_cast<VPRecipeWithIRFlags>(Cur->getDefiningRecipe())) {
+        RecWithFlags->dropPoisonGeneratingFlags();
+      }
+
+      for (VPUser *U : Cur->users()) {
+        auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
+        if (!UserRecipe)
+          continue;
+        for (VPValue *V : UserRecipe->definedValues())
+          Worklist.insert(V);
+      }
+    }
+  }
+}