Index: llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/trunk/lib/Target/X86/X86ISelLowering.cpp
@@ -5480,55 +5480,84 @@
 /// elements can be replaced by a single large load which has the same value as
 /// a build_vector or insert_subvector whose loaded operands are 'Elts'.
 ///
-/// Example: -> zextload a
-///
-/// FIXME: we'd also like to handle the case where the last elements are zero
-/// rather than undef via VZEXT_LOAD, but we do not detect that case today.
-/// There's even a handy isZeroNode for that purpose.
+/// Example: -> zextload a
 static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
                                         SDLoc &DL, SelectionDAG &DAG,
                                         bool isAfterLegalize) {
   unsigned NumElems = Elts.size();
 
-  LoadSDNode *LDBase = nullptr;
-  unsigned LastLoadedElt = -1U;
+  int LastLoadedElt = -1;
+  SmallBitVector LoadMask(NumElems, false);
+  SmallBitVector ZeroMask(NumElems, false);
+  SmallBitVector UndefMask(NumElems, false);
+
+  auto PeekThroughBitcast = [](SDValue V) {
+    while (V.getNode() && V.getOpcode() == ISD::BITCAST)
+      V = V.getOperand(0);
+    return V;
+  };
 
-  // For each element in the initializer, see if we've found a load or an undef.
-  // If we don't find an initial load element, or later load elements are
-  // non-consecutive, bail out.
+  // For each element in the initializer, see if we've found a load, zero or an
+  // undef.
   for (unsigned i = 0; i < NumElems; ++i) {
-    SDValue Elt = Elts[i];
-    // Look through a bitcast.
-    if (Elt.getNode() && Elt.getOpcode() == ISD::BITCAST)
-      Elt = Elt.getOperand(0);
-    if (!Elt.getNode() ||
-        (Elt.getOpcode() != ISD::UNDEF && !ISD::isNON_EXTLoad(Elt.getNode())))
+    SDValue Elt = PeekThroughBitcast(Elts[i]);
+    if (!Elt.getNode())
       return SDValue();
-    if (!LDBase) {
-      if (Elt.getNode()->getOpcode() == ISD::UNDEF)
-        return SDValue();
-      LDBase = cast<LoadSDNode>(Elt.getNode());
-      LastLoadedElt = i;
-      continue;
-    }
-    if (Elt.getOpcode() == ISD::UNDEF)
-      continue;
 
-    LoadSDNode *LD = cast<LoadSDNode>(Elt);
-    EVT LdVT = Elt.getValueType();
-    // Each loaded element must be the correct fractional portion of the
-    // requested vector load.
-    if (LdVT.getSizeInBits() != VT.getSizeInBits() / NumElems)
-      return SDValue();
-    if (!DAG.isConsecutiveLoad(LD, LDBase, LdVT.getSizeInBits() / 8, i))
+    if (Elt.isUndef())
+      UndefMask[i] = true;
+    else if (X86::isZeroNode(Elt) || ISD::isBuildVectorAllZeros(Elt.getNode()))
+      ZeroMask[i] = true;
+    else if (ISD::isNON_EXTLoad(Elt.getNode())) {
+      LoadMask[i] = true;
+      LastLoadedElt = i;
+      // Each loaded element must be the correct fractional portion of the
+      // requested vector load.
+      if ((NumElems * Elt.getValueSizeInBits()) != VT.getSizeInBits())
+        return SDValue();
+    } else
       return SDValue();
-    LastLoadedElt = i;
   }
+  assert((ZeroMask | UndefMask | LoadMask).count() == NumElems &&
+         "Incomplete element masks");
 
+  // Handle Special Cases - all undef or undef/zero.
+  if (UndefMask.count() == NumElems)
+    return DAG.getUNDEF(VT);
+
+  // FIXME: Should we return this as a BUILD_VECTOR instead?
+  if ((ZeroMask | UndefMask).count() == NumElems)
+    return VT.isInteger() ? DAG.getConstant(0, DL, VT)
+                          : DAG.getConstantFP(0.0, DL, VT);
+
+  int FirstLoadedElt = LoadMask.find_first();
+  SDValue EltBase = PeekThroughBitcast(Elts[FirstLoadedElt]);
+  LoadSDNode *LDBase = cast<LoadSDNode>(EltBase);
+  EVT LDBaseVT = EltBase.getValueType();
+
+  // Consecutive loads can contain UNDEFS but not ZERO elements.
+  bool IsConsecutiveLoad = true;
+  for (int i = FirstLoadedElt + 1; i <= LastLoadedElt; ++i) {
+    if (LoadMask[i]) {
+      SDValue Elt = PeekThroughBitcast(Elts[i]);
+      LoadSDNode *LD = cast<LoadSDNode>(Elt);
+      if (!DAG.isConsecutiveLoad(LD, LDBase,
+                                 Elt.getValueType().getStoreSizeInBits() / 8,
+                                 i - FirstLoadedElt)) {
+        IsConsecutiveLoad = false;
+        break;
+      }
+    } else if (ZeroMask[i]) {
+      IsConsecutiveLoad = false;
+      break;
+    }
+  }
+
+  // LOAD - all consecutive load/undefs (must start/end with a load).
   // If we have found an entire vector of loads and undefs, then return a large
-  // load of the entire vector width starting at the base pointer. If we found
-  // consecutive loads for the low half, generate a vzext_load node.
-  if (LastLoadedElt == NumElems - 1) {
+  // load of the entire vector width starting at the base pointer.
+  if (IsConsecutiveLoad && FirstLoadedElt == 0 &&
+      LastLoadedElt == (int)(NumElems - 1) && ZeroMask.none()) {
     assert(LDBase && "Did not find base load for merging consecutive loads");
     EVT EltVT = LDBase->getValueType(0);
     // Ensure that the input vector size for the merged loads matches the
@@ -5548,9 +5577,9 @@
                                 LDBase->getAlignment());
 
     if (LDBase->hasAnyUseOfValue(1)) {
-      SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
-                                     SDValue(LDBase, 1),
-                                     SDValue(NewLd.getNode(), 1));
+      SDValue NewChain =
+          DAG.getNode(ISD::TokenFactor, DL, MVT::Other, SDValue(LDBase, 1),
+                      SDValue(NewLd.getNode(), 1));
       DAG.ReplaceAllUsesOfValueWith(SDValue(LDBase, 1), NewChain);
       DAG.UpdateNodeOperands(NewChain.getNode(), SDValue(LDBase, 1),
                              SDValue(NewLd.getNode(), 1));
@@ -5559,11 +5588,14 @@
     return NewLd;
   }
 
-  //TODO: The code below fires only for for loading the low v2i32 / v2f32
-  //of a v4i32 / v4f32. It's probably worth generalizing.
-  EVT EltVT = VT.getVectorElementType();
-  if (NumElems == 4 && LastLoadedElt == 1 && (EltVT.getSizeInBits() == 32) &&
-      DAG.getTargetLoweringInfo().isTypeLegal(MVT::v2i64)) {
+  int LoadSize =
+      (1 + LastLoadedElt - FirstLoadedElt) * LDBaseVT.getStoreSizeInBits();
+
+  // VZEXT_LOAD - consecutive load/undefs followed by zeros/undefs.
+  // TODO: The code below fires only for for loading the low 64-bits of a
+  // of a 128-bit vector. It's probably worth generalizing more.
+  if (IsConsecutiveLoad && FirstLoadedElt == 0 && VT.is128BitVector() &&
+      (LoadSize == 64 && DAG.getTargetLoweringInfo().isTypeLegal(MVT::v2i64))) {
     SDVTList Tys = DAG.getVTList(MVT::v2i64, MVT::Other);
     SDValue Ops[] = { LDBase->getChain(), LDBase->getBasePtr() };
     SDValue ResNode =
@@ -5577,8 +5609,9 @@
     // terms of dependency. We create a TokenFactor for LDBase and ResNode, and
     // update uses of LDBase's output chain to use the TokenFactor.
     if (LDBase->hasAnyUseOfValue(1)) {
-      SDValue NewChain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
-                             SDValue(LDBase, 1), SDValue(ResNode.getNode(), 1));
+      SDValue NewChain =
+          DAG.getNode(ISD::TokenFactor, DL, MVT::Other, SDValue(LDBase, 1),
+                      SDValue(ResNode.getNode(), 1));
       DAG.ReplaceAllUsesOfValueWith(SDValue(LDBase, 1), NewChain);
       DAG.UpdateNodeOperands(NewChain.getNode(), SDValue(LDBase, 1),
                              SDValue(ResNode.getNode(), 1));
@@ -6551,15 +6584,17 @@
   if (IsAllConstants)
     return SDValue();
 
-  // For AVX-length vectors, see if we can use a vector load to get all of the
-  // elements, otherwise build the individual 128-bit pieces and use
-  // shuffles to put them in place.
-  if (VT.is256BitVector() || VT.is512BitVector()) {
+  // See if we can use a vector load to get all of the elements.
+  if (VT.is128BitVector() || VT.is256BitVector() || VT.is512BitVector()) {
     SmallVector<SDValue, 64> V(Op->op_begin(), Op->op_begin() + NumElems);
-
-    // Check for a build vector of consecutive loads.
     if (SDValue LD = EltsFromConsecutiveLoads(VT, V, dl, DAG, false))
       return LD;
+  }
+
+  // For AVX-length vectors, build the individual 128-bit pieces and use
+  // shuffles to put them in place.
+  if (VT.is256BitVector() || VT.is512BitVector()) {
+    SmallVector<SDValue, 64> V(Op->op_begin(), Op->op_begin() + NumElems);
 
     EVT HVT = EVT::getVectorVT(*DAG.getContext(), ExtVT, NumElems/2);
@@ -6648,10 +6683,6 @@
   for (unsigned i = 0; i < NumElems; ++i)
     V[i] = Op.getOperand(i);
 
-  // Check for elements which are consecutive loads.
-  if (SDValue LD = EltsFromConsecutiveLoads(VT, V, dl, DAG, false))
-    return LD;
-
   // Check for a build vector from mostly shuffle plus few inserting.
   if (SDValue Sh = buildFromShuffleMostly(Op, DAG))
     return Sh;
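For context, the kind of pattern the new zero-aware path is meant to catch looks like the IR sketch below. This is illustrative only and not part of the patch; the function name and the expected lowering are assumptions that mirror the updated CHECK lines in the tests that follow. A build vector whose low elements are consecutive loads and whose upper elements are zero should now be able to lower to a single 64-bit zero-extending vector load (movq/vmovq) instead of scalar inserts into a zeroed register.

; Illustrative sketch - names and the expected instruction are assumptions,
; not taken from the patch. Two consecutive f32 loads in lanes 0-1 with zeros
; in lanes 2-3 form the load/load/zero/zero pattern handled by the new
; VZEXT_LOAD path (FirstLoadedElt = 0, LoadSize = 64, 128-bit vector type).
define <4 x float> @example_merge_2f32_zz(float* %ptr) {
  %p1 = getelementptr inbounds float, float* %ptr, i64 1
  %f0 = load float, float* %ptr
  %f1 = load float, float* %p1
  %v0 = insertelement <4 x float> undef, float %f0, i32 0
  %v1 = insertelement <4 x float> %v0, float %f1, i32 1
  %v2 = insertelement <4 x float> %v1, float 0.0, i32 2
  %v3 = insertelement <4 x float> %v2, float 0.0, i32 3
  ; Expected (assumed): a single movq/vmovq loading 8 bytes from %ptr with the
  ; upper 64 bits of %xmm0 zeroed.
  ret <4 x float> %v3
}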
Index: llvm/trunk/test/CodeGen/X86/merge-consecutive-loads-128.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/merge-consecutive-loads-128.ll
+++ llvm/trunk/test/CodeGen/X86/merge-consecutive-loads-128.ll
@@ -347,18 +347,12 @@
 define <8 x i16> @merge_8i16_i16_45u7zzzz(i16* %ptr) nounwind uwtable noinline ssp {
 ; SSE-LABEL: merge_8i16_i16_45u7zzzz:
 ; SSE:       # BB#0:
-; SSE-NEXT:    pxor %xmm0, %xmm0
-; SSE-NEXT:    pinsrw $0, 8(%rdi), %xmm0
-; SSE-NEXT:    pinsrw $1, 10(%rdi), %xmm0
-; SSE-NEXT:    pinsrw $3, 14(%rdi), %xmm0
+; SSE-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: merge_8i16_i16_45u7zzzz:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    vpinsrw $0, 8(%rdi), %xmm0, %xmm0
-; AVX-NEXT:    vpinsrw $1, 10(%rdi), %xmm0, %xmm0
-; AVX-NEXT:    vpinsrw $3, 14(%rdi), %xmm0, %xmm0
+; AVX-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
 ; AVX-NEXT:    retq
   %ptr0 = getelementptr inbounds i16, i16* %ptr, i64 4
   %ptr1 = getelementptr inbounds i16, i16* %ptr, i64 5
@@ -478,46 +472,14 @@
 }
 
 define <16 x i8> @merge_16i8_i8_0123uu67uuuuuzzz(i8* %ptr) nounwind uwtable noinline ssp {
-; SSE2-LABEL: merge_16i8_i8_0123uu67uuuuuzzz:
-; SSE2:       # BB#0:
-; SSE2-NEXT:    movzbl 2(%rdi), %eax
-; SSE2-NEXT:    movzbl 3(%rdi), %ecx
-; SSE2-NEXT:    shll $8, %ecx
-; SSE2-NEXT:    orl %eax, %ecx
-; SSE2-NEXT:    movzbl (%rdi), %eax
-; SSE2-NEXT:    movzbl 1(%rdi), %edx
-; SSE2-NEXT:    shll $8, %edx
-; SSE2-NEXT:    orl %eax, %edx
-; SSE2-NEXT:    pxor %xmm0, %xmm0
-; SSE2-NEXT:    pinsrw $0, %edx, %xmm0
-; SSE2-NEXT:    pinsrw $1, %ecx, %xmm0
-; SSE2-NEXT:    movzbl 6(%rdi), %eax
-; SSE2-NEXT:    movzbl 7(%rdi), %ecx
-; SSE2-NEXT:    shll $8, %ecx
-; SSE2-NEXT:    orl %eax, %ecx
-; SSE2-NEXT:    pinsrw $3, %ecx, %xmm0
-; SSE2-NEXT:    retq
-;
-; SSE41-LABEL: merge_16i8_i8_0123uu67uuuuuzzz:
-; SSE41:       # BB#0:
-; SSE41-NEXT:    pxor %xmm0, %xmm0
-; SSE41-NEXT:    pinsrb $0, (%rdi), %xmm0
-; SSE41-NEXT:    pinsrb $1, 1(%rdi), %xmm0
-; SSE41-NEXT:    pinsrb $2, 2(%rdi), %xmm0
-; SSE41-NEXT:    pinsrb $3, 3(%rdi), %xmm0
-; SSE41-NEXT:    pinsrb $6, 6(%rdi), %xmm0
-; SSE41-NEXT:    pinsrb $7, 7(%rdi), %xmm0
-; SSE41-NEXT:    retq
+; SSE-LABEL: merge_16i8_i8_0123uu67uuuuuzzz:
+; SSE:       # BB#0:
+; SSE-NEXT:    movq {{.*#+}} xmm0 = mem[0],zero
+; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: merge_16i8_i8_0123uu67uuuuuzzz:
 ; AVX:       # BB#0:
-; AVX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; AVX-NEXT:    vpinsrb $0, (%rdi), %xmm0, %xmm0
-; AVX-NEXT:    vpinsrb $1, 1(%rdi), %xmm0, %xmm0
-; AVX-NEXT:    vpinsrb $2, 2(%rdi), %xmm0, %xmm0
-; AVX-NEXT:    vpinsrb $3, 3(%rdi), %xmm0, %xmm0
-; AVX-NEXT:    vpinsrb $6, 6(%rdi), %xmm0, %xmm0
-; AVX-NEXT:    vpinsrb $7, 7(%rdi), %xmm0, %xmm0
+; AVX-NEXT:    vmovq {{.*#+}} xmm0 = mem[0],zero
 ; AVX-NEXT:    retq
   %ptr0 = getelementptr inbounds i8, i8* %ptr, i64 0
   %ptr1 = getelementptr inbounds i8, i8* %ptr, i64 1