Index: llvm/include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- llvm/include/llvm/CodeGen/SelectionDAG.h
+++ llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1364,7 +1364,7 @@
                            ISD::MemIndexedMode AM);
   SDValue getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
                           ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
-                          ISD::MemIndexType IndexType);
+                          ISD::MemIndexType IndexType, ISD::LoadExtType ExtTy);
   SDValue getMaskedScatter(SDVTList VTs, EVT VT, const SDLoc &dl,
                            ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
                            ISD::MemIndexType IndexType,
Index: llvm/include/llvm/CodeGen/SelectionDAGNodes.h
===================================================================
--- llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -513,6 +513,7 @@
   class LoadSDNodeBitfields {
     friend class LoadSDNode;
     friend class MaskedLoadSDNode;
+    friend class MaskedGatherSDNode;
 
     uint16_t : NumLSBaseSDNodeBits;
 
@@ -2451,13 +2452,19 @@
   friend class SelectionDAG;
 
   MaskedGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
-                     EVT MemVT, MachineMemOperand *MMO,
+                     ISD::LoadExtType ETy, EVT MemVT, MachineMemOperand *MMO,
                      ISD::MemIndexType IndexType)
       : MaskedGatherScatterSDNode(ISD::MGATHER, Order, dl, VTs, MemVT, MMO,
-                                  IndexType) {}
+                                  IndexType) {
+    LoadSDNodeBits.ExtTy = ETy;
+  }
 
   const SDValue &getPassThru() const { return getOperand(1); }
 
+  ISD::LoadExtType getExtensionType() const {
+    return ISD::LoadExtType(LoadSDNodeBits.ExtTy);
+  }
+
   static bool classof(const SDNode *N) {
     return N->getOpcode() == ISD::MGATHER;
   }
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -931,6 +931,35 @@
   return false;
 }
 
+/// Returns true if \p N is a constant splat vector whose splat value equals
+/// the all-ones mask of \p ScalarTy (0xFF for i8, 0xFFFF for i16, 0xFFFFFFFF
+/// for i32). Any other scalar type, or a non-simple type, yields false.
+static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
+  if (!ScalarTy.isSimple())
+    return false;
+
+  uint64_t MaskForTy = 0ULL;
+  switch (ScalarTy.getSimpleVT().SimpleTy) {
+  case MVT::i8:
+    MaskForTy = 0xFFULL;
+    break;
+  case MVT::i16:
+    MaskForTy = 0xFFFFULL;
+    break;
+  case MVT::i32:
+    MaskForTy = 0xFFFFFFFFULL;
+    break;
+  default:
+    return false;
+  }
+
+  APInt Val;
+  if (ISD::isConstantSplatVector(N, Val))
+    return Val.getLimitedValue() == MaskForTy;
+
+  return false;
+}
+
 // Returns the SDNode if it is a constant float BuildVector
 // or constant float.
 static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
@@ -5647,6 +5676,28 @@
     }
   }
 
+  // fold (and (masked_gather x)) -> (zext_masked_gather x)
+  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
+    EVT MemVT = GN0->getMemoryVT();
+    EVT ScalarVT = MemVT.getScalarType();
+
+    if (SDValue(GN0, 0).hasOneUse() &&
+        isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
+        TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
+      SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
+                       GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
+
+      SDValue ZExtLoad = DAG.getMaskedGather(
+          DAG.getVTList(VT, MVT::Other), MemVT, SDLoc(N), Ops,
+          GN0->getMemOperand(), GN0->getIndexType(), ISD::ZEXTLOAD);
+
+      CombineTo(N, ZExtLoad);
+      AddToWorklist(ZExtLoad.getNode());
+      // Avoid recheck of N.
+      return SDValue(N, 0);
+    }
+  }
+
   // fold (and (load x), 255) -> (zextload x, i8)
   // fold (and (extload x, i16), 255) -> (zextload x, i8)
   // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
@@ -11483,6 +11534,25 @@
     }
   }
 
