Index: llvm/include/llvm/CodeGen/SelectionDAG.h =================================================================== --- llvm/include/llvm/CodeGen/SelectionDAG.h +++ llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1368,6 +1368,81 @@ SDValue getIndexedStoreVP(SDValue OrigStore, const SDLoc &dl, SDValue Base, SDValue Offset, ISD::MemIndexedMode AM); + SDValue getStridedLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, + EVT VT, const SDLoc &DL, SDValue Chain, SDValue Ptr, + SDValue Offset, SDValue Stride, SDValue Mask, + SDValue EVL, MachinePointerInfo PtrInfo, EVT MemVT, + Align Alignment, MachineMemOperand::Flags MMOFlags, + const AAMDNodes &AAInfo, + const MDNode *Ranges = nullptr, + bool IsExpanding = false); + inline SDValue getStridedLoadVP( + ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &DL, + SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Stride, SDValue Mask, + SDValue EVL, MachinePointerInfo PtrInfo, EVT MemVT, + MaybeAlign Alignment = MaybeAlign(), + MachineMemOperand::Flags MMOFlags = MachineMemOperand::MONone, + const AAMDNodes &AAInfo = AAMDNodes(), const MDNode *Ranges = nullptr, + bool IsExpanding = false) { + // Ensures that codegen never sees a None Alignment. 
+ return getStridedLoadVP(AM, ExtType, VT, DL, Chain, Ptr, Offset, Stride, + Mask, EVL, PtrInfo, MemVT, + Alignment.getValueOr(getEVTAlign(MemVT)), MMOFlags, + AAInfo, Ranges, IsExpanding); + } + SDValue getStridedLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, + EVT VT, const SDLoc &DL, SDValue Chain, SDValue Ptr, + SDValue Offset, SDValue Stride, SDValue Mask, + SDValue EVL, EVT MemVT, MachineMemOperand *MMO, + bool IsExpanding = false); + SDValue getStridedLoadVP(EVT VT, const SDLoc &DL, SDValue Chain, SDValue Ptr, + SDValue Stride, SDValue Mask, SDValue EVL, + MachinePointerInfo PtrInfo, MaybeAlign Alignment, + MachineMemOperand::Flags MMOFlags, + const AAMDNodes &AAInfo, + const MDNode *Ranges = nullptr, + bool IsExpanding = false); + SDValue getStridedLoadVP(EVT VT, const SDLoc &DL, SDValue Chain, SDValue Ptr, + SDValue Stride, SDValue Mask, SDValue EVL, + MachineMemOperand *MMO, bool IsExpanding = false); + SDValue + getExtStridedLoadVP(ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, + SDValue Chain, SDValue Ptr, SDValue Stride, SDValue Mask, + SDValue EVL, MachinePointerInfo PtrInfo, EVT MemVT, + MaybeAlign Alignment, MachineMemOperand::Flags MMOFlags, + const AAMDNodes &AAInfo, bool IsExpanding = false); + SDValue getExtStridedLoadVP(ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, + SDValue Chain, SDValue Ptr, SDValue Stride, + SDValue Mask, SDValue EVL, EVT MemVT, + MachineMemOperand *MMO, bool IsExpanding = false); + SDValue getIndexedStridedLoadVP(SDValue OrigLoad, const SDLoc &DL, + SDValue Base, SDValue Offset, + ISD::MemIndexedMode AM); + SDValue getStridedStoreVP(SDValue Chain, const SDLoc &DL, SDValue Val, + SDValue Ptr, SDValue Stride, SDValue Mask, + SDValue EVL, MachinePointerInfo PtrInfo, + Align Alignment, MachineMemOperand::Flags MMOFlags, + const AAMDNodes &AAInfo = AAMDNodes(), + bool IsCompressing = false); + SDValue getStridedStoreVP(SDValue Chain, const SDLoc &DL, SDValue Val, + SDValue Ptr, SDValue Stride, SDValue Mask, + 
SDValue EVL, MachineMemOperand *MMO, + bool IsCompressing = false); + SDValue getTruncStridedStoreVP(SDValue Chain, const SDLoc &DL, SDValue Val, + SDValue Ptr, SDValue Stride, SDValue Mask, + SDValue EVL, MachinePointerInfo PtrInfo, + EVT SVT, Align Alignment, + MachineMemOperand::Flags MMOFlags, + const AAMDNodes &AAInfo, + bool IsCompressing = false); + SDValue getTruncStridedStoreVP(SDValue Chain, const SDLoc &DL, SDValue Val, + SDValue Ptr, SDValue Stride, SDValue Mask, + SDValue EVL, EVT SVT, MachineMemOperand *MMO, + bool IsCompressing = false); + SDValue getIndexedStridedStoreVP(SDValue OrigStore, const SDLoc &DL, + SDValue Base, SDValue Offset, + ISD::MemIndexedMode AM); + SDValue getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef<SDValue> Ops, MachineMemOperand *MMO, ISD::MemIndexType IndexType); Index: llvm/include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -508,7 +508,7 @@ class LSBaseSDNodeBitfields { friend class LSBaseSDNode; - friend class VPLoadStoreSDNode; + friend class VPBaseLoadStoreSDNode; friend class MaskedLoadStoreSDNode; friend class MaskedGatherScatterSDNode; friend class VPGatherScatterSDNode; @@ -529,6 +529,7 @@ class LoadSDNodeBitfields { friend class LoadSDNode; friend class VPLoadSDNode; + friend class VPStridedLoadSDNode; friend class MaskedLoadSDNode; friend class MaskedGatherSDNode; friend class VPGatherSDNode; @@ -542,6 +543,7 @@ class StoreSDNodeBitfields { friend class StoreSDNode; friend class VPStoreSDNode; + friend class VPStridedStoreSDNode; friend class MaskedStoreSDNode; friend class MaskedScatterSDNode; friend class VPScatterSDNode; @@ -1363,6 +1365,7 @@ case ISD::VP_STORE: case ISD::MSTORE: case ISD::VP_SCATTER: + case ISD::EXPERIMENTAL_VP_STRIDED_STORE: return getOperand(2); case ISD::MGATHER: case ISD::MSCATTER: @@ -1406,6 +1409,8 @@ case ISD::VP_STORE: case 
ISD::VP_GATHER: case ISD::VP_SCATTER: + case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: + case ISD::EXPERIMENTAL_VP_STRIDED_STORE: return true; default: return N->isMemIntrinsic() || N->isTargetMemoryOpcode(); @@ -2352,34 +2357,64 @@ } }; -/// This base class is used to represent VP_LOAD and VP_STORE nodes -class VPLoadStoreSDNode : public MemSDNode { +/// This base class is used to represent VP_LOAD, VP_STORE, +/// EXPERIMENTAL_VP_STRIDED_LOAD and EXPERIMENTAL_VP_STRIDED_STORE nodes +class VPBaseLoadStoreSDNode : public MemSDNode { public: friend class SelectionDAG; - VPLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, const DebugLoc &dl, - SDVTList VTs, ISD::MemIndexedMode AM, EVT MemVT, - MachineMemOperand *MMO) - : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) { + VPBaseLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &DL, SDVTList VTs, + ISD::MemIndexedMode AM, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, DL, VTs, MemVT, MMO) { LSBaseSDNodeBits.AddressingMode = AM; assert(getAddressingMode() == AM && "Value truncated"); } - // VPLoadSDNode (Chain, Ptr, Offset, Mask, EVL) - // VPStoreSDNode (Chain, Data, Ptr, Offset, Mask, EVL) + // VPStridedStoreSDNode (Chain, Data, Ptr, Offset, Stride, Mask, EVL) + // VPStoreSDNode (Chain, Data, Ptr, Offset, Mask, EVL) + // VPStridedLoadSDNode (Chain, Ptr, Offset, Stride, Mask, EVL) + // VPLoadSDNode (Chain, Ptr, Offset, Mask, EVL) // Mask is a vector of i1 elements; // the type of EVL is TLI.getVPExplicitVectorLengthTy(). const SDValue &getOffset() const { - return getOperand(getOpcode() == ISD::VP_LOAD ? 2 : 3); + return getOperand((getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD || + getOpcode() == ISD::VP_LOAD) + ? 2 + : 3); } const SDValue &getBasePtr() const { - return getOperand(getOpcode() == ISD::VP_LOAD ? 1 : 2); + return getOperand((getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD || + getOpcode() == ISD::VP_LOAD) + ? 
1 + : 2); } const SDValue &getMask() const { - return getOperand(getOpcode() == ISD::VP_LOAD ? 3 : 4); + switch (getOpcode()) { + default: + llvm_unreachable("Invalid opcode"); + case ISD::VP_LOAD: + return getOperand(3); + case ISD::VP_STORE: + case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: + return getOperand(4); + case ISD::EXPERIMENTAL_VP_STRIDED_STORE: + return getOperand(5); + } } const SDValue &getVectorLength() const { - return getOperand(getOpcode() == ISD::VP_LOAD ? 4 : 5); + switch (getOpcode()) { + default: + llvm_unreachable("Invalid opcode"); + case ISD::VP_LOAD: + return getOperand(4); + case ISD::VP_STORE: + case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: + return getOperand(5); + case ISD::EXPERIMENTAL_VP_STRIDED_STORE: + return getOperand(6); + } } /// Return the addressing mode for this load or store: @@ -2395,19 +2430,21 @@ bool isUnindexed() const { return getAddressingMode() == ISD::UNINDEXED; } static bool classof(const SDNode *N) { - return N->getOpcode() == ISD::VP_LOAD || N->getOpcode() == ISD::VP_STORE; + return N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD || + N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE || + N->getOpcode() == ISD::VP_LOAD || N->getOpcode() == ISD::VP_STORE; } }; /// This class is used to represent a VP_LOAD node -class VPLoadSDNode : public VPLoadStoreSDNode { +class VPLoadSDNode : public VPBaseLoadStoreSDNode { public: friend class SelectionDAG; VPLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, ISD::MemIndexedMode AM, ISD::LoadExtType ETy, bool isExpanding, EVT MemVT, MachineMemOperand *MMO) - : VPLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, AM, MemVT, MMO) { + : VPBaseLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, AM, MemVT, MMO) { LoadSDNodeBits.ExtTy = ETy; LoadSDNodeBits.IsExpanding = isExpanding; } @@ -2427,15 +2464,45 @@ bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } }; +/// This class is used to represent an EXPERIMENTAL_VP_STRIDED_LOAD node. 
+class VPStridedLoadSDNode : public VPBaseLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPStridedLoadSDNode(unsigned Order, const DebugLoc &DL, SDVTList VTs, + ISD::MemIndexedMode AM, ISD::LoadExtType ETy, + bool IsExpanding, EVT MemVT, MachineMemOperand *MMO) + : VPBaseLoadStoreSDNode(ISD::EXPERIMENTAL_VP_STRIDED_LOAD, Order, DL, VTs, + AM, MemVT, MMO) { + LoadSDNodeBits.ExtTy = ETy; + LoadSDNodeBits.IsExpanding = IsExpanding; + } + + ISD::LoadExtType getExtensionType() const { + return static_cast<ISD::LoadExtType>(LoadSDNodeBits.ExtTy); + } + + const SDValue &getBasePtr() const { return getOperand(1); } + const SDValue &getOffset() const { return getOperand(2); } + const SDValue &getStride() const { return getOperand(3); } + const SDValue &getMask() const { return getOperand(4); } + const SDValue &getVectorLength() const { return getOperand(5); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_LOAD; + } + bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } +}; + /// This class is used to represent a VP_STORE node -class VPStoreSDNode : public VPLoadStoreSDNode { +class VPStoreSDNode : public VPBaseLoadStoreSDNode { public: friend class SelectionDAG; VPStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, ISD::MemIndexedMode AM, bool isTrunc, bool isCompressing, EVT MemVT, MachineMemOperand *MMO) - : VPLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, AM, MemVT, MMO) { + : VPBaseLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, AM, MemVT, MMO) { StoreSDNodeBits.IsTruncating = isTrunc; StoreSDNodeBits.IsCompressing = isCompressing; } @@ -2462,6 +2529,43 @@ } }; +/// This class is used to represent an EXPERIMENTAL_VP_STRIDED_STORE node. 
+class VPStridedStoreSDNode : public VPBaseLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPStridedStoreSDNode(unsigned Order, const DebugLoc &DL, SDVTList VTs, + ISD::MemIndexedMode AM, bool IsTrunc, bool IsCompressing, + EVT MemVT, MachineMemOperand *MMO) + : VPBaseLoadStoreSDNode(ISD::EXPERIMENTAL_VP_STRIDED_STORE, Order, DL, + VTs, AM, MemVT, MMO) { + StoreSDNodeBits.IsTruncating = IsTrunc; + StoreSDNodeBits.IsCompressing = IsCompressing; + } + + /// Return true if this is a truncating store. + /// For integers this is the same as doing a TRUNCATE and storing the result. + /// For floats, it is the same as doing an FP_ROUND and storing the result. + bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; } + + /// Returns true if the op does a compression to the vector before storing. + /// The node contiguously stores the active elements (integers or floats) + /// in src (those with their respective bit set in writemask k) to unaligned + /// memory at base_addr. 
+ bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; } + + const SDValue &getValue() const { return getOperand(1); } + const SDValue &getBasePtr() const { return getOperand(2); } + const SDValue &getOffset() const { return getOperand(3); } + const SDValue &getStride() const { return getOperand(4); } + const SDValue &getMask() const { return getOperand(5); } + const SDValue &getVectorLength() const { return getOperand(6); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EXPERIMENTAL_VP_STRIDED_STORE; + } +}; + /// This base class is used to represent MLOAD and MSTORE nodes class MaskedLoadStoreSDNode : public MemSDNode { public: Index: llvm/include/llvm/IR/Intrinsics.td =================================================================== --- llvm/include/llvm/IR/Intrinsics.td +++ llvm/include/llvm/IR/Intrinsics.td @@ -1397,6 +1397,22 @@ llvm_i32_ty], [ IntrArgMemOnly, IntrNoSync, IntrWillReturn ]>; // TODO allow IntrNoCapture for vectors of pointers +// Experimental strided memory accesses +def int_experimental_vp_strided_store : DefaultAttrsIntrinsic<[], + [ llvm_anyvector_ty, + LLVMAnyPointerType<LLVMMatchType<0>>, + llvm_anyint_ty, // Stride in bytes + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<ArgIndex<1>>, IntrNoSync, IntrWriteMem, IntrArgMemOnly, IntrWillReturn ]>; + +def int_experimental_vp_strided_load : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [ LLVMAnyPointerType<LLVMMatchType<0>>, + llvm_anyint_ty, // Stride in bytes + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<ArgIndex<0>>, IntrNoSync, IntrReadMem, IntrWillReturn, IntrArgMemOnly ]>; + // Speculatable Binary operators let IntrProperties = [IntrSpeculatable, IntrNoMem, IntrNoSync, IntrWillReturn] in { def int_vp_add : DefaultAttrsIntrinsic<[ llvm_anyvector_ty ], Index: llvm/include/llvm/IR/VPIntrinsics.def =================================================================== --- llvm/include/llvm/IR/VPIntrinsics.def +++ 
llvm/include/llvm/IR/VPIntrinsics.def @@ -214,7 +214,7 @@ ///// } Floating-Point Arithmetic ///// Memory Operations { -// llvm.vp.store(ptr,val,mask,vlen) +// llvm.vp.store(val,ptr,mask,vlen) BEGIN_REGISTER_VP_INTRINSIC(vp_store, 2, 3) // chain = VP_STORE chain,val,base,offset,mask,evl BEGIN_REGISTER_VP_SDNODE(VP_STORE, 0, vp_store, 4, 5) @@ -223,6 +223,13 @@ VP_PROPERTY_MEMOP(1, 0) END_REGISTER_VP(vp_store, VP_STORE) +// llvm.experimental.vp.strided.store(val,ptr,stride,mask,vlen) +BEGIN_REGISTER_VP_INTRINSIC(experimental_vp_strided_store, 3, 4) +// chain = EXPERIMENTAL_VP_STRIDED_STORE chain,val,base,offset,stride,mask,evl +BEGIN_REGISTER_VP_SDNODE(EXPERIMENTAL_VP_STRIDED_STORE, 0, experimental_vp_strided_store, 5, 6) +VP_PROPERTY_MEMOP(1, 0) +END_REGISTER_VP(experimental_vp_strided_store, EXPERIMENTAL_VP_STRIDED_STORE) + // llvm.vp.scatter(ptr,val,mask,vlen) BEGIN_REGISTER_VP_INTRINSIC(vp_scatter, 2, 3) // chain = VP_SCATTER chain,val,base,indices,scale,mask,evl @@ -240,6 +247,13 @@ VP_PROPERTY_MEMOP(0, None) END_REGISTER_VP(vp_load, VP_LOAD) +// llvm.experimental.vp.strided.load(ptr,stride,mask,vlen) +BEGIN_REGISTER_VP_INTRINSIC(experimental_vp_strided_load, 2, 3) +// chain = EXPERIMENTAL_VP_STRIDED_LOAD chain,base,offset,stride,mask,evl +BEGIN_REGISTER_VP_SDNODE(EXPERIMENTAL_VP_STRIDED_LOAD, -1, experimental_vp_strided_load, 4, 5) +VP_PROPERTY_MEMOP(0, None) +END_REGISTER_VP(experimental_vp_strided_load, EXPERIMENTAL_VP_STRIDED_LOAD) + // llvm.vp.gather(ptr,mask,vlen) BEGIN_REGISTER_VP_INTRINSIC(vp_gather, 1, 2) // val,chain = VP_GATHER chain,base,indices,scale,mask,evl Index: llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -1174,6 +1174,11 @@ Node->getOpcode(), cast<VPStoreSDNode>(Node)->getValue().getValueType()); break; + case ISD::EXPERIMENTAL_VP_STRIDED_STORE: + Action = TLI.getOperationAction( + 
Node->getOpcode(), + cast<VPStridedStoreSDNode>(Node)->getValue().getValueType()); + break; case ISD::VECREDUCE_FADD: case ISD::VECREDUCE_FMUL: case ISD::VECREDUCE_ADD: Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -722,6 +722,20 @@ ID.AddInteger(EST->getPointerInfo().getAddrSpace()); break; } + case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: { + const VPStridedLoadSDNode *SLD = cast<VPStridedLoadSDNode>(N); + ID.AddInteger(SLD->getMemoryVT().getRawBits()); + ID.AddInteger(SLD->getRawSubclassData()); + ID.AddInteger(SLD->getPointerInfo().getAddrSpace()); + break; + } + case ISD::EXPERIMENTAL_VP_STRIDED_STORE: { + const VPStridedStoreSDNode *SST = cast<VPStridedStoreSDNode>(N); + ID.AddInteger(SST->getMemoryVT().getRawBits()); + ID.AddInteger(SST->getRawSubclassData()); + ID.AddInteger(SST->getPointerInfo().getAddrSpace()); + break; + } case ISD::VP_GATHER: { const VPGatherSDNode *EG = cast<VPGatherSDNode>(N); ID.AddInteger(EG->getMemoryVT().getRawBits()); @@ -7905,6 +7919,296 @@ return V; } +SDValue SelectionDAG::getStridedLoadVP( + ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &DL, + SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Stride, SDValue Mask, + SDValue EVL, MachinePointerInfo PtrInfo, EVT MemVT, Align Alignment, + MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo, + const MDNode *Ranges, bool IsExpanding) { + assert(Chain.getValueType() == MVT::Other && "Invalid chain type"); + + MMOFlags |= MachineMemOperand::MOLoad; + assert((MMOFlags & MachineMemOperand::MOStore) == 0); + // If we don't have a PtrInfo, infer the trivial frame index case to simplify + // clients. 
+ if (PtrInfo.V.isNull()) + PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr, Offset); + + uint64_t Size = MemoryLocation::getSizeOrUnknown(MemVT.getStoreSize()); + MachineFunction &MF = getMachineFunction(); + MachineMemOperand *MMO = MF.getMachineMemOperand(PtrInfo, MMOFlags, Size, + Alignment, AAInfo, Ranges); + return getStridedLoadVP(AM, ExtType, VT, DL, Chain, Ptr, Offset, Stride, Mask, + EVL, MemVT, MMO, IsExpanding); +} + +SDValue SelectionDAG::getStridedLoadVP( + ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT, const SDLoc &DL, + SDValue Chain, SDValue Ptr, SDValue Offset, SDValue Stride, SDValue Mask, + SDValue EVL, EVT MemVT, MachineMemOperand *MMO, bool IsExpanding) { + if (VT == MemVT) { + ExtType = ISD::NON_EXTLOAD; + } else if (ExtType == ISD::NON_EXTLOAD) { + assert(VT == MemVT && "Non-extending load from different memory type!"); + } else { + // Extending load. + assert(MemVT.getScalarType().bitsLT(VT.getScalarType()) && + "Should only be an extending load, not truncating!"); + assert(VT.isInteger() == MemVT.isInteger() && + "Cannot convert from FP to Int or Int -> FP!"); + assert(VT.isVector() == MemVT.isVector() && + "Cannot use an ext load to convert to or from a vector!"); + assert((!VT.isVector() || + VT.getVectorElementCount() == MemVT.getVectorElementCount()) && + "Cannot use an ext load to change the number of vector elements!"); + } + + bool Indexed = AM != ISD::UNINDEXED; + assert((Indexed || Offset.isUndef()) && "Unindexed load with an offset!"); + + SDValue Ops[] = {Chain, Ptr, Offset, Stride, Mask, EVL}; + SDVTList VTs = Indexed ? 
getVTList(VT, Ptr.getValueType(), MVT::Other) + : getVTList(VT, MVT::Other); + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_LOAD, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData<VPStridedLoadSDNode>( + DL.getIROrder(), VTs, AM, ExtType, IsExpanding, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) { + cast<VPStridedLoadSDNode>(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + + auto *N = + newSDNode<VPStridedLoadSDNode>(DL.getIROrder(), DL.getDebugLoc(), VTs, AM, + ExtType, IsExpanding, MemVT, MMO); + createOperands(N, Ops); + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getStridedLoadVP( + EVT VT, const SDLoc &DL, SDValue Chain, SDValue Ptr, SDValue Stride, + SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, MaybeAlign Alignment, + MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo, + const MDNode *Ranges, bool IsExpanding) { + SDValue Undef = getUNDEF(Ptr.getValueType()); + return getStridedLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, DL, Chain, Ptr, + Undef, Stride, Mask, EVL, PtrInfo, VT, Alignment, + MMOFlags, AAInfo, Ranges, IsExpanding); +} + +SDValue SelectionDAG::getStridedLoadVP(EVT VT, const SDLoc &DL, SDValue Chain, + SDValue Ptr, SDValue Stride, + SDValue Mask, SDValue EVL, + MachineMemOperand *MMO, + bool IsExpanding) { + SDValue Undef = getUNDEF(Ptr.getValueType()); + return getStridedLoadVP(ISD::UNINDEXED, ISD::NON_EXTLOAD, VT, DL, Chain, Ptr, + Undef, Stride, Mask, EVL, VT, MMO, IsExpanding); +} + +SDValue SelectionDAG::getExtStridedLoadVP( + ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, SDValue Chain, + SDValue Ptr, SDValue Stride, SDValue Mask, SDValue EVL, + MachinePointerInfo PtrInfo, EVT MemVT, MaybeAlign Alignment, + MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo, + bool IsExpanding) { + SDValue Undef = 
getUNDEF(Ptr.getValueType()); + return getStridedLoadVP(ISD::UNINDEXED, ExtType, VT, DL, Chain, Ptr, Undef, + Stride, Mask, EVL, PtrInfo, MemVT, Alignment, + MMOFlags, AAInfo, nullptr, IsExpanding); +} + +SDValue SelectionDAG::getExtStridedLoadVP( + ISD::LoadExtType ExtType, const SDLoc &DL, EVT VT, SDValue Chain, + SDValue Ptr, SDValue Stride, SDValue Mask, SDValue EVL, EVT MemVT, + MachineMemOperand *MMO, bool IsExpanding) { + SDValue Undef = getUNDEF(Ptr.getValueType()); + return getStridedLoadVP(ISD::UNINDEXED, ExtType, VT, DL, Chain, Ptr, Undef, + Stride, Mask, EVL, MemVT, MMO, IsExpanding); +} + +SDValue SelectionDAG::getIndexedStridedLoadVP(SDValue OrigLoad, const SDLoc &DL, + SDValue Base, SDValue Offset, + ISD::MemIndexedMode AM) { + auto *SLD = cast<VPStridedLoadSDNode>(OrigLoad); + assert(SLD->getOffset().isUndef() && + "Strided load is already a indexed load!"); + // Don't propagate the invariant or dereferenceable flags. + auto MMOFlags = + SLD->getMemOperand()->getFlags() & + ~(MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable); + return getStridedLoadVP( + AM, SLD->getExtensionType(), OrigLoad.getValueType(), DL, SLD->getChain(), + Base, Offset, SLD->getStride(), SLD->getMask(), SLD->getVectorLength(), + SLD->getPointerInfo(), SLD->getMemoryVT(), SLD->getAlign(), MMOFlags, + SLD->getAAInfo(), nullptr, SLD->isExpandingLoad()); +} + +SDValue SelectionDAG::getStridedStoreVP( + SDValue Chain, const SDLoc &DL, SDValue Val, SDValue Ptr, SDValue Stride, + SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, Align Alignment, + MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo, + bool IsCompressing) { + assert(Chain.getValueType() == MVT::Other && "Invalid chain type"); + + MMOFlags |= MachineMemOperand::MOStore; + assert((MMOFlags & MachineMemOperand::MOLoad) == 0); + + if (PtrInfo.V.isNull()) + PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr); + + MachineFunction &MF = getMachineFunction(); + uint64_t Size = + 
MemoryLocation::getSizeOrUnknown(Val.getValueType().getStoreSize()); + MachineMemOperand *MMO = + MF.getMachineMemOperand(PtrInfo, MMOFlags, Size, Alignment, AAInfo); + return getStridedStoreVP(Chain, DL, Val, Ptr, Stride, Mask, EVL, MMO, + IsCompressing); +} + +SDValue SelectionDAG::getStridedStoreVP(SDValue Chain, const SDLoc &DL, + SDValue Val, SDValue Ptr, + SDValue Stride, SDValue Mask, + SDValue EVL, MachineMemOperand *MMO, + bool IsCompressing) { + assert(Chain.getValueType() == MVT::Other && "Invalid chain type"); + EVT VT = Val.getValueType(); + SDVTList VTs = getVTList(MVT::Other); + SDValue Undef = getUNDEF(Ptr.getValueType()); + SDValue Ops[] = {Chain, Val, Ptr, Undef, Stride, Mask, EVL}; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData<VPStridedStoreSDNode>( + DL.getIROrder(), VTs, ISD::UNINDEXED, false, IsCompressing, VT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) { + cast<VPStridedStoreSDNode>(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + + auto *N = newSDNode<VPStridedStoreSDNode>(DL.getIROrder(), DL.getDebugLoc(), + VTs, ISD::UNINDEXED, false, + IsCompressing, VT, MMO); + createOperands(N, Ops); + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getTruncStridedStoreVP( + SDValue Chain, const SDLoc &DL, SDValue Val, SDValue Ptr, SDValue Stride, + SDValue Mask, SDValue EVL, MachinePointerInfo PtrInfo, EVT SVT, + Align Alignment, MachineMemOperand::Flags MMOFlags, const AAMDNodes &AAInfo, + bool IsCompressing) { + assert(Chain.getValueType() == MVT::Other && "Invalid chain type"); + + MMOFlags |= MachineMemOperand::MOStore; + assert((MMOFlags & MachineMemOperand::MOLoad) == 0); + + if (PtrInfo.V.isNull()) + PtrInfo = InferPointerInfo(PtrInfo, *this, Ptr); + + MachineFunction &MF = 
getMachineFunction(); + MachineMemOperand *MMO = MF.getMachineMemOperand( + PtrInfo, MMOFlags, MemoryLocation::getSizeOrUnknown(SVT.getStoreSize()), + Alignment, AAInfo); + return getTruncStridedStoreVP(Chain, DL, Val, Ptr, Stride, Mask, EVL, SVT, + MMO, IsCompressing); +} + +SDValue SelectionDAG::getTruncStridedStoreVP(SDValue Chain, const SDLoc &DL, + SDValue Val, SDValue Ptr, + SDValue Stride, SDValue Mask, + SDValue EVL, EVT SVT, + MachineMemOperand *MMO, + bool IsCompressing) { + EVT VT = Val.getValueType(); + + assert(Chain.getValueType() == MVT::Other && "Invalid chain type"); + if (VT == SVT) + return getStridedStoreVP(Chain, DL, Val, Ptr, Stride, Mask, EVL, MMO, + IsCompressing); + + assert(SVT.getScalarType().bitsLT(VT.getScalarType()) && + "Should only be a truncating store, not extending!"); + assert(VT.isInteger() == SVT.isInteger() && "Can't do FP-INT conversion!"); + assert(VT.isVector() == SVT.isVector() && + "Cannot use trunc store to convert to or from a vector!"); + assert((!VT.isVector() || + VT.getVectorElementCount() == SVT.getVectorElementCount()) && + "Cannot use trunc store to change the number of vector elements!"); + + SDVTList VTs = getVTList(MVT::Other); + SDValue Undef = getUNDEF(Ptr.getValueType()); + SDValue Ops[] = {Chain, Val, Ptr, Undef, Stride, Mask, EVL}; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops); + ID.AddInteger(SVT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData<VPStridedStoreSDNode>( + DL.getIROrder(), VTs, ISD::UNINDEXED, true, IsCompressing, SVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) { + cast<VPStridedStoreSDNode>(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode<VPStridedStoreSDNode>(DL.getIROrder(), DL.getDebugLoc(), + VTs, ISD::UNINDEXED, true, + IsCompressing, SVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", 
this); + return V; +} + +SDValue SelectionDAG::getIndexedStridedStoreVP(SDValue OrigStore, + const SDLoc &DL, SDValue Base, + SDValue Offset, + ISD::MemIndexedMode AM) { + auto *SST = cast<VPStridedStoreSDNode>(OrigStore); + assert(SST->getOffset().isUndef() && + "Strided store is already an indexed store!"); + SDVTList VTs = getVTList(Base.getValueType(), MVT::Other); + SDValue Ops[] = { + SST->getChain(), SST->getValue(), Base, Offset, SST->getStride(), + SST->getMask(), SST->getVectorLength()}; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EXPERIMENTAL_VP_STRIDED_STORE, VTs, Ops); + ID.AddInteger(SST->getMemoryVT().getRawBits()); + ID.AddInteger(SST->getRawSubclassData()); + ID.AddInteger(SST->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, DL, IP)) + return SDValue(E, 0); + + auto *N = newSDNode<VPStridedStoreSDNode>( + DL.getIROrder(), DL.getDebugLoc(), VTs, AM, SST->isTruncatingStore(), + SST->isCompressingStore(), SST->getMemoryVT(), SST->getMemOperand()); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, ArrayRef<SDValue> Ops, MachineMemOperand *MMO, ISD::MemIndexType IndexType) { Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -571,6 +571,10 @@ SmallVector<SDValue, 7> &OpValues, bool isGather); void visitVPStoreScatter(const VPIntrinsic &VPIntrin, SmallVector<SDValue, 7> &OpValues, bool isScatter); + void visitVPStridedLoad(const VPIntrinsic &VPIntrin, EVT VT, + SmallVector<SDValue, 7> &OpValues); + void visitVPStridedStore(const VPIntrinsic &VPIntrin, + SmallVector<SDValue, 7> &OpValues); void visitVectorPredicationIntrinsic(const VPIntrinsic &VPIntrin); void visitVAStart(const CallInst &I); Index: 
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -7427,6 +7427,58 @@ setValue(&VPIntrin, ST); } +void SelectionDAGBuilder::visitVPStridedLoad( + const VPIntrinsic &VPIntrin, EVT VT, SmallVector<SDValue, 7> &OpValues) { + SDLoc DL = getCurSDLoc(); + Value *PtrOperand = VPIntrin.getArgOperand(0); + MaybeAlign Alignment = DAG.getEVTAlign(VT); + AAMDNodes AAInfo = VPIntrin.getAAMetadata(); + const MDNode *Ranges = VPIntrin.getMetadata(LLVMContext::MD_range); + MemoryLocation ML; + if (VT.isScalableVector()) + ML = MemoryLocation::getAfter(PtrOperand); + else + ML = MemoryLocation( + PtrOperand, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(VPIntrin.getType())), + AAInfo); + + bool AddToChain = !AA || !AA->pointsToConstantMemory(ML); + SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode(); + MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand( + MachinePointerInfo(PtrOperand), MachineMemOperand::MOLoad, + VT.getStoreSize().getKnownMinSize(), *Alignment, AAInfo, Ranges); + + SDValue LD = DAG.getStridedLoadVP(VT, DL, InChain, OpValues[0], OpValues[1], + OpValues[2], OpValues[3], MMO, + false /*IsExpanding*/); + + if (AddToChain) + PendingLoads.push_back(LD.getValue(1)); + setValue(&VPIntrin, LD); +} + +void SelectionDAGBuilder::visitVPStridedStore( + const VPIntrinsic &VPIntrin, SmallVector<SDValue, 7> &OpValues) { + SDLoc DL = getCurSDLoc(); + Value *PtrOperand = VPIntrin.getArgOperand(1); + EVT VT = OpValues[0].getValueType(); + MaybeAlign Alignment = DAG.getEVTAlign(VT); + AAMDNodes AAInfo = VPIntrin.getAAMetadata(); + + MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand( + MachinePointerInfo(PtrOperand), MachineMemOperand::MOStore, + VT.getStoreSize().getKnownMinSize(), *Alignment, AAInfo); + + SDValue ST = DAG.getStridedStoreVP( + 
getMemoryRoot(), DL, OpValues[0], OpValues[1], OpValues[2], OpValues[3], + OpValues[4], MMO, false /* IsCompressing */); + + DAG.setRoot(ST); + setValue(&VPIntrin, ST); +} + void SelectionDAGBuilder::visitVectorPredicationIntrinsic( const VPIntrinsic &VPIntrin) { SDLoc DL = getCurSDLoc(); @@ -7464,10 +7516,16 @@ visitVPLoadGather(VPIntrin, ValueVTs[0], OpValues, Opcode == ISD::VP_GATHER); break; + case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: + visitVPStridedLoad(VPIntrin, ValueVTs[0], OpValues); + break; case ISD::VP_STORE: case ISD::VP_SCATTER: visitVPStoreScatter(VPIntrin, OpValues, Opcode == ISD::VP_SCATTER); break; + case ISD::EXPERIMENTAL_VP_STRIDED_STORE: + visitVPStridedStore(VPIntrin, OpValues); + break; } } Index: llvm/lib/IR/IntrinsicInst.cpp =================================================================== --- llvm/lib/IR/IntrinsicInst.cpp +++ llvm/lib/IR/IntrinsicInst.cpp @@ -489,6 +489,12 @@ M, VPID, {Params[0]->getType()->getPointerElementType(), Params[0]->getType()}); break; + case Intrinsic::experimental_vp_strided_load: + VPFunc = Intrinsic::getDeclaration( + M, VPID, + {Params[0]->getType()->getPointerElementType(), Params[0]->getType(), + Params[1]->getType()}); + break; case Intrinsic::vp_gather: VPFunc = Intrinsic::getDeclaration( M, VPID, @@ -503,6 +509,12 @@ M, VPID, {Params[1]->getType()->getPointerElementType(), Params[1]->getType()}); break; + case Intrinsic::experimental_vp_strided_store: + VPFunc = Intrinsic::getDeclaration( + M, VPID, + {Params[1]->getType()->getPointerElementType(), Params[1]->getType(), + Params[2]->getType()}); + break; case Intrinsic::vp_scatter: VPFunc = Intrinsic::getDeclaration( M, VPID, {Params[0]->getType(), Params[1]->getType()}); Index: llvm/unittests/IR/VPIntrinsicTest.cpp =================================================================== --- llvm/unittests/IR/VPIntrinsicTest.cpp +++ llvm/unittests/IR/VPIntrinsicTest.cpp @@ -53,10 +53,16 @@ Str << " declare void @llvm.vp.store.v8i32.p0v8i32(<8 x i32>, <8 x 
i32>*, " "<8 x i1>, i32) "; + Str << "declare void " + "@llvm.experimental.vp.strided.store.v8i32.p0v8i32.i32(<8 x i32>, " + "<8 x i32>*, i32, <8 x i1>, i32) "; Str << " declare void @llvm.vp.scatter.v8i32.v8p0i32(<8 x i32>, <8 x " "i32*>, <8 x i1>, i32) "; Str << " declare <8 x i32> @llvm.vp.load.v8i32.p0v8i32(<8 x i32>*, <8 x " "i1>, i32) "; + Str << "declare <8 x i32> " + "@llvm.experimental.vp.strided.load.v8i32.p0v8i32.i32(<8 x i32>*, " + "i32, <8 x i1>, i32) "; Str << " declare <8 x i32> @llvm.vp.gather.v8i32.v8p0i32(<8 x i32*>, <8 x " "i1>, i32) ";