diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1303,6 +1303,73 @@
   SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base,
                           SDValue Offset, ISD::MemIndexedMode AM);
 
+  SDValue getLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT,
+                    const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Offset,
+                    SDValue Mask, SDValue VLen, MachinePointerInfo PtrInfo,
+                    EVT MemVT, Align Alignment,
+                    MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
+                    const MDNode *Ranges = nullptr, bool IsExpanding = false);
+  inline SDValue
+  getLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT,
+            const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Offset,
+            SDValue Mask, SDValue VLen, MachinePointerInfo PtrInfo, EVT MemVT,
+            MaybeAlign Alignment = MaybeAlign(),
+            MachineMemOperand::Flags MMOFlags = MachineMemOperand::MONone,
+            const AAMDNodes &AAInfo = AAMDNodes(),
+            const MDNode *Ranges = nullptr, bool IsExpanding = false) {
+    // Ensures that codegen never sees a None Alignment.
+    return getLoadVP(AM, ExtType, VT, dl, Chain, Ptr, Offset, Mask, VLen,
+                     PtrInfo, MemVT, Alignment.getValueOr(getEVTAlign(MemVT)),
+                     MMOFlags, AAInfo, Ranges, IsExpanding);
+  }
+  SDValue getLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT,
+                    const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Offset,
+                    SDValue Mask, SDValue VLen, EVT MemVT,
+                    MachineMemOperand *MMO, bool IsExpanding = false);
+  SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr,
+                    SDValue Mask, SDValue VLen, MachinePointerInfo PtrInfo,
+                    MaybeAlign Alignment, MachineMemOperand::Flags MMOFlags,
+                    const AAMDNodes &AAInfo, const MDNode *Ranges = nullptr,
+                    bool IsExpanding = false);
+  SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr,
+                    SDValue Mask, SDValue VLen, MachineMemOperand *MMO,
+                    bool IsExpanding = false);
+  SDValue getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT,
+                       SDValue Chain, SDValue Ptr, SDValue Mask, SDValue VLen,
+                       MachinePointerInfo PtrInfo, EVT MemVT,
+                       MaybeAlign Alignment, MachineMemOperand::Flags MMOFlags,
+                       const AAMDNodes &AAInfo, bool IsExpanding = false);
+  SDValue getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl, EVT VT,
+                       SDValue Chain, SDValue Ptr, SDValue Mask, SDValue VLen,
+                       EVT MemVT, MachineMemOperand *MMO,
+                       bool IsExpanding = false);
+  SDValue getIndexedLoadVP(SDValue OrigLoad, const SDLoc &dl, SDValue Base,
+                           SDValue Offset, ISD::MemIndexedMode AM);
+  SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr,
+                     SDValue Mask, SDValue VLen, MachinePointerInfo PtrInfo,
+                     Align Alignment, MachineMemOperand::Flags MMOFlags,
+                     const AAMDNodes &AAInfo, bool IsCompressing = false);
+  SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr,
+                     SDValue Mask, SDValue VLen, MachineMemOperand *MMO,
+                     bool IsCompressing = false);
+  SDValue getTruncStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
+                          SDValue Ptr, SDValue Mask, SDValue VLen,
+                          MachinePointerInfo PtrInfo, EVT SVT, Align Alignment,
+                          MachineMemOperand::Flags MMOFlags,
+                          const AAMDNodes &AAInfo, bool IsCompressing = false);
+  SDValue getTruncStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
+                          SDValue Ptr, SDValue Mask, SDValue VLen, EVT SVT,
+                          MachineMemOperand *MMO, bool IsCompressing = false);
+  SDValue getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl, SDValue Base,
+                            SDValue Offset, ISD::MemIndexedMode AM);
+
+  SDValue getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl,
+                      ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
+                      ISD::MemIndexType IndexType);
+  SDValue getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl,
+                       ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
+                       ISD::MemIndexType IndexType);
+
   SDValue getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Base,
                         SDValue Offset, SDValue Mask, SDValue Src0, EVT MemVT,
                         MachineMemOperand *MMO, ISD::MemIndexedMode AM,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -509,15 +509,19 @@
 
   class LSBaseSDNodeBitfields {
     friend class LSBaseSDNode;
+    friend class VPLoadStoreSDNode;
     friend class MaskedLoadStoreSDNode;
     friend class MaskedGatherScatterSDNode;
+    friend class VPGatherScatterSDNode;
 
     uint16_t : NumMemSDNodeBits;
 
     // This storage is shared between disparate class hierarchies to hold an
     // enumeration specific to the class hierarchy in use.
     //   LSBaseSDNode => enum ISD::MemIndexedMode
+    //   VPLoadStoreSDNode => enum ISD::MemIndexedMode
     //   MaskedLoadStoreBaseSDNode => enum ISD::MemIndexedMode
+    //   VPGatherScatterSDNode => enum ISD::MemIndexType
     //   MaskedGatherScatterSDNode => enum ISD::MemIndexType
     uint16_t AddressingMode : 3;
   };
