Index: llvm/include/llvm/CodeGen/SelectionDAG.h =================================================================== --- llvm/include/llvm/CodeGen/SelectionDAG.h +++ llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1358,7 +1358,7 @@ ISD::MemIndexedMode AM); SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef<SDValue> Ops, MachineMemOperand *MMO, - ISD::MemIndexType IndexType); + ISD::MemIndexType IndexType, ISD::LoadExtType ExtTy); SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef<SDValue> Ops, MachineMemOperand *MMO, ISD::MemIndexType IndexType, Index: llvm/include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -513,6 +513,7 @@ class LoadSDNodeBitfields { friend class LoadSDNode; friend class MaskedLoadSDNode; + friend class MaskedGatherSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -2425,13 +2426,19 @@ friend class SelectionDAG; MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, - EVT MemVT, MachineMemOperand *MMO, + ISD::LoadExtType ETy, EVT MemVT, MachineMemOperand *MMO, ISD::MemIndexType IndexType) : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, VTs, MemVT, MMO, - IndexType) {} + IndexType) { + LoadSDNodeBits.ExtTy = ETy; + } const SDValue &getPassThru() const { return getOperand(1); } + ISD::LoadExtType getExtensionType() const { + return ISD::LoadExtType(LoadSDNodeBits.ExtTy); + } + static bool classof(const SDNode *N) { return N->getOpcode() == ISD::MGATHER; } Index: llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -679,12 +679,17 @@ assert(NVT == ExtPassThru.getValueType() && "Gather result type and the passThru argument type should be the same"); + ISD::LoadExtType 
ExtType = N->getExtensionType(); + if (ExtType == ISD::NON_EXTLOAD) + ExtType = ISD::EXTLOAD; + SDLoc dl(N); SDValue Ops[] = {N->getChain(), ExtPassThru, N->getMask(), N->getBasePtr(), N->getIndex(), N->getScale() }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other), N->getMemoryVT(), dl, Ops, - N->getMemOperand(), N->getIndexType()); + N->getMemOperand(), N->getIndexType(), + ExtType); // Legalize the chain result - switch anything that used the old chain to // use the new one. ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -1713,6 +1713,7 @@ SDValue Index = MGT->getIndex(); SDValue Scale = MGT->getScale(); Align Alignment = MGT->getOriginalAlign(); + ISD::LoadExtType ExtType = MGT->getExtensionType(); // Split Mask operand SDValue MaskLo, MaskHi; @@ -1744,11 +1745,11 @@ SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale}; Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo, - MMO, MGT->getIndexType()); + MMO, MGT->getIndexType(), ExtType); SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale}; Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi, - MMO, MGT->getIndexType()); + MMO, MGT->getIndexType(), ExtType); // Build a factor node to remember that this load is independent of the // other one. 
@@ -2359,6 +2360,7 @@ SDValue Mask = MGT->getMask(); SDValue PassThru = MGT->getPassThru(); Align Alignment = MGT->getOriginalAlign(); + ISD::LoadExtType ExtType = MGT->getExtensionType(); SDValue MaskLo, MaskHi; if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector) @@ -2390,11 +2392,11 @@ SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale}; SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, - OpsLo, MMO, MGT->getIndexType()); + OpsLo, MMO, MGT->getIndexType(), ExtType); SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale}; SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, - OpsHi, MMO, MGT->getIndexType()); + OpsHi, MMO, MGT->getIndexType(), ExtType); // Build a factor node to remember that this load is independent of the // other one. @@ -3895,7 +3897,8 @@ Scale }; SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other), N->getMemoryVT(), dl, Ops, - N->getMemOperand(), N->getIndexType()); + N->getMemOperand(), N->getIndexType(), + N->getExtensionType()); // Legalize the chain result - switch anything that used the old chain to // use the new one. 
@@ -4689,7 +4692,8 @@ SDValue Ops[] = {MG->getChain(), DataOp, Mask, MG->getBasePtr(), Index, Scale}; SDValue Res = DAG.getMaskedGather(MG->getVTList(), MG->getMemoryVT(), dl, Ops, - MG->getMemOperand(), MG->getIndexType()); + MG->getMemOperand(), MG->getIndexType(), + MG->getExtensionType()); ReplaceValueWith(SDValue(N, 1), Res.getValue(1)); ReplaceValueWith(SDValue(N, 0), Res.getValue(0)); return SDValue(); Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -7345,14 +7345,15 @@ SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef<SDValue> Ops, MachineMemOperand *MMO, - ISD::MemIndexType IndexType) { + ISD::MemIndexType IndexType, + ISD::LoadExtType ExtTy) { assert(Ops.size() == 6 && "Incompatible number of operands"); FoldingSetNodeID ID; AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops); ID.AddInteger(VT.getRawBits()); ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>( - dl.getIROrder(), VTs, VT, MMO, IndexType)); + dl.getIROrder(), VTs, ExtTy, VT, MMO, IndexType)); ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); void *IP = nullptr; if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { @@ -7361,7 +7362,7 @@ } auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(), - VTs, VT, MMO, IndexType); + VTs, ExtTy, VT, MMO, IndexType); createOperands(N, Ops); assert(N->getPassThru().getValueType() == N->getValueType(0) && Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4413,7 +4413,7 @@ } SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale }; SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl, - Ops, MMO, IndexType); + Ops, MMO, IndexType, ISD::NON_EXTLOAD); 
PendingLoads.push_back(Gather.getValue(1)); setValue(&I, Gather); Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -735,6 +735,25 @@ OS << ", compressing"; OS << ">"; + } else if (const auto *MGather = dyn_cast<MaskedGatherSDNode>(this)) { + OS << "<"; + printMemOperand(OS, *MGather->getMemOperand(), G); + + bool doExt = true; + switch (MGather->getExtensionType()) { + default: doExt = false; break; + case ISD::EXTLOAD: OS << ", anyext"; break; + case ISD::SEXTLOAD: OS << ", sext"; break; + case ISD::ZEXTLOAD: OS << ", zext"; break; + } + if (doExt) + OS << " from " << MGather->getMemoryVT().getEVTString(); + + auto Signed = MGather->isIndexSigned() ? "signed" : "unsigned"; + auto Scaled = MGather->isIndexScaled() ? "scaled" : "unscaled"; + OS << ", " << Signed << " " << Scaled << " offset"; + + OS << ">"; } else if (const auto *MScatter = dyn_cast<MaskedScatterSDNode>(this)) { OS << "<"; printMemOperand(OS, *MScatter->getMemOperand(), G); Index: llvm/lib/Target/X86/X86ISelLowering.cpp =================================================================== --- llvm/lib/Target/X86/X86ISelLowering.cpp +++ llvm/lib/Target/X86/X86ISelLowering.cpp @@ -47542,7 +47542,8 @@ return DAG.getMaskedGather(Gather->getVTList(), Gather->getMemoryVT(), DL, Ops, Gather->getMemOperand(), - Gather->getIndexType()); + Gather->getIndexType(), + Gather->getExtensionType()); } auto *Scatter = cast<MaskedScatterSDNode>(GorS); SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),