diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -301,6 +301,11 @@
 /// from the vector.
 Value *findScalarElement(Value *V, unsigned EltNo);
 
+/// If all non-negative \p Mask elements are the same value, return that value.
+/// If all elements are negative (undefined) or \p Mask contains different
+/// non-negative values, return -1.
+int getSplatIndex(ArrayRef<int> Mask);
+
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -307,6 +307,24 @@
   return nullptr;
 }
 
+int llvm::getSplatIndex(ArrayRef<int> Mask) {
+  int SplatIndex = -1;
+  for (int M : Mask) {
+    // Ignore invalid (undefined) mask elements.
+    if (M < 0)
+      continue;
+
+    // There can be only 1 non-negative mask element value if this is a splat.
+    if (SplatIndex != -1 && SplatIndex != M)
+      return -1;
+
+    // Initialize the splat index to the 1st non-negative mask element.
+    SplatIndex = M;
+  }
+  assert((SplatIndex == -1 || SplatIndex >= 0) && "Negative index?");
+  return SplatIndex;
+}
+
 /// Get splat value if the input is a splat vector or return nullptr.
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -98,6 +98,17 @@
   EXPECT_FALSE(isSplatValue(SplatWithUndefC));
 }
 
+TEST_F(BasicTest, getSplatIndex) {
+  EXPECT_EQ(getSplatIndex({0,0,0}), 0);
+  EXPECT_EQ(getSplatIndex({1,0,0}), -1);     // no splat
+  EXPECT_EQ(getSplatIndex({0,1,1}), -1);     // no splat
+  EXPECT_EQ(getSplatIndex({42,42,42}), 42);  // array size is independent of splat index
+  EXPECT_EQ(getSplatIndex({42,42,-1}), 42);  // ignore negative
+  EXPECT_EQ(getSplatIndex({-1,42,-1}), 42);  // ignore negatives
+  EXPECT_EQ(getSplatIndex({-4,42,-42}), 42); // ignore all negatives
+  EXPECT_EQ(getSplatIndex({-4,-1,-42}), -1); // all negative values map to -1
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_00) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"