diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp --- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -7113,7 +7113,6 @@ // into a vector and can be represented as a permutation elements in a // single input vector or of 2 input vectors. Cost += computeExtractCost(VL, Mask, ShuffleKind); - InVectors.assign(1, E); return VecBase; } void add(const TreeEntry *E1, const TreeEntry *E2, ArrayRef Mask) { @@ -7124,18 +7123,57 @@ CommonMask.assign(Mask.begin(), Mask.end()); InVectors.assign(1, E1); } - void gather(ArrayRef VL, Value *Root = nullptr) { + /// Adds another one input vector and the mask for the shuffling. + void add(Value *V1, ArrayRef Mask) { + assert(CommonMask.empty() && InVectors.empty() && + "Expected empty input mask/vectors."); + CommonMask.assign(Mask.begin(), Mask.end()); + InVectors.assign(1, V1); + } + Value *gather(ArrayRef VL, Value *Root = nullptr) { Cost += getBuildVectorCost(VL, Root); if (!Root) { assert(InVectors.empty() && "Unexpected input vectors for buildvector."); // FIXME: Need to find a way to avoid use of getNullValue here. - InVectors.assign(1, Constant::getNullValue(FixedVectorType::get( - VL.front()->getType(), VL.size()))); + SmallVector Vals; + for (Value *V : VL) { + if (isa(V)) { + Vals.push_back(cast(V)); + continue; + } + Vals.push_back(Constant::getNullValue(V->getType())); + } + return ConstantVector::get(Vals); } + return ConstantVector::getSplat( + ElementCount::getFixed(VL.size()), + Constant::getNullValue(VL.front()->getType())); } /// Finalize emission of the shuffles. - InstructionCost finalize(ArrayRef ExtMask) { + InstructionCost + finalize(ArrayRef ExtMask, unsigned VF = 0, + function_ref &)> Action = {}) { IsFinalized = true; + if (Action) { + const PointerUnion &Vec = InVectors.front(); + if (InVectors.size() == 2) { + Cost += createShuffle(Vec, InVectors.back(), CommonMask); + InVectors.pop_back(); + } else { + Cost += createShuffle(Vec, nullptr, CommonMask); + } + for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx) + if (CommonMask[Idx] != PoisonMaskElem) + CommonMask[Idx] = Idx; + assert(VF > 0 && + "Expected vector length for the final value before action."); + Value *V = Vec.dyn_cast(); + if (!Vec.isNull() && !V) + V = Constant::getNullValue(FixedVectorType::get( + Vec.get()->Scalars.front()->getType(), + CommonMask.size())); + Action(V, CommonMask); + } ::addMask(CommonMask, ExtMask, /*ExtendingManyInputs=*/true); if (CommonMask.empty()) return Cost; @@ -7259,18 +7297,31 @@ Estimator.add(Entries.front(), Mask); else Estimator.add(Entries.front(), Entries.back(), Mask); - Estimator.gather( - GatheredScalars, - Constant::getNullValue(FixedVectorType::get( - GatheredScalars.front()->getType(), GatheredScalars.size()))); - return Estimator.finalize(E->ReuseShuffleIndices); - } - Estimator.gather( - GatheredScalars, - VL.equals(GatheredScalars) - ? nullptr - : Constant::getNullValue(FixedVectorType::get( - GatheredScalars.front()->getType(), GatheredScalars.size()))); + if (all_of(GatheredScalars, PoisonValue ::classof)) + return Estimator.finalize(E->ReuseShuffleIndices); + return Estimator.finalize( + E->ReuseShuffleIndices, E->Scalars.size(), + [&](Value *&Vec, SmallVectorImpl &Mask) { + Vec = Estimator.gather(GatheredScalars, + Constant::getNullValue(FixedVectorType::get( + GatheredScalars.front()->getType(), + GatheredScalars.size()))); + }); + } + if (!all_of(GatheredScalars, PoisonValue::classof)) { + auto Gathers = ArrayRef(GatheredScalars).take_front(VL.size()); + bool SameGathers = VL.equals(Gathers); + Value *BV = Estimator.gather( + Gathers, SameGathers ? nullptr + : Constant::getNullValue(FixedVectorType::get( + GatheredScalars.front()->getType(), + GatheredScalars.size()))); + SmallVector ReuseMask(Gathers.size(), PoisonMaskElem); + std::iota(ReuseMask.begin(), ReuseMask.end(), 0); + Estimator.add(BV, ReuseMask); + } + if (ExtractShuffle) + Estimator.add(E, std::nullopt); return Estimator.finalize(E->ReuseShuffleIndices); } InstructionCost CommonCost = 0;