diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1908,6 +1908,30 @@
   /// the vector width and set the bits where elements are undef.
   SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
 
+  /// Find the shortest repeating sequence of values in the build vector.
+  ///
+  /// Currently this must be a power-of-2 build vector.
+  /// Only sequences shorter than the vector itself are reported.
+  /// The DemandedElts mask indicates the elements that must be present.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the vector width and set the bits where demanded elements are undef.
+  /// Sequence entries that no demanded element maps to are left as null
+  /// SDValues.
+  /// If result is false, Sequence will be empty.
+  bool getRepeatedSequence(const APInt &DemandedElts,
+                           SmallVectorImpl<SDValue> &Sequence,
+                           BitVector *UndefElements = nullptr) const;
+
+  /// Find the shortest repeating sequence of values in the build vector.
+  ///
+  /// Currently this must be a power-of-2 build vector.
+  /// Only sequences shorter than the vector itself are reported.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the vector width and set the bits where elements are undef.
+  /// If result is false, Sequence will be empty.
+  bool getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+                           BitVector *UndefElements = nullptr) const;
+
   /// Returns the demanded splatted constant or null if this is not a constant
   /// splat.
   ///
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9870,6 +9870,60 @@
   return getSplatValue(DemandedElts, UndefElements);
 }
 
+bool BuildVectorSDNode::getRepeatedSequence(const APInt &DemandedElts,
+                                            SmallVectorImpl<SDValue> &Sequence,
+                                            BitVector *UndefElements) const {
+  unsigned NumOps = getNumOperands();
+  Sequence.clear();
+  if (UndefElements) {
+    UndefElements->clear();
+    UndefElements->resize(NumOps);
+  }
+  assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
+  if (!DemandedElts || NumOps < 2 || !isPowerOf2_32(NumOps))
+    return false;
+
+  // Set the undefs even if we don't find a sequence (like getSplatValue).
+  if (UndefElements)
+    for (unsigned I = 0; I != NumOps; ++I)
+      if (DemandedElts[I] && getOperand(I).isUndef())
+        (*UndefElements)[I] = true;
+
+  // Iteratively widen the sequence length looking for repetitions.
+  for (unsigned SeqLen = 1; SeqLen < NumOps; SeqLen *= 2) {
+    Sequence.append(SeqLen, SDValue());
+    for (unsigned I = 0; I != NumOps; ++I) {
+      if (!DemandedElts[I])
+        continue;
+      SDValue &SeqOp = Sequence[I % SeqLen];
+      SDValue Op = getOperand(I);
+      if (Op.isUndef()) {
+        // Undef matches anything; it only fills a still-empty slot.
+        if (!SeqOp)
+          SeqOp = Op;
+        continue;
+      }
+      // Two different defined values in one slot break the repetition.
+      if (SeqOp && !SeqOp.isUndef() && SeqOp != Op) {
+        Sequence.clear();
+        break;
+      }
+      SeqOp = Op;
+    }
+    if (!Sequence.empty())
+      return true;
+  }
+
+  assert(Sequence.empty() && "Failed to empty non-repeating sequence pattern");
+  return false;
+}
+
+bool BuildVectorSDNode::getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+                                            BitVector *UndefElements) const {
+  APInt DemandedElts = APInt::getAllOnesValue(getNumOperands());
+  return getRepeatedSequence(DemandedElts, Sequence, UndefElements);
+}
+
 ConstantSDNode *
 BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
                                         BitVector *UndefElements) const {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -8644,43 +8644,6 @@
   return false;
 }
 