@@ -525,8 +529,10 @@
 
   class LoadSDNodeBitfields {
     friend class LoadSDNode;
+    friend class VPLoadSDNode;
     friend class MaskedLoadSDNode;
     friend class MaskedGatherSDNode;
+    friend class VPGatherSDNode;
 
     uint16_t : NumLSBaseSDNodeBits;
 
@@ -536,8 +542,10 @@
 
   class StoreSDNodeBitfields {
     friend class StoreSDNode;
+    friend class VPStoreSDNode;
     friend class MaskedStoreSDNode;
     friend class MaskedScatterSDNode;
+    friend class VPScatterSDNode;
 
     uint16_t : NumLSBaseSDNodeBits;
 
@@ -1353,10 +1361,13 @@
   const SDValue &getBasePtr() const {
     switch (getOpcode()) {
     case ISD::STORE:
+    case ISD::VP_STORE:
+    case ISD::VP_SCATTER:
     case ISD::MSTORE:
       return getOperand(2);
     case ISD::MGATHER:
     case ISD::MSCATTER:
       return getOperand(3);
     default:
+      // VP_LOAD and VP_GATHER keep the base pointer in operand 1.
       return getOperand(1);
@@ -1393,6 +1404,10 @@
     case ISD::MSTORE:
     case ISD::MGATHER:
     case ISD::MSCATTER:
+    case ISD::VP_LOAD:
+    case ISD::VP_STORE:
+    case ISD::VP_GATHER:
+    case ISD::VP_SCATTER:
       return true;
     default:
       return N->isMemIntrinsic() || N->isTargetMemoryOpcode();
@@ -2316,6 +2331,115 @@
   }
 };
 
+/// This base class is used to represent VP_LOAD and VP_STORE nodes
+class VPLoadStoreSDNode : public MemSDNode {
+public:
+  friend class SelectionDAG;
+
+  VPLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, const DebugLoc &dl,
+                    SDVTList VTs, ISD::MemIndexedMode AM, EVT MemVT,
+                    MachineMemOperand *MMO)
+      : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
+    LSBaseSDNodeBits.AddressingMode = AM;
+    assert(getAddressingMode() == AM && "Value truncated");
+  }
+
+  // VPLoadSDNode (Chain, ptr, offset, mask, VLen)
+  // VPStoreSDNode (Chain, data, ptr, offset, mask, VLen)
+  // Mask is a vector of i1 elements, VLen is i32
+  const SDValue &getOffset() const {
+    return getOperand(getOpcode() == ISD::VP_LOAD ? 2 : 3);
+  }
+  const SDValue &getBasePtr() const {
+    return getOperand(getOpcode() == ISD::VP_LOAD ? 1 : 2);
+  }
+  const SDValue &getMask() const {
+    return getOperand(getOpcode() == ISD::VP_LOAD ? 3 : 4);
+  }
+  const SDValue &getVectorLength() const {
+    return getOperand(getOpcode() == ISD::VP_LOAD ? 4 : 5);
+  }
+
+  /// Return the addressing mode for this load or store:
+  /// unindexed, pre-inc, pre-dec, post-inc, or post-dec.
+  ISD::MemIndexedMode getAddressingMode() const {
+    return static_cast<ISD::MemIndexedMode>(LSBaseSDNodeBits.AddressingMode);
+  }
+
+  /// Return true if this is a pre/post inc/dec load/store.
+  bool isIndexed() const { return getAddressingMode() != ISD::UNINDEXED; }
+
+  /// Return true if this is NOT a pre/post inc/dec load/store.
+  bool isUnindexed() const { return getAddressingMode() == ISD::UNINDEXED; }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::VP_LOAD || N->getOpcode() == ISD::VP_STORE;
+  }
+};
+
+/// This class is used to represent a VP_LOAD node
+class VPLoadSDNode : public VPLoadStoreSDNode {
+public:
+  friend class SelectionDAG;
+
+  VPLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+               ISD::MemIndexedMode AM, ISD::LoadExtType ETy, bool isExpanding,
+               EVT MemVT, MachineMemOperand *MMO)
+      : VPLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, AM, MemVT, MMO) {
+    LoadSDNodeBits.ExtTy = ETy;
+    LoadSDNodeBits.IsExpanding = isExpanding;
+  }
+
+  ISD::LoadExtType getExtensionType() const {
+    return static_cast<ISD::LoadExtType>(LoadSDNodeBits.ExtTy);
+  }
+
+  const SDValue &getBasePtr() const { return getOperand(1); }
+  const SDValue &getOffset() const { return getOperand(2); }
+  const SDValue &getMask() const { return getOperand(3); }
+  const SDValue &getVectorLength() const { return getOperand(4); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::VP_LOAD;
+  }
+  bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; }
+};
+
+/// This class is used to represent a VP_STORE node
+class VPStoreSDNode : public VPLoadStoreSDNode {
+public:
+  friend class SelectionDAG;
+
+  VPStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs,
+                ISD::MemIndexedMode AM, bool isTrunc, bool isCompressing,
+                EVT MemVT, MachineMemOperand *MMO)
+      : VPLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, AM, MemVT, MMO) {
+    StoreSDNodeBits.IsTruncating = isTrunc;
+    StoreSDNodeBits.IsCompressing = isCompressing;
+  }
+
+  /// Return true if this is a truncating store.
+  /// For integers this is the same as doing a TRUNCATE and storing the result.
+  /// For floats, it is the same as doing an FP_ROUND and storing the result.
+  bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; }
+
+  /// Returns true if the op does a compression to the vector before storing.
+  /// The node contiguously stores the active elements (integers or floats)
+  /// in src (those with their respective bit set in writemask k) to unaligned
+  /// memory at base_addr.
+  bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; }
+
+  const SDValue &getValue() const { return getOperand(1); }
+  const SDValue &getBasePtr() const { return getOperand(2); }
+  const SDValue &getOffset() const { return getOperand(3); }
+  const SDValue &getMask() const { return getOperand(4); }
+  const SDValue &getVectorLength() const { return getOperand(5); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::VP_STORE;
+  }
+};
+
 /// This base class is used to represent MLOAD and MSTORE nodes
 class MaskedLoadStoreSDNode : public MemSDNode {
 public:
@@ -2421,6 +2545,94 @@
   }
 };
 
