@@ -5480,55 +5480,84 @@ LowerAsSplatVectorLoad(SDValue SrcOp, MVT VT, SDLoc dl, SelectionDAG &DAG) {
/// elements can be replaced by a single large load which has the same value as
/// a build_vector or insert_subvector whose loaded operands are 'Elts'.
///
- /// Example: <load i32 *a, load i32 *a+4, undef, undef> -> zextload a
- ///
- /// FIXME: we'd also like to handle the case where the last elements are zero
- /// rather than undef via VZEXT_LOAD, but we do not detect that case today.
- /// There's even a handy isZeroNode for that purpose.
+ /// Example: <load i32 *a, load i32 *a+4, zero, undef> -> zextload a
static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
                                        SDLoc &DL, SelectionDAG &DAG,
                                        bool isAfterLegalize) {
  unsigned NumElems = Elts.size();

- LoadSDNode *LDBase = nullptr;
- unsigned LastLoadedElt = -1U;
+ int LastLoadedElt = -1;
+ SmallBitVector LoadMask(NumElems, false);
+ SmallBitVector ZeroMask(NumElems, false);
+ SmallBitVector UndefMask(NumElems, false);
+
+ auto PeekThroughBitcast = [](SDValue V) {
+   while (V.getNode() && V.getOpcode() == ISD::BITCAST)
+     V = V.getOperand(0);
+   return V;
+ };

- // For each element in the initializer, see if we've found a load or an undef.
- // If we don't find an initial load element, or later load elements are
- // non-consecutive, bail out.
+ // For each element in the initializer, see if we've found a load, zero or an
+ // undef.
  for (unsigned i = 0; i < NumElems; ++i) {
-   SDValue Elt = Elts[i];
-   // Look through a bitcast.
-   if (Elt.getNode() && Elt.getOpcode() == ISD::BITCAST)
-     Elt = Elt.getOperand(0);
-   if (!Elt.getNode() ||
-       (Elt.getOpcode() != ISD::UNDEF && !ISD::isNON_EXTLoad(Elt.getNode())))
+   SDValue Elt = PeekThroughBitcast(Elts[i]);
+   if (!Elt.getNode())
      return SDValue();
-   if (!LDBase) {
-     if (Elt.getNode()->getOpcode() == ISD::UNDEF)
-       return SDValue();
-     LDBase = cast<LoadSDNode>(Elt.getNode());
-     LastLoadedElt = i;
-     continue;
-   }
-   if (Elt.getOpcode() == ISD::UNDEF)
-     continue;

-   LoadSDNode *LD = cast<LoadSDNode>(Elt);
-   EVT LdVT = Elt.getValueType();
-   // Each loaded element must be the correct fractional portion of the
-   // requested vector load.
-   if (LdVT.getSizeInBits() != VT.getSizeInBits() / NumElems)
-     return SDValue();
-   if (!DAG.isConsecutiveLoad(LD, LDBase, LdVT.getSizeInBits() / 8, i))
+   if (Elt.isUndef())
+     UndefMask[i] = true;
+   else if (X86::isZeroNode(Elt) || ISD::isBuildVectorAllZeros(Elt.getNode()))
+     ZeroMask[i] = true;
+   else if (ISD::isNON_EXTLoad(Elt.getNode())) {
+     LoadMask[i] = true;
+     LastLoadedElt = i;
+     // Each loaded element must be the correct fractional portion of the
+     // requested vector load.
+     if ((NumElems * Elt.getValueSizeInBits()) != VT.getSizeInBits())
+       return SDValue();
+   } else
      return SDValue();
-   LastLoadedElt = i;
  }
+ assert((ZeroMask | UndefMask | LoadMask).count() == NumElems &&
+        "Incomplete element masks");

+ // Handle Special Cases - all undef or undef/zero.
+ if (UndefMask.count() == NumElems)
+   return DAG.getUNDEF(VT);
+
+ // FIXME: Should we return this as a BUILD_VECTOR instead?
+ if ((ZeroMask | UndefMask).count() == NumElems)
+   return VT.isInteger() ? DAG.getConstant(0, DL, VT)
+                         : DAG.getConstantFP(0.0, DL, VT);
+
+ int FirstLoadedElt = LoadMask.find_first();
+ SDValue EltBase = PeekThroughBitcast(Elts[FirstLoadedElt]);
+ LoadSDNode *LDBase = cast<LoadSDNode>(EltBase);
+ EVT LDBaseVT = EltBase.getValueType();
+
+ // Consecutive loads can contain UNDEFS but not ZERO elements.
+ bool IsConsecutiveLoad = true;
+ for (int i = FirstLoadedElt + 1; i <= LastLoadedElt; ++i) {
+   if (LoadMask[i]) {
+     SDValue Elt = PeekThroughBitcast(Elts[i]);
+     LoadSDNode *LD = cast<LoadSDNode>(Elt);
+     if (!DAG.isConsecutiveLoad(LD, LDBase,
+                                Elt.getValueType().getStoreSizeInBits() / 8,
+                                i - FirstLoadedElt)) {
+       IsConsecutiveLoad = false;
+       break;
+     }
+   } else if (ZeroMask[i]) {
+     IsConsecutiveLoad = false;
+     break;
+   }
+ }
+
+ // LOAD - all consecutive load/undefs (must start/end with a load).
  // If we have found an entire vector of loads and undefs, then return a large
- // load of the entire vector width starting at the base pointer. If we found
- // consecutive loads for the low half, generate a vzext_load node.
- if (LastLoadedElt == NumElems - 1) {
+ // load of the entire vector width starting at the base pointer.
+ if (IsConsecutiveLoad && FirstLoadedElt == 0 &&
+     LastLoadedElt == (int)(NumElems - 1) && ZeroMask.none()) {
    assert(LDBase && "Did not find base load for merging consecutive loads");
    EVT EltVT = LDBase->getValueType(0);
    // Ensure that the input vector size for the merged loads matches the
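The heart of the change is the hunk above: instead of bailing out at the first non-load element, the loop now classifies every build_vector operand into one of three bit vectors (LoadMask, ZeroMask, UndefMask), and consecutiveness is checked in a second pass that tolerates undefs between loads but not zeros. A minimal standalone sketch of that two-pass structure, using toy element and offset types rather than LLVM's SDValue and SmallBitVector:

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Toy element model: each build_vector operand is a load (with a byte
    // offset from some base pointer), a zero constant, or undef.
    enum class EltKind { Load, Zero, Undef };
    struct Elt {
      EltKind Kind;
      int ByteOffset; // only meaningful for loads
    };

    // Classify the elements into three masks, mirroring the LoadMask /
    // ZeroMask / UndefMask bit vectors in the patch, then test whether the
    // loads between the first and the last load are consecutive. Undefs
    // are tolerated in between; zeros are not.
    static bool classifyAndCheck(const std::vector<Elt> &Elts, int EltBytes) {
      const int N = (int)Elts.size();
      std::vector<bool> LoadMask(N), ZeroMask(N), UndefMask(N);
      int First = -1, Last = -1;
      for (int i = 0; i < N; ++i) {
        switch (Elts[i].Kind) {
        case EltKind::Load:
          LoadMask[i] = true;
          if (First < 0) First = i;
          Last = i;
          break;
        case EltKind::Zero:  ZeroMask[i] = true;  break;
        case EltKind::Undef: UndefMask[i] = true; break; // kept for parity
        }
      }
      if (First < 0)
        return false; // all zero/undef: handled separately in the patch
      int BaseOff = Elts[First].ByteOffset;
      for (int i = First + 1; i <= Last; ++i) {
        if (ZeroMask[i])
          return false; // a zero between two loads breaks consecutiveness
        if (LoadMask[i] && Elts[i].ByteOffset != BaseOff + (i - First) * EltBytes)
          return false; // load at the wrong offset
      }
      return true;
    }

    int main() {
      // <load a+0, load a+4, undef, load a+12>: consecutive, undef tolerated.
      std::vector<Elt> A = {{EltKind::Load, 0}, {EltKind::Load, 4},
                            {EltKind::Undef, 0}, {EltKind::Load, 12}};
      // <load a+0, zero, load a+8, undef>: not consecutive, zero in range.
      std::vector<Elt> B = {{EltKind::Load, 0}, {EltKind::Zero, 0},
                            {EltKind::Load, 8}, {EltKind::Undef, 0}};
      std::printf("A: %d  B: %d\n", classifyAndCheck(A, 4), classifyAndCheck(B, 4));
      assert(classifyAndCheck(A, 4) && !classifyAndCheck(B, 4));
    }

Note how classification always completes even when the combine ultimately fails; that is what lets the all-undef and all-zero/undef special cases in the patch fire before any consecutiveness checks.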
@@ -5548,9 +5577,9 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
                          LDBase->getAlignment());

    if (LDBase->hasAnyUseOfValue(1)) {
-     SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
-                                    SDValue(LDBase, 1),
-                                    SDValue(NewLd.getNode(), 1));
+     SDValue NewChain =
+         DAG.getNode(ISD::TokenFactor, DL, MVT::Other, SDValue(LDBase, 1),
+                     SDValue(NewLd.getNode(), 1));
      DAG.ReplaceAllUsesOfValueWith(SDValue(LDBase, 1), NewChain);
      DAG.UpdateNodeOperands(NewChain.getNode(), SDValue(LDBase, 1),
                             SDValue(NewLd.getNode(), 1));
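Both chain-fixup blocks in this function (here, and again below for the VZEXT_LOAD result) follow the same recipe: if anything was ordered after the old narrow load through its chain output, a TokenFactor joining the old and new chains is substituted for it. The usual reading of the trailing UpdateNodeOperands call is that the blanket replacement also rewrites the TokenFactor's own operand, which must then be repaired. A toy model of that rewiring, with plain pointers standing in for SDValue chain edges:

    #include <algorithm>
    #include <cassert>
    #include <string>
    #include <vector>

    // Toy stand-ins: a node is ordered after every node in ChainOps. In the
    // real DAG these edges are SDValue(N, 1) chain results, not raw pointers.
    struct Node {
      std::string Name;
      std::vector<Node *> ChainOps;
    };

    int main() {
      Node LDBase{"narrow load", {}};
      Node Store{"dependent store", {&LDBase}}; // a user of LDBase's chain

      // The combine built a wide replacement load; a TokenFactor joins the
      // chains of both loads so neither ordering edge is lost.
      Node NewLd{"wide load", {}};
      Node TF{"TokenFactor", {&LDBase, &NewLd}};

      // ReplaceAllUsesOfValueWith: every user of LDBase's chain now points
      // at TF, including (inconveniently) TF itself ...
      for (Node *N : {&Store, &TF})
        std::replace(N->ChainOps.begin(), N->ChainOps.end(), &LDBase, &TF);

      // ... so UpdateNodeOperands restores TF's own operands afterwards.
      TF.ChainOps = {&LDBase, &NewLd};

      assert(Store.ChainOps[0] == &TF && "store now ordered after both loads");
      assert(TF.ChainOps[0] == &LDBase && TF.ChainOps[1] == &NewLd);
    }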
@@ -5559,11 +5588,14 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
    return NewLd;
  }

- //TODO: The code below fires only for for loading the low v2i32 / v2f32
- //of a v4i32 / v4f32. It's probably worth generalizing.
- EVT EltVT = VT.getVectorElementType();
- if (NumElems == 4 && LastLoadedElt == 1 && (EltVT.getSizeInBits() == 32) &&
-     DAG.getTargetLoweringInfo().isTypeLegal(MVT::v2i64)) {
+ int LoadSize =
+     (1 + LastLoadedElt - FirstLoadedElt) * LDBaseVT.getStoreSizeInBits();
+
+ // VZEXT_LOAD - consecutive load/undefs followed by zeros/undefs.
+ // TODO: The code below fires only for loading the low 64-bits of a
+ // 128-bit vector. It's probably worth generalizing more.
+ if (IsConsecutiveLoad && FirstLoadedElt == 0 && VT.is128BitVector() &&
+     (LoadSize == 64 && DAG.getTargetLoweringInfo().isTypeLegal(MVT::v2i64))) {
    SDVTList Tys = DAG.getVTList(MVT::v2i64, MVT::Other);
    SDValue Ops[] = { LDBase->getChain(), LDBase->getBasePtr() };
    SDValue ResNode =
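The hunk above generalizes the old v4i32/v4f32-only check: any 128-bit vector whose low 64 bits come from consecutive loads and whose remaining lanes are zero or undef can be lowered to X86ISD::VZEXT_LOAD, i.e. a movq-style 64-bit load that implicitly zeroes the upper half of the register. At the source level the effect looks like this (SSE2 intrinsics, illustrative only, not part of the patch):

    #include <cstdint>
    #include <cstdio>
    #include <emmintrin.h> // SSE2

    // What the VZEXT_LOAD node buys us: building a v4i32 as
    // <load a[0], load a[1], 0, 0> can be done with one 64-bit load that
    // implicitly zeroes the upper lanes (movq), instead of two scalar loads
    // plus shuffles. _mm_loadl_epi64 is the intrinsic form of that pattern.
    int main() {
      int32_t a[2] = {7, 42};
      __m128i v = _mm_loadl_epi64(reinterpret_cast<const __m128i *>(a));
      alignas(16) int32_t out[4];
      _mm_store_si128(reinterpret_cast<__m128i *>(out), v);
      std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]); // 7 42 0 0
    }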
@@ -5577,8 +5609,9 @@ static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
    // terms of dependency. We create a TokenFactor for LDBase and ResNode, and
    // update uses of LDBase's output chain to use the TokenFactor.
    if (LDBase->hasAnyUseOfValue(1)) {
-     SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
-                                    SDValue(LDBase, 1), SDValue(ResNode.getNode(), 1));
+     SDValue NewChain =
+         DAG.getNode(ISD::TokenFactor, DL, MVT::Other, SDValue(LDBase, 1),
+                     SDValue(ResNode.getNode(), 1));
      DAG.ReplaceAllUsesOfValueWith(SDValue(LDBase, 1), NewChain);
      DAG.UpdateNodeOperands(NewChain.getNode(), SDValue(LDBase, 1),
                             SDValue(ResNode.getNode(), 1));
@@ -6551,15 +6584,17 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
  if (IsAllConstants)
    return SDValue();

- // For AVX-length vectors, see if we can use a vector load to get all of the
- // elements, otherwise build the individual 128-bit pieces and use
- // shuffles to put them in place.
- if (VT.is256BitVector() || VT.is512BitVector()) {
+ // See if we can use a vector load to get all of the elements.
+ if (VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) {
    SmallVector<SDValue, 64> V(Op->op_begin(), Op->op_begin() + NumElems);
-
-   // Check for a build vector of consecutive loads.
    if (SDValue LD = EltsFromConsecutiveLoads(VT, V, dl, DAG, false))
      return LD;
+ }
+
+ // For AVX-length vectors, build the individual 128-bit pieces and use
+ // shuffles to put them in place.
+ if (VT.is256BitVector() || VT.is512BitVector()) {
+   SmallVector<SDValue, 64> V(Op->op_begin(), Op->op_begin() + NumElems);

    EVT HVT = EVT::getVectorVT(*DAG.getContext(), ExtVT, NumElems/2);
@@ -6648,10 +6683,6 @@ X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
  for (unsigned i = 0; i < NumElems; ++i)
    V[i] = Op.getOperand(i);

- // Check for elements which are consecutive loads.
- if (SDValue LD = EltsFromConsecutiveLoads(VT, V, dl, DAG, false))
-   return LD;
-
  // Check for a build vector from mostly shuffle plus few inserting.
  if (SDValue Sh = buildFromShuffleMostly(Op, DAG))
    return Sh;