diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -8477,7 +8477,8 @@
 static SDValue EltsFromConsecutiveLoads(EVT VT, ArrayRef<SDValue> Elts,
                                         const SDLoc &DL, SelectionDAG &DAG,
                                         const X86Subtarget &Subtarget,
-                                        bool isAfterLegalize) {
+                                        bool IsAfterLegalize,
+                                        bool VectorLoadsOnly = false) {
   if ((VT.getScalarSizeInBits() % 8) != 0)
     return SDValue();
 
@@ -8607,7 +8608,7 @@
   if (FirstLoadedElt == 0 &&
       (NumLoadedElts == (int)NumElems || IsDereferenceable) &&
       (IsConsecutiveLoad || IsConsecutiveLoadWithZeros)) {
-    if (isAfterLegalize && !TLI.isOperationLegal(ISD::LOAD, VT))
+    if (IsAfterLegalize && !TLI.isOperationLegal(ISD::LOAD, VT))
       return SDValue();
 
     // Don't create 256-bit non-temporal aligned loads without AVX2 as these
@@ -8624,7 +8625,7 @@
 
   // IsConsecutiveLoadWithZeros - we need to create a shuffle of the loaded
   // vector and a zero vector to clear out the zero elements.
-  if (!isAfterLegalize && VT.isVector()) {
+  if (!VectorLoadsOnly && !IsAfterLegalize && VT.isVector()) {
     unsigned NumMaskElts = VT.getVectorNumElements();
     if ((NumMaskElts % NumElems) == 0) {
       unsigned Scale = NumMaskElts / NumElems;
@@ -8644,6 +8645,9 @@
     }
   }
 
+  if (VectorLoadsOnly)
+    return SDValue();
+
   // If the upper half of a ymm/zmm load is undef then just load the lower half.
   if (VT.is256BitVector() || VT.is512BitVector()) {
     unsigned HalfNumElems = NumElems / 2;
@@ -8652,7 +8656,7 @@
           EVT::getVectorVT(*DAG.getContext(), VT.getScalarType(), HalfNumElems);
       SDValue HalfLD =
           EltsFromConsecutiveLoads(HalfVT, Elts.drop_back(HalfNumElems), DL,
-                                   DAG, Subtarget, isAfterLegalize);
+                                   DAG, Subtarget, IsAfterLegalize);
       if (HalfLD)
         return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT),
                            HalfLD, DAG.getIntPtrConstant(0, DL));
@@ -8728,7 +8732,7 @@
                                       VT.getSizeInBits() / ScalarSize);
       if (TLI.isTypeLegal(BroadcastVT)) {
         if (SDValue RepeatLoad = EltsFromConsecutiveLoads(
-                RepeatVT, RepeatedLoads, DL, DAG, Subtarget, isAfterLegalize)) {
+                RepeatVT, RepeatedLoads, DL, DAG, Subtarget, IsAfterLegalize)) {
           SDValue Broadcast = RepeatLoad;
           if (RepeatSize > ScalarSize) {
             while (Broadcast.getValueSizeInBits() < VT.getSizeInBits())
@@ -8752,7 +8756,8 @@
 static SDValue combineToConsecutiveLoads(EVT VT, SDValue Op, const SDLoc &DL,
                                          SelectionDAG &DAG,
                                          const X86Subtarget &Subtarget,
-                                         bool isAfterLegalize) {
+                                         bool IsAfterLegalize,
+                                         bool VectorLoadsOnly = false) {
   SmallVector<SDValue, 64> Elts;
   for (unsigned i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
     if (SDValue Elt = getShuffleScalarElt(Op, i, DAG, 0)) {
@@ -8763,7 +8768,7 @@
   }
   assert(Elts.size() == VT.getVectorNumElements());
   return EltsFromConsecutiveLoads(VT, Elts, DL, DAG, Subtarget,
                                   isAfterLegalize);
-                                  isAfterLegalize);
+                                  IsAfterLegalize, VectorLoadsOnly);
 }
 
 static Constant *getConstantVector(MVT VT, const APInt &SplatValue,
@@ -38466,8 +38471,8 @@
     return AddSub;
 
   // Attempt to combine into a vector load/broadcast.
-  if (SDValue LD = combineToConsecutiveLoads(VT, SDValue(N, 0), dl, DAG,
-                                             Subtarget, true))
+  if (SDValue LD = combineToConsecutiveLoads(
+          VT, SDValue(N, 0), dl, DAG, Subtarget, /*IsAfterLegalize*/ true))
     return LD;
 
   // For AVX2, we sometimes want to combine
@@ -50617,7 +50622,9 @@
   return SDValue();
 }
 
-static SDValue combineScalarToVector(SDNode *N, SelectionDAG &DAG) {
+static SDValue combineScalarToVector(SDNode *N, SelectionDAG &DAG,
+                                     TargetLowering::DAGCombinerInfo &DCI,
+                                     const X86Subtarget &Subtarget) {
   EVT VT = N->getValueType(0);
   SDValue Src = N->getOperand(0);
   SDLoc DL(N);
@@ -50667,10 +50674,11 @@
       Src.getOperand(0).getValueType() == MVT::x86mmx)
     return DAG.getNode(X86ISD::MOVQ2DQ, DL, VT, Src.getOperand(0));
 
-  // See if we're broadcasting the scalar value, in which case just reuse that.
-  // Ensure the same SDValue from the SDNode use is being used.
-  if (VT.getScalarType() == Src.getValueType())
-    for (SDNode *User : Src->uses())
+  // Ensure we don't have any implicit truncation.
+  if (VT.getScalarType() == Src.getValueType()) {
+    // See if we're broadcasting the scalar value, in which case just reuse
+    // that. Ensure the same SDValue from the SDNode use is being used.
+    for (SDNode *User : Src->uses()) {
       if (User->getOpcode() == X86ISD::VBROADCAST &&
           Src == User->getOperand(0)) {
         unsigned SizeInBits = VT.getFixedSizeInBits();
@@ -50683,6 +50691,21 @@
         // TODO: Handle BroadcastSizeInBits < SizeInBits when we have test
         // coverage.
       }
+    }
+
+    // Attempt to combine into a vector load.
+    if (auto *Ld = dyn_cast<LoadSDNode>(peekThroughBitcasts(Src))) {
+      bool Fast;
+      const X86TargetLowering *TLI = Subtarget.getTargetLowering();
+      if (N->isOnlyUserOf(Src.getNode()) &&
+          TLI->allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VT,
+                                  *Ld->getMemOperand(), &Fast) && Fast)
+        if (SDValue LD = combineToConsecutiveLoads(
+                VT, SDValue(N, 0), DL, DAG, Subtarget, DCI.isAfterLegalizeDAG(),
+                /*VectorLoadsOnly*/ true))
+          return LD;
+    }
+  }
 
   return SDValue();
 }