+/// This is a base class used to represent
+/// VP_GATHER and VP_SCATTER nodes
+///
+class VPGatherScatterSDNode : public MemSDNode {
+public:
+  friend class SelectionDAG;
+
+  VPGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order,
+                        const DebugLoc &dl, SDVTList VTs, EVT MemVT,
+                        MachineMemOperand *MMO, ISD::MemIndexType IndexType)
+      : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {
+    LSBaseSDNodeBits.AddressingMode = IndexType;
+    assert(getIndexType() == IndexType && "Value truncated");
+  }
+
+  /// How is Index applied to BasePtr when computing addresses.
+  ISD::MemIndexType getIndexType() const {
+    return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
+  }
+  bool isIndexScaled() const {
+    return (getIndexType() == ISD::SIGNED_SCALED) ||
+           (getIndexType() == ISD::UNSIGNED_SCALED);
+  }
+  bool isIndexSigned() const {
+    return (getIndexType() == ISD::SIGNED_SCALED) ||
+           (getIndexType() == ISD::SIGNED_UNSCALED);
+  }
+
+  // Operand layout of the two node kinds (scatter has one extra leading value):
+  // VPGatherSDNode  (Chain, base, index, scale, mask, vlen)
+  // VPScatterSDNode (Chain, value, base, index, scale, mask, vlen)
+  // Mask is a vector of i1 elements
+  const SDValue &getBasePtr() const {
+    return getOperand((getOpcode() == ISD::VP_GATHER) ? 1 : 2);
+  }
+  const SDValue &getIndex() const {
+    return getOperand((getOpcode() == ISD::VP_GATHER) ? 2 : 3);
+  }
+  const SDValue &getScale() const {
+    return getOperand((getOpcode() == ISD::VP_GATHER) ? 3 : 4);
+  }
+  const SDValue &getMask() const {
+    return getOperand((getOpcode() == ISD::VP_GATHER) ? 4 : 5);
+  }
+  const SDValue &getVectorLength() const {
+    return getOperand((getOpcode() == ISD::VP_GATHER) ? 5 : 6);
+  }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::VP_GATHER ||
+           N->getOpcode() == ISD::VP_SCATTER;
+  }
+};
+
+/// This class is used to represent a VP_GATHER node
+///
+class VPGatherSDNode : public VPGatherScatterSDNode {
+public:
+  friend class SelectionDAG;
+
+  VPGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT MemVT,
+                 MachineMemOperand *MMO, ISD::MemIndexType IndexType)
+      : VPGatherScatterSDNode(ISD::VP_GATHER, Order, dl, VTs, MemVT, MMO,
+                              IndexType) {}
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::VP_GATHER;
+  }
+};
+
+/// This class is used to represent a VP_SCATTER node
+///
+class VPScatterSDNode : public VPGatherScatterSDNode {
+public:
+  friend class SelectionDAG;
+
+  VPScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, EVT MemVT,
+                  MachineMemOperand *MMO, ISD::MemIndexType IndexType)
+      : VPGatherScatterSDNode(ISD::VP_SCATTER, Order, dl, VTs, MemVT, MMO,
+                              IndexType) {}
+
+  const SDValue &getValue() const { return getOperand(1); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::VP_SCATTER;
+  }
+};
+
 /// This is a base class used to represent
 /// MGATHER and MSCATTER nodes
 ///
