diff --git a/llvm/include/llvm/IR/Instructions.h b/llvm/include/llvm/IR/Instructions.h
--- a/llvm/include/llvm/IR/Instructions.h
+++ b/llvm/include/llvm/IR/Instructions.h
@@ -2354,6 +2354,39 @@
     return isInsertSubvectorMask(ShuffleMask, NumSrcElts, NumSubElts, Index);
   }
 
+  /// Return true if this shuffle mask replicates each of the \p VF elements
+  /// in a vector \p ReplicationFactor times.
+  /// For example, the mask for \p ReplicationFactor=3 and \p VF=4 is:
+  ///   <0,0,0,1,1,1,2,2,2,3,3,3>
+  /// On success the deduced parameters are stored into \p ReplicationFactor
+  /// and \p VF.
+  static bool isReplicationMask(ArrayRef<int> Mask, int &ReplicationFactor,
+                                int &VF);
+
+  /// Overload of isReplicationMask() taking the mask as a vector constant,
+  /// which is first converted to integers via getShuffleMask().
+  static bool isReplicationMask(const Constant *Mask, int &ReplicationFactor,
+                                int &VF) {
+    assert(Mask->getType()->isVectorTy() && "Shuffle needs vector constant.");
+    // Not possible to express a shuffle mask for a scalable vector for this
+    // case.
+    if (isa<ScalableVectorType>(Mask->getType()))
+      return false;
+    SmallVector<int, 16> MaskAsInts;
+    getShuffleMask(Mask, MaskAsInts);
+    return isReplicationMask(MaskAsInts, ReplicationFactor, VF);
+  }
+
+  /// Return true if this shuffle mask is a replication mask.
+  bool isReplicationMask(int &ReplicationFactor, int &VF) const {
+    // Not possible to express a shuffle mask for a scalable vector for this
+    // case.
+    if (isa<ScalableVectorType>(getType()))
+      return false;
+
+    return isReplicationMask(ShuffleMask, ReplicationFactor, VF);
+  }
+
   /// Change values in a shuffle permute mask assuming the two vector operands
   /// of length InVecNumElts have swapped position.
   static void commuteShuffleMask(MutableArrayRef<int> Mask,
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -2436,6 +2436,70 @@
   return isIdentityMaskImpl(getShuffleMask(), NumMaskElts);
 }
 
+/// Check whether \p Mask is a replication mask for exactly the given
+/// \p ReplicationFactor and \p VF: the I-th group of \p ReplicationFactor
+/// consecutive mask elements must consist only of I and/or undef elements.
+static bool isReplicationMaskWithParams(ArrayRef<int> Mask,
+                                        int ReplicationFactor, int VF) {
+  assert(Mask.size() == (unsigned)ReplicationFactor * VF &&
+         "Unexpected mask size.");
+
+  for (int CurrElt : seq(0, VF)) {
+    // Peel off the group of mask elements that should replicate CurrElt.
+    ArrayRef<int> CurrSubMask = Mask.take_front(ReplicationFactor);
+    assert(CurrSubMask.size() == (unsigned)ReplicationFactor &&
+           "Run out of mask?");
+    Mask = Mask.drop_front(ReplicationFactor);
+    if (!all_of(CurrSubMask, [CurrElt](int MaskElt) {
+          return MaskElt == UndefMaskElem || MaskElt == CurrElt;
+        }))
+      return false;
+  }
+  assert(Mask.empty() && "Did not consume the whole mask?");
+
+  return true;
+}
+
+bool ShuffleVectorInst::isReplicationMask(ArrayRef<int> Mask,
+                                          int &ReplicationFactor, int &VF) {
+  // undef-less case is trivial: the replication factor is the length of the
+  // leading run of zeros.
+  if (none_of(Mask, [](int MaskElt) { return MaskElt == UndefMaskElem; })) {
+    int LeadingZeros =
+        Mask.take_while([](int MaskElt) { return MaskElt == 0; }).size();
+    if (LeadingZeros == 0 || Mask.size() % LeadingZeros != 0)
+      return false;
+    int PossibleVF = Mask.size() / LeadingZeros;
+    if (!isReplicationMaskWithParams(Mask, LeadingZeros, PossibleVF))
+      return false;
+    // Only write the out-parameters on success, like the undef path below.
+    ReplicationFactor = LeadingZeros;
+    VF = PossibleVF;
+    return true;
+  }
+
+  // However, if the mask contains undef's, we have to enumerate possible tuples
+  // and pick one. There are bounds on replication factor: [1, mask size]
+  // (where RF=1 is an identity shuffle, RF=mask size is a broadcast shuffle)
+  // Additionally, mask size is a replication factor multiplied by vector size,
+  // which significantly reduces the search space.
+  // Prefer larger replication factor if all else equal.
+  for (int PossibleReplicationFactor :
+       reverse(seq_inclusive<unsigned>(1, Mask.size()))) {
+    if (Mask.size() % PossibleReplicationFactor != 0)
+      continue;
+    int PossibleVF = Mask.size() / PossibleReplicationFactor;
+    if (!isReplicationMaskWithParams(Mask, PossibleReplicationFactor,
+                                     PossibleVF))
+      continue;
+    ReplicationFactor = PossibleReplicationFactor;
+    VF = PossibleVF;
+    return true;
+  }
+
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 //                         InsertValueInst Class
 //===----------------------------------------------------------------------===//
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -6,10 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/ADT/CombinationGenerator.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -1115,6 +1117,63 @@
   delete Id15;
 }
 
