diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -601,10 +601,6 @@
   /// Create code for the loop exit value of the reduction.
   void fixReduction(VPReductionPHIRecipe *Phi, VPTransformState &State);
 
-  /// Clear NSW/NUW flags from reduction instructions if necessary.
-  void clearReductionWrapFlags(VPReductionPHIRecipe *PhiR,
-                               VPTransformState &State);
-
   /// Iteratively sink the scalarized operands of a predicated instruction into
   /// the block that was created for it.
   void sinkScalarOperands(Instruction *PredInst);
@@ -3770,9 +3766,6 @@
   // This is the vector-clone of the value that leaves the loop.
   Type *VecTy = State.get(LoopExitInstDef, 0)->getType();
 
-  // Wrap flags are in general invalid after vectorization, clear them.
-  clearReductionWrapFlags(PhiR, State);
-
   // Before each round, move the insertion point right between
   // the PHIs and the values we are going to write.
   // This allows us to write both PHINodes and the extractelement
@@ -3952,38 +3945,6 @@
   OrigPhi->setIncomingValue(IncomingEdgeBlockIdx, LoopExitInst);
 }
 
-void InnerLoopVectorizer::clearReductionWrapFlags(VPReductionPHIRecipe *PhiR,
-                                                  VPTransformState &State) {
-  const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
-  RecurKind RK = RdxDesc.getRecurrenceKind();
-  if (RK != RecurKind::Add && RK != RecurKind::Mul)
-    return;
-
-  SmallVector<VPValue *, 8> Worklist;
-  SmallPtrSet<VPValue *, 8> Visited;
-  Worklist.push_back(PhiR);
-  Visited.insert(PhiR);
-
-  while (!Worklist.empty()) {
-    VPValue *Cur = Worklist.pop_back_val();
-    for (unsigned Part = 0; Part < UF; ++Part) {
-      Value *V = State.get(Cur, Part);
-      if (!isa<OverflowingBinaryOperator>(V))
-        break;
-      cast<Instruction>(V)->dropPoisonGeneratingFlags();
-    }
-
-    for (VPUser *U : Cur->users()) {
-      auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
-      if (!UserRecipe)
-        continue;
-      for (VPValue *V : UserRecipe->definedValues())
-        if (Visited.insert(V).second)
-          Worklist.push_back(V);
-    }
-  }
-}
-
 void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) {
   // The basic block and loop containing the predicated instruction.
   auto *PredBB = PredInst->getParent();
@@ -8882,6 +8843,7 @@
   // Adjust the recipes for any inloop reductions.
   adjustRecipesForReductions(cast<VPBasicBlock>(TopRegion->getExiting()), Plan,
                              RecipeBuilder, Range.Start);
+  VPlanTransforms::clearReductionWrapFlags(*Plan);
 
   // Sink users of fixed-order recurrence past the recipe defining the previous
   // value and introduce FirstOrderRecurrenceSplice VPInstructions.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -81,6 +81,12 @@
   /// not valid.
   static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder);
 
+  /// Drop poison-generating flags (e.g. NSW/NUW) from recipes of integer Add
+  /// and Mul reductions in \p Plan: for each such reduction phi in the vector
+  /// loop header, the flags of every recipe transitively using the phi are
+  /// cleared, as wrap flags are in general invalid after vectorization.
+  static void clearReductionWrapFlags(VPlan &Plan);
+
   /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the
   /// resulting plan to \p BestVF and \p BestUF.
   static void optimizeForVFAndUF(VPlan &Plan, ElementCount BestVF,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -755,6 +755,46 @@
   return true;
 }
 
+void VPlanTransforms::clearReductionWrapFlags(VPlan &Plan) {
+  for (VPRecipeBase &R :
+       Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
+    auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+    if (!PhiR)
+      continue;
+    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+    RecurKind RK = RdxDesc.getRecurrenceKind();
+    if (RK != RecurKind::Add && RK != RecurKind::Mul)
+      continue;
+
+    // Wrap flags are in general invalid after vectorization. Walk the
+    // reduction phi and all values transitively computed from it, and drop
+    // the poison-generating flags of every recipe that carries IR flags.
+    SmallVector<VPValue *, 8> Worklist;
+    SmallPtrSet<VPValue *, 8> Visited;
+    Worklist.push_back(PhiR);
+    Visited.insert(PhiR);
+
+    while (!Worklist.empty()) {
+      VPValue *Cur = Worklist.pop_back_val();
+      // Every value on the worklist is defined by a recipe (the phi itself or
+      // a value defined by a user recipe), so getDefiningRecipe() is non-null.
+      if (auto *OpWithFlags =
+              dyn_cast<VPRecipeWithIRFlags>(Cur->getDefiningRecipe())) {
+        OpWithFlags->dropPoisonGeneratingFlags();
+      }
+
+      for (VPUser *U : Cur->users()) {
+        auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
+        if (!UserRecipe)
+          continue;
+        for (VPValue *V : UserRecipe->definedValues())
+          if (Visited.insert(V).second)
+            Worklist.push_back(V);
+      }
+    }
+  }
+}
+
 void VPlanTransforms::truncateToMinimalBitwidths(
     VPlan &Plan, const MapVector<Instruction *, uint64_t> &MinBWs) {
   auto GetType = [](VPValue *Op) {