@@ -51024,7 +51047,7 @@
   switch (N->getOpcode()) {
   default: break;
   case ISD::SCALAR_TO_VECTOR:
-    return combineScalarToVector(N, DAG);
+    return combineScalarToVector(N, DAG, DCI, Subtarget);
   case ISD::EXTRACT_VECTOR_ELT:
   case X86ISD::PEXTRW:
   case X86ISD::PEXTRB:
diff --git a/llvm/test/CodeGen/X86/load-partial-dot-product.ll b/llvm/test/CodeGen/X86/load-partial-dot-product.ll
--- a/llvm/test/CodeGen/X86/load-partial-dot-product.ll
+++ b/llvm/test/CodeGen/X86/load-partial-dot-product.ll
@@ -178,9 +178,9 @@
 ;
 ; AVX-LABEL: dot3_float3:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
+; AVX-NEXT:    vmovups (%rdi), %xmm0
 ; AVX-NEXT:    vinsertps {{.*#+}} xmm0 = xmm0[0,1],mem[0],xmm0[3]
-; AVX-NEXT:    vmovsd {{.*#+}} xmm1 = mem[0],zero
+; AVX-NEXT:    vmovups (%rsi), %xmm1
 ; AVX-NEXT:    vinsertps {{.*#+}} xmm1 = xmm1[0,1],mem[0],xmm1[3]
 ; AVX-NEXT:    vmulps %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    vmovshdup {{.*#+}} xmm1 = xmm0[1,1,3,3]
@@ -241,10 +241,9 @@
 ;
 ; AVX-LABEL: dot3_float2_float:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
-; AVX-NEXT:    vmovsd {{.*#+}} xmm1 = mem[0],zero
-; AVX-NEXT:    vmulps %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vmovups (%rdi), %xmm0
 ; AVX-NEXT:    vmovss {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; AVX-NEXT:    vmulps (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vmulss 8(%rsi), %xmm1, %xmm1
 ; AVX-NEXT:    vmovshdup {{.*#+}} xmm2 = xmm0[1,1,3,3]
 ; AVX-NEXT:    vaddss %xmm2, %xmm0, %xmm0
@@ -414,9 +413,8 @@
 ;
 ; AVX-LABEL: dot2_float2:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
-; AVX-NEXT:    vmovsd {{.*#+}} xmm1 = mem[0],zero
-; AVX-NEXT:    vmulps %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vmovups (%rdi), %xmm0
+; AVX-NEXT:    vmulps (%rsi), %xmm0, %xmm0
 ; AVX-NEXT:    vmovshdup {{.*#+}} xmm1 = xmm0[1,1,3,3]
 ; AVX-NEXT:    vaddss %xmm1, %xmm0, %xmm0
 ; AVX-NEXT:    retq
diff --git a/llvm/test/CodeGen/X86/load-partial.ll b/llvm/test/CodeGen/X86/load-partial.ll
--- a/llvm/test/CodeGen/X86/load-partial.ll
+++ b/llvm/test/CodeGen/X86/load-partial.ll
@@ -139,7 +139,7 @@
 ;
 ; AVX-LABEL: load_float4_float3_as_float2_float_0122:
 ; AVX:       # %bb.0:
-; AVX-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
+; AVX-NEXT:    vmovups (%rdi), %xmm0
 ; AVX-NEXT:    vmovss {{.*#+}} xmm1 = mem[0],zero,zero,zero
 ; AVX-NEXT:    vshufps {{.*#+}} xmm0 = xmm0[0,1],xmm1[0,0]
 ; AVX-NEXT:    retq
diff --git a/llvm/test/CodeGen/X86/masked_gather.ll b/llvm/test/CodeGen/X86/masked_gather.ll
--- a/llvm/test/CodeGen/X86/masked_gather.ll
+++ b/llvm/test/CodeGen/X86/masked_gather.ll
@@ -1148,7 +1148,7 @@
 ; SSE-NEXT:    testb $1, %al
 ; SSE-NEXT:    je .LBB4_1
 ; SSE-NEXT:  # %bb.2: # %cond.load
-; SSE-NEXT:    movd {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; SSE-NEXT:    movdqu c+12(%rip), %xmm0
 ; SSE-NEXT:    testb $2, %al
 ; SSE-NEXT:    jne .LBB4_4
 ; SSE-NEXT:    jmp .LBB4_5
@@ -1455,7 +1455,7 @@
 ; AVX1-NEXT:    vinsertf128 $1, %xmm2, %ymm0, %ymm0
 ; AVX1-NEXT:    retq
 ; AVX1-NEXT:  .LBB4_1: # %cond.load
-; AVX1-NEXT:    vmovd {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; AVX1-NEXT:    vmovdqu c+12(%rip), %xmm1
 ; AVX1-NEXT:    testb $2, %al
 ; AVX1-NEXT:    je .LBB4_4
 ; AVX1-NEXT:  .LBB4_3: # %cond.load1
@@ -1657,7 +1657,7 @@
 ; AVX2-NEXT:    vpaddd %ymm0, %ymm1, %ymm0
 ; AVX2-NEXT:    retq
 ; AVX2-NEXT:  .LBB4_1: # %cond.load
-; AVX2-NEXT:    vmovd {{.*#+}} xmm1 = mem[0],zero,zero,zero
+; AVX2-NEXT:    vmovdqu c+12(%rip), %xmm1
 ; AVX2-NEXT:    testb $2, %al
 ; AVX2-NEXT:    je .LBB4_4
 ; AVX2-NEXT:  .LBB4_3: # %cond.load1