diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -1266,7 +1266,7 @@
   void collectElementTypesForWidening();
 
   /// Split reductions into those that happen in the loop, and those that happen
-  /// outside. In loop reductions are collected into InLoopReductionChains.
+  /// outside. In loop reductions are collected into InLoopReductions.
   void collectInLoopReductions();
 
   /// Returns true if we should use strict in-order reductions for the given
@@ -1602,20 +1602,9 @@
     return foldTailByMasking() || Legal->blockNeedsPredication(BB);
   }
 
-  /// A SmallMapVector to store the InLoop reduction op chains, mapping phi
-  /// nodes to the chain of instructions representing the reductions. Uses a
-  /// MapVector to ensure deterministic iteration order.
-  using ReductionChainMap =
-      SmallMapVector<PHINode *, SmallVector<Instruction *, 4>, 4>;
-
-  /// Return the chain of instructions representing an inloop reduction.
-  const ReductionChainMap &getInLoopReductionChains() const {
-    return InLoopReductionChains;
-  }
-
   /// Returns true if the Phi is part of an inloop reduction.
   bool isInLoopReduction(PHINode *Phi) const {
-    return InLoopReductionChains.count(Phi);
+    return InLoopReductions.contains(Phi);
   }
 
   /// Estimate cost of an intrinsic call instruction CI if it were vectorized
@@ -1779,15 +1768,12 @@
   /// scalarized.
   DenseMap<ElementCount, SmallPtrSet<Instruction *, 4>> ForcedScalars;
 
-  /// PHINodes of the reductions that should be expanded in-loop along with
-  /// their associated chains of reduction operations, in program order from top
-  /// (PHI) to bottom
-  ReductionChainMap InLoopReductionChains;
+  /// PHINodes of the reductions that should be expanded in-loop.
+  SmallPtrSet<PHINode *, 4> InLoopReductions;
 
   /// A Map of inloop reduction operations and their immediate chain operand.
   /// FIXME: This can be removed once reductions can be costed correctly in
-  /// vplan. This was added to allow quick lookup to the inloop operations,
-  /// without having to loop through InLoopReductionChains.
+  /// VPlan. This was added to allow quick lookup of the inloop operations.
   DenseMap<Instruction *, Instruction *> InLoopReductionImmediateChains;
 
   /// Returns the expected difference in cost from scalarizing the expression
@@ -6623,7 +6609,7 @@
     Instruction *I, ElementCount VF, Type *Ty, TTI::TargetCostKind CostKind) {
   using namespace llvm::PatternMatch;
   // Early exit for no inloop reductions
-  if (InLoopReductionChains.empty() || VF.isScalar() || !isa<VectorType>(Ty))
+  if (InLoopReductions.empty() || VF.isScalar() || !isa<VectorType>(Ty))
     return std::nullopt;
   auto *VectorTy = cast<VectorType>(Ty);
 
@@ -7473,8 +7459,9 @@
     SmallVector<Instruction *, 4> ReductionOperations =
         RdxDesc.getReductionOpChain(Phi, TheLoop);
     bool InLoop = !ReductionOperations.empty();
+
     if (InLoop) {
-      InLoopReductionChains[Phi] = ReductionOperations;
+      InLoopReductions.insert(Phi);
       // Add the elements to InLoopReductionImmediateChains for cost modelling.
       Instruction *LastChain = Phi;
       for (auto *I : ReductionOperations) {
@@ -8866,24 +8853,6 @@
   // process after constructing the initial VPlan.
   // ---------------------------------------------------------------------------
 
-  for (const auto &Reduction : CM.getInLoopReductionChains()) {
-    PHINode *Phi = Reduction.first;
-    RecurKind Kind =
-        Legal->getReductionVars().find(Phi)->second.getRecurrenceKind();
-    const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second;
-
-    RecipeBuilder.recordRecipeOf(Phi);
-    for (const auto &R : ReductionOperations) {
-      RecipeBuilder.recordRecipeOf(R);
-      // For min/max reductions, where we have a pair of icmp/select, we also
-      // need to record the ICmp recipe, so it can be removed later.
-      assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
-             "Only min/max recurrences allowed for inloop reductions");
-      if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind))
-        RecipeBuilder.recordRecipeOf(cast<Instruction>(R->getOperand(0)));
-    }
-  }
-
   // For each interleave group which is relevant for this (possibly trimmed)
   // Range, add it to the set of groups to be later applied to the VPlan and add
   // placeholders for its members' Recipes which we'll be replacing with a