+  // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
+  if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
+    if (SDValue(GN0, 0).hasOneUse() &&
+        ExtVT == GN0->getMemoryVT() &&
+        TLI.isVectorLoadExtDesirable(SDValue(GN0, 0))) {
+      SDValue Ops[] = {GN0->getChain(), GN0->getPassThru(), GN0->getMask(),
+                       GN0->getBasePtr(), GN0->getIndex(), GN0->getScale()};
+
+      SDValue ExtLoad = DAG.getMaskedGather(
+          DAG.getVTList(VT, MVT::Other), ExtVT, SDLoc(N), Ops,
+          GN0->getMemOperand(), GN0->getIndexType(), ISD::SEXTLOAD);
+
+      CombineTo(N, ExtLoad);
+      CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+      AddToWorklist(ExtLoad.getNode());
+      return SDValue(N, 0); // Return N so it doesn't get rechecked!
+    }
+  }
+
   // Form (sext_inreg (bswap >> 16)) or (sext_inreg (rotl (bswap) 16))
   if (ExtVTBits <= 16 && N0.getOpcode() == ISD::OR) {
     if (SDValue BSwap = MatchBSwapHWordLow(N0.getNode(), N0.getOperand(0),
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -679,12 +679,17 @@
   assert(NVT == ExtPassThru.getValueType() &&
       "Gather result type and the passThru argument type should be the same");
 
+  ISD::LoadExtType ExtType = N->getExtensionType();
+  if (ExtType == ISD::NON_EXTLOAD)
+    ExtType = ISD::EXTLOAD;
+
   SDLoc dl(N);
   SDValue Ops[] = {N->getChain(), ExtPassThru, N->getMask(), N->getBasePtr(),
                    N->getIndex(), N->getScale() };
   SDValue Res = DAG.getMaskedGather(DAG.getVTList(NVT, MVT::Other),
                                     N->getMemoryVT(), dl, Ops,
-                                    N->getMemOperand(), N->getIndexType());
+                                    N->getMemOperand(), N->getIndexType(),
+                                    ExtType);
   // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -1710,7 +1710,9 @@
   SDValue PassThru = MGT->getPassThru();
   SDValue Index = MGT->getIndex();
   SDValue Scale = MGT->getScale();
+  EVT MemoryVT = MGT->getMemoryVT();
   Align Alignment = MGT->getOriginalAlign();
+  ISD::LoadExtType ExtType = MGT->getExtensionType();
 
   // Split Mask operand
   SDValue MaskLo, MaskHi;
@@ -1723,6 +1725,10 @@
     std::tie(MaskLo, MaskHi) = DAG.SplitVector(Mask, dl);
   }
 
+  EVT LoMemVT, HiMemVT;
+  // Split MemoryVT
+  std::tie(LoMemVT, HiMemVT) = DAG.GetSplitDestVTs(MemoryVT);
+
   SDValue PassThruLo, PassThruHi;
   if (getTypeAction(PassThru.getValueType()) == TargetLowering::TypeSplitVector)
     GetSplitVector(PassThru, PassThruLo, PassThruHi);
@@ -1741,12 +1747,12 @@
       MGT->getRanges());
 
   SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
-  Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl, OpsLo,
-                           MMO, MGT->getIndexType());
+  Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl, OpsLo,
+                           MMO, MGT->getIndexType(), ExtType);
 
   SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
-  Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl, OpsHi,
-                           MMO, MGT->getIndexType());
+  Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl, OpsHi,
+                           MMO, MGT->getIndexType(), ExtType);
 
   // Build a factor node to remember that this load is independent of the
   // other one.
@@ -2355,6 +2361,7 @@
   SDValue Mask = MGT->getMask();
   SDValue PassThru = MGT->getPassThru();
   Align Alignment = MGT->getOriginalAlign();
+  ISD::LoadExtType ExtType = MGT->getExtensionType();
 
   SDValue MaskLo, MaskHi;
   if (getTypeAction(Mask.getValueType()) == TargetLowering::TypeSplitVector)
@@ -2385,12 +2392,12 @@
       MGT->getRanges());
 
   SDValue OpsLo[] = {Ch, PassThruLo, MaskLo, Ptr, IndexLo, Scale};
-  SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoVT, dl,
-                                   OpsLo, MMO, MGT->getIndexType());
+  SDValue Lo = DAG.getMaskedGather(DAG.getVTList(LoVT, MVT::Other), LoMemVT, dl,
+                                   OpsLo, MMO, MGT->getIndexType(), ExtType);
 
   SDValue OpsHi[] = {Ch, PassThruHi, MaskHi, Ptr, IndexHi, Scale};
-  SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiVT, dl,
-                                   OpsHi, MMO, MGT->getIndexType());
+  SDValue Hi = DAG.getMaskedGather(DAG.getVTList(HiVT, MVT::Other), HiMemVT, dl,
+                                   OpsHi, MMO, MGT->getIndexType(), ExtType);
 
   // Build a factor node to remember that this load is independent of the
   // other one.
