Remove SinkAfter from fixed-order recurrence handling.

RecurrenceDescriptor::isFixedOrderRecurrence no longer takes or fills a
SinkAfter map; it only checks that every (transitive) user of the phi can
be sunk after the recurrence's previous value. LoopVectorizationLegality
drops its SinkAfter member together with the bail-out for recurrences
whose previous value would itself have to be sunk for another recurrence.

Whether the users can actually be re-arranged is now decided while the
VPlan is built: VPlanTransforms::adjustFixedOrderRecurrences returns false
if a recurrence user cannot be sunk after the recipe defining the previous
value (the previous value depends on the user, or the user is the previous
value of an already-processed recurrence), and the VPlan construction in
LoopVectorize.cpp then returns std::nullopt. Recipes defining multiple
values are now handled by walking all of their defined values.

@test_crash in first-order-recurrence-chains.ll is now vectorized.

diff --git a/llvm/include/llvm/Analysis/IVDescriptors.h b/llvm/include/llvm/Analysis/IVDescriptors.h --- a/llvm/include/llvm/Analysis/IVDescriptors.h +++ b/llvm/include/llvm/Analysis/IVDescriptors.h @@ -186,14 +186,9 @@ /// previous iteration (e.g. if the value is defined in the previous /// iteration, we refer to it as first-order recurrence, if it is defined in /// the iteration before the previous, we refer to it as second-order - /// recurrence and so on). \p SinkAfter includes pairs of instructions where - /// the first will be rescheduled to appear after the second if/when the loop - /// is vectorized. It may be augmented with additional pairs if needed in - /// order to handle Phi as a first-order recurrence. - static bool - isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop, - MapVector &SinkAfter, - DominatorTree *DT); + /// recurrence and so on). + static bool isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop, + DominatorTree *DT); RecurKind getRecurrenceKind() const { return Kind; } diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h --- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h +++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h @@ -512,10 +512,6 @@ /// Holds the phi nodes that are fixed-order recurrences. RecurrenceSet FixedOrderRecurrences; - /// Holds instructions that need to sink past other instructions to handle - /// fixed-order recurrences. - MapVector SinkAfter; - /// Holds the widest induction type encountered. 
Type *WidestIndTy = nullptr; diff --git a/llvm/lib/Analysis/IVDescriptors.cpp b/llvm/lib/Analysis/IVDescriptors.cpp --- a/llvm/lib/Analysis/IVDescriptors.cpp +++ b/llvm/lib/Analysis/IVDescriptors.cpp @@ -927,9 +927,8 @@ return false; } -bool RecurrenceDescriptor::isFixedOrderRecurrence( - PHINode *Phi, Loop *TheLoop, - MapVector &SinkAfter, DominatorTree *DT) { +bool RecurrenceDescriptor::isFixedOrderRecurrence(PHINode *Phi, Loop *TheLoop, + DominatorTree *DT) { // Ensure the phi node is in the loop header and has two incoming values. if (Phi->getParent() != TheLoop->getHeader() || @@ -965,8 +964,7 @@ Previous = dyn_cast(PrevPhi->getIncomingValueForBlock(Latch)); } - if (!Previous || !TheLoop->contains(Previous) || isa(Previous) || - SinkAfter.count(Previous)) // Cannot rely on dominance due to motion. + if (!Previous || !TheLoop->contains(Previous) || isa(Previous)) return false; // Ensure every user of the phi node (recursively) is dominated by the @@ -975,23 +973,9 @@ // loop. // TODO: Consider extending this sinking to handle memory instructions. - // We optimistically assume we can sink all users after Previous. Keep a set - // of instructions to sink after Previous ordered by dominance in the common - // basic block. It will be applied to SinkAfter if all users can be sunk. - auto CompareByComesBefore = [](const Instruction *A, const Instruction *B) { - return A->comesBefore(B); - }; - std::set InstrsToSink( - CompareByComesBefore); - BasicBlock *PhiBB = Phi->getParent(); SmallVector WorkList; auto TryToPushSinkCandidate = [&](Instruction *SinkCandidate) { - // Already sunk SinkCandidate. - if (SinkCandidate->getParent() == PhiBB && - InstrsToSink.find(SinkCandidate) != InstrsToSink.end()) - return true; - // Cyclic dependence. 
if (Previous == SinkCandidate) return false; @@ -1004,56 +988,13 @@ SinkCandidate->mayHaveSideEffects() || SinkCandidate->mayReadFromMemory() || SinkCandidate->isTerminator()) return false; - - // Avoid sinking an instruction multiple times (if multiple operands are - // fixed order recurrences) by sinking once - after the latest 'previous' - // instruction. - auto It = SinkAfter.find(SinkCandidate); - if (It != SinkAfter.end()) { - auto *OtherPrev = It->second; - // Find the earliest entry in the 'sink-after' chain. The last entry in - // the chain is the original 'Previous' for a recurrence handled earlier. - auto EarlierIt = SinkAfter.find(OtherPrev); - while (EarlierIt != SinkAfter.end()) { - Instruction *EarlierInst = EarlierIt->second; - EarlierIt = SinkAfter.find(EarlierInst); - // Bail out if order has not been preserved. - if (EarlierIt != SinkAfter.end() && - !DT->dominates(EarlierInst, OtherPrev)) - return false; - OtherPrev = EarlierInst; - } - // Bail out if order has not been preserved. - if (OtherPrev != It->second && !DT->dominates(It->second, OtherPrev)) - return false; - - // SinkCandidate is already being sunk after an instruction after - // Previous. Nothing left to do. - if (DT->dominates(Previous, OtherPrev) || Previous == OtherPrev) - return true; - - // If there are other instructions to be sunk after SinkCandidate, remove - // and re-insert SinkCandidate can break those instructions. Bail out for - // simplicity. - if (any_of(SinkAfter, - [SinkCandidate](const std::pair &P) { - return P.second == SinkCandidate; - })) - return false; - - // Otherwise, Previous comes after OtherPrev and SinkCandidate needs to be - // re-sunk to Previous, instead of sinking to OtherPrev. Remove - // SinkCandidate from SinkAfter to ensure it's insert position is updated. - SinkAfter.erase(SinkCandidate); - } // If we reach a PHI node that is not dominated by Previous, we reached a // header PHI. No need for sinking. 
if (isa(SinkCandidate)) return true; // Sink User tentatively and check its users - InstrsToSink.insert(SinkCandidate); WorkList.push_back(SinkCandidate); return true; }; @@ -1068,11 +1009,6 @@ } } - // We can sink all users of Phi. Update the mapping. - for (Instruction *I : InstrsToSink) { - SinkAfter[I] = Previous; - Previous = I; - } return true; } diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -721,8 +721,7 @@ continue; } - if (RecurrenceDescriptor::isFixedOrderRecurrence(Phi, TheLoop, - SinkAfter, DT)) { + if (RecurrenceDescriptor::isFixedOrderRecurrence(Phi, TheLoop, DT)) { AllowedExit.insert(Phi); FixedOrderRecurrences.insert(Phi); continue; @@ -894,18 +893,6 @@ } } - // For fixed order recurrences, we use the previous value (incoming value from - // the latch) to check if it dominates all users of the recurrence. Bail out - // if we have to sink such an instruction for another recurrence, as the - // dominance requirement may not hold after sinking. - BasicBlock *LoopLatch = TheLoop->getLoopLatch(); - if (any_of(FixedOrderRecurrences, [LoopLatch, this](const PHINode *Phi) { - Instruction *V = - cast(Phi->getIncomingValueForBlock(LoopLatch)); - return SinkAfter.contains(V); - })) - return false; - // Now we know the widest induction type, check if our found induction // is the same size. If it's not, unset it here and InnerLoopVectorizer // will create another. diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -8993,7 +8993,8 @@ // Sink users of fixed-order recurrence past the recipe defining the previous // value and introduce FirstOrderRecurrenceSplice VPInstructions. 
- VPlanTransforms::adjustFixedOrderRecurrences(*Plan, Builder); + if (!VPlanTransforms::adjustFixedOrderRecurrences(*Plan, Builder)) + return std::nullopt; // Interleave memory: for each Interleave Group we marked earlier as relevant // for this VPlan, replace the Recipes widening its memory instructions with a diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h @@ -77,7 +77,8 @@ - /// to combine the value from the recurrence phis and previous values. The - /// current implementation assumes all users can be sunk after the previous - /// value, which is enforced by earlier legality checks. - static void adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder); + /// to combine the value from the recurrence phis and previous values. + /// \returns true if all users of fixed-order recurrences could be re-arranged + /// as needed, or false otherwise. In the latter case \p Plan may already have + /// been partially modified and must be discarded. + static bool adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder); /// Optimize \p Plan based on \p BestVF and \p BestUF. This may restrict the /// resulting plan to \p BestVF and \p BestUF. diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -658,35 +658,40 @@ // Sink users of \p FOR after the recipe defining the previous value \p Previous // of the recurrence. -static void -sinkRecurrenceUsersAfterPrevious(VPFirstOrderRecurrencePHIRecipe *FOR, - VPRecipeBase *Previous, - VPDominatorTree &VPDT) { +static bool sinkRecurrenceUsersAfterPrevious( + VPFirstOrderRecurrencePHIRecipe *FOR, VPRecipeBase *Previous, + SmallPtrSetImpl &SeenPrevious, VPDominatorTree &VPDT) { // Collect recipes that need sinking. 
SmallVector WorkList; SmallPtrSet Seen; Seen.insert(Previous); auto TryToPushSinkCandidate = [&](VPRecipeBase *SinkCandidate) { - assert( - SinkCandidate != Previous && - "The previous value cannot depend on the users of the recurrence phi."); + // The previous value cannot depend on the users of the recurrence phi. + if (SinkCandidate == Previous) + return false; + if (isa(SinkCandidate) || !Seen.insert(SinkCandidate).second || properlyDominates(Previous, SinkCandidate, VPDT)) - return; + return true; + + if (SeenPrevious.contains(SinkCandidate)) + return false; WorkList.push_back(SinkCandidate); + return true; }; // Recursively sink users of FOR after Previous. WorkList.push_back(FOR); for (unsigned I = 0; I != WorkList.size(); ++I) { VPRecipeBase *Current = WorkList[I]; - assert(Current->getNumDefinedValues() == 1 && - "only recipes with a single defined value expected"); - for (VPUser *User : Current->getVPSingleValue()->users()) { - if (auto *R = dyn_cast(User)) - TryToPushSinkCandidate(R); + for (VPValue *Val : Current->definedValues()) { + for (VPUser *User : Val->users()) { + if (auto *R = dyn_cast(User)) + if (!TryToPushSinkCandidate(R)) + return false; + } } } @@ -703,14 +708,16 @@ SinkCandidate->moveAfter(Previous); Previous = SinkCandidate; } + return true; } -void VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan, +bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan, VPBuilder &Builder) { VPDominatorTree VPDT; VPDT.recalculate(Plan); SmallVector RecurrencePhis; + SmallPtrSet SeenPrevious; for (VPRecipeBase &R : Plan.getVectorLoopRegion()->getEntry()->getEntryBasicBlock()->phis()) if (auto *FOR = dyn_cast(&R)) @@ -728,7 +735,9 @@ Previous = PrevPhi->getBackedgeValue()->getDefiningRecipe(); } - sinkRecurrenceUsersAfterPrevious(FOR, Previous, VPDT); + if (!sinkRecurrenceUsersAfterPrevious(FOR, Previous, SeenPrevious, VPDT)) + return false; + SeenPrevious.insert(Previous); // Introduce a recipe to combine the incoming and previous values 
of a // fixed-order recurrence. @@ -747,4 +756,5 @@ // all users. RecurSplice->setOperand(0, FOR); } + return true; } diff --git a/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll b/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll --- a/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll +++ b/llvm/test/Transforms/LoopVectorize/first-order-recurrence-chains.ll @@ -634,12 +634,38 @@ ret void } -; Make sure LLVM doesn't generate wrong data in SinkAfter, and causes crash in -; loop vectorizer. define void @test_crash(ptr %p) { -; CHECK-LABEL: @test_crash -; CHECK-NOT: vector.body: -; CHECK: ret +; CHECK-LABEL: @test_crash( +; CHECK: vector.body: +; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, %vector.ph ], [ [[INDEX_NEXT:%.*]], %vector.body ] +; CHECK-NEXT: [[VECTOR_RECUR:%.*]] = phi <4 x double> [ , %vector.ph ], [ [[BROADCAST_SPLAT:%.*]], %vector.body ] +; CHECK-NEXT: [[VECTOR_RECUR1:%.*]] = phi <4 x double> [ , %vector.ph ], [ [[BROADCAST_SPLAT4:%.*]], %vector.body ] +; CHECK-NEXT: [[VECTOR_RECUR2:%.*]] = phi <4 x double> [ , %vector.ph ], [ [[TMP4:%.*]], %vector.body ] +; CHECK-NEXT: [[TMP0:%.*]] = load double, ptr null, align 8 +; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <4 x double> poison, double [[TMP0]], i64 0 +; CHECK-NEXT: [[BROADCAST_SPLAT]] = shufflevector <4 x double> [[BROADCAST_SPLATINSERT]], <4 x double> poison, <4 x i32> zeroinitializer +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x double> [[VECTOR_RECUR]], <4 x double> [[BROADCAST_SPLAT]], <4 x i32> +; CHECK-NEXT: [[TMP2:%.*]] = fdiv <4 x double> zeroinitializer, [[TMP1]] +; CHECK-NEXT: [[TMP3:%.*]] = load double, ptr null, align 8 +; CHECK-NEXT: [[BROADCAST_SPLATINSERT3:%.*]] = insertelement <4 x double> poison, double [[TMP3]], i64 0 +; CHECK-NEXT: [[BROADCAST_SPLAT4]] = shufflevector <4 x double> [[BROADCAST_SPLATINSERT3]], <4 x double> poison, <4 x i32> zeroinitializer +; CHECK-NEXT: [[TMP4]] = shufflevector <4 x double> 
[[VECTOR_RECUR1]], <4 x double> [[BROADCAST_SPLAT4]], <4 x i32> +; CHECK-NEXT: [[TMP5:%.*]] = shufflevector <4 x double> [[VECTOR_RECUR2]], <4 x double> [[TMP4]], <4 x i32> +; CHECK-NEXT: [[TMP6:%.*]] = extractelement <4 x double> [[TMP2]], i32 3 +; CHECK-NEXT: store double [[TMP6]], ptr [[P:%.*]], align 8 +; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 4 +; CHECK-NEXT: [[TMP7:%.*]] = icmp eq i64 [[INDEX_NEXT]], 0 +; CHECK-NEXT: br i1 [[TMP7]], label [[MIDDLE_BLOCK:%.*]], label %vector.body, !llvm.loop [[LOOP28:![0-9]+]] +; CHECK: middle.block: +; CHECK-NEXT: [[CMP_N:%.*]] = icmp eq i64 0, 0 +; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT]], i32 3 +; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT_FOR_PHI:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT]], i32 2 +; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT5:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT4]], i32 3 +; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT_FOR_PHI6:%.*]] = extractelement <4 x double> [[BROADCAST_SPLAT4]], i32 2 +; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT9:%.*]] = extractelement <4 x double> [[TMP4]], i32 3 +; CHECK-NEXT: [[VECTOR_RECUR_EXTRACT_FOR_PHI10:%.*]] = extractelement <4 x double> [[TMP4]], i32 2 +; CHECK-NEXT: br i1 [[CMP_N]], label %End, label %scalar.ph +; Entry: br label %Loop