diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -306,11 +306,15 @@
 /// a sequence of instructions that broadcast a single value into a vector.
 const Value *getSplatValue(const Value *V);
 
-/// Return true if the input value is known to be a vector with all identical
-/// elements (potentially including undefined elements).
+/// Return true if each element of the vector value \p V is poisoned or equal to
+/// every other non-poisoned element. If \p Index is specified (non-negative),
+/// either every element of the vector is poisoned or the element at that index
+/// is not poisoned and equal to every other non-poisoned element.
 /// This may be more powerful than the related getSplatValue() because it is
 /// not limited by finding a scalar source value to a splatted vector.
-bool isSplatValue(const Value *V, unsigned Depth = 0);
+/// The default \p Index of -1 matches a splat of any element. \p Depth is the
+/// current recursion depth and must not exceed the analysis depth limit.
+bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0);
 
 /// Compute a map of integer instructions to their minimum legal type
 /// size.
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -330,21 +330,32 @@
 // adjusted if needed.
 const unsigned MaxDepth = 6;
 
-bool llvm::isSplatValue(const Value *V, unsigned Depth) {
+bool llvm::isSplatValue(const Value *V, int Index, unsigned Depth) {
   assert(Depth <= MaxDepth && "Limit Search Depth");
 
   if (isa<VectorType>(V->getType())) {
     if (isa<UndefValue>(V))
       return true;
-    // FIXME: Constant splat analysis does not allow undef elements.
+    // FIXME: We can allow undefs, but if Index was specified, we may want to
+    // check that the constant is defined at that index.
     if (auto *C = dyn_cast<Constant>(V))
       return C->getSplatValue() != nullptr;
   }
 
-  // FIXME: Constant splat analysis does not allow undef elements.
-  Constant *Mask;
-  if (match(V, m_ShuffleVector(m_Value(), m_Value(), m_Constant(Mask))))
-    return Mask->getSplatValue() != nullptr;
+  if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
+    // FIXME: We can safely allow undefs here. If Index was specified, we will
+    // check that the mask elt is defined at the required index.
+    if (!Shuf->getMask()->getSplatValue())
+      return false;
+
+    // Match any index.
+    if (Index == -1)
+      return true;
+
+    // Match a specific element. The mask should be defined at and match the
+    // specified index.
+    return Shuf->getMaskValue(Index) == Index;
+  }
 
   // The remaining tests are all recursive, so bail out if we hit the limit.
   if (Depth++ == MaxDepth)
@@ -353,12 +364,12 @@
   // If both operands of a binop are splats, the result is a splat.
   Value *X, *Y, *Z;
   if (match(V, m_BinOp(m_Value(X), m_Value(Y))))
-    return isSplatValue(X, Depth) && isSplatValue(Y, Depth);
+    return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth);
 
   // If all operands of a select are splats, the result is a splat.
   if (match(V, m_Select(m_Value(X), m_Value(Y), m_Value(Z))))
-    return isSplatValue(X, Depth) && isSplatValue(Y, Depth) &&
-           isSplatValue(Z, Depth);
+    return isSplatValue(X, Index, Depth) && isSplatValue(Y, Index, Depth) &&
+           isSplatValue(Z, Index, Depth);
 
   // TODO: Add support for unary ops (fneg), casts, intrinsics (overflow ops).
 
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -107,6 +107,26 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+// A zero mask splats element 0, so asking for index 0 succeeds.
+TEST_F(VectorUtilsTest, isSplatValue_00_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 0));
+}
+
+// A zero mask splats element 0, not element 1.
+TEST_F(VectorUtilsTest, isSplatValue_00_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> zeroinitializer\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_11) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -116,6 +136,26 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+// A <1, 1> mask splats element 1, not element 0.
+TEST_F(VectorUtilsTest, isSplatValue_11_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+// A <1, 1> mask splats element 1, so asking for index 1 succeeds.
+TEST_F(VectorUtilsTest, isSplatValue_11_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_01) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -125,7 +165,27 @@
   EXPECT_FALSE(isSplatValue(A));
 }
 
-// FIXME: Constant (mask) splat analysis does not allow undef elements.
+// An identity mask is not a splat, so index 0 cannot match.
+TEST_F(VectorUtilsTest, isSplatValue_01_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+// An identity mask is not a splat, so index 1 cannot match either.
+TEST_F(VectorUtilsTest, isSplatValue_01_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 1>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
+// FIXME: Allow undef matching with Constant (mask) splat analysis.
 
 TEST_F(VectorUtilsTest, isSplatValue_0u) {
   parseAssembly(
@@ -136,6 +196,28 @@
   EXPECT_FALSE(isSplatValue(A));
 }
 
+// FIXME: Allow undef matching with Constant (mask) splat analysis.
+
+// Undef mask elements are not matched yet, even though index 0 is defined.
+TEST_F(VectorUtilsTest, isSplatValue_0u_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+// Undef mask elements are not matched yet; the mask is undef at index 1.
+TEST_F(VectorUtilsTest, isSplatValue_0u_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %A = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 undef>\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -147,6 +229,30 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+// Operand %v1 splats element 1, so the binop is not a splat of element 0.
+TEST_F(VectorUtilsTest, isSplatValue_Binop_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = udiv <2 x i8> %v0, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+// Operand %v0 splats element 0, so the binop is not a splat of element 1.
+TEST_F(VectorUtilsTest, isSplatValue_Binop_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v0 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 0, i32 0>\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = udiv <2 x i8> %v0, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"
@@ -157,6 +263,28 @@
   EXPECT_TRUE(isSplatValue(A));
 }
 
+// The shuffle operand splats element 1, so index 0 does not match.
+TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index0) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_FALSE(isSplatValue(A, 0));
+}
+
+// A splat constant matches any index; the shuffle operand splats element 1.
+TEST_F(VectorUtilsTest, isSplatValue_Binop_ConstantOp0_index1) {
+  parseAssembly(
+      "define <2 x i8> @test(<2 x i8> %x) {\n"
+      "  %v1 = shufflevector <2 x i8> %x, <2 x i8> undef, <2 x i32> <i32 1, i32 1>\n"
+      "  %A = ashr <2 x i8> <i8 42, i8 42>, %v1\n"
+      "  ret <2 x i8> %A\n"
+      "}\n");
+  EXPECT_TRUE(isSplatValue(A, 1));
+}
+
 TEST_F(VectorUtilsTest, isSplatValue_Binop_Not_Op0) {
   parseAssembly(
       "define <2 x i8> @test(<2 x i8> %x) {\n"