Index: llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
+++ llvm/trunk/include/llvm/CodeGen/SelectionDAG.h
@@ -1515,6 +1515,18 @@
   /// allow an 'add' to be transformed into an 'or'.
   bool haveNoCommonBitsSet(SDValue A, SDValue B) const;
 
+  /// Test whether \p V has a splatted value for all the demanded elements.
+  ///
+  /// On success \p UndefElts will indicate the elements that have UNDEF
+  /// values instead of the splat value, this is only guaranteed to be correct
+  /// for \p DemandedElts.
+  ///
+  /// NOTE: The function will return true for a demanded splat of UNDEF values.
+  bool isSplatValue(SDValue V, const APInt &DemandedElts, APInt &UndefElts);
+
+  /// Test whether \p V has a splatted value.
+  bool isSplatValue(SDValue V, bool AllowUndefs = false);
+
   /// Match a binop + shuffle pyramid that represents a horizontal reduction
   /// over the elements of a vector starting from the EXTRACT_VECTOR_ELT node /p
   /// Extract. The reduction must use one of the opcodes listed in /p
Index: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2121,6 +2121,109 @@
   return Mask.isSubsetOf(computeKnownBits(Op, Depth).Zero);
 }
 
+/// isSplatValue - Return true if the vector V has the same value
+/// across all DemandedElts.
+/// Returns false (leaving UndefElts untouched) when no element is demanded;
+/// otherwise UndefElts is cleared and then set for the lanes seen as undef.
+bool SelectionDAG::isSplatValue(SDValue V, const APInt &DemandedElts,
+                                APInt &UndefElts) {
+  if (!DemandedElts)
+    return false; // No demanded elts, better to assume we don't know anything.
+
+  EVT VT = V.getValueType();
+  assert(VT.isVector() && "Vector type expected");
+
+  unsigned NumElts = VT.getVectorNumElements();
+  assert(NumElts == DemandedElts.getBitWidth() && "Vector size mismatch");
+  UndefElts = APInt::getNullValue(NumElts);
+
+  switch (V.getOpcode()) {
+  case ISD::BUILD_VECTOR: {
+    // Undef operands are recorded for every lane; only the demanded, defined
+    // operands are required to be the same SDValue.
+    SDValue Scl;
+    for (unsigned i = 0; i != NumElts; ++i) {
+      SDValue Op = V.getOperand(i);
+      if (Op.isUndef()) {
+        UndefElts.setBit(i);
+        continue;
+      }
+      if (!DemandedElts[i])
+        continue;
+      if (Scl && Scl != Op)
+        return false;
+      Scl = Op;
+    }
+    return true;
+  }
+  case ISD::VECTOR_SHUFFLE: {
+    // Check if this is a shuffle node doing a splat.
+    // TODO: Do we need to handle shuffle(splat, undef, mask)?
+    int SplatIndex = -1;
+    ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(V)->getMask();
+    for (int i = 0; i != (int)NumElts; ++i) {
+      int M = Mask[i];
+      if (M < 0) {
+        UndefElts.setBit(i);
+        continue;
+      }
+      if (!DemandedElts[i])
+        continue;
+      if (0 <= SplatIndex && SplatIndex != M)
+        return false;
+      SplatIndex = M;
+    }
+    return true;
+  }
+  case ISD::EXTRACT_SUBVECTOR: {
+    SDValue Src = V.getOperand(0);
+    ConstantSDNode *SubIdx = dyn_cast<ConstantSDNode>(V.getOperand(1));
+    unsigned NumSrcElts = Src.getValueType().getVectorNumElements();
+    if (SubIdx && SubIdx->getAPIntValue().ule(NumSrcElts - NumElts)) {
+      // Offset the demanded elts by the subvector index.
+      uint64_t Idx = SubIdx->getZExtValue();
+      APInt UndefSrcElts;
+      APInt DemandedSrc = DemandedElts.zextOrSelf(NumSrcElts).shl(Idx);
+      if (isSplatValue(Src, DemandedSrc, UndefSrcElts)) {
+        UndefElts = UndefSrcElts.extractBits(NumElts, Idx);
+        return true;
+      }
+    }
+    break;
+  }
+  case ISD::ADD:
+  case ISD::SUB:
+  case ISD::AND: {
+    // Treated as a splat only when both operands are demanded splats; a lane
+    // is reported undef if it is undef in either operand.
+    APInt UndefLHS, UndefRHS;
+    SDValue LHS = V.getOperand(0);
+    SDValue RHS = V.getOperand(1);
+    if (isSplatValue(LHS, DemandedElts, UndefLHS) &&
+        isSplatValue(RHS, DemandedElts, UndefRHS)) {
+      UndefElts = UndefLHS | UndefRHS;
+      return true;
+    }
+    break;
+  }
+  }
+
+  return false;
+}
+
+/// Helper wrapper to main isSplatValue function.
+/// Demands all elements; undef lanes are only accepted if AllowUndefs is set.
+bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) {
+  EVT VT = V.getValueType();
+  assert(VT.isVector() && "Vector type expected");
+  unsigned NumElts = VT.getVectorNumElements();
+
+  APInt UndefElts;
+  APInt DemandedElts = APInt::getAllOnesValue(NumElts);
+  return isSplatValue(V, DemandedElts, UndefElts) &&
+         (AllowUndefs || !UndefElts);
+}
+
 /// Helper function that checks to see if a node is a constant or a
 /// build vector of splat constants at least within the demanded elts.
 static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N,