diff --git a/llvm/include/llvm/IR/VPIntrinsics.def b/llvm/include/llvm/IR/VPIntrinsics.def
--- a/llvm/include/llvm/IR/VPIntrinsics.def
+++ b/llvm/include/llvm/IR/VPIntrinsics.def
@@ -204,7 +204,8 @@
 
 ///// Memory Operations {
 // llvm.vp.store(ptr,val,mask,vlen)
-BEGIN_REGISTER_VP(vp_store, 2, 3, VP_STORE, 0)
+BEGIN_REGISTER_VP_INTRINSIC(vp_store, 2, 3)
+BEGIN_REGISTER_VP_SDNODE(VP_STORE, 0, vp_store, 4, 5)
 HANDLE_VP_TO_OPC(Store)
 HANDLE_VP_TO_INTRIN(masked_store)
 HANDLE_VP_IS_MEMOP(vp_store, 1, 0)
@@ -217,7 +218,8 @@
 END_REGISTER_VP(vp_scatter, VP_SCATTER)
 
 // llvm.vp.load(ptr,mask,vlen)
-BEGIN_REGISTER_VP(vp_load, 1, 2, VP_LOAD, -1)
+BEGIN_REGISTER_VP_INTRINSIC(vp_load, 1, 2)
+BEGIN_REGISTER_VP_SDNODE(VP_LOAD, -1, vp_load, 3, 4)
 HANDLE_VP_TO_OPC(Load)
HANDLE_VP_TO_INTRIN(masked_load)
 HANDLE_VP_IS_MEMOP(vp_load, 0, None)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -682,6 +682,34 @@
     ID.AddInteger(ST->getPointerInfo().getAddrSpace());
     break;
   }
+  case ISD::VP_LOAD: {
+    const VPLoadSDNode *ELD = cast<VPLoadSDNode>(N);
+    ID.AddInteger(ELD->getMemoryVT().getRawBits());
+    ID.AddInteger(ELD->getRawSubclassData());
+    ID.AddInteger(ELD->getPointerInfo().getAddrSpace());
+    break;
+  }
+  case ISD::VP_STORE: {
+    const VPStoreSDNode *EST = cast<VPStoreSDNode>(N);
+    ID.AddInteger(EST->getMemoryVT().getRawBits());
+    ID.AddInteger(EST->getRawSubclassData());
+    ID.AddInteger(EST->getPointerInfo().getAddrSpace());
+    break;
+  }
+  case ISD::VP_GATHER: {
+    const VPGatherSDNode *EG = cast<VPGatherSDNode>(N);
+    ID.AddInteger(EG->getMemoryVT().getRawBits());
+    ID.AddInteger(EG->getRawSubclassData());
+    ID.AddInteger(EG->getPointerInfo().getAddrSpace());
+    break;
+  }
+  case ISD::VP_SCATTER: {
+    const VPScatterSDNode *ES = cast<VPScatterSDNode>(N);
+    ID.AddInteger(ES->getMemoryVT().getRawBits());
+    ID.AddInteger(ES->getRawSubclassData());
+    ID.AddInteger(ES->getPointerInfo().getAddrSpace());
+    break;
+  }
   case ISD::MLOAD: {
     const MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N);
     ID.AddInteger(MLD->getMemoryVT().getRawBits());
@@ -7563,6 +7591,374 @@
   return V;
 }
 
