diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -50227,9 +50227,40 @@ } static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG, - TargetLowering::DAGCombinerInfo &DCI) { + TargetLowering::DAGCombinerInfo &DCI, + const X86Subtarget &Subtarget) { + auto *MemOp = cast(N); + SDValue Index = MemOp->getIndex(); + SDValue Scale = MemOp->getScale(); + SDValue Mask = MemOp->getMask(); + + // Attempt to fold an index scale into the scale value directly. + // TODO: Move this into X86DAGToDAGISel::matchVectorAddressRecursively? + if ((Index.getOpcode() == X86ISD::VSHLI || + (Index.getOpcode() == ISD::ADD && + Index.getOperand(0) == Index.getOperand(1))) && + isa(Scale)) { + unsigned ShiftAmt = + Index.getOpcode() == ISD::ADD ? 1 : Index.getConstantOperandVal(1); + uint64_t ScaleAmt = cast(Scale)->getZExtValue(); + uint64_t NewScaleAmt = ScaleAmt * (1ULL << ShiftAmt); + if (isPowerOf2_64(NewScaleAmt) && NewScaleAmt <= 8) { + SDValue NewIndex = Index.getOperand(0); + SDValue NewScale = + DAG.getTargetConstant(NewScaleAmt, SDLoc(N), Scale.getValueType()); + if (N->getOpcode() == X86ISD::MGATHER) + return getAVX2GatherNode(N->getOpcode(), SDValue(N, 0), DAG, + MemOp->getOperand(1), Mask, + MemOp->getBasePtr(), NewIndex, NewScale, + MemOp->getChain(), Subtarget); + if (N->getOpcode() == X86ISD::MSCATTER) + return getScatterNode(N->getOpcode(), SDValue(N, 0), DAG, + MemOp->getOperand(1), Mask, MemOp->getBasePtr(), + NewIndex, NewScale, MemOp->getChain(), Subtarget); + } + } + // With vector masks we only demand the upper bit of the mask. - SDValue Mask = cast(N)->getMask(); if (Mask.getScalarValueSizeInBits() != 1) { const TargetLowering &TLI = DAG.getTargetLoweringInfo(); APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits())); @@ -52886,7 +52917,8 @@ case X86ISD::FMSUBADD: return combineFMADDSUB(N, DAG, DCI); case X86ISD::MOVMSK: return combineMOVMSK(N, DAG, DCI, Subtarget); case X86ISD::MGATHER: - case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI); + case X86ISD::MSCATTER: + return combineX86GatherScatter(N, DAG, DCI, Subtarget); case ISD::MGATHER: case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI); case X86ISD::PCMPEQ: diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll --- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll +++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll @@ -808,20 +808,19 @@ ; KNL_64-NEXT: vmovd %esi, %xmm0 ; KNL_64-NEXT: vpbroadcastd %xmm0, %ymm0 ; KNL_64-NEXT: vpmovsxdq %ymm0, %zmm0 -; KNL_64-NEXT: vpsllq $2, %zmm0, %zmm0 ; KNL_64-NEXT: kxnorw %k0, %k0, %k1 ; KNL_64-NEXT: vxorps %xmm1, %xmm1, %xmm1 -; KNL_64-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1} +; KNL_64-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1} ; KNL_64-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0 ; KNL_64-NEXT: retq ; ; KNL_32-LABEL: test14: ; KNL_32: # %bb.0: ; KNL_32-NEXT: vmovd %xmm0, %eax -; KNL_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1 +; KNL_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1 ; KNL_32-NEXT: kxnorw %k0, %k0, %k1 ; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; KNL_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1} +; KNL_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1} ; KNL_32-NEXT: retl ; ; SKX-LABEL: test14: @@ -829,20 +828,19 @@ ; SKX-NEXT: vmovq %xmm0, %rax ; SKX-NEXT: vpbroadcastd %esi, %ymm0 ; SKX-NEXT: vpmovsxdq %ymm0, %zmm0 -; SKX-NEXT: vpsllq $2, %zmm0, %zmm0 ; SKX-NEXT: kxnorw %k0, %k0, %k1 ; SKX-NEXT: vxorps %xmm1, %xmm1, %xmm1 -; SKX-NEXT: vgatherqps (%rax,%zmm0), %ymm1 {%k1} +; SKX-NEXT: vgatherqps (%rax,%zmm0,4), %ymm1 {%k1} ; SKX-NEXT: vinsertf64x4 $1, %ymm1, %zmm1, %zmm0 ; SKX-NEXT: retq ; ; SKX_32-LABEL: test14: ; SKX_32: # %bb.0: ; SKX_32-NEXT: vmovd %xmm0, %eax -; SKX_32-NEXT: vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1 +; SKX_32-NEXT: vbroadcastss {{[0-9]+}}(%esp), %zmm1 ; SKX_32-NEXT: kxnorw %k0, %k0, %k1 ; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; SKX_32-NEXT: vgatherdps (%eax,%zmm1), %zmm0 {%k1} +; SKX_32-NEXT: vgatherdps (%eax,%zmm1,4), %zmm0 {%k1} ; SKX_32-NEXT: retl %broadcast.splatinsert = insertelement <16 x float*> %vec, float* %base, i32 1 @@ -4988,38 +4986,38 @@ ; ; PR13310 -; FIXME: Failure to fold scaled-index into gather/scatter scale operand. +; Failure to fold scaled-index into gather/scatter scale operand. ; define <8 x float> @scaleidx_x86gather(float* %base, <8 x i32> %index, <8 x i32> %imask) nounwind { ; KNL_64-LABEL: scaleidx_x86gather: ; KNL_64: # %bb.0: -; KNL_64-NEXT: vpslld $2, %ymm0, %ymm2 -; KNL_64-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0 +; KNL_64-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; KNL_64-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2 +; KNL_64-NEXT: vmovaps %ymm2, %ymm0 ; KNL_64-NEXT: retq ; ; KNL_32-LABEL: scaleidx_x86gather: ; KNL_32: # %bb.0: ; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; KNL_32-NEXT: vpslld $2, %ymm0, %ymm2 -; KNL_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0 +; KNL_32-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; KNL_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2 +; KNL_32-NEXT: vmovaps %ymm2, %ymm0 ; KNL_32-NEXT: retl ; ; SKX-LABEL: scaleidx_x86gather: ; SKX: # %bb.0: -; SKX-NEXT: vpslld $2, %ymm0, %ymm2 -; SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm2), %ymm0 +; SKX-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; SKX-NEXT: vgatherdps %ymm1, (%rdi,%ymm0,4), %ymm2 +; SKX-NEXT: vmovaps %ymm2, %ymm0 ; SKX-NEXT: retq ; ; SKX_32-LABEL: scaleidx_x86gather: ; SKX_32: # %bb.0: ; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; SKX_32-NEXT: vpslld $2, %ymm0, %ymm2 -; SKX_32-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm2), %ymm0 +; SKX_32-NEXT: vxorps %xmm2, %xmm2, %xmm2 +; SKX_32-NEXT: vgatherdps %ymm1, (%eax,%ymm0,4), %ymm2 +; SKX_32-NEXT: vmovaps %ymm2, %ymm0 ; SKX_32-NEXT: retl %ptr = bitcast float* %base to i8* %mask = bitcast <8 x i32> %imask to <8 x float> @@ -5070,8 +5068,7 @@ ; KNL_64-LABEL: scaleidx_x86scatter: ; KNL_64: # %bb.0: ; KNL_64-NEXT: kmovw %esi, %k1 -; KNL_64-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1} +; KNL_64-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1} ; KNL_64-NEXT: vzeroupper ; KNL_64-NEXT: retq ; @@ -5079,16 +5076,14 @@ ; KNL_32: # %bb.0: ; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax ; KNL_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1 -; KNL_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1} +; KNL_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1} ; KNL_32-NEXT: vzeroupper ; KNL_32-NEXT: retl ; ; SKX-LABEL: scaleidx_x86scatter: ; SKX: # %bb.0: ; SKX-NEXT: kmovw %esi, %k1 -; SKX-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,2) {%k1} +; SKX-NEXT: vscatterdps %zmm0, (%rdi,%zmm1,4) {%k1} ; SKX-NEXT: vzeroupper ; SKX-NEXT: retq ; @@ -5096,8 +5091,7 @@ ; SKX_32: # %bb.0: ; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax ; SKX_32-NEXT: kmovw {{[0-9]+}}(%esp), %k1 -; SKX_32-NEXT: vpaddd %zmm1, %zmm1, %zmm1 -; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,2) {%k1} +; SKX_32-NEXT: vscatterdps %zmm0, (%eax,%zmm1,4) {%k1} ; SKX_32-NEXT: vzeroupper ; SKX_32-NEXT: retl %ptr = bitcast float* %base to i8* @@ -5135,18 +5129,16 @@ ; ; SKX-LABEL: scaleidx_scatter: ; SKX: # %bb.0: -; SKX-NEXT: vpaddd %ymm1, %ymm1, %ymm1 ; SKX-NEXT: kmovw %esi, %k1 -; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,4) {%k1} +; SKX-NEXT: vscatterdps %ymm0, (%rdi,%ymm1,8) {%k1} ; SKX-NEXT: vzeroupper ; SKX-NEXT: retq ; ; SKX_32-LABEL: scaleidx_scatter: ; SKX_32: # %bb.0: ; SKX_32-NEXT: movl {{[0-9]+}}(%esp), %eax -; SKX_32-NEXT: vpaddd %ymm1, %ymm1, %ymm1 ; SKX_32-NEXT: kmovb {{[0-9]+}}(%esp), %k1 -; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,4) {%k1} +; SKX_32-NEXT: vscatterdps %ymm0, (%eax,%ymm1,8) {%k1} ; SKX_32-NEXT: vzeroupper ; SKX_32-NEXT: retl %scaledindex = mul <8 x i32> %index,