diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1908,6 +1908,32 @@
   /// the vector width and set the bits where elements are undef.
   SDValue getSplatValue(BitVector *UndefElements = nullptr) const;
 
+  /// Find the shortest repeating sequence of values in the build vector.
+  ///
+  /// Currently this must be a power-of-2 build vector.
+  /// The DemandedElts mask indicates the elements that must be present.
+  /// Only sequences shorter than the vector are considered, so a build vector
+  /// with no repetition returns false.
+  /// Sequence positions that no demanded element maps to are left as null
+  /// SDValues.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the vector width and set the bits where elements are undef.
+  /// If result is false, Sequence will be empty.
+  bool getRepeatedSequence(const APInt &DemandedElts,
+                           SmallVectorImpl<SDValue> &Sequence,
+                           BitVector *UndefElements = nullptr) const;
+
+  /// Find the shortest repeating sequence of values in the build vector.
+  ///
+  /// Currently this must be a power-of-2 build vector.
+  /// Only sequences shorter than the vector are considered, so a build vector
+  /// with no repetition returns false.
+  /// If passed a non-null UndefElements bitvector, it will resize it to match
+  /// the vector width and set the bits where elements are undef.
+  /// If result is false, Sequence will be empty.
+  bool getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+                           BitVector *UndefElements = nullptr) const;
+
   /// Returns the demanded splatted constant or null if this is not a constant
   /// splat.
   ///
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9804,6 +9804,60 @@
   return getSplatValue(DemandedElts, UndefElements);
 }
 
+bool BuildVectorSDNode::getRepeatedSequence(const APInt &DemandedElts,
+                                            SmallVectorImpl<SDValue> &Sequence,
+                                            BitVector *UndefElements) const {
+  unsigned NumOps = getNumOperands();
+  Sequence.clear();
+  if (UndefElements) {
+    UndefElements->clear();
+    UndefElements->resize(NumOps);
+  }
+  assert(NumOps == DemandedElts.getBitWidth() && "Unexpected vector size");
+  if (!DemandedElts || NumOps < 2 || !isPowerOf2_32(NumOps))
+    return false;
+
+  // Set the undefs even if we don't find a sequence (like getSplatValue).
+  if (UndefElements)
+    for (unsigned I = 0; I != NumOps; ++I)
+      if (DemandedElts[I] && getOperand(I).isUndef())
+        (*UndefElements)[I] = true;
+
+  // Iteratively widen the sequence length looking for repetitions.
+  for (unsigned SeqLen = 1; SeqLen < NumOps; SeqLen *= 2) {
+    Sequence.append(SeqLen, SDValue());
+    for (unsigned I = 0; I != NumOps; ++I) {
+      if (!DemandedElts[I])
+        continue;
+      SDValue &SeqOp = Sequence[I % SeqLen];
+      SDValue Op = getOperand(I);
+      // Undef matches any value; only record it if the slot is still empty.
+      if (Op.isUndef()) {
+        if (!SeqOp)
+          SeqOp = Op;
+        continue;
+      }
+      // Two different defined values in one slot: try a longer sequence.
+      if (SeqOp && !SeqOp.isUndef() && SeqOp != Op) {
+        Sequence.clear();
+        break;
+      }
+      SeqOp = Op;
+    }
+    if (!Sequence.empty())
+      return true;
+  }
+
+  assert(Sequence.empty() && "Failed to empty non-repeating sequence pattern");
+  return false;
+}
+
+bool BuildVectorSDNode::getRepeatedSequence(SmallVectorImpl<SDValue> &Sequence,
+                                            BitVector *UndefElements) const {
+  APInt DemandedElts = APInt::getAllOnesValue(getNumOperands());
+  return getRepeatedSequence(DemandedElts, Sequence, UndefElements);
+}
+
 ConstantSDNode *
 BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
                                         BitVector *UndefElements) const {
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -8597,43 +8597,6 @@
   return false;
 }
 
-// Check if the current node of build vector is a zero extended vector.
-// // If so, return the value extended.
-// // For example: (0,0,0,a,0,0,0,a,0,0,0,a,0,0,0,a) returns a.
-// // NumElt - return the number of zero extended identical values.
-// // EltType - return the type of the value include the zero extend.
-static SDValue isSplatZeroExtended(const BuildVectorSDNode *Op,
-                                   unsigned &NumElt, MVT &EltType) {
-  SDValue ExtValue = Op->getOperand(0);
-  unsigned NumElts = Op->getNumOperands();
-  unsigned Delta = NumElts;
-
-  for (unsigned i = 1; i < NumElts; i++) {
-    if (Op->getOperand(i) == ExtValue) {
-      Delta = i;
-      break;
-    }
-    if (!(Op->getOperand(i).isUndef() || isNullConstant(Op->getOperand(i))))
-      return SDValue();
-  }
-  if (!isPowerOf2_32(Delta) || Delta == 1)
-    return SDValue();
-
-  for (unsigned i = Delta; i < NumElts; i++) {
-    if (i % Delta == 0) {
-      if (Op->getOperand(i) != ExtValue)
-        return SDValue();
-    } else if (!(isNullConstant(Op->getOperand(i)) ||
-                 Op->getOperand(i).isUndef()))
-      return SDValue();
-  }
-  unsigned EltSize = Op->getSimpleValueType(0).getScalarSizeInBits();
-  unsigned ExtVTSize = EltSize * Delta;
-  EltType = MVT::getIntegerVT(ExtVTSize);
-  NumElt = NumElts / Delta;
-  return ExtValue;
-}
-
 /// Attempt to use the vbroadcast instruction to generate a splat value
 /// from a splat BUILD_VECTOR which uses:
 ///  a. A single scalar load, or a constant.