-// Check if the current node of build vector is a zero extended vector.
-// // If so, return the value extended.
-// // For example: (0,0,0,a,0,0,0,a,0,0,0,a,0,0,0,a) returns a.
-// // NumElt - return the number of zero extended identical values.
-// // EltType - return the type of the value include the zero extend.
-static SDValue isSplatZeroExtended(const BuildVectorSDNode *Op,
-                                   unsigned &NumElt, MVT &EltType) {
-  SDValue ExtValue = Op->getOperand(0);
-  unsigned NumElts = Op->getNumOperands();
-  unsigned Delta = NumElts;
-
-  for (unsigned i = 1; i < NumElts; i++) {
-    if (Op->getOperand(i) == ExtValue) {
-      Delta = i;
-      break;
-    }
-    if (!(Op->getOperand(i).isUndef() || isNullConstant(Op->getOperand(i))))
-      return SDValue();
-  }
-  if (!isPowerOf2_32(Delta) || Delta == 1)
-    return SDValue();
-
-  for (unsigned i = Delta; i < NumElts; i++) {
-    if (i % Delta == 0) {
-      if (Op->getOperand(i) != ExtValue)
-        return SDValue();
-    } else if (!(isNullConstant(Op->getOperand(i)) ||
-                 Op->getOperand(i).isUndef()))
-      return SDValue();
-  }
-  unsigned EltSize = Op->getSimpleValueType(0).getScalarSizeInBits();
-  unsigned ExtVTSize = EltSize * Delta;
-  EltType = MVT::getIntegerVT(ExtVTSize);
-  NumElt = NumElts / Delta;
-  return ExtValue;
-}
-
 /// Attempt to use the vbroadcast instruction to generate a splat value
 /// from a splat BUILD_VECTOR which uses:
 ///  a. A single scalar load, or a constant.
@@ -8698,13 +8661,21 @@
     return SDValue();
 
   MVT VT = BVOp->getSimpleValueType(0);
+  unsigned NumElts = VT.getVectorNumElements();
   SDLoc dl(BVOp);
 
   assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
          "Unsupported vector type for broadcast.");
 
+  // See if the build vector is a repeating sequence of scalars (inc. splat).
+  SDValue Ld;
   BitVector UndefElements;
-  SDValue Ld = BVOp->getSplatValue(&UndefElements);
+  SmallVector<SDValue, 16> Sequence;
+  if (BVOp->getRepeatedSequence(Sequence, &UndefElements)) {
+    assert((NumElts % Sequence.size()) == 0 && "Sequence doesn't fit.");
+    if (Sequence.size() == 1)
+      Ld = Sequence[0];
+  }
 
   // Attempt to use VBROADCASTM
   // From this pattern:
@@ -8712,29 +8683,29 @@
   // b. t1 = (build_vector t0 t0)
   //
   // Create (VBROADCASTM v2i1 X)
