Skip to content

Commit

Permalink
[SelectionDAG] Add demanded elts variants to isConstOrConstSplat help…
Browse files Browse the repository at this point in the history
…ers. NFCI.

These helpers extend the existing isConstOrConstSplat helper checks to support DemandedElts masks as well.

We already had a local version of this in SelectionDAG that computeKnownBits/ComputeNumSignBits used; this patch adds the functionality directly to BuildVectorSDNode and extends isConstOrConstSplat etc. to use it.

This will allow us to reuse the functionality in SimplifyDemandedVectorElts/SimplifyDemandedBits.

Differential Revision: https://reviews.llvm.org/D58503

llvm-svn: 354797
  • Loading branch information
RKSimon committed Feb 25, 2019
1 parent 8a7f4c9 commit 80d0e9c
Showing 2 changed files with 113 additions and 37 deletions.
39 changes: 39 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
@@ -1629,9 +1629,19 @@ bool isBitwiseNot(SDValue V);
/// Returns the SDNode if it is a constant splat BuildVector or constant int.
ConstantSDNode *isConstOrConstSplat(SDValue N, bool AllowUndefs = false);

/// Returns the SDNode if it is a demanded constant splat BuildVector or
/// constant int.
///
/// A scalar ConstantSDNode is returned as-is; DemandedElts is not consulted.
/// For a BUILD_VECTOR, only the elements set in DemandedElts must agree on
/// the splatted constant. Undef demanded elements are tolerated only when
/// AllowUndefs is true. Splats whose constant operand type differs from the
/// vector's scalar type (implicitly truncating BuildVectors) are rejected.
/// Returns null otherwise, including when no element is demanded.
ConstantSDNode *isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
bool AllowUndefs = false);

/// Returns the SDNode if it is a constant splat BuildVector or constant float.
ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, bool AllowUndefs = false);

/// Returns the SDNode if it is a demanded constant splat BuildVector or
/// constant float.
///
/// A scalar ConstantFPSDNode is returned as-is; DemandedElts is not consulted.
/// For a BUILD_VECTOR, only the elements set in DemandedElts must agree on
/// the splatted constant. Undef demanded elements are tolerated only when
/// AllowUndefs is true. Unlike the integer variant, no operand/scalar type
/// comparison is performed. Returns null otherwise, including when no
/// element is demanded.
ConstantFPSDNode *isConstOrConstSplatFP(SDValue N, const APInt &DemandedElts,
bool AllowUndefs = false);

/// Return true if the value is a constant 0 integer or a splatted vector of
/// a constant 0 integer (with no undefs by default).
/// Build vector implicit truncation is not an issue for null values.
@@ -1868,12 +1878,31 @@ class BuildVectorSDNode : public SDNode {
unsigned MinSplatBits = 0,
bool isBigEndian = false) const;

/// Returns the demanded splatted value or a null value if this is not a
/// splat.
///
/// The DemandedElts mask indicates the elements that must be in the splat.
/// Its bit width must equal the number of operands (asserted), and elements
/// not set in the mask are ignored entirely. A null value is returned if no
/// elements are demanded. If every demanded element is undef, the first
/// demanded (undef) operand is returned.
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where *demanded* elements are undef;
/// bits for non-demanded elements are always left clear.
SDValue getSplatValue(const APInt &DemandedElts,
BitVector *UndefElements = nullptr) const;

/// Returns the splatted value or a null value if this is not a splat.
///
/// Equivalent to the DemandedElts overload with every element demanded.
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where elements are undef.
SDValue getSplatValue(BitVector *UndefElements = nullptr) const;

/// Returns the demanded splatted constant or null if this is not a constant
/// splat.
///
/// The DemandedElts mask indicates the elements that must be in the splat;
/// elements not set in the mask are ignored, and null is returned if no
/// elements are demanded.
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where demanded elements are undef
/// (non-demanded elements are never marked).
ConstantSDNode *
getConstantSplatNode(const APInt &DemandedElts,
BitVector *UndefElements = nullptr) const;