@@ -8651,13 +8614,21 @@
     return SDValue();
 
   MVT VT = BVOp->getSimpleValueType(0);
+  unsigned NumElts = VT.getVectorNumElements();
   SDLoc dl(BVOp);
 
   assert((VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) &&
          "Unsupported vector type for broadcast.");
 
+  // See if the build vector is a repeating sequence of scalars (inc. splat).
+  SDValue Ld;
   BitVector UndefElements;
-  SDValue Ld = BVOp->getSplatValue(&UndefElements);
+  SmallVector<SDValue, 16> Sequence;
+  if (BVOp->getRepeatedSequence(Sequence, &UndefElements)) {
+    assert((NumElts % Sequence.size()) == 0 && "Sequence doesn't fit.");
+    if (Sequence.size() == 1)
+      Ld = Sequence[0];
+  }
 
   // Attempt to use VBROADCASTM
   // From this pattern:
@@ -8665,29 +8636,29 @@
   // b. t1 = (build_vector t0 t0)
   //
   // Create (VBROADCASTM v2i1 X)
-  if (Subtarget.hasCDI()) {
-    MVT EltType = VT.getScalarType();
-    unsigned NumElts = VT.getVectorNumElements();
-    SDValue BOperand;
-    SDValue ZeroExtended = isSplatZeroExtended(BVOp, NumElts, EltType);
-    if ((ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST) ||
-        (ZeroExtended && ZeroExtended.getOpcode() == ISD::ZERO_EXTEND &&
-         ZeroExtended.getOperand(0).getOpcode() == ISD::BITCAST) ||
-        (Ld && Ld.getOpcode() == ISD::ZERO_EXTEND &&
-         Ld.getOperand(0).getOpcode() == ISD::BITCAST)) {
-      if (ZeroExtended && ZeroExtended.getOpcode() == ISD::BITCAST)
-        BOperand = ZeroExtended.getOperand(0);
-      else if (ZeroExtended)
-        BOperand = ZeroExtended.getOperand(0).getOperand(0);
-      else
-        BOperand = Ld.getOperand(0).getOperand(0);
+  if (!Sequence.empty() && Subtarget.hasCDI()) {
+    // If not a splat, are the upper sequence values zeroable?
+    unsigned SeqLen = Sequence.size();
+    bool UpperZeroOrUndef =
+        SeqLen == 1 ||
+        llvm::all_of(makeArrayRef(Sequence).drop_front(), [](SDValue V) {
+          return V.isUndef() || isNullConstant(V);
+        });
+    SDValue Op0 = Sequence[0];
+    if (UpperZeroOrUndef && ((Op0.getOpcode() == ISD::BITCAST) ||
+                             (Op0.getOpcode() == ISD::ZERO_EXTEND &&
+                              Op0.getOperand(0).getOpcode() == ISD::BITCAST))) {
+      SDValue BOperand = Op0.getOpcode() == ISD::BITCAST
+                             ? Op0.getOperand(0)
+                             : Op0.getOperand(0).getOperand(0);
       MVT MaskVT = BOperand.getSimpleValueType();
+      MVT EltType = MVT::getIntegerVT(VT.getScalarSizeInBits() * SeqLen);
       if ((EltType == MVT::i64 && MaskVT == MVT::v8i1) ||  // for broadcastmb2q
           (EltType == MVT::i32 && MaskVT == MVT::v16i1)) { // for broadcastmw2d
-        MVT BcstVT = MVT::getVectorVT(EltType, NumElts);
+        MVT BcstVT = MVT::getVectorVT(EltType, NumElts / SeqLen);
         if (!VT.is512BitVector() && !Subtarget.hasVLX()) {
           unsigned Scale = 512 / VT.getSizeInBits();
-          BcstVT = MVT::getVectorVT(EltType, NumElts * Scale);
+          BcstVT = MVT::getVectorVT(EltType, Scale * (NumElts / SeqLen));
         }
         SDValue Bcst = DAG.getNode(X86ISD::VBROADCASTM, dl, BcstVT, BOperand);
         if (BcstVT.getSizeInBits() != VT.getSizeInBits())
@@ -8697,7 +8668,6 @@
     }
   }
 
-  unsigned NumElts = VT.getVectorNumElements();
   unsigned NumUndefElts = UndefElements.count();
   if (!Ld || (NumElts - NumUndefElts) <= 1) {
     APInt SplatValue, Undef;
@@ -8771,6 +8741,8 @@
       (Ld.getOpcode() == ISD::Constant || Ld.getOpcode() == ISD::ConstantFP);
   bool IsLoad = ISD::isNormalLoad(Ld.getNode());
 
+  // TODO: Handle broadcasts of non-constant sequences.
+
   // Make sure that all of the users of a non-constant load are from the
   // BUILD_VECTOR node.
   // FIXME: Is the use count needed for non-constant, non-load case?