+SDValue SelectionDAG::getLoadVP(
+    ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &dl,
+    SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Mask, SDValue VLen,
+    MachinePointerInfo PtrInfo, EVT MemVT, Align Alignment,
+    MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo,
+    const MDNode *Ranges, bool IsExpanding) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+
+  MMOFlags |= MachineMemOperand::MOLoad;
+  assert((MMOFlags & MachineMemOperand::MOStore) == 0);
+  // If we don't have a PtrInfo, infer the trivial frame index case to simplify
+  // clients.
+  if (PtrInfo.V.isNull())
+    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset);
+
+  uint64_t Size = MemoryLocation::getSizeOrUnknown(MemVT.getStoreSize());
+  MachineFunction &MF = getMachineFunction();
+  MachineMemOperand *MMO = MF.getMachineMemOperand(PtrInfo, MMOFlags, Size,
+                                                   Alignment, AAInfo, Ranges);
+  return getLoadVP(AM, ExtType, VT, dl, Chain, Ptr, Offset, Mask, VLen, MemVT,
+                   MMO, IsExpanding);
+}
+
+SDValue SelectionDAG::getLoadVP(ISD::MemIndexedMode AM,
+                                ISD::LoadExtType ExtType, EVT VT,
+                                const SDLoc &dl, SDValue Chain, SDValue Ptr,
+                                SDValue Offset, SDValue Mask, SDValue VLen,
+                                EVT MemVT, MachineMemOperand *MMO,
+                                bool IsExpanding) {
+  if (VT == MemVT) {
+    ExtType = ISD::NON_EXTLOAD;
+  } else if (ExtType == ISD::NON_EXTLOAD) {
+    assert(VT == MemVT && "Non-extending load from different memory type!");
+  } else {
+    // Extending load.
+    assert(MemVT.getScalarType().bitsLT(VT.getScalarType()) &&
+           "Should only be an extending load, not truncating!");
+    assert(VT.isInteger() == MemVT.isInteger() &&
+           "Cannot convert from FP to Int or Int -> FP!");
+    assert(VT.isVector() == MemVT.isVector() &&
+           "Cannot use an ext load to convert to or from a vector!");
+    assert((!VT.isVector() ||
+            VT.getVectorElementCount() == MemVT.getVectorElementCount()) &&
+           "Cannot use an ext load to change the number of vector elements!");
+  }
+
+  bool Indexed = AM != ISD::UNINDEXED;
+  assert((Indexed || Offset.isUndef()) && "Unindexed load with an offset!");
+
+  SDVTList VTs = Indexed ? getVTList(VT, Ptr.getValueType(), MVT::Other)
+                         : getVTList(VT, MVT::Other);
+  SDValue Ops[] = {Chain, Ptr, Offset, Mask, VLen};
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::VP_LOAD, VTs, Ops);
+  ID.AddInteger(MemVT.getRawBits()); // Memory VT, as in the VP_LOAD profile.
+  ID.AddInteger(getSyntheticNodeSubclassData<VPLoadSDNode>(
+      dl.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
+    cast<VPLoadSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  auto *N = newSDNode<VPLoadSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
+                                    ExtType, IsExpanding, MemVT, MMO);
+  createOperands(N, Ops);
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain,
+                                SDValue Ptr, SDValue Mask, SDValue VLen,
+                                MachinePointerInfo PtrInfo,
+                                MaybeAlign Alignment,
+                                MachineMemOperand::Flags MMOFlags,
+                                const AAMDNodes &AAInfo, const MDNode *Ranges,
+                                bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, dl, Chain, Ptr, Undef,
+                   Mask, VLen, PtrInfo, VT, Alignment, MMOFlags, AAInfo, Ranges,
+                   IsExpanding);
+}
+
+SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain,
+                                SDValue Ptr, SDValue Mask, SDValue VLen,
+                                MachineMemOperand *MMO, bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, dl, Chain, Ptr, Undef,
+                   Mask, VLen, VT, MMO, IsExpanding);
+}
+
+SDValue SelectionDAG::getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl,
+                                   EVT VT, SDValue Chain, SDValue Ptr,
+                                   SDValue Mask, SDValue VLen,
+                                   MachinePointerInfo PtrInfo, EVT MemVT,
+                                   MaybeAlign Alignment,
+                                   MachineMemOperand::Flags MMOFlags,
+                                   const AAMDNodes &AAInfo, bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getLoadVP(ISD::UNINDEXED, ExtType, VT, dl, Chain, Ptr, Undef, Mask,
+                   VLen, PtrInfo, MemVT, Alignment, MMOFlags, AAInfo, nullptr,
+                   IsExpanding);
+}
+
+SDValue SelectionDAG::getExtLoadVP(ISD::LoadExtType ExtType, const SDLoc &dl,
+                                   EVT VT, SDValue Chain, SDValue Ptr,
+                                   SDValue Mask, SDValue VLen, EVT MemVT,
+                                   MachineMemOperand *MMO, bool IsExpanding) {
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  return getLoadVP(ISD::UNINDEXED, ExtType, VT, dl, Chain, Ptr, Undef, Mask,
+                   VLen, MemVT, MMO, IsExpanding);
+}
+
+SDValue SelectionDAG::getIndexedLoadVP(SDValue OrigLoad, const SDLoc &dl,
+                                       SDValue Base, SDValue Offset,
+                                       ISD::MemIndexedMode AM) {
+  auto *LD = cast<VPLoadSDNode>(OrigLoad);
+  assert(LD->getOffset().isUndef() && "Load is already a indexed load!");
+  // Don't propagate the invariant or dereferenceable flags.
+  auto MMOFlags =
+      LD->getMemOperand()->getFlags() &
+      ~(MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
+  return getLoadVP(AM, LD->getExtensionType(), OrigLoad.getValueType(), dl,
+                   LD->getChain(), Base, Offset, LD->getMask(),
+                   LD->getVectorLength(), LD->getPointerInfo(),
+                   LD->getMemoryVT(), LD->getAlign(), MMOFlags, LD->getAAInfo(),
+                   nullptr, LD->isExpandingLoad());
+}
+
+SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
+                                 SDValue Ptr, SDValue Mask, SDValue VLen,
+                                 MachinePointerInfo PtrInfo, Align Alignment,
+                                 MachineMemOperand::Flags MMOFlags,
+                                 const AAMDNodes &AAInfo, bool IsCompressing) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+
+  MMOFlags |= MachineMemOperand::MOStore;
+  assert((MMOFlags & MachineMemOperand::MOLoad) == 0);
+
+  if (PtrInfo.V.isNull())
+    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);
+
+  MachineFunction &MF = getMachineFunction();
+  uint64_t Size =
+      MemoryLocation::getSizeOrUnknown(Val.getValueType().getStoreSize());
+  MachineMemOperand *MMO =
+      MF.getMachineMemOperand(PtrInfo, MMOFlags, Size, Alignment, AAInfo);
+  return getStoreVP(Chain, dl, Val, Ptr, Mask, VLen, MMO, IsCompressing);
+}
+
+SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val,
+                                 SDValue Ptr, SDValue Mask, SDValue VLen,
+                                 MachineMemOperand *MMO, bool IsCompressing) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+  EVT VT = Val.getValueType();
+  SDVTList VTs = getVTList(MVT::Other);
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  SDValue Ops[] = {Chain, Val, Ptr, Undef, Mask, VLen};
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
+      dl.getIROrder(), VTs, ISD::UNINDEXED, false, IsCompressing, VT, MMO));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
+    cast<VPStoreSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  auto *N =
+      newSDNode<VPStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
+                               ISD::UNINDEXED, false, IsCompressing, VT, MMO);
+  createOperands(N, Ops);
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getTruncStoreVP(SDValue Chain, const SDLoc &dl,
+                                      SDValue Val, SDValue Ptr, SDValue Mask,
+                                      SDValue VLen, MachinePointerInfo PtrInfo,
+                                      EVT SVT, Align Alignment,
+                                      MachineMemOperand::Flags MMOFlags,
+                                      const AAMDNodes &AAInfo,
+                                      bool IsCompressing) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+
+  MMOFlags |= MachineMemOperand::MOStore;
+  assert((MMOFlags & MachineMemOperand::MOLoad) == 0);
+
+  if (PtrInfo.V.isNull())
+    PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr);
+
+  MachineFunction &MF = getMachineFunction();
+  MachineMemOperand *MMO = MF.getMachineMemOperand(
+      PtrInfo, MMOFlags, MemoryLocation::getSizeOrUnknown(SVT.getStoreSize()),
+      Alignment, AAInfo);
+  return getTruncStoreVP(Chain, dl, Val, Ptr, Mask, VLen, SVT, MMO,
+                         IsCompressing);
+}
+
+SDValue SelectionDAG::getTruncStoreVP(SDValue Chain, const SDLoc &dl,
+                                      SDValue Val, SDValue Ptr, SDValue Mask,
+                                      SDValue VLen, EVT SVT,
+                                      MachineMemOperand *MMO,
+                                      bool IsCompressing) {
+  EVT VT = Val.getValueType();
+
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+  if (VT == SVT)
+    return getStoreVP(Chain, dl, Val, Ptr, Mask, VLen, MMO, IsCompressing);
+
+  assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
+         "Should only be a truncating store, not extending!");
+  assert(VT.isInteger() == SVT.isInteger() && "Can't do FP-INT conversion!");
+  assert(VT.isVector() == SVT.isVector() &&
+         "Cannot use trunc store to convert to or from a vector!");
+  assert((!VT.isVector() ||
+          VT.getVectorElementCount() == SVT.getVectorElementCount()) &&
+         "Cannot use trunc store to change the number of vector elements!");
+
+  SDVTList VTs = getVTList(MVT::Other);
+  SDValue Undef = getUNDEF(Ptr.getValueType());
+  SDValue Ops[] = {Chain, Val, Ptr, Undef, Mask, VLen};
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops);
+  ID.AddInteger(SVT.getRawBits());
+  ID.AddInteger(getSyntheticNodeSubclassData<VPStoreSDNode>(
+      dl.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
+    cast<VPStoreSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  auto *N =
+      newSDNode<VPStoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
+                               ISD::UNINDEXED, true, IsCompressing, SVT, MMO);
+  createOperands(N, Ops);
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl,
+                                        SDValue Base, SDValue Offset,
+                                        ISD::MemIndexedMode AM) {
+  auto *ST = cast<VPStoreSDNode>(OrigStore);
+  assert(ST->getOffset().isUndef() && "Store is already an indexed store!");
+  SDVTList VTs = getVTList(Base.getValueType(), MVT::Other);
+  SDValue Ops[] = {ST->getChain(), ST->getValue(), Base,
+                   Offset, ST->getMask(), ST->getVectorLength()};
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops);
+  ID.AddInteger(ST->getMemoryVT().getRawBits());
+  ID.AddInteger(ST->getRawSubclassData());
+  ID.AddInteger(ST->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
+    return SDValue(E, 0);
+
+  auto *N = newSDNode<VPStoreSDNode>(
+      dl.getIROrder(), dl.getDebugLoc(), VTs, AM, ST->isTruncatingStore(),
+      ST->isCompressingStore(), ST->getMemoryVT(), ST->getMemOperand());
+  createOperands(N, Ops);
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl,
+                                  ArrayRef<SDValue> Ops, MachineMemOperand *MMO,
+                                  ISD::MemIndexType IndexType) {
+  assert(Ops.size() == 6 && "Incompatible number of operands");
+
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::VP_GATHER, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(getSyntheticNodeSubclassData<VPGatherSDNode>(
+      dl.getIROrder(), VTs, VT, MMO, IndexType));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
+    cast<VPGatherSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+
+  auto *N = newSDNode<VPGatherSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
+                                      VT, MMO, IndexType);
+  createOperands(N, Ops);
+
+  assert(N->getMask().getValueType().getVectorElementCount() ==
+             N->getValueType(0).getVectorElementCount() &&
+         "Vector width mismatch between mask and data");
+  assert(N->getIndex().getValueType().getVectorElementCount().isScalable() ==
+             N->getValueType(0).getVectorElementCount().isScalable() &&
+         "Scalable flags of index and data do not match");
+  assert(ElementCount::isKnownGE(
+             N->getIndex().getValueType().getVectorElementCount(),
+             N->getValueType(0).getVectorElementCount()) &&
+         "Vector width mismatch between index and data");
+  assert(isa<ConstantSDNode>(N->getScale()) &&
+         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
+         "Scale should be a constant power of 2");
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
+SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl,
+                                   ArrayRef<SDValue> Ops,
+                                   MachineMemOperand *MMO,
+                                   ISD::MemIndexType IndexType) {
+  assert(Ops.size() == 7 && "Incompatible number of operands");
+
+  FoldingSetNodeID ID;
+  AddNodeIDNode(ID, ISD::VP_SCATTER, VTs, Ops);
+  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(getSyntheticNodeSubclassData<VPScatterSDNode>(
+      dl.getIROrder(), VTs, VT, MMO, IndexType));
+  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
+  void *IP = nullptr;
+  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
+    cast<VPScatterSDNode>(E)->refineAlignment(MMO);
+    return SDValue(E, 0);
+  }
+  auto *N = newSDNode<VPScatterSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
+                                       VT, MMO, IndexType);
+  createOperands(N, Ops);
+
+  assert(N->getMask().getValueType().getVectorElementCount() ==
+             N->getValue().getValueType().getVectorElementCount() &&
+         "Vector width mismatch between mask and data");
+  assert(
+      N->getIndex().getValueType().getVectorElementCount().isScalable() ==
+          N->getValue().getValueType().getVectorElementCount().isScalable() &&
+      "Scalable flags of index and data do not match");
+  assert(ElementCount::isKnownGE(
+             N->getIndex().getValueType().getVectorElementCount(),
+             N->getValue().getValueType().getVectorElementCount()) &&
+         "Vector width mismatch between index and data");
+  assert(isa<ConstantSDNode>(N->getScale()) &&
+         cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() &&
+         "Scale should be a constant power of 2");
+
+  CSEMap.InsertNode(N, IP);
+  InsertNode(N);
+  SDValue V(N, 0);
+  NewSDValueDbgMsg(V, "Creating new node: ", this);
+  return V;
+}
+
 SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain,
                                     SDValue Base, SDValue Offset, SDValue Mask,
                                     SDValue PassThru, EVT MemVT,
