Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -2120,13 +2120,14 @@ : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} // In the both nodes address is Op1, mask is Op2: - // MaskedGatherSDNode (Chain, src0, mask, base, index), src0 is a passthru value - // MaskedScatterSDNode (Chain, value, mask, base, index) + // MaskedGatherSDNode (Chain, passthru, mask, base, index, scale) + // MaskedScatterSDNode (Chain, value, mask, base, index, scale) // Mask is a vector of i1 elements const SDValue &getBasePtr() const { return getOperand(3); } const SDValue &getIndex() const { return getOperand(4); } const SDValue &getMask() const { return getOperand(2); } const SDValue &getValue() const { return getOperand(1); } + const SDValue &getScale() const { return getOperand(5); } static bool classof(const SDNode *N) { return N->getOpcode() == ISD::MGATHER || Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -501,7 +501,7 @@ SDLoc dl(N); SDValue Ops[] = {N->getChain(), ExtSrc0, N->getMask(), N->getBasePtr(), - N->getIndex()}; + N->getIndex(), N->getScale() }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other), N->getMemoryVT(), dl, Ops, N->getMemOperand()); Index: lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1232,6 +1232,7 @@ SDValue Mask = MGT->getMask(); SDValue Src0 = MGT->getValue(); SDValue Index = MGT->getIndex(); + SDValue Scale = MGT->getScale(); unsigned Alignment = MGT->getOriginalAlignment(); // Split Mask operand @@ -1263,11 +1264,11 @@ MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo}; + SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale}; Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo, MMO); - SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi}; + SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale}; Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi, MMO); @@ -1810,6 +1811,7 @@ SDValue Ch = MGT->getChain(); SDValue Ptr = MGT->getBasePtr(); SDValue Index = MGT->getIndex(); + SDValue Scale = MGT->getScale(); SDValue Mask = MGT->getMask(); SDValue Src0 = MGT->getValue(); unsigned Alignment = MGT->getOriginalAlignment(); @@ -1842,7 +1844,7 @@ MachineMemOperand::MOLoad, LoMemVT.getStoreSize(), Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo}; + SDValue OpsLo[] = {Ch, Src0Lo, MaskLo, Ptr, IndexLo, Scale}; SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo, MMO); @@ -1852,7 +1854,7 @@ Alignment, MGT->getAAInfo(), MGT->getRanges()); - SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi}; + SDValue OpsHi[] = {Ch, Src0Hi, MaskHi, Ptr, IndexHi, Scale}; SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi, MMO); @@ -1935,6 +1937,7 @@ SDValue Ptr = N->getBasePtr(); SDValue Mask = N->getMask(); SDValue Index = N->getIndex(); + SDValue Scale = N->getScale(); SDValue Data = N->getValue(); EVT MemoryVT = N->getMemoryVT(); unsigned Alignment = N->getOriginalAlignment(); @@ -1970,7 +1973,7 @@ MachineMemOperand::MOStore, LoMemVT.getStoreSize(), Alignment, N->getAAInfo(), N->getRanges()); - SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo}; + SDValue OpsLo[] = {Ch, DataLo, MaskLo, Ptr, IndexLo, Scale}; Lo = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataLo.getValueType(), DL, OpsLo, MMO); @@ -1982,7 +1985,7 @@ // The order of the Scatter operation after split is well defined. The "Hi" // part comes after the "Lo". So these two operations should be chained one // after another. - SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi}; + SDValue OpsHi[] = {Lo, DataHi, MaskHi, Ptr, IndexHi, Scale}; return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), DataHi.getValueType(), DL, OpsHi, MMO); } @@ -2948,6 +2951,7 @@ SDValue Mask = N->getMask(); EVT MaskVT = Mask.getValueType(); SDValue Src0 = GetWidenedVector(N->getValue()); + SDValue Scale = N->getScale(); unsigned NumElts = WideVT.getVectorNumElements(); SDLoc dl(N); @@ -2963,7 +2967,7 @@ Index.getValueType().getScalarType(), NumElts); Index = ModifyToType(Index, WideIndexVT); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), N->getMemoryVT(), dl, Ops, N->getMemOperand()); @@ -3587,6 +3591,7 @@ SDValue DataOp = MSC->getValue(); SDValue Mask = MSC->getMask(); EVT MaskVT = Mask.getValueType(); + SDValue Scale = MSC->getScale(); // Widen the value. SDValue WideVal = GetWidenedVector(DataOp); @@ -3606,7 +3611,8 @@ NumElts); Index = ModifyToType(Index, WideIndexVT); - SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index}; + SDValue Ops[] = {MSC->getChain(), WideVal, Mask, MSC->getBasePtr(), Index, + Scale}; return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), dl, Ops, MSC->getMemOperand()); Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -6208,7 +6208,7 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef Ops, MachineMemOperand *MMO) { - assert(Ops.size() == 5 && "Incompatible number of operands"); + assert(Ops.size() == 6 && "Incompatible number of operands"); FoldingSetNodeID ID; AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops); @@ -6234,6 +6234,9 @@ assert(N->getIndex().getValueType().getVectorNumElements() == N->getValueType(0).getVectorNumElements() && "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); CSEMap.InsertNode(N, IP); InsertNode(N); @@ -6245,7 +6248,7 @@ SDValue SelectionDAG::getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef Ops, MachineMemOperand *MMO) { - assert(Ops.size() == 5 && "Incompatible number of operands"); + assert(Ops.size() == 6 && "Incompatible number of operands"); FoldingSetNodeID ID; AddNodeIDNode(ID, ISD::MSCATTER, VTs, Ops); @@ -6268,6 +6271,9 @@ assert(N->getIndex().getValueType().getVectorNumElements() == N->getValue().getValueType().getVectorNumElements() && "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); CSEMap.InsertNode(N, IP); InsertNode(N); Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3867,7 +3867,7 @@ // extract the splat value and use it as a uniform base. // In all other cases the function returns 'false'. static bool getUniformBase(const Value* &Ptr, SDValue& Base, SDValue& Index, - SelectionDAGBuilder* SDB) { + SDValue &Scale, SelectionDAGBuilder* SDB) { SelectionDAG& DAG = SDB->DAG; LLVMContext &Context = *DAG.getContext(); @@ -3897,6 +3897,10 @@ if (!SDB->findValue(Ptr) || !SDB->findValue(IndexVal)) return false; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + const DataLayout &DL = DAG.getDataLayout(); + Scale = DAG.getTargetConstant(DL.getTypeAllocSize(GEP->getResultElementType()), + SDB->getCurSDLoc(), TLI.getPointerTy(DL)); Base = SDB->getValue(Ptr); Index = SDB->getValue(IndexVal); @@ -3926,8 +3930,9 @@ SDValue Base; SDValue Index; + SDValue Scale; const Value *BasePtr = Ptr; - bool UniformBase = getUniformBase(BasePtr, Base, Index, this); + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; MachineMemOperand *MMO = DAG.getMachineFunction(). @@ -3935,10 +3940,11 @@ MachineMemOperand::MOStore, VT.getStoreSize(), Alignment, AAInfo); if (!UniformBase) { - Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); } - SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index }; + SDValue Ops[] = { getRoot(), Src0, Mask, Base, Index, Scale }; SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), VT, sdl, Ops, MMO); DAG.setRoot(Scatter); @@ -4025,8 +4031,9 @@ SDValue Root = DAG.getRoot(); SDValue Base; SDValue Index; + SDValue Scale; const Value *BasePtr = Ptr; - bool UniformBase = getUniformBase(BasePtr, Base, Index, this); + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); bool ConstantMemory = false; if (UniformBase && AA && AA->pointsToConstantMemory(MemoryLocation( @@ -4044,10 +4051,11 @@ Alignment, AAInfo, Ranges); if (!UniformBase) { - Base = DAG.getTargetConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); } - SDValue Ops[] = { Root, Src0, Mask, Base, Index }; + SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale }; SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, MMO); Index: lib/Target/X86/X86ISelDAGToDAG.cpp =================================================================== --- lib/Target/X86/X86ISelDAGToDAG.cpp +++ lib/Target/X86/X86ISelDAGToDAG.cpp @@ -1508,6 +1508,12 @@ bool X86DAGToDAGISel::matchVectorAddress(SDValue N, X86ISelAddressMode &AM) { // TODO: Support other operations. switch (N.getOpcode()) { + case ISD::Constant: { + uint64_t Val = cast(N)->getSExtValue(); + if (!foldOffsetIntoAddress(Val, AM)) + return false; + break; + } case X86ISD::Wrapper: if (!matchWrapper(N, AM)) return false; @@ -1524,6 +1530,7 @@ auto *Mgs = cast(Parent); AM.IndexReg = Mgs->getIndex(); AM.Scale = Mgs->getValue().getScalarValueSizeInBits() / 8; + AM.Scale = cast(Mgs->getScale())->getZExtValue(); unsigned AddrSpace = cast(Parent)->getPointerInfo().getAddrSpace(); // AddrSpace 256 -> GS, 257 -> FS, 258 -> SS. @@ -1534,14 +1541,8 @@ if (AddrSpace == 258) AM.Segment = CurDAG->getRegister(X86::SS, MVT::i16); - // If Base is 0, the whole address is in index and the Scale is 1 - if (isa(N)) { - assert(cast(N)->isNullValue() && - "Unexpected base in gather/scatter"); - AM.Scale = 1; - } - // Otherwise, try to match into the base and displacement fields. - else if (matchVectorAddress(N, AM)) + // Try to match into the base and displacement fields. + if (matchVectorAddress(N, AM)) return false; MVT VT = N.getSimpleValueType(); Index: lib/Target/X86/X86ISelLowering.h =================================================================== --- lib/Target/X86/X86ISelLowering.h +++ lib/Target/X86/X86ISelLowering.h @@ -1442,6 +1442,7 @@ const SDValue &getIndex() const { return getOperand(4); } const SDValue &getMask() const { return getOperand(2); } const SDValue &getValue() const { return getOperand(1); } + const SDValue &getScale() const { return getOperand(5); } static bool classof(const SDNode *N) { return N->getOpcode() == X86ISD::MGATHER || Index: lib/Target/X86/X86ISelLowering.cpp =================================================================== --- lib/Target/X86/X86ISelLowering.cpp +++ lib/Target/X86/X86ISelLowering.cpp @@ -24317,6 +24317,7 @@ assert(VT.getScalarSizeInBits() >= 32 && "Unsupported scatter op"); SDLoc dl(Op); + SDValue Scale = N->getScale(); SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Chain = N->getChain(); @@ -24383,7 +24384,7 @@ // The mask is killed by scatter, add it to the values SDVTList VTs = DAG.getVTList(Mask.getValueType(), MVT::Other); - SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index}; + SDValue Ops[] = {Chain, Src, Mask, BasePtr, Index, Scale}; SDValue NewScatter = DAG.getTargetMemSDNode( VTs, Ops, dl, N->getMemoryVT(), N->getMemOperand()); DAG.ReplaceAllUsesWith(Op, SDValue(NewScatter.getNode(), 1)); @@ -24489,6 +24490,7 @@ MaskedGatherSDNode *N = cast(Op.getNode()); SDLoc dl(Op); MVT VT = Op.getSimpleValueType(); + SDValue Scale = N->getScale(); SDValue Index = N->getIndex(); SDValue Mask = N->getMask(); SDValue Src0 = N->getValue(); @@ -24509,7 +24511,8 @@ // the vector contains 8 elements, we just sign-extend the index if (NumElts == 8) { Index = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v8i64, Index); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, + Scale }; SDValue NewGather = DAG.getTargetMemSDNode( DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), N->getMemOperand()); @@ -24533,7 +24536,7 @@ MVT NewVT = MVT::getVectorVT(VT.getScalarType(), NumElts); Src0 = ExtendToType(Src0, NewVT, DAG); - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue NewGather = DAG.getTargetMemSDNode( DAG.getVTList(NewVT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), N->getMemOperand()); @@ -24544,7 +24547,7 @@ return DAG.getMergeValues(RetOps, dl); } - SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index }; + SDValue Ops[] = { N->getChain(), Src0, Mask, N->getBasePtr(), Index, Scale }; SDValue NewGather = DAG.getTargetMemSDNode( DAG.getVTList(VT, MaskVT, MVT::Other), Ops, dl, N->getMemoryVT(), N->getMemOperand()); @@ -25080,7 +25083,7 @@ Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); } SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index }; + Index, Gather->getScale() }; SDValue Res = DAG.getTargetMemSDNode( DAG.getVTList(MVT::v4f32, Mask.getValueType(), MVT::Other), Ops, dl, Gather->getMemoryVT(), Gather->getMemOperand()); @@ -25107,7 +25110,7 @@ Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, MVT::v4i32, Mask); } SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index }; + Index, Gather->getScale() }; SDValue Res = DAG.getTargetMemSDNode( DAG.getVTList(MVT::v4i32, Mask.getValueType(), MVT::Other), Ops, dl, Gather->getMemoryVT(), Gather->getMemOperand()); @@ -25128,7 +25131,7 @@ Mask = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4i1, Mask, DAG.getConstant(0, dl, MVT::v2i1)); SDValue Ops[] = { Gather->getChain(), Src0, Mask, Gather->getBasePtr(), - Index }; + Index, Gather->getScale() }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(MVT::v4i32, MVT::Other), Gather->getMemoryVT(), dl, Ops, Gather->getMemOperand());