Index: llvm/include/llvm/Analysis/VectorUtils.h
===================================================================
--- llvm/include/llvm/Analysis/VectorUtils.h
+++ llvm/include/llvm/Analysis/VectorUtils.h
@@ -301,6 +301,11 @@
 /// from the vector.
 Value *findScalarElement(Value *V, unsigned EltNo);
 
+/// If all non-negative \p Mask elements are the same value, return that value.
+/// If all elements are negative (undefined) or \p Mask contains different
+/// non-negative values, return -1.
+int getSplatIndex(ArrayRef<int> Mask);
+
 /// Get splat value if the input is a splat vector or return nullptr.
 /// The value may be extracted from a splat constants vector or from
 /// a sequence of instructions that broadcast a single value into a vector.
Index: llvm/lib/Analysis/VectorUtils.cpp
===================================================================
--- llvm/lib/Analysis/VectorUtils.cpp
+++ llvm/lib/Analysis/VectorUtils.cpp
@@ -307,6 +307,23 @@
   return nullptr;
 }
 
+int llvm::getSplatIndex(ArrayRef<int> Mask) {
+  int SplatIndex = -1;
+  for (int M : Mask) {
+    // Ignore invalid (undefined) mask elements.
+    if (M < 0)
+      continue;
+
+    // There can be only 1 non-negative mask element value if this is a splat.
+    if (SplatIndex >= 0 && SplatIndex != M)
+      return -1;
+
+    // Initialize the splat index to the 1st non-negative mask element.
+    SplatIndex = M;
+  }
+  return SplatIndex;
+}
+
 /// Get splat value if the input is a splat vector or return nullptr.
 /// This function is not fully general. It checks only 2 cases:
 /// the input value is (1) a splat constant vector or (2) a sequence
Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -28,6 +28,7 @@
 #include "llvm/Analysis/BlockFrequencyInfo.h"
 #include "llvm/Analysis/EHPersonalities.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/CodeGen/IntrinsicLowering.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -12715,23 +12716,14 @@
 
   // With MOVDDUP (v2f64) we can broadcast from a register or a load, otherwise
   // we can only broadcast from a register with AVX2.
-  unsigned NumElts = Mask.size();
   unsigned NumEltBits = VT.getScalarSizeInBits();
   unsigned Opcode = (VT == MVT::v2f64 && !Subtarget.hasAVX2())
                         ? X86ISD::MOVDDUP
                         : X86ISD::VBROADCAST;
   bool BroadcastFromReg = (Opcode == X86ISD::MOVDDUP) || Subtarget.hasAVX2();
 
   // Check that the mask is a broadcast.
-  int BroadcastIdx = -1;
-  for (int i = 0; i != (int)NumElts; ++i) {
-    SmallVector<int, 8> BroadcastMask(NumElts, i);
-    if (isShuffleEquivalent(V1, V2, Mask, BroadcastMask)) {
-      BroadcastIdx = i;
-      break;
-    }
-  }
-
+  int BroadcastIdx = getSplatIndex(Mask);
   if (BroadcastIdx < 0)
     return SDValue();
   assert(BroadcastIdx < (int)Mask.size() && "We only expect to be called with "
Index: llvm/unittests/Analysis/VectorUtilsTest.cpp
===================================================================
--- llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -98,6 +98,18 @@
   EXPECT_FALSE(isSplatValue(SplatWithUndefC));
 }
 
+TEST_F(BasicTest, getSplatIndex) {
+  EXPECT_EQ(getSplatIndex({0,0,0}), 0);
+  EXPECT_EQ(getSplatIndex({1,0,0}), -1);     // no splat
+  EXPECT_EQ(getSplatIndex({0,1,1}), -1);     // no splat
+  EXPECT_EQ(getSplatIndex({42,42,42}), 42);  // index may exceed array size
+  EXPECT_EQ(getSplatIndex({42,42,-1}), 42);  // ignore negative
+  EXPECT_EQ(getSplatIndex({-1,42,-1}), 42);  // ignore negatives
+  EXPECT_EQ(getSplatIndex({-4,42,-42}), 42); // ignore all negatives
+  EXPECT_EQ(getSplatIndex({-4,-1,-42}), -1); // all negative values map to -1
+  EXPECT_EQ(getSplatIndex({}), -1);          // empty mask has no splat
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_00) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"