-  if (Subtarget.hasCDI()) {
-    MVT EltType = VT.getScalarType();
-    unsigned NumElts = VT.getVectorNumElements();
-    SDValue BOperand;
-    SDValue ZeroExtended = isSplatZeroExtended(BVOp, NumElts, EltType);
-    if ((ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST) ||
-        (ZeroExtended && ZeroExtended.getOpcode() == ISD::ZERO_EXTEND &&
-         ZeroExtended.getOperand(0).getOpcode() == ISD::BITCAST) ||
-        (Ld && Ld.getOpcode() == ISD::ZERO_EXTEND &&
-         Ld.getOperand(0).getOpcode() == ISD::BITCAST)) {
-      if (ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST)
-        BOperand = ZeroExtended.getOperand(0);
-      else if (ZeroExtended)
-        BOperand = ZeroExtended.getOperand(0).getOperand(0);
-      else
-        BOperand = Ld.getOperand(0).getOperand(0);
+  if (!Sequence.empty() && Subtarget.hasCDI()) {
+    // If not a splat, are the upper sequence values zeroable?
+    unsigned SeqLen = Sequence.size();
+    bool UpperZeroOrUndef =
+        SeqLen == 1 ||
+        llvm::all_of(makeArrayRef(Sequence).drop_front(), [](SDValue V) {
+          return V.isUndef() || isNullConstant(V);
+        });
+    SDValue Op0 = Sequence[0];
+    if (UpperZeroOrUndef && ((Op0.getOpcode() == ISD::BITCAST) ||
+                             (Op0.getOpcode() == ISD::ZERO_EXTEND &&
+                              Op0.getOperand(0).getOpcode() == ISD::BITCAST))) {
+      SDValue BOperand = Op0.getOpcode() == ISD::BITCAST
+                             ? Op0.getOperand(0)
+                             : Op0.getOperand(0).getOperand(0);
       MVT MaskVT = BOperand.getSimpleValueType();
+      MVT EltType = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
       if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) || // for broadcastmb2q
           (EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
-        MVT BcstVT = MVT::getVectorVT(EltType, NumElts);
+        MVT BcstVT = MVT::getVectorVT(EltType, NumElts / SeqLen);
         if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
           unsigned Scale = 512 / VT.getSizeInBits();
-          BcstVT = MVT::getVectorVT(EltType, NumElts * Scale);
+          BcstVT = MVT::getVectorVT(EltType, Scale * (NumElts / SeqLen));
         }
         SDValue Bcst = DAG.getNode(X86ISD::VBROADCASTM, dl, BcstVT, BOperand);
         if (BcstVT.getSizeInBits() != VT.getSizeInBits())
@@ -8744,7 +8715,6 @@
     }
   }
 
-  unsigned NumElts = VT.getVectorNumElements();
   unsigned NumUndefElts = UndefElements.count();
   if (!Ld || (NumElts - NumUndefElts) <= 1) {
     APInt SplatValue, Undef;
@@ -8818,6 +8788,8 @@
       (Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
   bool IsLoad = ISD::isNormalLoad(Ld.getNode());
 
+  // TODO: Handle broadcasts of non-constant sequences.
+
   // Make sure that all of the users of a non-constant load are from the
   // BUILD_VECTOR node.
   // FIXME: Is the use count needed for non-constant, non-load case?
diff --git a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
--- a/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
+++ b/llvm/unittests/CodeGen/AArch64SelectionDAGTest.cpp
@@ -472,6 +472,125 @@
   EXPECT_EQ(SplatIdx, 0);
 }
 
+TEST_F(AArch64SelectionDAGTest, getRepeatedSequence_Patterns) {
+  if (!TM)
+    return;
+
+  SDLoc Loc;
+  unsigned NumElts = 16;
+  MVT IntVT = MVT::i8;
+  MVT VecVT = MVT::getVectorVT(IntVT, NumElts);
+
+  // Scalar building blocks for the test vectors.
+  SDValue Val0 = DAG->getConstant(0, Loc, IntVT);
+  SDValue Val1 = DAG->getConstant(1, Loc, IntVT);
+  SDValue Val2 = DAG->getConstant(2, Loc, IntVT);
+  SDValue Val3 = DAG->getConstant(3, Loc, IntVT);
+  SDValue UndefVal = DAG->getUNDEF(IntVT);
+
+  // Build some repeating sequences.
+  SmallVector<SDValue, 16> Pattern1111, Pattern1133, Pattern0123;
+  for (int I = 0; I != 4; ++I) {
+    Pattern1111.append(4, Val1);
+    Pattern1133.append(2, Val1);
+    Pattern1133.append(2, Val3);
+    Pattern0123.push_back(Val0);
+    Pattern0123.push_back(Val1);
+    Pattern0123.push_back(Val2);
+    Pattern0123.push_back(Val3);
+  }
+
+  // Build a non-pow2 repeating sequence.
+  SmallVector<SDValue, 16> Pattern022;
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+  Pattern022.append(2, Val2);
+  Pattern022.push_back(Val0);
+
+  // Build a non-repeating sequence.
+  SmallVector<SDValue, 16> Pattern1_3;
+  Pattern1_3.append(8, Val1);
+  Pattern1_3.append(8, Val3);
+
+  // Add some undefs to make it trickier.
+  Pattern1111[1] = Pattern1111[2] = Pattern1111[15] = UndefVal;
+  Pattern1133[0] = Pattern1133[2] = UndefVal;
+
+  auto *BV1111 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1111));
+  auto *BV1133 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1133));
+  auto *BV0123 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern0123));
+  auto *BV022 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern022));
+  auto *BV1_3 =
+      cast<BuildVectorSDNode>(DAG->getBuildVector(VecVT, Loc, Pattern1_3));
+
+  // Check for sequences.
+  SmallVector<SDValue, 16> Seq1111, Seq1133, Seq0123, Seq022, Seq1_3;
+  BitVector Undefs1111, Undefs1133, Undefs0123, Undefs022, Undefs1_3;
+
+  EXPECT_TRUE(BV1111->getRepeatedSequence(Seq1111, &Undefs1111));
+  EXPECT_EQ(Undefs1111.count(), 3u);
+  EXPECT_EQ(Seq1111.size(), 1u);
+  EXPECT_EQ(Seq1111[0], Val1);
+
+  EXPECT_TRUE(BV1133->getRepeatedSequence(Seq1133, &Undefs1133));
+  EXPECT_EQ(Undefs1133.count(), 2u);
+  EXPECT_EQ(Seq1133.size(), 4u);
+  EXPECT_EQ(Seq1133[0], Val1);
+  EXPECT_EQ(Seq1133[1], Val1);
+  EXPECT_EQ(Seq1133[2], Val3);
+  EXPECT_EQ(Seq1133[3], Val3);
+
+  EXPECT_TRUE(BV0123->getRepeatedSequence(Seq0123, &Undefs0123));
+  EXPECT_EQ(Undefs0123.count(), 0u);
+  EXPECT_EQ(Seq0123.size(), 4u);
+  EXPECT_EQ(Seq0123[0], Val0);
+  EXPECT_EQ(Seq0123[1], Val1);
+  EXPECT_EQ(Seq0123[2], Val2);
+  EXPECT_EQ(Seq0123[3], Val3);
+
+  EXPECT_FALSE(BV022->getRepeatedSequence(Seq022, &Undefs022));
+  EXPECT_FALSE(BV1_3->getRepeatedSequence(Seq1_3, &Undefs1_3));
+
+  // Try again with DemandedElts masks.
+  APInt Mask1111_0 = APInt::getOneBitSet(NumElts, 0);
+  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_0, Seq1111, &Undefs1111));
+  EXPECT_EQ(Undefs1111.count(), 0u);
+  EXPECT_EQ(Seq1111.size(), 1u);
+  EXPECT_EQ(Seq1111[0], Val1);
+
+  APInt Mask1111_1 = APInt::getOneBitSet(NumElts, 2);
+  EXPECT_TRUE(BV1111->getRepeatedSequence(Mask1111_1, Seq1111, &Undefs1111));
+  EXPECT_EQ(Undefs1111.count(), 1u);
+  EXPECT_EQ(Seq1111.size(), 1u);
+  EXPECT_EQ(Seq1111[0], UndefVal);
+
+  APInt Mask0123 = APInt(NumElts, 0x7777);
+  EXPECT_TRUE(BV0123->getRepeatedSequence(Mask0123, Seq0123, &Undefs0123));
+  EXPECT_EQ(Undefs0123.count(), 0u);
+  EXPECT_EQ(Seq0123.size(), 4u);
+  EXPECT_EQ(Seq0123[0], Val0);
+  EXPECT_EQ(Seq0123[1], Val1);
+  EXPECT_EQ(Seq0123[2], Val2);
+  EXPECT_EQ(Seq0123[3], SDValue());
+
+  APInt Mask1_3 = APInt::getHighBitsSet(NumElts, 8);
+  EXPECT_TRUE(BV1_3->getRepeatedSequence(Mask1_3, Seq1_3, &Undefs1_3));
+  EXPECT_EQ(Undefs1_3.count(), 0u);
+  EXPECT_EQ(Seq1_3.size(), 1u);
+  EXPECT_EQ(Seq1_3[0], Val3);
+}
+
 TEST_F(AArch64SelectionDAGTest, getTypeConversion_SplitScalableMVT) {
   if (!TM)
     return;