diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -238,6 +238,16 @@ SDTCisSameNumEltsAs<0, 3> ]>; +def SDTMaskedGather : SDTypeProfile<1, 4, [ + SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVec<2>, SDTCisPtrTy<3>, SDTCisVec<4>, + SDTCisSameNumEltsAs<0, 2>, SDTCisSameNumEltsAs<0, 4> +]>; + +def SDTMaskedScatter : SDTypeProfile<0, 4, [ + SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, SDTCisVec<3>, + SDTCisSameNumEltsAs<0, 1>, SDTCisSameNumEltsAs<0, 3> +]>; + def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; @@ -652,6 +662,12 @@ def masked_ld : SDNode<"ISD::MLOAD", SDTMaskedLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def masked_gather : SDNode<"ISD::MGATHER", SDTMaskedGather, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + +def masked_scatter : SDNode<"ISD::MSCATTER", SDTMaskedScatter, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; + // Do not use ld, st directly. Use load, extload, sextload, zextload, store, // and truncst (see below). def ld : SDNode<"ISD::LOAD" , SDTLoad, @@ -1628,6 +1644,124 @@ let MemoryVT = i64; } +def nonext_masked_gather : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + return cast<MaskedGatherSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD; +}]>; + +// Any extending masked gather fragments.
+def ext_masked_gather_i8 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::EXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def ext_masked_gather_i16 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::EXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def ext_masked_gather_i32 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::EXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i32; +}]>; + +// Sign extending masked gather fragments. +def sext_masked_gather_i8 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::SEXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def sext_masked_gather_i16 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::SEXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def sext_masked_gather_i32 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::SEXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i32; +}]>; + +// Zero extending masked gather fragments.
+def zext_masked_gather_i8 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::ZEXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def zext_masked_gather_i16 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::ZEXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def zext_masked_gather_i32 : + PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx), + (masked_gather node:$def, node:$pred, node:$ptr, node:$idx), [{ + auto MGN = cast<MaskedGatherSDNode>(N); + return MGN->getExtensionType() == ISD::ZEXTLOAD && + MGN->getMemoryVT().getScalarType() == MVT::i32; +}]>; + +// Any/Zero extending masked gather fragments. +def azext_masked_gather_i8 : + PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx), + [(ext_masked_gather_i8 node:$def, node:$pred, node:$ptr, node:$idx), + (zext_masked_gather_i8 node:$def, node:$pred, node:$ptr, node:$idx)]>; +def azext_masked_gather_i16 : + PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx), + [(ext_masked_gather_i16 node:$def, node:$pred, node:$ptr, node:$idx), + (zext_masked_gather_i16 node:$def, node:$pred, node:$ptr, node:$idx)]>; +def azext_masked_gather_i32 : + PatFrags<(ops node:$def, node:$pred, node:$ptr, node:$idx), + [(ext_masked_gather_i32 node:$def, node:$pred, node:$ptr, node:$idx), + (zext_masked_gather_i32 node:$def, node:$pred, node:$ptr, node:$idx)]>; + +def nontrunc_masked_scatter : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + return !cast<MaskedScatterSDNode>(N)->isTruncatingStore(); +}]>; + +// Truncating masked scatter fragments.
+def trunc_masked_scatter_i8 : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + auto MSN = cast<MaskedScatterSDNode>(N); + return MSN->isTruncatingStore() && + MSN->getMemoryVT().getScalarType() == MVT::i8; +}]>; +def trunc_masked_scatter_i16 : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + auto MSN = cast<MaskedScatterSDNode>(N); + return MSN->isTruncatingStore() && + MSN->getMemoryVT().getScalarType() == MVT::i16; +}]>; +def trunc_masked_scatter_i32 : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{ + auto MSN = cast<MaskedScatterSDNode>(N); + return MSN->isTruncatingStore() && + MSN->getMemoryVT().getScalarType() == MVT::i32; +}]>; + //===----------------------------------------------------------------------===// // Selection DAG Pattern Support. // diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -4562,29 +4562,6 @@ return AddrModes.find(Key)->second; } -unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) { - std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = { - {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false), - AArch64ISD::SST1_PRED}, - {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true), - AArch64ISD::SST1_UXTW_PRED}, - {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false), - AArch64ISD::SST1_PRED}, - {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true), - AArch64ISD::SST1_SXTW_PRED}, - {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false), - AArch64ISD::SST1_SCALED_PRED}, - {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true), - AArch64ISD::SST1_UXTW_SCALED_PRED}, - {std::make_tuple(/*Scaled*/ true,
/*Signed*/ true, /*Extend*/ false), - AArch64ISD::SST1_SCALED_PRED}, - {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true), - AArch64ISD::SST1_SXTW_SCALED_PRED}, - }; - auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend); - return AddrModes.find(Key)->second; -} - unsigned getSignExtendedGatherOpcode(unsigned Opcode) { switch (Opcode) { default: @@ -4607,53 +4584,6 @@ } } -bool getGatherScatterIndexIsExtended(SDValue Index) { - // Ignore non-pointer sized indices. - if (Index.getValueType() != MVT::nxv2i64) - return false; - - unsigned Opcode = Index.getOpcode(); - if (Opcode == ISD::SIGN_EXTEND_INREG) - return cast<VTSDNode>(Index.getOperand(1))->getVT() == MVT::nxv2i32; - - if (Opcode == ISD::AND) { - SDValue Splat = Index.getOperand(1); - if (Splat.getOpcode() != ISD::SPLAT_VECTOR) - return false; - ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(Splat.getOperand(0)); - if (!Mask || Mask->getZExtValue() != 0xFFFFFFFF) - return false; - return true; - } - - return false; -} - -// If the base pointer of a masked gather or scatter is constant, we -// may be able to swap BasePtr & Index and use the vector + immediate addressing -// mode, e.g. -// VECTOR + IMMEDIATE: -// getelementptr nullptr, (splat(#x)) + %indices) -// -> getelementptr #x, %indices -void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index, - bool IsScaled, EVT MemVT, unsigned &Opcode, - bool IsGather, SelectionDAG &DAG) { - ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(BasePtr); - if (!Offset || IsScaled) - return; - - uint64_t OffsetVal = Offset->getZExtValue(); - unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8; - - if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31) - return; - - // Immediate is in range - Opcode = - IsGather ?
AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED; - std::swap(BasePtr, Index); -} - SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op, SelectionDAG &DAG) const { MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op); @@ -4749,37 +4679,8 @@ return DAG.getMergeValues({Result, Load.getValue(1)}, DL); } - bool IdxNeedsExtend = - getGatherScatterIndexIsExtended(Index) || - Index.getSimpleValueType().getVectorElementType() == MVT::i32; - - EVT IndexVT = Index.getSimpleValueType(); - SDValue InputVT = DAG.getValueType(MemVT); - - // Handle FP data by using an integer gather and casting the result. - if (VT.isFloatingPoint()) - InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); - - SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other); - - if (getGatherScatterIndexIsExtended(Index)) - Index = Index.getOperand(0); - - unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend); - selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode, - /*isGather=*/true, DAG); - - if (ExtType == ISD::SEXTLOAD) - Opcode = getSignExtendedGatherOpcode(Opcode); - - SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT}; - SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops); - Chain = Result.getValue(1); - - if (VT.isFloatingPoint()) - Result = getSVESafeBitCast(VT, Result, DAG); - - return DAG.getMergeValues({Result, Chain}, DL); + // Everything else is legal and is selected directly by the masked_gather isel patterns. + return Op; } SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op, @@ -4859,29 +4760,8 @@ MSC->getMemOperand(), IndexType, Truncating); } - bool NeedsExtend = - getGatherScatterIndexIsExtended(Index) || - Index.getSimpleValueType().getVectorElementType() == MVT::i32; - - SDVTList VTs = DAG.getVTList(MVT::Other); - SDValue InputVT = DAG.getValueType(MemVT); - - if (VT.isFloatingPoint()) { - // Handle FP data by casting the data so an integer scatter can be used.
- EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount()); - StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG); - InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger()); - } - - if (getGatherScatterIndexIsExtended(Index)) - Index = Index.getOperand(0); - - unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend); - selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode, - /*isGather=*/false, DAG); - - SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT}; - return DAG.getNode(Opcode, DL, VTs, Ops); + // Everything else is legal and is selected directly by the masked_scatter isel patterns. + return Op; } SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const { diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -443,6 +443,58 @@ cast(N)->isNonTemporal(); }]>; +multiclass masked_gather_scatter { + // offsets = (signed)Index * sizeof(elt) + def NAME#_signed_scaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{ + auto MGS = cast(N); + bool Signed = MGS->isIndexSigned() || + MGS->getIndex().getValueType().getVectorElementType() == MVT::i64; + return Signed && MGS->isIndexScaled(); + }]>; + // offsets = (signed)Index + def NAME#_signed_unscaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{ + auto MGS = cast(N); + bool Signed = MGS->isIndexSigned() || + MGS->getIndex().getValueType().getVectorElementType() == MVT::i64; + return Signed && !MGS->isIndexScaled(); + }]>; + // offsets = (unsigned)Index * sizeof(elt) + def NAME#_unsigned_scaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{ + auto MGS = cast(N); + bool Signed = MGS->isIndexSigned()
|| + MGS->getIndex().getValueType().getVectorElementType() == MVT::i64; + return !Signed && MGS->isIndexScaled(); + }]>; + // offsets = (unsigned)Index + def NAME#_unsigned_unscaled : + PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx), + (GatherScatterOp node:$val, node:$pred, node:$ptr, node:$idx),[{ + auto MGS = cast(N); + bool Signed = MGS->isIndexSigned() || + MGS->getIndex().getValueType().getVectorElementType() == MVT::i64; + return !Signed && !MGS->isIndexScaled(); + }]>; +} + +defm nonext_masked_gather : masked_gather_scatter; +defm azext_masked_gather_i8 : masked_gather_scatter; +defm azext_masked_gather_i16 : masked_gather_scatter; +defm azext_masked_gather_i32 : masked_gather_scatter; +defm sext_masked_gather_i8 : masked_gather_scatter; +defm sext_masked_gather_i16 : masked_gather_scatter; +defm sext_masked_gather_i32 : masked_gather_scatter; + +defm nontrunc_masked_scatter : masked_gather_scatter; +defm trunc_masked_scatter_i8 : masked_gather_scatter; +defm trunc_masked_scatter_i16 : masked_gather_scatter; +defm trunc_masked_scatter_i32 : masked_gather_scatter; + // top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise def top16Zero: PatLeaf<(i32 GPR32:$src), [{ return SDValue(N,0)->getValueType(0) == MVT::i32 && diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td --- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td @@ -1036,6 +1036,92 @@ defm GLDFF1W_D : sve_mem_64b_gld_sv_32_scaled<0b1011, "ldff1w", AArch64ldff1_gather_sxtw_scaled_z, AArch64ldff1_gather_uxtw_scaled_z, ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>; defm GLD1D : sve_mem_64b_gld_sv_32_scaled<0b1110, "ld1d", AArch64ld1_gather_sxtw_scaled_z, AArch64ld1_gather_uxtw_scaled_z, ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>; defm GLDFF1D : sve_mem_64b_gld_sv_32_scaled<0b1111, "ldff1d", AArch64ldff1_gather_sxtw_scaled_z, AArch64ldff1_gather_uxtw_scaled_z,
ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>; + + multiclass sve_masked_gather_x2_scaled { + // base + vector of scaled offsets + def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))), + (!cast(Inst # _SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit scaled offsets + def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))), + (!cast(Inst # _SXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit scaled offsets + def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))), + (!cast(Inst # _UXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + multiclass sve_masked_gather_x2_unscaled { + // vector of pointers + immediate offset (includes zero) + def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs))), + (!cast(Inst # _IMM) PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>; + // base + vector of offsets + def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))), + (!cast(Inst) PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit offsets + def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))), + (!cast(Inst # _SXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit offsets + def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))), + (!cast(Inst # _UXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + multiclass sve_masked_gather_x4 { + def : Pat<(Ty (Load (SVEDup0Undef), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs))), + (Inst PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + defm : sve_masked_gather_x2_scaled; + defm : sve_masked_gather_x2_scaled; + defm : sve_masked_gather_x2_scaled; + defm : sve_masked_gather_x2_scaled; + defm :
sve_masked_gather_x2_scaled; + defm : sve_masked_gather_x2_scaled; + defm : sve_masked_gather_x2_scaled; + defm : sve_masked_gather_x2_scaled; + defm : sve_masked_gather_x2_scaled; + + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + defm : sve_masked_gather_x2_unscaled; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; + defm : sve_masked_gather_x4; } // End HasSVE let Predicates = [HasSVEorStreamingSVE] in { @@ -1126,6 +1212,81 @@ defm SST1H_D : sve_mem_sst_sv_64_scaled<0b01, "st1h", AArch64st1_scatter_scaled, ZPR64ExtLSL16, nxv2i16>; defm SST1W_D : sve_mem_sst_sv_64_scaled<0b10, "st1w", AArch64st1_scatter_scaled, ZPR64ExtLSL32, nxv2i32>; defm SST1D : sve_mem_sst_sv_64_scaled<0b11, "st1d", AArch64st1_scatter_scaled, ZPR64ExtLSL64, nxv2i64>; + + multiclass sve_masked_scatter_x2_scaled { + // base + vector of scaled offsets + def :
Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)), + (!cast(Inst # _SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit scaled offsets + def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)), + (!cast(Inst # _SXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit scaled offsets + def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))), + (!cast(Inst # _UXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + multiclass sve_masked_scatter_x2_unscaled { + // vector of pointers + immediate offset (includes zero) + def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs)), + (!cast(Inst # _IMM) ZPR:$data, PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>; + // base + vector of offsets + def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)), + (!cast(Inst) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of signed 32bit offsets + def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)), + (!cast(Inst # _SXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>; + // base + vector of unsigned 32bit offsets + def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))), + (!cast(Inst # _UXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + multiclass sve_masked_scatter_x4 { + def : Pat<(Store (Ty ZPR:$data), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs)), + (Inst ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>; + } + + defm : sve_masked_scatter_x2_scaled; + defm : sve_masked_scatter_x2_scaled; + defm : sve_masked_scatter_x2_scaled; + defm : sve_masked_scatter_x2_scaled; + defm : sve_masked_scatter_x2_scaled; + defm : sve_masked_scatter_x2_scaled; +
defm : sve_masked_scatter_x2_scaled; + + defm : sve_masked_scatter_x2_unscaled; + defm : sve_masked_scatter_x2_unscaled; + defm : sve_masked_scatter_x2_unscaled; + defm : sve_masked_scatter_x2_unscaled; + defm : sve_masked_scatter_x2_unscaled; + defm : sve_masked_scatter_x2_unscaled; + defm : sve_masked_scatter_x2_unscaled; + defm : sve_masked_scatter_x2_unscaled; + + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; + defm : sve_masked_scatter_x4; } // End HasSVE let Predicates = [HasSVEorStreamingSVE] in {