Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9428,13 +9428,13 @@
 }
 
 // Fold sext/zext of index into index type.
-bool refineIndexType(MaskedScatterSDNode *MSC, SDValue &Index, bool Scaled,
-                     SelectionDAG &DAG) {
+bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index,
+                     bool Scaled, SelectionDAG &DAG) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDValue Op = Index.getOperand(0);
 
   if (Index.getOpcode() == ISD::ZERO_EXTEND) {
-    MSC->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
+    MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
       Index = Op;
       return true;
@@ -9442,7 +9442,7 @@
   }
 
   if (Index.getOpcode() == ISD::SIGN_EXTEND) {
-    MSC->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
+    MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
     if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
       Index = Op;
       return true;
@@ -9511,11 +9511,30 @@
 SDValue DAGCombiner::visitMGATHER(SDNode *N) {
   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(N);
   SDValue Mask = MGT->getMask();
+  SDValue Chain = MGT->getChain();
+  SDValue Index = MGT->getIndex();
+  SDValue Scale = MGT->getScale();
+  SDValue PassThru = MGT->getPassThru();
+  SDValue BasePtr = MGT->getBasePtr();
   SDLoc DL(N);
 
   // Zap gathers with a zero mask.
   if (ISD::isBuildVectorAllZeros(Mask.getNode()))
-    return CombineTo(N, MGT->getPassThru(), MGT->getChain());
+    return CombineTo(N, PassThru, MGT->getChain());
+
+  if (refineUniformBase(BasePtr, Index, DAG)) {
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                               PassThru.getValueType(), DL, Ops,
+                               MGT->getMemOperand(), MGT->getIndexType());
+  }
+
+  if (refineIndexType(MGT, Index, MGT->isIndexScaled(), DAG)) {
+    SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
+    return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
+                               PassThru.getValueType(), DL, Ops,
+                               MGT->getMemOperand(), MGT->getIndexType());
+  }
 
   return SDValue();
 }
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1746,6 +1746,7 @@
   SDValue PassThru = MGT->getPassThru();
   SDValue Index = MGT->getIndex();
   SDValue Scale = MGT->getScale();
+  EVT MemoryVT = MGT->getMemoryVT();
   Align Alignment = MGT->getOriginalAlign();
 
   // Split Mask operand
@@ -1759,6 +1760,10 @@
     std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, dl);
   }
 
+  EVT LoMemVT, HiMemVT;
+  // Split MemoryVT
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
   SDValue PassThruLo, PassThruHi;
   if (getTypeAction(PassThru.getValueType()) == TargetLowering::TypeSplitVector)
     GetSplitVector(PassThru, PassThruLo, PassThruHi);
@@ -1777,11 +1782,11 @@
                            MGT->getRanges());
 
   SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
-  Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo,
+  Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl, OpsLo,
                            MMO, MGT->getIndexType());
 
   SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
-  Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi,
+  Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl, OpsHi,
                            MMO, MGT->getIndexType());
 
   // Build a factor node to remember that this load is independent of the
@@ -2421,11 +2426,11 @@
                            MGT->getRanges());
 
   SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
-  SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl,
+  SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl,
                                    OpsLo, MMO, MGT->getIndexType());
 
   SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
-  SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl,
+  SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl,
                                    OpsHi, MMO, MGT->getIndexType());
 
   // Build a factor node to remember that this load is independent of the
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4416,7 +4416,7 @@
   if (!UniformBase) {
     Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout()));
     Index = getValue(Ptr);
-    IndexType = ISD::SIGNED_SCALED;
+    IndexType = ISD::SIGNED_UNSCALED;
     Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout()));
   }
   SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
Index: llvm/test/CodeGen/X86/masked_gather_scatter.ll
===================================================================
--- llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -765,45 +765,41 @@
 define <16 x float> @test14(float* %base, i32 %ind, <16 x float*> %vec) {
 ; KNL_64-LABEL: test14:
 ; KNL_64:       # %bb.0:
-; KNL_64-NEXT:    vpbroadcastq %xmm0, %zmm0
-; KNL_64-NEXT:    vmovd %esi, %xmm1
-; KNL_64-NEXT:    vpbroadcastd %xmm1, %ymm1
-; KNL_64-NEXT:    vpmovsxdq %ymm1, %zmm1
-; KNL_64-NEXT:    vpsllq $2, %zmm1, %zmm1
-; KNL_64-NEXT:    vpaddq %zmm1, %zmm0, %zmm0
+; KNL_64-NEXT:    vmovq %xmm0, %rax
+; KNL_64-NEXT:    vmovd %esi, %xmm0
+; KNL_64-NEXT:    vpbroadcastd %xmm0, %ymm0
+; KNL_64-NEXT:    vpmovsxdq %ymm0, %zmm0
+; KNL_64-NEXT:    vpsllq $2, %zmm0, %zmm0
 ; KNL_64-NEXT:    kxnorw %k0, %k0, %k1
-; KNL_64-NEXT:    vgatherqps (,%zmm0), %ymm1 {%k1}
+; KNL_64-NEXT:    vgatherqps (%rax,%zmm0), %ymm1 {%k1}
 ; KNL_64-NEXT:    vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
 ; KNL_64-NEXT:    retq
 ;
 ; KNL_32-LABEL: test14:
 ; KNL_32:       # %bb.0:
-; KNL_32-NEXT:    vpbroadcastd %xmm0, %zmm0
+; KNL_32-NEXT:    vmovd %xmm0, %eax
 ; KNL_32-NEXT:    vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
-; KNL_32-NEXT:    vpaddd %zmm1, %zmm0, %zmm1
 ; KNL_32-NEXT:    kxnorw %k0, %k0, %k1
-; KNL_32-NEXT:    vgatherdps (,%zmm1), %zmm0 {%k1}
+; KNL_32-NEXT:    vgatherdps (%eax,%zmm1), %zmm0 {%k1}
 ; KNL_32-NEXT:    retl
 ;
 ; SKX-LABEL: test14:
 ; SKX:       # %bb.0:
-; SKX-NEXT:    vpbroadcastq %xmm0, %zmm0
-; SKX-NEXT:    vpbroadcastd %esi, %ymm1
-; SKX-NEXT:    vpmovsxdq %ymm1, %zmm1
-; SKX-NEXT:    vpsllq $2, %zmm1, %zmm1
-; SKX-NEXT:    vpaddq %zmm1, %zmm0, %zmm0
+; SKX-NEXT:    vmovq %xmm0, %rax
+; SKX-NEXT:    vpbroadcastd %esi, %ymm0
+; SKX-NEXT:    vpmovsxdq %ymm0, %zmm0
+; SKX-NEXT:    vpsllq $2, %zmm0, %zmm0
 ; SKX-NEXT:    kxnorw %k0, %k0, %k1
-; SKX-NEXT:    vgatherqps (,%zmm0), %ymm1 {%k1}
+; SKX-NEXT:    vgatherqps (%rax,%zmm0), %ymm1 {%k1}
 ; SKX-NEXT:    vinsertf64x4 $1, %ymm1, %zmm1, %zmm0
 ; SKX-NEXT:    retq
 ;
 ; SKX_32-LABEL: test14:
 ; SKX_32:       # %bb.0:
-; SKX_32-NEXT:    vpbroadcastd %xmm0, %zmm0
+; SKX_32-NEXT:    vmovd %xmm0, %eax
 ; SKX_32-NEXT:    vpslld $2, {{[0-9]+}}(%esp){1to16}, %zmm1
-; SKX_32-NEXT:    vpaddd %zmm1, %zmm0, %zmm1
 ; SKX_32-NEXT:    kxnorw %k0, %k0, %k1
-; SKX_32-NEXT:    vgatherdps (,%zmm1), %zmm0 {%k1}
+; SKX_32-NEXT:    vgatherdps (%eax,%zmm1), %zmm0 {%k1}
 ; SKX_32-NEXT:    retl
   %broadcast.splatinsert = insertelement <16 x float*> %vec, float* %base, i32 1
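
Note on the DAGCombiner change above: the new visitMGATHER combines call the refineUniformBase helper, whose definition sits outside these hunks (it already exists in DAGCombiner.cpp for masked scatters and is not touched by this patch). As a rough sketch of what that helper is assumed to check, not part of this diff: when the gather's base pointer is the constant zero and the vector of addresses is an ADD whose left operand is a splat of a scalar pointer, the splatted scalar becomes the uniform base and the right operand becomes the per-lane index. That is the fold that lets test14 keep %base in a scalar register (%rax/%eax) and use it as the gather's base operand instead of adding a broadcast of %base into the index vector.

// Illustrative sketch only; assumed shape of the pre-existing helper in
// DAGCombiner.cpp, reconstructed from its call site in this patch.
static bool refineUniformBase(SDValue &BasePtr, SDValue &Index,
                              SelectionDAG &DAG) {
  // Only handle the "zero base + (splat(Base) + Offsets)" address pattern
  // produced when SelectionDAGBuilder could not prove a uniform base.
  if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
    return false;

  // Look for a splatted scalar base on the LHS of the add.
  SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
  if (!SplatVal)
    return false;

  BasePtr = SplatVal;          // new uniform scalar base
  Index = Index.getOperand(1); // remaining per-lane offsets
  return true;
}

With that fold in place, visitMGATHER rebuilds the gather from the refined BasePtr/Index operands, which is what the updated X86 CHECK lines in test14 reflect.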