@@ -3891,7 +3898,8 @@
                    Scale };
   SDValue Res = DAG.getMaskedGather(DAG.getVTList(WideVT, MVT::Other),
                                     N->getMemoryVT(), dl, Ops,
-                                    N->getMemOperand(), N->getIndexType());
+                                    N->getMemOperand(), N->getIndexType(),
+                                    N->getExtensionType());
 
   // Legalize the chain result - switch anything that used the old chain to
   // use the new one.
@@ -4685,7 +4693,8 @@
   SDValue Ops[] = {MG->getChain(), DataOp, Mask, MG->getBasePtr(), Index,
                    Scale};
   SDValue Res = DAG.getMaskedGather(MG->getVTList(), MG->getMemoryVT(), dl, Ops,
-                                    MG->getMemOperand(), MG->getIndexType());
+                                    MG->getMemOperand(), MG->getIndexType(),
+                                    MG->getExtensionType());
   ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
   ReplaceValueWith(SDValue(N, 0), Res.getValue(0));
   return SDValue();
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -7324,14 +7324,15 @@
 SDValue SelectionDAG::getMaskedGather(SDVTList VTs, EVT VT, const SDLoc &dl,
                                       ArrayRef<SDValue> Ops,
                                       MachineMemOperand *MMO,
-                                      ISD::MemIndexType IndexType) {
+                                      ISD::MemIndexType IndexType,
+                                      ISD::LoadExtType ExtTy) {
   assert(Ops.size() == 6 && "Incompatible number of operands");
 
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::MGATHER, VTs, Ops);
   ID.AddInteger(VT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<MaskedGatherSDNode>(
-      dl.getIROrder(), VTs, VT, MMO, IndexType));
+      dl.getIROrder(), VTs, ExtTy, VT, MMO, IndexType));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   void *IP = nullptr;
   if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
@@ -7340,7 +7341,7 @@
   }
 
   auto *N = newSDNode<MaskedGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(),
-                                          VTs, VT, MMO, IndexType);
+                                          VTs, ExtTy, VT, MMO, IndexType);
   createOperands(N, Ops);
 
   assert(N->getPassThru().getValueType() == N->getValueType(0) &&
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4416,7 +4416,7 @@
   }
   SDValue Ops[] = { Root, Src0, Mask, Base, Index, Scale };
   SDValue Gather = DAG.getMaskedGather(DAG.getVTList(VT, MVT::Other), VT, sdl,
-                                       Ops, MMO, IndexType);
+                                       Ops, MMO, IndexType, ISD::NON_EXTLOAD);
 
   PendingLoads.push_back(Gather.getValue(1));
   setValue(&I, Gather);
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -737,6 +737,25 @@
       OS << ", compressing";
 
     OS << ">";
+  } else if (const auto *MGather = dyn_cast<MaskedGatherSDNode>(this)) {
+    OS << "<";
+    printMemOperand(OS, *MGather->getMemOperand(), G);
+
+    bool doExt = true;
+    switch (MGather->getExtensionType()) {
+    default: doExt = false; break;
+    case ISD::EXTLOAD: OS << ", anyext"; break;
+    case ISD::SEXTLOAD: OS << ", sext"; break;
+    case ISD::ZEXTLOAD: OS << ", zext"; break;
+    }
+    if (doExt)
+      OS << " from " << MGather->getMemoryVT().getEVTString();
+
+    auto Signed = MGather->isIndexSigned() ? "signed" : "unsigned";
+    auto Scaled = MGather->isIndexScaled() ? "scaled" : "unscaled";
+    OS << ", " << Signed << " " << Scaled << " offset";
+
+    OS << ">";
   } else if (const auto *MScatter = dyn_cast<MaskedScatterSDNode>(this)) {
     OS << "<";
     printMemOperand(OS, *MScatter->getMemOperand(), G);
Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47548,7 +47548,8 @@
     return DAG.getMaskedGather(Gather->getVTList(),
                                Gather->getMemoryVT(), DL, Ops,
                                Gather->getMemOperand(),
-                               Gather->getIndexType());
+                               Gather->getIndexType(),
+                               Gather->getExtensionType());
   }
   auto *Scatter = cast<MaskedScatterSDNode>(GorS);
   SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),