@@ -9163,86 +9132,118 @@
 void LoopVectorizationPlanner::adjustRecipesForReductions(
     VPBasicBlock *LatchVPBB, VPlanPtr &Plan, VPRecipeBuilder &RecipeBuilder,
     ElementCount MinVF) {
-  for (const auto &Reduction : CM.getInLoopReductionChains()) {
-    PHINode *Phi = Reduction.first;
-    const RecurrenceDescriptor &RdxDesc =
-        Legal->getReductionVars().find(Phi)->second;
-    const SmallVector<Instruction *, 4> &ReductionOperations = Reduction.second;
-
-    if (MinVF.isScalar() && !CM.useOrderedReductions(RdxDesc))
+  SmallVector<VPReductionPHIRecipe *> InLoopReductionPhis;
+  for (VPRecipeBase &R :
+       Plan->getVectorLoopRegion()->getEntryBasicBlock()->phis()) {
+    auto *PhiR = dyn_cast<VPReductionPHIRecipe>(&R);
+    if (!PhiR || !PhiR->isInLoop() || (MinVF.isScalar() && !PhiR->isOrdered()))
       continue;
+    InLoopReductionPhis.push_back(PhiR);
+  }
+
+  for (VPReductionPHIRecipe *PhiR : InLoopReductionPhis) {
+    const RecurrenceDescriptor &RdxDesc = PhiR->getRecurrenceDescriptor();
+    RecurKind Kind = RdxDesc.getRecurrenceKind();
+    assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
+           "select/cmp reductions are not allowed for in-loop reductions");
+
+    // Collect the chain of "link" recipes for the reduction starting at PhiR.
+    SetVector<VPRecipeBase *> Worklist;
+    Worklist.insert(PhiR);
+    for (unsigned I = 0; I != Worklist.size(); ++I) {
+      VPRecipeBase *Cur = Worklist[I];
+      for (VPUser *U : Cur->getVPSingleValue()->users()) {
+        auto *UserRecipe = dyn_cast<VPRecipeBase>(U);
+        if (!UserRecipe)
+          continue;
+        assert(UserRecipe->getNumDefinedValues() == 1 &&
+               "recipes must define exactly one result value");
+        Worklist.insert(UserRecipe);
+      }
+    }
+
+    // Visit operation "Links" along the reduction chain top-down starting from
+    // the phi until LoopExitValue. We keep track of the previous item
+    // (PreviousLink) to tell which of the two operands of a Link will remain
+    // scalar and which will be reduced. For minmax by select(cmp), Link will be
+    // the select instructions.
+    VPRecipeBase *PreviousLink = PhiR; // Aka Worklist[0].
+    for (VPRecipeBase *CurrentLink : Worklist.getArrayRef().drop_front()) {
+      VPValue *PreviousLinkV = PreviousLink->getVPSingleValue();
 
-    // ReductionOperations are orders top-down from the phi's use to the
-    // LoopExitValue. We keep a track of the previous item (the Chain) to tell
-    // which of the two operands will remain scalar and which will be reduced.
-    // For minmax the chain will be the select instructions.
-    Instruction *Chain = Phi;
-    for (Instruction *R : ReductionOperations) {
-      VPRecipeBase *WidenRecipe = RecipeBuilder.getRecipe(R);
-      RecurKind Kind = RdxDesc.getRecurrenceKind();
-
-      VPValue *ChainOp = Plan->getVPValue(Chain);
-      unsigned FirstOpId;
-      assert(!RecurrenceDescriptor::isSelectCmpRecurrenceKind(Kind) &&
-             "Only min/max recurrences allowed for inloop reductions");
+      Instruction *CurrentLinkI = CurrentLink->getUnderlyingInstr();
+
+      // Index of the first operand which holds a non-mask vector operand.
+      unsigned IndexOfFirstOperand;
       // Recognize a call to the llvm.fmuladd intrinsic.
       bool IsFMulAdd = (Kind == RecurKind::FMulAdd);
-      assert((!IsFMulAdd || RecurrenceDescriptor::isFMulAddIntrinsic(R)) &&
-             "Expected instruction to be a call to the llvm.fmuladd intrinsic");
-      if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
-        assert(isa<VPWidenSelectRecipe>(WidenRecipe) &&
-               "Expected to replace a VPWidenSelectSC");
-        FirstOpId = 1;
+      VPValue *VecOp;
+      VPBasicBlock *LinkVPBB = CurrentLink->getParent();
+      if (IsFMulAdd) {
+        assert(
+            RecurrenceDescriptor::isFMulAddIntrinsic(CurrentLinkI) &&
+            "Expected instruction to be a call to the llvm.fmuladd intrinsic");
+        assert(((MinVF.isScalar() && isa<VPReplicateRecipe>(CurrentLink)) ||
+                isa<VPWidenCallRecipe>(CurrentLink)) &&
+               CurrentLink->getOperand(2) == PreviousLinkV &&
+               "expected a call where the previous link is the added operand");
+
+        // If the instruction is a call to the llvm.fmuladd intrinsic then we
+        // need to create an fmul recipe (multiplying the first two operands of
+        // the fmuladd together) to use as the vector operand for the fadd
+        // reduction.
+        VPInstruction *FMulRecipe =
+            new VPInstruction(Instruction::FMul, {CurrentLink->getOperand(0),
+                                                  CurrentLink->getOperand(1)});
+        FMulRecipe->setFastMathFlags(CurrentLinkI->getFastMathFlags());
+        LinkVPBB->insert(FMulRecipe, CurrentLink->getIterator());
+        VecOp = FMulRecipe;
       } else {
-        assert((MinVF.isScalar() || isa<VPWidenRecipe>(WidenRecipe) ||
-                (IsFMulAdd && isa<VPWidenCallRecipe>(WidenRecipe))) &&
-               "Expected to replace a VPWidenSC");
-        FirstOpId = 0;
+        if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
+          if (isa<VPWidenRecipe>(CurrentLink)) {
+            assert(isa<CmpInst>(CurrentLinkI) &&
+                   "need to have the compare of the select");
+            continue;
+          }
+          assert(isa<VPWidenSelectRecipe>(CurrentLink) &&
+                 "must be a select recipe");
+          IndexOfFirstOperand = 1;
+        } else {
+          assert((MinVF.isScalar() || isa<VPWidenRecipe>(CurrentLink)) &&
+                 "Expected to replace a VPWidenSC");
+          IndexOfFirstOperand = 0;
+        }
+        // Note that for non-commutable operands (cmp-selects), the semantics of
+        // the cmp-select are captured in the recurrence kind.
+        unsigned VecOpId =
+            CurrentLink->getOperand(IndexOfFirstOperand) == PreviousLinkV
+                ? IndexOfFirstOperand + 1
+                : IndexOfFirstOperand;
+        VecOp = CurrentLink->getOperand(VecOpId);
+        assert(VecOp != PreviousLinkV &&
+               CurrentLink->getOperand(CurrentLink->getNumOperands() - 1 -
+                                       (VecOpId - IndexOfFirstOperand)) ==
+                   PreviousLinkV &&
+               "PreviousLinkV must be the operand other than VecOp");
       }
-      unsigned VecOpId =
-          R->getOperand(FirstOpId) == Chain ? FirstOpId + 1 : FirstOpId;
-      VPValue *VecOp = Plan->getVPValue(R->getOperand(VecOpId));
 
+      BasicBlock *BB = CurrentLinkI->getParent();
       VPValue *CondOp = nullptr;
-      if (CM.blockNeedsPredicationForAnyReason(R->getParent())) {
+      if (CM.blockNeedsPredicationForAnyReason(BB)) {
         VPBuilder::InsertPointGuard Guard(Builder);
-        Builder.setInsertPoint(WidenRecipe->getParent(),
-                               WidenRecipe->getIterator());
-        CondOp = RecipeBuilder.createBlockInMask(R->getParent(), *Plan);
+        Builder.setInsertPoint(LinkVPBB, CurrentLink->getIterator());
+        CondOp = RecipeBuilder.createBlockInMask(BB, *Plan);
       }
 
-      if (IsFMulAdd) {
-        // If the instruction is a call to the llvm.fmuladd intrinsic then we
-        // need to create an fmul recipe to use as the vector operand for the
-        // fadd reduction.
-        VPInstruction *FMulRecipe = new VPInstruction(
-            Instruction::FMul, {VecOp, Plan->getVPValue(R->getOperand(1))});
-        FMulRecipe->setFastMathFlags(R->getFastMathFlags());
-        WidenRecipe->getParent()->insert(FMulRecipe,
-                                         WidenRecipe->getIterator());
-        VecOp = FMulRecipe;
-      }
-      VPReductionRecipe *RedRecipe =
-          new VPReductionRecipe(&RdxDesc, R, ChainOp, VecOp, CondOp, &TTI);
-      WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
-      Plan->removeVPValueFor(R);
-      Plan->addVPValue(R, RedRecipe);
+      VPReductionRecipe *RedRecipe = new VPReductionRecipe(
+          &RdxDesc, CurrentLinkI, PreviousLinkV, VecOp, CondOp, &TTI);
       // Append the recipe to the end of the VPBasicBlock because we need to
       // ensure that it comes after all of it's inputs, including CondOp.
-      WidenRecipe->getParent()->appendRecipe(RedRecipe);
-      WidenRecipe->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
-      WidenRecipe->eraseFromParent();
-
-      if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind)) {
-        VPRecipeBase *CompareRecipe =
-            RecipeBuilder.getRecipe(cast<Instruction>(R->getOperand(0)));
-        assert(isa<VPWidenRecipe>(CompareRecipe) &&
-               "Expected to replace a VPWidenSC");
-        assert(cast<VPWidenRecipe>(CompareRecipe)->getNumUsers() == 0 &&
-               "Expected no remaining users");
-        CompareRecipe->eraseFromParent();
-      }
-      Chain = R;
+      // Note that this transformation may leave dead recipes behind (including
+      // CurrentLink), which will be cleaned up by a later VPlan transform.
+      LinkVPBB->appendRecipe(RedRecipe);
+      CurrentLink->getVPSingleValue()->replaceAllUsesWith(RedRecipe);
+      PreviousLink = RedRecipe;
     }
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2586,12 +2586,6 @@
     return getVPValue(V);
   }
 
-  void removeVPValueFor(Value *V) {
-    assert(Value2VPValueEnabled &&
-           "IR value to VPValue mapping may be out of date!");
-    Value2VPValue.erase(V);
-  }
-
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print this VPlan to \p O.
   void print(raw_ostream &O) const;