/// Returns the splatted constant or null if this is not a constant
/// splat.
///
@@ -1882,6 +1911,16 @@ class BuildVectorSDNode : public SDNode {
ConstantSDNode *
getConstantSplatNode(BitVector *UndefElements = nullptr) const;

/// Returns the demanded splatted constant FP or null if this is not a
/// constant FP splat.
///
/// The DemandedElts mask indicates the elements that must be in the splat;
/// elements not set in the mask are ignored, and null is returned if no
/// elements are demanded.
/// If passed a non-null UndefElements bitvector, it will resize it to match
/// the vector width and set the bits where demanded elements are undef
/// (non-demanded elements are never marked).
ConstantFPSDNode *
getConstantFPSplatNode(const APInt &DemandedElts,
BitVector *UndefElements = nullptr) const;

/// Returns the splatted constant FP or null if this is not a constant
/// FP splat.
///
111 changes: 74 additions & 37 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
@@ -2253,30 +2253,6 @@ bool SelectionDAG::isSplatValue(SDValue V, bool AllowUndefs) {
(AllowUndefs || !UndefElts);
}

/// Returns N itself when it is a ConstantSDNode. Otherwise, when N is a
/// BUILD_VECTOR, returns the shared constant operand of all demanded
/// elements, provided each is a ConstantSDNode of the vector's scalar type
/// and all have the same value. Returns null in every other case, including
/// when no element is demanded.
static ConstantSDNode *isConstOrDemandedConstSplat(SDValue N,
                                                   const APInt &DemandedElts) {
  if (ConstantSDNode *Scalar = dyn_cast<ConstantSDNode>(N))
    return Scalar;
  if (N.getOpcode() != ISD::BUILD_VECTOR)
    return nullptr;

  EVT VecVT = N.getValueType();
  unsigned NumElts = VecVT.getVectorNumElements();
  assert(DemandedElts.getBitWidth() == NumElts && "Unexpected vector size");

  ConstantSDNode *Splat = nullptr;
  for (unsigned Idx = 0; Idx != NumElts; ++Idx) {
    // Lanes outside the mask do not participate in the splat.
    if (!DemandedElts[Idx])
      continue;

    ConstantSDNode *Elt = dyn_cast<ConstantSDNode>(N.getOperand(Idx));
    if (!Elt)
      return nullptr;
    if (Splat && Splat->getAPIntValue() != Elt->getAPIntValue())
      return nullptr;
    // Reject operands that the BUILD_VECTOR would implicitly truncate.
    if (Elt->getValueType(0) != VecVT.getScalarType())
      return nullptr;

    Splat = Elt;
  }
  return Splat;
}

/// If a SHL/SRA/SRL node has a constant or splat constant shift amount that
/// is less than the element bit-width of the shift node, return it.
static const APInt *getValidShiftAmountConstant(SDValue V) {
@@ -2717,8 +2693,7 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
break;
case ISD::FSHL:
case ISD::FSHR:
if (ConstantSDNode *C =
isConstOrDemandedConstSplat(Op.getOperand(2), DemandedElts)) {
if (ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(2), DemandedElts)) {
unsigned Amt = C->getAPIntValue().urem(BitWidth);

// For fshl, 0-shift returns the 1st arg.
@@ -3155,10 +3130,10 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
// the minimum of the clamp min/max range.
bool IsMax = (Opcode == ISD::SMAX);
ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
DemandedElts);
CstHigh =
isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
if (CstLow && CstHigh) {
if (!IsMax)
std::swap(CstLow, CstHigh);
@@ -3439,15 +3414,15 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
// SRA X, C -> adds C sign bits.
if (ConstantSDNode *C =
isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
APInt ShiftVal = C->getAPIntValue();
ShiftVal += Tmp;
Tmp = ShiftVal.uge(VTBits) ? VTBits : ShiftVal.getZExtValue();
}
return Tmp;
case ISD::SHL:
if (ConstantSDNode *C =
isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)) {
isConstOrConstSplat(Op.getOperand(1), DemandedElts)) {
// shl destroys sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth+1);
if (C->getAPIntValue().uge(VTBits) || // Bad shift.
@@ -3487,10 +3462,10 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
// the minimum of the clamp min/max range.
bool IsMax = (Opcode == ISD::SMAX);
ConstantSDNode *CstLow = nullptr, *CstHigh = nullptr;
if ((CstLow = isConstOrDemandedConstSplat(Op.getOperand(1), DemandedElts)))
if ((CstLow = isConstOrConstSplat(Op.getOperand(1), DemandedElts)))
if (Op.getOperand(0).getOpcode() == (IsMax ? ISD::SMIN : ISD::SMAX))
CstHigh = isConstOrDemandedConstSplat(Op.getOperand(0).getOperand(1),
DemandedElts);
CstHigh =
isConstOrConstSplat(Op.getOperand(0).getOperand(1), DemandedElts);
if (CstLow && CstHigh) {
if (!IsMax)
std::swap(CstLow, CstHigh);
@@ -8593,6 +8568,24 @@ ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, bool AllowUndefs) {
return nullptr;
}

ConstantSDNode *llvm::isConstOrConstSplat(SDValue N, const APInt &DemandedElts,
                                          bool AllowUndefs) {
  // A scalar constant has no lanes to filter, so the mask is irrelevant.
  if (ConstantSDNode *Scalar = dyn_cast<ConstantSDNode>(N))
    return Scalar;

  BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
  if (!BV)
    return nullptr;

  BitVector UndefElements;
  ConstantSDNode *Splat =
      BV->getConstantSplatNode(DemandedElts, &UndefElements);
  if (!Splat)
    return nullptr;

  // Demanded undef lanes are only acceptable when the caller opts in.
  if (!UndefElements.none() && !AllowUndefs)
    return nullptr;

  // BuildVectors can truncate their operands. Ignore that case here.
  if (Splat->getValueType(0) != N.getValueType().getScalarType())
    return nullptr;

  return Splat;
}

ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) {
if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
return CN;
@@ -8607,6 +8600,23 @@ ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N, bool AllowUndefs) {
return nullptr;
}

ConstantFPSDNode *llvm::isConstOrConstSplatFP(SDValue N,
                                              const APInt &DemandedElts,
                                              bool AllowUndefs) {
  // A scalar FP constant is trivially its own splat.
  if (ConstantFPSDNode *Scalar = dyn_cast<ConstantFPSDNode>(N))
    return Scalar;

  BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N);
  if (!BV)
    return nullptr;

  BitVector UndefElements;
  ConstantFPSDNode *Splat =
      BV->getConstantFPSplatNode(DemandedElts, &UndefElements);

  // Reject non-splats, and demanded undef lanes unless the caller allows them.
  if (!Splat || (!AllowUndefs && !UndefElements.none()))
    return nullptr;
  return Splat;
}

bool llvm::isNullOrNullSplat(SDValue N, bool AllowUndefs) {
// TODO: may want to use peekThroughBitcast() here.
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
@@ -9193,13 +9203,20 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue, APInt &SplatUndef,
return true;
}

SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
SDValue BuildVectorSDNode::getSplatValue(const APInt &DemandedElts,
BitVector *UndefElements) const {
if (UndefElements) {
UndefElements->clear();
UndefElements->resize(getNumOperands());
}
assert(getNumOperands() == DemandedElts.getBitWidth() &&
"Unexpected vector size");
if (!DemandedElts)
return SDValue();
SDValue Splatted;
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
if (!DemandedElts[i])
continue;
SDValue Op = getOperand(i);
if (Op.isUndef()) {
if (UndefElements)
@@ -9212,19 +9229,39 @@ SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
}

if (!Splatted) {
assert(getOperand(0).isUndef() &&
unsigned FirstDemandedIdx = DemandedElts.countTrailingZeros();
assert(getOperand(FirstDemandedIdx).isUndef() &&
"Can only have a splat without a constant for all undefs.");
return getOperand(0);
return getOperand(FirstDemandedIdx);
}

return Splatted;
}

SDValue BuildVectorSDNode::getSplatValue(BitVector *UndefElements) const {
  // Without an explicit mask, every operand is demanded.
  return getSplatValue(APInt::getAllOnesValue(getNumOperands()),
                       UndefElements);
}

ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(const APInt &DemandedElts,
                                        BitVector *UndefElements) const {
  // Find the demanded splat first, then keep it only if it is an int constant.
  const auto Splat = getSplatValue(DemandedElts, UndefElements);
  return dyn_cast_or_null<ConstantSDNode>(Splat);
}

ConstantSDNode *
BuildVectorSDNode::getConstantSplatNode(BitVector *UndefElements) const {
  // All-elements form: find the splat, then narrow to an int constant.
  const auto Splat = getSplatValue(UndefElements);
  return dyn_cast_or_null<ConstantSDNode>(Splat);
}

ConstantFPSDNode *
BuildVectorSDNode::getConstantFPSplatNode(const APInt &DemandedElts,
                                          BitVector *UndefElements) const {
  // Find the demanded splat first, then keep it only if it is an FP constant.
  const auto Splat = getSplatValue(DemandedElts, UndefElements);
  return dyn_cast_or_null<ConstantFPSDNode>(Splat);
}

ConstantFPSDNode *
BuildVectorSDNode::getConstantFPSplatNode(BitVector *UndefElements) const {
return dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements));

0 comments on commit 80d0e9c

Please sign in to comment.