diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -366,6 +366,14 @@ /// not limited by finding a scalar source value to a splatted vector. bool isSplatValue(const Value *V, int Index = -1, unsigned Depth = 0); +/// Transform a shuffle mask's output demanded element mask into demanded +/// element masks for the 2 operands, returns false if the mask isn't valid. +/// Both \p DemandedLHS and \p DemandedRHS are initialised to [SrcWidth]. +/// \p AllowUndefElts permits "-1" indices to be treated as undef. +bool getShuffleDemandedElts(int SrcWidth, ArrayRef Mask, + const APInt &DemandedElts, APInt &DemandedLHS, + APInt &DemandedRHS, bool AllowUndefElts = false); + /// Replace each shuffle mask index with the scaled sequential indices for an /// equivalent mask of narrowed elements. Mask elements that are less than 0 /// (sentinel values) are repeated in the output mask. diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -34,6 +34,7 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -173,32 +174,8 @@ int NumElts = cast(Shuf->getOperand(0)->getType())->getNumElements(); - int NumMaskElts = cast(Shuf->getType())->getNumElements(); - DemandedLHS = DemandedRHS = APInt::getZero(NumElts); - if (DemandedElts.isZero()) - return true; - // Simple case of a shuffle with zeroinitializer. - if (all_of(Shuf->getShuffleMask(), [](int Elt) { return Elt == 0; })) { - DemandedLHS.setBit(0); - return true; - } - for (int i = 0; i != NumMaskElts; ++i) { - if (!DemandedElts[i]) - continue; - int M = Shuf->getMaskValue(i); - assert(M < (NumElts * 2) && "Invalid shuffle mask constant"); - - // For undef elements, we don't know anything about the common state of - // the shuffle result. - if (M == -1) - return false; - if (M < NumElts) - DemandedLHS.setBit(M % NumElts); - else - DemandedRHS.setBit(M % NumElts); - } - - return true; + return llvm::getShuffleDemandedElts(NumElts, Shuf->getShuffleMask(), + DemandedElts, DemandedLHS, DemandedRHS); } static void computeKnownBits(const Value *V, const APInt &DemandedElts, diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp --- a/llvm/lib/Analysis/VectorUtils.cpp +++ b/llvm/lib/Analysis/VectorUtils.cpp @@ -429,6 +429,43 @@ return false; } +bool llvm::getShuffleDemandedElts(int SrcWidth, ArrayRef Mask, + const APInt &DemandedElts, APInt &DemandedLHS, + APInt &DemandedRHS, bool AllowUndefElts) { + DemandedLHS = DemandedRHS = APInt::getZero(SrcWidth); + + // Early out if we don't demand any elements. + if (DemandedElts.isZero()) + return true; + + // Simple case of a shuffle with zeroinitializer. + if (all_of(Mask, [](int Elt) { return Elt == 0; })) { + DemandedLHS.setBit(0); + return true; + } + + for (unsigned I = 0, E = Mask.size(); I != E; ++I) { + int M = Mask[I]; + assert((-1 <= M) && (M < (SrcWidth * 2)) && + "Invalid shuffle mask constant"); + + if (!DemandedElts[I] || (AllowUndefElts && (M < 0))) + continue; + + // For undef elements, we don't know anything about the common state of + // the shuffle result. + if (M < 0) + return false; + + if (M < SrcWidth) + DemandedLHS.setBit(M); + else + DemandedRHS.setBit(M - SrcWidth); + } + + return true; +} + void llvm::narrowShuffleMaskElts(int Scale, ArrayRef Mask, SmallVectorImpl &ScaledMask) { assert(Scale > 0 && "Unexpected scaling factor"); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -27,6 +27,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/CodeGen/Analysis.h" #include "llvm/CodeGen/FunctionLoweringInfo.h" #include "llvm/CodeGen/ISDOpcodes.h" @@ -2978,30 +2979,15 @@ case ISD::VECTOR_SHUFFLE: { // Collect the known bits that are shared by every vector element referenced // by the shuffle. - APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0); - Known.Zero.setAllBits(); Known.One.setAllBits(); + APInt DemandedLHS, DemandedRHS; const ShuffleVectorSDNode *SVN = cast(Op); assert(NumElts == SVN->getMask().size() && "Unexpected vector size"); - for (unsigned i = 0; i != NumElts; ++i) { - if (!DemandedElts[i]) - continue; - - int M = SVN->getMaskElt(i); - if (M < 0) { - // For UNDEF elements, we don't know anything about the common state of - // the shuffle result. - Known.resetAll(); - DemandedLHS.clearAllBits(); - DemandedRHS.clearAllBits(); - break; - } + if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts, + DemandedLHS, DemandedRHS)) + break; - if ((unsigned)M < NumElts) - DemandedLHS.setBit((unsigned)M % NumElts); - else - DemandedRHS.setBit((unsigned)M % NumElts); - } // Known bits are the values that are shared by every demanded element. + Known.Zero.setAllBits(); Known.One.setAllBits(); if (!!DemandedLHS) { SDValue LHS = Op.getOperand(0); Known2 = computeKnownBits(LHS, DemandedLHS, Depth + 1); @@ -3984,22 +3970,13 @@ case ISD::VECTOR_SHUFFLE: { // Collect the minimum number of sign bits that are shared by every vector // element referenced by the shuffle. - APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0); + APInt DemandedLHS, DemandedRHS; const ShuffleVectorSDNode *SVN = cast(Op); assert(NumElts == SVN->getMask().size() && "Unexpected vector size"); - for (unsigned i = 0; i != NumElts; ++i) { - int M = SVN->getMaskElt(i); - if (!DemandedElts[i]) - continue; - // For UNDEF elements, we don't know anything about the common state of - // the shuffle result. - if (M < 0) - return 1; - if ((unsigned)M < NumElts) - DemandedLHS.setBit((unsigned)M % NumElts); - else - DemandedRHS.setBit((unsigned)M % NumElts); - } + if (!getShuffleDemandedElts(NumElts, SVN->getMask(), DemandedElts, + DemandedLHS, DemandedRHS)) + return 1; + Tmp = std::numeric_limits::max(); if (!!DemandedLHS) Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedLHS, Depth + 1); diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp --- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp @@ -12,6 +12,7 @@ #include "llvm/CodeGen/TargetLowering.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/CodeGenCommonISel.h" #include "llvm/CodeGen/MachineFrameInfo.h" @@ -1291,25 +1292,10 @@ ArrayRef ShuffleMask = cast(Op)->getMask(); // Collect demanded elements from shuffle operands.. - APInt DemandedLHS(NumElts, 0); - APInt DemandedRHS(NumElts, 0); - for (unsigned i = 0; i != NumElts; ++i) { - if (!DemandedElts[i]) - continue; - int M = ShuffleMask[i]; - if (M < 0) { - // For UNDEF elements, we don't know anything about the common state of - // the shuffle result. - DemandedLHS.clearAllBits(); - DemandedRHS.clearAllBits(); - break; - } - assert(0 <= M && M < (int)(2 * NumElts) && "Shuffle index out of range"); - if (M < (int)NumElts) - DemandedLHS.setBit(M); - else - DemandedRHS.setBit(M - NumElts); - } + APInt DemandedLHS, DemandedRHS; + if (!getShuffleDemandedElts(NumElts, ShuffleMask, DemandedElts, DemandedLHS, + DemandedRHS)) + break; if (!!DemandedLHS || !!DemandedRHS) { SDValue Op0 = Op.getOperand(0); diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp --- a/llvm/unittests/Analysis/VectorUtilsTest.cpp +++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp @@ -166,6 +166,39 @@ EXPECT_EQ(makeArrayRef(WideMask), makeArrayRef({-2,-3})); } +TEST_F(BasicTest, getShuffleDemandedElts) { + APInt LHS, RHS; + + // broadcast zero + EXPECT_TRUE(getShuffleDemandedElts(4, {0, 0, 0, 0}, APInt(4,0xf), LHS, RHS)); + EXPECT_EQ(LHS.getZExtValue(), 0x1); + EXPECT_EQ(RHS.getZExtValue(), 0x0); + + // broadcast zero (with non-permitted undefs) + EXPECT_FALSE(getShuffleDemandedElts(2, {0, -1}, APInt(2, 0x3), LHS, RHS)); + + // broadcast zero (with permitted undefs) + EXPECT_TRUE(getShuffleDemandedElts(3, {0, 0, -1}, APInt(3, 0x7), LHS, RHS, true)); + EXPECT_EQ(LHS.getZExtValue(), 0x1); + EXPECT_EQ(RHS.getZExtValue(), 0x0); + + // broadcast one in demanded + EXPECT_TRUE(getShuffleDemandedElts(4, {1, 1, 1, -1}, APInt(4, 0x7), LHS, RHS)); + EXPECT_EQ(LHS.getZExtValue(), 0x2); + EXPECT_EQ(RHS.getZExtValue(), 0x0); + + // broadcast 7 in demanded + EXPECT_TRUE(getShuffleDemandedElts(4, {7, 0, 7, 7}, APInt(4, 0xd), LHS, RHS)); + EXPECT_EQ(LHS.getZExtValue(), 0x0); + EXPECT_EQ(RHS.getZExtValue(), 0x8); + + // general test + EXPECT_TRUE(getShuffleDemandedElts(4, {4, 2, 7, 3}, APInt(4, 0xf), LHS, RHS)); + EXPECT_EQ(LHS.getZExtValue(), 0xc); + EXPECT_EQ(RHS.getZExtValue(), 0x9); +} + + TEST_F(BasicTest, getSplatIndex) { EXPECT_EQ(getSplatIndex({0,0,0}), 0); EXPECT_EQ(getSplatIndex({1,0,0}), -1); // no splat