Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -2117,13 +2117,14 @@ : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} // In the both nodes address is Op1, mask is Op2: - // MaskedGatherSDNode (Chain, src0, mask, base, index), src0 is a passthru value - // MaskedScatterSDNode (Chain, value, mask, base, index) + // MaskedGatherSDNode (Chain, passthru, mask, base, index, scale) + // MaskedScatterSDNode (Chain, value, mask, base, index, scale) // Mask is a vector of i1 elements const SDValue &getBasePtr() const { return getOperand(3); } const SDValue &getIndex() const { return getOperand(4); } const SDValue &getMask() const { return getOperand(2); } const SDValue &getValue() const { return getOperand(1); } + const SDValue &getScale() const { return getOperand(5); } static bool classof(const SDNode *N) { return N->getOpcode() == ISD::MGATHER || Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -501,7 +501,7 @@ SDLoc dl(N); SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(), - N->getIndex()}; + N->getIndex(), N->getScale() }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other), N->getMemoryVT(), dl, Ops, N->getMemOperand()); Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1224,6 +1224,7 @@ SDValue Mask = MGT->getMask(); SDValue Src0 = MGT->getValue(); SDValue Index = MGT->getIndex(); + SDValue Scale = MGT->getScale(); unsigned Alignment = MGT->getOriginalAlignment(); // Split Mask operand @@ -1255,11 +1256,11 @@ MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo}; + SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale}; Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo, MMO); - SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi}; + SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale}; Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi, MMO); @@ -1798,6 +1799,7 @@ SDValue Ch = MGT->getChain(); SDValue Ptr = MGT->getBasePtr(); SDValue Index = MGT->getIndex(); + SDValue Scale = MGT->getScale(); SDValue Mask = MGT->getMask(); SDValue Src0 = MGT->getValue(); unsigned Alignment = MGT->getOriginalAlignment(); @@ -1830,7 +1832,7 @@ MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo}; + SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale}; SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo, MMO); @@ -1840,7 +1842,7 @@ Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi}; + SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale}; SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi, MMO); @@ -1926,6 +1928,7 @@ SDValue Ptr = N->getBasePtr(); SDValue Mask = N->getMask(); SDValue Index = N->getIndex(); + SDValue Scale = N->getScale(); SDValue Data = N->getValue(); EVT MemoryVT = N->getMemoryVT(); unsigned Alignment = N->getOriginalAlignment(); @@ -1961,7 +1964,7 @@ MachineMemOperand::MOStore, LoMemVT.getStoreSize(), Alignment, N->getAAInfo(), N->getRanges()); - SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo}; + SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale}; Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(), DL, OpsLo, MMO); @@ -1973,7 +1976,7 @@ // The order of the Scatter operation after split is well defined. The "Hi" // part comes after the "Lo". So these two operations should be chained one // after another. - SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi}; + SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale}; return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), DL, OpsHi, MMO); } @@ -2953,6 +2956,7 @@ EVT WideVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)); SDValue Mask = N->getMask(); SDValue Src0 = GetWidenedVector(N->getValue()); + SDValue Scale = N->getScale(); unsigned NumElts = WideVT.getVectorNumElements(); SDLoc dl(N); @@ -2965,7 +2969,7 @@ Index.getValueType().getScalarType(), NumElts); Index = ModifyToType(Index, WideIndexVT); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), N->getMemoryVT(), dl, Ops, N->getMemOperand()); @@ -3601,6 +3605,7 @@ MaskedScatterSDNode *MSC = cast(N); SDValue DataOp = MSC->getValue(); SDValue Mask = MSC->getMask(); + SDValue Scale = MSC->getScale(); // Widen the value. SDValue WideVal = GetWidenedVector(DataOp); @@ -3618,7 +3623,8 @@ NumElts); Index = ModifyToType(Index, WideIndexVT); - SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index}; + SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index, + Scale}; return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), dl, Ops, MSC->getMemOperand()); Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6173,7 +6173,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef Ops, MachineMemOperand *MMO) { - assert(Ops.size() == 5 && "Incompatible number of operands"); + assert(Ops.size() == 6 && "Incompatible number of operands"); FoldingSetNodeID ID; AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops); @@ -6208,7 +6208,7 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef Ops, MachineMemOperand *MMO) { - assert(Ops.size() == 5 && "Incompatible number of operands"); + assert(Ops.size() == 6 && "Incompatible number of operands"); FoldingSetNodeID ID; AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops); Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3865,7 +3865,7 @@ // extract the splat value and use it as a uniform base. // In all other cases the function returns 'false'. static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index, - SelectionDAGBuilder* SDB) { + SDValue &Scale, SelectionDAGBuilder* SDB) { SelectionDAG& DAG = SDB->DAG; LLVMContext &Context = *DAG.getContext(); @@ -3895,6 +3895,10 @@ if (!SDB->findValue(Ptr) || !SDB->findValue(IndexVal)) return false; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + const DataLayout &DL = DAG.getDataLayout(); + Scale = DAG.getTargetConstant(DL.getTypeAllocSize(GEP->getResultElementType()), + SDB->getCurSDLoc(), TLI.getPointerTy(DL)); Base = SDB->getValue(Ptr); Index = SDB->getValue(IndexVal); @@ -3931,8 +3935,9 @@ SDValue Base; SDValue Index; + SDValue Scale; const Value *BasePtr = Ptr; - bool UniformBase = getUniformBase(BasePtr, Base, Index, this); + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; MachineMemOperand *MMO = DAG.getMachineFunction(). @@ -3940,10 +3945,11 @@ MachineMemOperand::MOStore, VT.getStoreSize(), Alignment, AAInfo); if (!UniformBase) { - Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); } - SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index }; + SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index, Scale }; SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO); DAG.setRoot(Scatter); @@ -4030,8 +4036,9 @@ SDValue Root = DAG.getRoot(); SDValue Base; SDValue Index; + SDValue Scale; const Value *BasePtr = Ptr; - bool UniformBase = getUniformBase(BasePtr, Base, Index, this); + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); bool ConstantMemory = false; if (UniformBase && AA && AA->pointsToConstantMemory(MemoryLocation( @@ -4049,10 +4056,11 @@ Alignment, AAInfo, Ranges); if (!UniformBase) { - Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); } - SDValue Ops[] = { Root, Src0, Mask, Base, Index }; + SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale }; SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, MMO); Index: lib/Target/X86/X86ISelDAGToDAG.cpp =================================================================== --- lib/Target/X86/X86ISelDAGToDAG.cpp +++ lib/Target/X86/X86ISelDAGToDAG.cpp @@ -1509,6 +1509,12 @@ bool X86DAGToDAGISel::matchVectorAddress(SDValue N, X86ISelAddressMode &AM) { // TODO: Support other operations. switch (N.getOpcode()) { + case ISD::Constant: { + uint64_t Val = cast(N)->getSExtValue(); + if (!foldOffsetIntoAddress(Val, AM)) + return false; + break; + } case X86ISD::Wrapper: if (!matchWrapper(N, AM)) return false; @@ -1524,11 +1530,11 @@ X86ISelAddressMode AM; if (auto Mgs = dyn_cast(Parent)) { AM.IndexReg = Mgs->getIndex(); - AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8; + AM.Scale = cast(Mgs->getScale())->getZExtValue(); } else { auto X86Gather = cast(Parent); AM.IndexReg = X86Gather->getIndex(); - AM.Scale = X86Gather->getValue().getScalarValueSizeInBits() / 8; + AM.Scale = cast(X86Gather->getScale())->getZExtValue(); } unsigned AddrSpace = cast(Parent)->getPointerInfo().getAddrSpace(); @@ -1540,14 +1546,8 @@ if (AddrSpace == 258) AM.Segment = CurDAG->getRegister(X86::SS, MVT::i16); - // If Base is 0, the whole address is in index and the Scale is 1 - if (isa(N)) { - assert(cast(N)->isNullValue() && - "Unexpected base in gather/scatter"); - AM.Scale = 1; - } - // Otherwise, try to match into the base and displacement fields. - else if (matchVectorAddress(N, AM)) + // Try to match into the base and displacement fields. + if (matchVectorAddress(N, AM)) return false; MVT VT = N.getSimpleValueType(); Index: lib/Target/X86/X86ISelLowering.h =================================================================== --- lib/Target/X86/X86ISelLowering.h +++ lib/Target/X86/X86ISelLowering.h @@ -1420,6 +1420,7 @@ const SDValue &getIndex() const { return getOperand(4); } const SDValue &getMask() const { return getOperand(2); } const SDValue &getValue() const { return getOperand(1); } + const SDValue &getScale() const { return getOperand(5); } static bool classof(const SDNode *N) { return N->getOpcode() == X86ISD::MGATHER; Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -24110,6 +24110,7 @@ SDLoc dl(Op); SDValue NewScatter; + SDValue Scale = N->getScale(); SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Chain = N->getChain(); @@ -24179,7 +24180,7 @@ // The mask is killed by scatter, add it to the values SDVTList VTs = DAG.getVTList(BitMaskVT, MVT::Other); - SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index}; + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale }; NewScatter = DAG.getMaskedScatter(VTs, N->getMemoryVT(), dl, Ops, N->getMemOperand()); DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); @@ -24302,6 +24303,7 @@ MaskedGatherSDNode *N = cast(Op.getNode()); SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); + SDValue Scale = N->getScale(); SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Src0 = N->getValue(); @@ -24319,7 +24321,7 @@ if (NumElts == 8) { Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); SDValue Ops[] = { N->getOperand(0), N->getOperand(1), N->getOperand(2), - N->getOperand(3), Index }; + N->getOperand(3), Index, Scale }; DAG.UpdateNodeOperands(N, Ops); return Op; } @@ -24344,7 +24346,7 @@ MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts); Src0 = ExtendToType(Src0, NewVT, DAG); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue NewGather = DAG.getMaskedGather(DAG.getVTList(NewVT, MVT::Other), N->getMemoryVT(), dl, Ops, N->getMemOperand()); @@ -24369,7 +24371,7 @@ // is not necessary since instruction itself reads only two values from // memory. Mask = ExtendToType(Mask, MVT::v4i1, DAG, false); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue NewGather = DAG.getTargetMemSDNode( DAG.getVTList(MVT::v4i32, MVT::Other), Ops, dl, N->getMemoryVT(), N->getMemOperand()); @@ -24396,7 +24398,7 @@ Index = Index.getOperand(0); } else return Op; - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue NewGather = DAG.getTargetMemSDNode( DAG.getVTList(MVT::v4f32, MVT::Other), Ops, dl, N->getMemoryVT(), N->getMemOperand()); Index: test/CodeGen/X86/masked_gather_scatter.ll =================================================================== --- test/CodeGen/X86/masked_gather_scatter.ll +++ test/CodeGen/X86/masked_gather_scatter.ll @@ -1272,7 +1272,7 @@ ; KNL_64-NEXT: vmovdqa %xmm1, %xmm1 ; KNL_64-NEXT: vpsllq $63, %zmm1, %zmm1 ; KNL_64-NEXT: vptestmq %zmm1, %zmm1, %k1 -; KNL_64-NEXT: vpgatherqq (%rdi,%zmm0,8), %zmm2 {%k1} +; KNL_64-NEXT: vpgatherqq (%rdi,%zmm0,4), %zmm2 {%k1} ; KNL_64-NEXT: vmovdqa %xmm2, %xmm0 ; KNL_64-NEXT: vzeroupper ; KNL_64-NEXT: retq @@ -1285,7 +1285,7 @@ ; KNL_32-NEXT: movl {{[0-9]+}}(%esp), %eax ; KNL_32-NEXT: vpsllq $63, %zmm1, %zmm1 ; KNL_32-NEXT: vptestmq %zmm1, %zmm1, %k1 -; KNL_32-NEXT: vpgatherqq (%eax,%zmm0,8), %zmm2 {%k1} +; KNL_32-NEXT: vpgatherqq (%eax,%zmm0,4), %zmm2 {%k1} ; KNL_32-NEXT: vmovdqa %xmm2, %xmm0 ; KNL_32-NEXT: vzeroupper ; KNL_32-NEXT: retl @@ -1320,7 +1320,7 @@ ; KNL_64-NEXT: # kill: %XMM0 %XMM0 %ZMM0 ; KNL_64-NEXT: movb $3, %al ; KNL_64-NEXT: kmovw %eax, %k1 -; KNL_64-NEXT: vpgatherqq (%rdi,%zmm0,8), %zmm1 {%k1} +; KNL_64-NEXT: vpgatherqq (%rdi,%zmm0,4), %zmm1 {%k1} ; KNL_64-NEXT: vmovdqa %xmm1, %xmm0 ; KNL_64-NEXT: vzeroupper ; KNL_64-NEXT: retq @@ -1332,7 +1332,7 @@ ; KNL_32-NEXT: vmovdqa {{.*#+}} xmm1 = [1,0,1,0] ; KNL_32-NEXT: vpsllq $63, %zmm1, %zmm1 ; KNL_32-NEXT: vptestmq %zmm1, %zmm1, %k1 -; KNL_32-NEXT: vpgatherqq (%eax,%zmm0,8), %zmm1 {%k1} +; KNL_32-NEXT: vpgatherqq (%eax,%zmm0,4), %zmm1 {%k1} ; KNL_32-NEXT: vmovdqa %xmm1, %xmm0 ; KNL_32-NEXT: vzeroupper ; KNL_32-NEXT: retl