Index: llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h +++ llvm/trunk/include/llvm/CodeGen/SelectionDAGNodes.h @@ -1629,9 +1629,19 @@ /// Returns the SDNode if it is a constant splat BuildVector or constant int. ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false); +/// Returns the SDNode if it is a demanded constant splat BuildVector or +/// constant int. +ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts, + bool AllowUndefs = false); + /// Returns the SDNode if it is a constant splat BuildVector or constant float. ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false); +/// Returns the SDNode if it is a demanded constant splat BuildVector or +/// constant float. +ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, const APInt &DemandedElts, + bool AllowUndefs = false); + /// Return true if the value is a constant 0 integer or a splatted vector of /// a constant 0 integer (with no undefs by default). /// Build vector implicit truncation is not an issue for null values. @@ -1868,12 +1878,31 @@ unsigned MinSplatBits = 0, bool isBigEndian = false) const; + /// Returns the demanded splatted value or a null value if this is not a + /// splat. + /// + /// The DemandedElts mask indicates the elements that must be in the splat. + /// If passed a non-null UndefElements bitvector, it will resize it to match + /// the vector width and set the bits where demanded elements are undef. + SDValue getSplatValue(const APInt &DemandedElts, + BitVector *UndefElements = nullptr) const; + /// Returns the splatted value or a null value if this is not a splat. /// /// If passed a non-null UndefElements bitvector, it will resize it to match /// the vector width and set the bits where elements are undef.
SDValue getSplatValue(BitVector *UndefElements = nullptr) const; + /// Returns the demanded splatted constant or null if this is not a constant + /// splat. + /// + /// The DemandedElts mask indicates the elements that must be in the splat. + /// If passed a non-null UndefElements bitvector, it will resize it to match + /// the vector width and set the bits where demanded elements are undef. + ConstantSDNode * + getConstantSplatNode(const APInt &DemandedElts, + BitVector *UndefElements = nullptr) const; + /// Returns the splatted constant or null if this is not a constant /// splat. /// @@ -1882,6 +1911,16 @@ ConstantSDNode * getConstantSplatNode(BitVector *UndefElements = nullptr) const; + /// Returns the demanded splatted constant FP or null if this is not a + /// constant FP splat. + /// + /// The DemandedElts mask indicates the elements that must be in the splat. + /// If passed a non-null UndefElements bitvector, it will resize it to match + /// the vector width and set the bits where demanded elements are undef. + ConstantFPSDNode * + getConstantFPSplatNode(const APInt &DemandedElts, + BitVector *UndefElements = nullptr) const; + /// Returns the splatted constant FP or null if this is not a constant /// FP splat. /// Index: llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ llvm/trunk/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -2253,30 +2253,6 @@ (AllowUndefs || !UndefElts); } -/// Helper function that checks to see if a node is a constant or a -/// build vector of splat constants at least within the demanded elts.
-static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N, - const APInt &DemandedElts) { - if (ConstantSDNode *CN = dyn_cast(N)) - return CN; - if (N.getOpcode() != ISD::BUILD_VECTOR) - return nullptr; - EVT VT = N.getValueType(); - ConstantSDNode *Cst = nullptr; - unsigned NumElts = VT.getVectorNumElements(); - assert(DemandedElts.getBitWidth() == NumElts && "Unexpected vector size"); - for (unsigned i = 0; i != NumElts; ++i) { - if (!DemandedElts[i]) - continue; - ConstantSDNode *C = dyn_cast(N.getOperand(i)); - if (!C || (Cst && Cst->getAPIntValue() != C->getAPIntValue()) || - C->getValueType(0) != VT.getScalarType()) - return nullptr; - Cst = C; - } - return Cst; -} - /// If a SHL/SRA/SRL node has a constant or splat constant shift amount that /// is less than the element bit-width of the shift node, return it. static const APInt *getValidShiftAmountConstant(SDValue V) { @@ -2717,8 +2693,7 @@ break; case ISD::FSHL: case ISD::FSHR: - if (ConstantSDNode *C = - isConstOrDemandedConstSplat(Op.getOperand(2), DemandedElts)) { + if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) { unsigned Amt = C->getAPIntValue().urem(BitWidth); // For fshl, 0-shift returns the 1st arg. @@ -3155,10 +3130,10 @@ // the minimum of the clamp min/max range. bool IsMax = (Opcode == ISD::SMAX); ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr; - if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts))) + if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts))) if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX)) - CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1), - DemandedElts); + CstHigh = + isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts); if (CstLow && CstHigh) { if (!IsMax) std::swap(CstLow, CstHigh); @@ -3439,7 +3414,7 @@ Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1); // SRA X, C -> adds C sign bits. 
if (ConstantSDNode *C = - isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) { + isConstOrConstSplat(Op.getOperand(1), DemandedElts)) { APInt ShiftVal = C->getAPIntValue(); ShiftVal += Tmp; Tmp = ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue(); @@ -3447,7 +3422,7 @@ return Tmp; case ISD::SHL: if (ConstantSDNode *C = - isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) { + isConstOrConstSplat(Op.getOperand(1), DemandedElts)) { // shl destroys sign bits. Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1); if (C->getAPIntValue().uge(VTBits) || // Bad shift. @@ -3487,10 +3462,10 @@ // the minimum of the clamp min/max range. bool IsMax = (Opcode == ISD::SMAX); ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr; - if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts))) + if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts))) if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX)) - CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1), - DemandedElts); + CstHigh = + isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts); if (CstLow && CstHigh) { if (!IsMax) std::swap(CstLow, CstHigh); @@ -8593,6 +8568,24 @@ return nullptr; } +ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts, + bool AllowUndefs) { + if (ConstantSDNode *CN = dyn_cast(N)) + return CN; + + if (BuildVectorSDNode *BV = dyn_cast(N)) { + BitVector UndefElements; + ConstantSDNode *CN = BV->getConstantSplatNode(DemandedElts, &UndefElements); + + // BuildVectors can truncate their operands. Ignore that case here. 
+ if (CN && (UndefElements.none() || AllowUndefs) && + CN->getValueType(0) == N.getValueType().getScalarType()) + return CN; + } + + return nullptr; +} + ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) { if (ConstantFPSDNode *CN = dyn_cast(N)) return CN; @@ -8607,6 +8600,23 @@ return nullptr; } +ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, + const APInt &DemandedElts, + bool AllowUndefs) { + if (ConstantFPSDNode *CN = dyn_cast(N)) + return CN; + + if (BuildVectorSDNode *BV = dyn_cast(N)) { + BitVector UndefElements; + ConstantFPSDNode *CN = + BV->getConstantFPSplatNode(DemandedElts, &UndefElements); + if (CN && (UndefElements.none() || AllowUndefs)) + return CN; + } + + return nullptr; +} + bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) { // TODO: may want to use peekThroughBitcast() here. ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs); @@ -9193,13 +9203,20 @@ return true; } -SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const { +SDValue BuildVectorSDNode::getSplatValue(const APInt &DemandedElts, + BitVector *UndefElements) const { if (UndefElements) { UndefElements->clear(); UndefElements->resize(getNumOperands()); } + assert(getNumOperands() == DemandedElts.getBitWidth() && + "Unexpected vector size"); + if (!DemandedElts) + return SDValue(); SDValue Splatted; for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { + if (!DemandedElts[i]) + continue; SDValue Op = getOperand(i); if (Op.isUndef()) { if (UndefElements) @@ -9212,20 +9229,40 @@ } if (!Splatted) { - assert(getOperand(0).isUndef() && + unsigned FirstDemandedIdx = DemandedElts.countTrailingZeros(); + assert(getOperand(FirstDemandedIdx).isUndef() && "Can only have a splat without a constant for all undefs."); - return getOperand(0); + return getOperand(FirstDemandedIdx); } return Splatted; } +SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const { + APInt DemandedElts = 
APInt::getAllOnesValue(getNumOperands()); + return getSplatValue(DemandedElts, UndefElements); +} + +ConstantSDNode * +BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts, + BitVector *UndefElements) const { + return dyn_cast_or_null( + getSplatValue(DemandedElts, UndefElements)); +} + ConstantSDNode * BuildVectorSDNode::getConstantSplatNode(BitVector *UndefElements) const { return dyn_cast_or_null(getSplatValue(UndefElements)); } ConstantFPSDNode * +BuildVectorSDNode::getConstantFPSplatNode(const APInt &DemandedElts, + BitVector *UndefElements) const { + return dyn_cast_or_null( + getSplatValue(DemandedElts, UndefElements)); +} + +ConstantFPSDNode * BuildVectorSDNode::getConstantFPSplatNode(BitVector *UndefElements) const { return dyn_cast_or_null(getSplatValue(UndefElements)); }