Index: llvm/include/llvm/Analysis/VectorUtils.h
===================================================================
--- llvm/include/llvm/Analysis/VectorUtils.h
+++ llvm/include/llvm/Analysis/VectorUtils.h
@@ -301,10 +301,11 @@
 const Value *getSplatValue(const Value *V);
 
 /// Return true if the input value is known to be a vector with all identical
-/// elements (potentially including undefined elements).
+/// elements (potentially including undefined elements). If an index value is
+/// specified, require that the splat is from that element of the vector.
 /// This may be more powerful than the related getSplatValue() because it is
 /// not limited by finding a scalar source value to a splatted vector.
-bool isSplatValue(const Value *V, unsigned Depth = 0);
+bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
 
 /// Compute a map of integer instructions to their minimum legal type
 /// size.
Index: llvm/lib/Analysis/VectorUtils.cpp
===================================================================
--- llvm/lib/Analysis/VectorUtils.cpp
+++ llvm/lib/Analysis/VectorUtils.cpp
@@ -330,21 +330,41 @@
 // adjusted if needed.
 const unsigned MaxDepth = 6;
 
-bool llvm::isSplatValue(const Value *V, unsigned Depth) {
+bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
   assert(Depth <= MaxDepth && "Limit Search Depth");
 
   if (isa<VectorType>(V->getType())) {
     if (isa<UndefValue>(V))
       return true;
-    // FIXME: Constant splat analysis does not allow undef elements.
+    // FIXME: We can allow undefs, but if Index was specified, we may want to
+    //        check that the constant is defined at that index.
     if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue() != nullptr;
   }
 
-  // FIXME: Constant splat analysis does not allow undef elements.
   Constant *Mask;
-  if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask))))
-    return Mask->getSplatValue() != nullptr;
+  if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask)))) {
+    // FIXME: We can safely allow undefs here. If Index was specified, we will
+    //        check that the mask elt is defined at the required index.
+    Constant *SplatC = Mask->getSplatValue();
+    if (!SplatC)
+      return false;
+
+    // Match any index.
+    if (Index == -1)
+      return true;
+
+    // Match a specific element.
+    unsigned Elt = (unsigned)Index;
+    assert(Elt < V->getType()->getVectorNumElements() &&
+           "Expected valid shuffle element");
+
+    // The mask should be defined at and match the specified index.
+    if (auto *MaskElt = dyn_cast<ConstantInt>(Mask->getAggregateElement(Elt)))
+      return MaskElt->getZExtValue() == Elt;
+
+    return false;
+  }
 
   // The remaining tests are all recursive, so bail out if we hit the limit.
   if (Depth++ == MaxDepth)
@@ -353,12 +373,12 @@
   // If both operands of a binop are splats, the result is a splat.
   Value *X, *Y, *Z;
   if (match(V, m_BinOp(m_Value(X), m_Value(Y))))
-    return isSplatValue(X, Depth) && isSplatValue(Y, Depth);
+    return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth);
 
   // If all operands of a select are splats, the result is a splat.
   if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z))))
-    return isSplatValue(X, Depth) && isSplatValue(Y, Depth) &&
-           isSplatValue(Z, Depth);
+    return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth) &&
+           isSplatValue(Z, Index, Depth);
 
   // TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops).
 
Index: llvm/unittests/Analysis/VectorUtilsTest.cpp
===================================================================
--- llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -107,6 +107,24 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_00_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_00_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_11) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -116,6 +134,24 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_11_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_11_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_01) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -125,7 +161,25 @@
   EXPECT_FALSE(isSplatValue(A));
 }
 
-// FIXME: Constant (mask) splat analysis does not allow undef elements.
+TEST_F(VectorUtilsTest, isSplatValue_01_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_01_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
+// FIXME: Allow undef matching with Constant (mask) splat analysis.
 
 TEST_F(VectorUtilsTest, isSplatValue_0u) {
   parseAssembly(
@@ -136,6 +190,26 @@
   EXPECT_FALSE(isSplatValue(A));
 }
 
+// FIXME: Allow undef matching with Constant (mask) splat analysis.
+
+TEST_F(VectorUtilsTest, isSplatValue_0u_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_0u_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -147,6 +221,28 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_Binop_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = udiv <2 x i8> %v0, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_Binop_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = udiv <2 x i8> %v0, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -157,6 +253,26 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop_Not_Op0) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"