diff --git
a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7356,8 +7356,122 @@
     OpValues.push_back(Op);
   }
 
-  SDValue Result = DAG.getNode(Opcode, DL, VTs, OpValues);
-  setValue(&VPIntrin, Result);
+  switch (Opcode) {
+  default: {
+    SDValue Result = DAG.getNode(Opcode, DL, VTs, OpValues);
+    setValue(&VPIntrin, Result);
+    break;
+  }
+  case ISD::VP_LOAD:
+  case ISD::VP_GATHER: {
+    Value *PtrOperand = VPIntrin.getArgOperand(0);
+    EVT VT = ValueVTs[0];
+    MaybeAlign Alignment = DAG.getEVTAlign(VT);
+    AAMDNodes AAInfo;
+    VPIntrin.getAAMetadata(AAInfo);
+    const MDNode *Ranges = VPIntrin.getMetadata(LLVMContext::MD_range);
+    SDValue LD;
+    bool AddToChain = true;
+    if (Opcode == ISD::VP_LOAD) {
+      // Do not serialize variable-length loads of constant memory with
+      // anything.
+      MemoryLocation ML;
+      if (VT.isScalableVector())
+        ML = MemoryLocation::getAfter(PtrOperand);
+      else
+        ML = MemoryLocation(
+            PtrOperand,
+            LocationSize::precise(
+                DAG.getDataLayout().getTypeStoreSize(VPIntrin.getType())),
+            AAInfo);
+      AddToChain = !AA || !AA->pointsToConstantMemory(ML);
+      SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode();
+      MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+          MachinePointerInfo(PtrOperand), MachineMemOperand::MOLoad,
+          VT.getStoreSize().getKnownMinSize(), *Alignment, AAInfo, Ranges);
+      LD = DAG.getLoadVP(VT, DL, InChain, OpValues[0], OpValues[1], OpValues[2],
+                         MMO, false /* IsExpanding */);
+    } else {
+      unsigned AS =
+          PtrOperand->getType()->getScalarType()->getPointerAddressSpace();
+      MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+          MachinePointerInfo(AS), MachineMemOperand::MOLoad,
+          MemoryLocation::UnknownSize, *Alignment, AAInfo, Ranges);
+      SDValue Base, Index, Scale;
+      ISD::MemIndexType IndexType;
+      bool UniformBase = getUniformBase(PtrOperand, Base, Index, IndexType,
+                                        Scale, this, VPIntrin.getParent());
+      if (!UniformBase) {
+        Base = DAG.getConstant(0, DL, TLI.getPointerTy(DAG.getDataLayout()));
+        Index = getValue(PtrOperand);
+        IndexType = ISD::SIGNED_UNSCALED;
+        Scale =
+            DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
+      }
+      EVT IdxVT = Index.getValueType();
+      EVT EltTy = IdxVT.getVectorElementType();
+      if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+        EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
+      }
+      LD = DAG.getGatherVP(
+          DAG.getVTList(VT, MVT::Other), VT, DL,
+          {DAG.getRoot(), Base, Index, Scale, OpValues[1], OpValues[2]}, MMO,
+          IndexType);
+    }
+    if (AddToChain)
+      PendingLoads.push_back(LD.getValue(1));
+    setValue(&VPIntrin, LD);
+    break;
+  }
+  case ISD::VP_STORE:
+  case ISD::VP_SCATTER: {
+    Value *PtrOperand = VPIntrin.getArgOperand(1);
+    EVT VT = OpValues[0].getValueType();
+    MaybeAlign Alignment = DAG.getEVTAlign(VT);
+    AAMDNodes AAInfo;
+    VPIntrin.getAAMetadata(AAInfo);
+    SDValue ST;
+    if (Opcode == ISD::VP_STORE) {
+      MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+          MachinePointerInfo(PtrOperand), MachineMemOperand::MOStore,
+          VT.getStoreSize().getKnownMinSize(), *Alignment, AAInfo);
+      ST = DAG.getStoreVP(getMemoryRoot(), DL, OpValues[0], OpValues[1],
+                          OpValues[2], OpValues[3], MMO,
+                          false /* IsCompressing */);
+    } else {
+      unsigned AS =
+          PtrOperand->getType()->getScalarType()->getPointerAddressSpace();
+      MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
+          MachinePointerInfo(AS), MachineMemOperand::MOStore,
+          MemoryLocation::UnknownSize, *Alignment, AAInfo);
+      SDValue Base, Index, Scale;
+      ISD::MemIndexType IndexType;
+      bool UniformBase = getUniformBase(PtrOperand, Base, Index, IndexType,
+                                        Scale, this, VPIntrin.getParent());
+      if (!UniformBase) {
+        Base = DAG.getConstant(0, DL, TLI.getPointerTy(DAG.getDataLayout()));
+        Index = getValue(PtrOperand);
+        IndexType = ISD::SIGNED_UNSCALED;
+        Scale =
+            DAG.getTargetConstant(1, DL, TLI.getPointerTy(DAG.getDataLayout()));
+      }
+      EVT IdxVT = Index.getValueType();
+      EVT EltTy = IdxVT.getVectorElementType();
+      if (TLI.shouldExtendGSIndex(IdxVT, EltTy)) {
+        EVT NewIdxVT = IdxVT.changeVectorElementType(EltTy);
+        Index = DAG.getNode(ISD::SIGN_EXTEND, DL, NewIdxVT, Index);
+      }
+      ST = DAG.getScatterVP(DAG.getVTList(MVT::Other), VT, DL,
+                            {getMemoryRoot(), OpValues[0], Base, Index, Scale,
+                             OpValues[2], OpValues[3]},
+                            MMO, IndexType);
+    }
+    DAG.setRoot(ST);
+    setValue(&VPIntrin, ST);
+    break;
+  }
+  }
 }
 
 SDValue SelectionDAGBuilder::lowerStartEH(SDValue Chain,