+// Every mask produced by createReplicatedMask() must be recognized, and the
+// replication factor and VF must be recovered exactly.
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask) {
+  for (int ReplicationFactor : seq_inclusive(1, 8)) {
+    for (int VF : seq_inclusive(1, 8)) {
+      const auto ReplicatedMask = createReplicatedMask(ReplicationFactor, VF);
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+      EXPECT_TRUE(ShuffleVectorInst::isReplicationMask(
+          ReplicatedMask, GuessedReplicationFactor, GuessedVF));
+      EXPECT_EQ(GuessedReplicationFactor, ReplicationFactor);
+      EXPECT_EQ(GuessedVF, VF);
+    }
+  }
+}
+
+// For every mask of up to 8 elements drawn from {-1, 0, ..., NumElts}, any
+// mask accepted as a replication mask must agree with the canonical mask for
+// the guessed parameters on all of its non-undef elements.
+TEST(InstructionsTest, ShuffleMaskIsReplicationMask_Exhaustive_Correctness) {
+  for (int ShufMaskNumElts : seq_inclusive(1, 8)) {
+    SmallVector<int> PossibleShufMaskElts;
+    PossibleShufMaskElts.reserve(ShufMaskNumElts + 2);
+    for (int PossibleShufMaskElt : seq_inclusive(-1, ShufMaskNumElts))
+      PossibleShufMaskElts.emplace_back(PossibleShufMaskElt);
+    assert(PossibleShufMaskElts.size() == ShufMaskNumElts + 2U &&
+           "Size misprediction");
+
+    SmallVector<SmallVector<int>> ElementChoices(ShufMaskNumElts,
+                                                 PossibleShufMaskElts);
+
+    CombinationGenerator<int, SmallVector<int>, /*variable_smallsize=*/4>
+        G(ElementChoices);
+
+    G.generate([&](ArrayRef<int> Mask) -> bool {
+      int GuessedReplicationFactor = -1, GuessedVF = -1;
+      bool Match = ShuffleVectorInst::isReplicationMask(
+          Mask, GuessedReplicationFactor, GuessedVF);
+      if (!Match)
+        return /*Abort=*/false;
+
+      const auto ActualMask =
+          createReplicatedMask(GuessedReplicationFactor, GuessedVF);
+      EXPECT_EQ(Mask.size(), ActualMask.size());
+      for (auto I : zip(Mask, ActualMask)) {
+        int Elt = std::get<0>(I);
+        int ActualElt = std::get<1>(I);
+
+        // -1 (presumably UndefMaskElem) is a wildcard; only check the rest.
+        if (Elt != -1)
+          EXPECT_EQ(Elt, ActualElt);
+      }
+
+      return /*Abort=*/false;
+    });
+  }
+}
+
 TEST(InstructionsTest, GetSplat) {
   // Create the elements for various constant vectors.
   LLVMContext Ctx;