Index: llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp
+++ llvm/trunk/lib/Target/Mips/MipsSEISelLowering.cpp
@@ -2360,24 +2360,6 @@
   }
 }
 
-/// Check if the given BuildVectorSDNode is a splat.
-/// This method currently relies on DAG nodes being reused when equivalent,
-/// so it's possible for this to return false even when isConstantSplat returns
-/// true.
-static bool isSplatVector(const BuildVectorSDNode *N) {
-  unsigned int nOps = N->getNumOperands();
-  assert(nOps > 1 && "isSplatVector has 0 or 1 sized build vector");
-
-  SDValue Operand0 = N->getOperand(0);
-
-  for (unsigned int i = 1; i < nOps; ++i) {
-    if (N->getOperand(i) != Operand0)
-      return false;
-  }
-
-  return true;
-}
-
 // Lower ISD::EXTRACT_VECTOR_ELT into MipsISD::VEXTRACT_SEXT_ELT.
 //
 // The non-value bits resulting from ISD::EXTRACT_VECTOR_ELT are undefined. We
@@ -2488,7 +2470,7 @@
       Result = DAG.getNode(ISD::BITCAST, SDLoc(Node), ResTy, Result);
 
     return Result;
-  } else if (isSplatVector(Node))
+  } else if (DAG.isSplatValue(Op, /* AllowUndefs */ false))
     return Op;
   else if (!isConstantOrUndefBUILD_VECTOR(Node)) {
     // Use INSERT_VECTOR_ELT operations rather than expand to stores.
Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -24072,26 +24072,30 @@
 }
 
 // If V is a splat value, return the source vector and splat index;
-// TODO - can we make this generic and move to SelectionDAG?
-static SDValue IsSplatVector(SDValue V, int &SplatIdx) {
+static SDValue IsSplatVector(SDValue V, int &SplatIdx, SelectionDAG &DAG) {
   V = peekThroughEXTRACT_SUBVECTORs(V);
 
+  EVT VT = V.getValueType();
   unsigned Opcode = V.getOpcode();
   switch (Opcode) {
-  case ISD::BUILD_VECTOR: {
-    BitVector UndefElts;
-    SDValue SplatAmt = cast<BuildVectorSDNode>(V)->getSplatValue(&UndefElts);
-    if (SplatAmt && !SplatAmt.isUndef()) {
-      for (int i = 0, e = UndefElts.size(); i != e; ++i)
-        if (!UndefElts[i]) {
-          SplatIdx = i;
-          return V;
-        }
+  default: {
+    APInt UndefElts;
+    APInt DemandedElts = APInt::getAllOnesValue(VT.getVectorNumElements());
+    if (DAG.isSplatValue(V, DemandedElts, UndefElts)) {
+      // Handle case where all demanded elements are UNDEF.
+      if (DemandedElts.isSubsetOf(UndefElts)) {
+        SplatIdx = 0;
+        return DAG.getUNDEF(VT);
+      }
+      SplatIdx = (UndefElts & DemandedElts).countTrailingOnes();
+      return V;
     }
     break;
   }
   case ISD::VECTOR_SHUFFLE: {
     // Check if this is a shuffle node doing a splat.
+    // TODO - remove this and rely purely on SelectionDAG::isSplatValue,
+    // getTargetVShiftNode currently struggles without the splat source.
     auto *SVN = cast<ShuffleVectorSDNode>(V);
     if (!SVN->isSplat())
       break;
@@ -24100,23 +24104,6 @@
     SplatIdx = Idx % NumElts;
     return V.getOperand(Idx / NumElts);
   }
-  case ISD::SUB: {
-    SDValue LHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(0));
-    SDValue RHS = peekThroughEXTRACT_SUBVECTORs(V.getOperand(1));
-
-    // Ensure that the corresponding splat BV element is not UNDEF.
-    BitVector UndefElts;
-    auto *BV0 = dyn_cast<BuildVectorSDNode>(LHS);
-    auto *SVN1 = dyn_cast<ShuffleVectorSDNode>(RHS);
-    if (BV0 && SVN1 && BV0->getSplatValue(&UndefElts) && SVN1->isSplat()) {
-      int Idx = SVN1->getSplatIndex();
-      if (!UndefElts[Idx]) {
-        SplatIdx = Idx;
-        return V;
-      }
-    }
-    break;
-  }
   }
 
   return SDValue();
@@ -24125,7 +24112,7 @@
 static SDValue GetSplatValue(SDValue V, const SDLoc &dl,
                              SelectionDAG &DAG) {
   int SplatIdx;
-  if (SDValue SrcVector = IsSplatVector(V, SplatIdx))
+  if (SDValue SrcVector = IsSplatVector(V, SplatIdx, DAG))
     return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
                        SrcVector.getValueType().getScalarType(), SrcVector,
                        DAG.getIntPtrConstant(SplatIdx, dl));
@@ -24850,8 +24837,7 @@
   // Rotate by splat - expand back to shifts.
   // TODO - legalizers should be able to handle this.
   if (EltSizeInBits >= 16 || Subtarget.hasBWI()) {
-    int SplatIdx;
-    if (IsSplatVector(Amt, SplatIdx)) {
+    if (DAG.isSplatValue(Amt)) {
       SDValue AmtR = DAG.getConstant(EltSizeInBits, DL, VT);
       AmtR = DAG.getNode(ISD::SUB, DL, VT, AmtR, Amt);
       SDValue SHL = DAG.getNode(ISD::SHL, DL, VT, R, Amt);