diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4565,29 +4565,6 @@
   return AddrModes.find(Key)->second;
 }
 
-unsigned getScatterVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
-  std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
-       AArch64ISD::SST1_PRED},
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
-       AArch64ISD::SST1_UXTW_PRED},
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
-       AArch64ISD::SST1_PRED},
-      {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
-       AArch64ISD::SST1_SXTW_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
-       AArch64ISD::SST1_SCALED_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
-       AArch64ISD::SST1_UXTW_SCALED_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
-       AArch64ISD::SST1_SCALED_PRED},
-      {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
-       AArch64ISD::SST1_SXTW_SCALED_PRED},
-  };
-  auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
-  return AddrModes.find(Key)->second;
-}
-
 unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
   switch (Opcode) {
   default:
@@ -4610,53 +4587,6 @@
   }
 }
 
-bool getGatherScatterIndexIsExtended(SDValue Index) {
-  // Ignore non-pointer sized indices.
-  if (Index.getValueType() != MVT::nxv2i64)
-    return false;
-
-  unsigned Opcode = Index.getOpcode();
-  if (Opcode == ISD::SIGN_EXTEND_INREG)
-    return cast<VTSDNode>(Index.getOperand(1))->getVT() == MVT::nxv2i32;
-
-  if (Opcode == ISD::AND) {
-    SDValue Splat = Index.getOperand(1);
-    if (Splat.getOpcode() != ISD::SPLAT_VECTOR)
-      return false;
-    ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(Splat.getOperand(0));
-    if (!Mask || Mask->getZExtValue() != 0xFFFFFFFF)
-      return false;
-    return true;
-  }
-
-  return false;
-}
-
-// If the base pointer of a masked gather or scatter is constant, we
-// may be able to swap BasePtr & Index and use the vector + immediate addressing
-// mode, e.g.
-// VECTOR + IMMEDIATE:
-//    getelementptr nullptr, <vscale x N x T> (splat(#x)) + %indices)
-// -> getelementptr #x, <vscale x N x T> %indices
-void selectGatherScatterAddrMode(SDValue &BasePtr, SDValue &Index,
-                                 bool IsScaled, EVT MemVT, unsigned &Opcode,
-                                 bool IsGather, SelectionDAG &DAG) {
-  ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(BasePtr);
-  if (!Offset || IsScaled)
-    return;
-
-  uint64_t OffsetVal = Offset->getZExtValue();
-  unsigned ScalarSizeInBytes = MemVT.getScalarSizeInBits() / 8;
-
-  if (OffsetVal % ScalarSizeInBytes || OffsetVal / ScalarSizeInBytes > 31)
-    return;
-
-  // Immediate is in range
-  Opcode =
-      IsGather ? AArch64ISD::GLD1_IMM_MERGE_ZERO : AArch64ISD::SST1_IMM_PRED;
-  std::swap(BasePtr, Index);
-}
-
 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
                                             SelectionDAG &DAG) const {
   MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
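The helper deleted above only switched to the vector-plus-immediate addressing mode when the constant base fit the instruction's immediate field. A standalone sketch of that range check, for illustration only (not part of the patch; the function name is invented):

#include <cstdint>

// Mirrors the check in the deleted selectGatherScatterAddrMode: the byte
// offset must be a multiple of the element size, and the scaled value must
// fit the 5-bit immediate field [0, 31] of the vector-plus-immediate forms.
static bool fitsVectorPlusImm(uint64_t OffsetInBytes, unsigned EltSizeInBytes) {
  if (OffsetInBytes % EltSizeInBytes != 0)
    return false;
  return OffsetInBytes / EltSizeInBytes <= 31;
}

int main() {
  // For 64-bit elements, byte offsets 0, 8, ..., 248 are representable.
  return (fitsVectorPlusImm(248, 8) && !fitsVectorPlusImm(256, 8)) ? 0 : 1;
}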
@@ -4752,37 +4682,8 @@
     return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
   }
 
-  bool IdxNeedsExtend =
-      getGatherScatterIndexIsExtended(Index) ||
-      Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-
-  EVT IndexVT = Index.getSimpleValueType();
-  SDValue InputVT = DAG.getValueType(MemVT);
-
-  // Handle FP data by using an integer gather and casting the result.
-  if (VT.isFloatingPoint())
-    InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
-
-  SDVTList VTs = DAG.getVTList(IndexVT, MVT::Other);
-
-  if (getGatherScatterIndexIsExtended(Index))
-    Index = Index.getOperand(0);
-
-  unsigned Opcode = getGatherVecOpcode(IsScaled, IsSigned, IdxNeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
-                              /*isGather=*/true, DAG);
-
-  if (ExtType == ISD::SEXTLOAD)
-    Opcode = getSignExtendedGatherOpcode(Opcode);
-
-  SDValue Ops[] = {Chain, Mask, BasePtr, Index, InputVT};
-  SDValue Result = DAG.getNode(Opcode, DL, VTs, Ops);
-  Chain = Result.getValue(1);
-
-  if (VT.isFloatingPoint())
-    Result = getSVESafeBitCast(VT, Result, DAG);
-
-  return DAG.getMergeValues({Result, Chain}, DL);
+  // Everything else is legal.
+  return Op;
 }
 
 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
@@ -4862,29 +4763,8 @@
                               MSC->getMemOperand(), IndexType, Truncating);
   }
 
-  bool NeedsExtend =
-      getGatherScatterIndexIsExtended(Index) ||
-      Index.getSimpleValueType().getVectorElementType() == MVT::i32;
-
-  SDVTList VTs = DAG.getVTList(MVT::Other);
-  SDValue InputVT = DAG.getValueType(MemVT);
-
-  if (VT.isFloatingPoint()) {
-    // Handle FP data by casting the data so an integer scatter can be used.
-    EVT StoreValVT = getPackedSVEVectorVT(VT.getVectorElementCount());
-    StoreVal = getSVESafeBitCast(StoreValVT, StoreVal, DAG);
-    InputVT = DAG.getValueType(MemVT.changeVectorElementTypeToInteger());
-  }
-
-  if (getGatherScatterIndexIsExtended(Index))
-    Index = Index.getOperand(0);
-
-  unsigned Opcode = getScatterVecOpcode(IsScaled, IsSigned, NeedsExtend);
-  selectGatherScatterAddrMode(BasePtr, Index, IsScaled, MemVT, Opcode,
-                              /*isGather=*/false, DAG);
-
-  SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, InputVT};
-  return DAG.getNode(Opcode, DL, VTs, Ops);
+  // Everything else is legal.
+  return Op;
 }
 
 SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
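With these hunks LowerMGATHER and LowerMSCATTER return the node untouched, so the choice between the plain, scaled, SXTW and UXTW addressing forms moves from the deleted C++ opcode tables into the isel patterns added below. A standalone sketch of the mapping the deleted getScatterVecOpcode table encoded, for illustration only (the enum and function names are invented):

// (IsScaled, IsSigned, NeedsExtend) -> scatter addressing form, matching the
// deleted table: without a 32-bit extend the plain/scaled forms are used
// regardless of signedness; with an extend, signed picks SXTW, unsigned UXTW.
enum class ScatterForm {
  Unscaled,       // SST1_PRED
  UnscaledSXTW,   // SST1_SXTW_PRED
  UnscaledUXTW,   // SST1_UXTW_PRED
  Scaled,         // SST1_SCALED_PRED
  ScaledSXTW,     // SST1_SXTW_SCALED_PRED
  ScaledUXTW      // SST1_UXTW_SCALED_PRED
};

static ScatterForm classifyScatter(bool IsScaled, bool IsSigned, bool NeedsExtend) {
  if (!NeedsExtend)
    return IsScaled ? ScatterForm::Scaled : ScatterForm::Unscaled;
  if (IsSigned)
    return IsScaled ? ScatterForm::ScaledSXTW : ScatterForm::UnscaledSXTW;
  return IsScaled ? ScatterForm::ScaledUXTW : ScatterForm::UnscaledUXTW;
}

int main() {
  // Scaled, signed, 32-bit extended offsets map to the SXTW scaled form.
  return classifyScatter(true, true, true) == ScatterForm::ScaledSXTW ? 0 : 1;
}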
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -443,6 +443,192 @@
   cast(N)->isNonTemporal();
 }]>;
 
+def SDT_SVEMaskedGather : SDTypeProfile<1, 4, [
+  SDTCisVec<0>, SDTCisSameAs<0, 1>, SDTCisVec<2>, SDTCisPtrTy<3>, SDTCisVec<4>,
+  SDTCVecEltisVT<2, i1>, SDTCisSameNumEltsAs<0, 2>, SDTCisSameNumEltsAs<0, 4>
+]>;
+def sve_masked_gather : SDNode<"ISD::MGATHER", SDT_SVEMaskedGather,
+                               [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>;
+
+// masked gather (signed scaled offsets).
+def masked_gather_signed_scaled :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  bool Signed =
+      MGN->isIndexSigned() || MGN->getIndex().getValueType() == MVT::nxv2i64;
+  return Signed && MGN->isIndexScaled();
+}]>;
+// masked gather (signed unscaled offsets).
+def masked_gather_signed_unscaled :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  bool Signed =
+      MGN->isIndexSigned() || MGN->getIndex().getValueType() == MVT::nxv2i64;
+  return Signed && !MGN->isIndexScaled();
+}]>;
+// masked gather (unsigned scaled offsets).
+def masked_gather_unsigned_scaled :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  bool Signed =
+      MGN->isIndexSigned() || MGN->getIndex().getValueType() == MVT::nxv2i64;
+  return !Signed && MGN->isIndexScaled();
+}]>;
+// masked gather (unsigned unscaled offsets).
+def masked_gather_unsigned_unscaled :
+  PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_gather node:$def, node:$pred, node:$ptr, node:$idx),[{
+  auto MGN = cast<MaskedGatherSDNode>(N);
+  bool Signed =
+      MGN->isIndexSigned() || MGN->getIndex().getValueType() == MVT::nxv2i64;
+  return !Signed && !MGN->isIndexScaled();
+}]>;
+
+multiclass masked_gather<SDPatternOperator GatherOp> {
+  def nonext_#NAME :
+    PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+            (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{
+    return cast<MaskedGatherSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+  }]>;
+
+  // Any/Zero extending masked gather fragments.
+  def azext_#NAME#_i8 :
+    PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+            (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{
+    auto MGN = cast<MaskedGatherSDNode>(N);
+    return (MGN->getExtensionType() == ISD::EXTLOAD ||
+            MGN->getExtensionType() == ISD::ZEXTLOAD) &&
+           MGN->getMemoryVT().getScalarType() == MVT::i8;
+  }]>;
+  def azext_#NAME#_i16 :
+    PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+            (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{
+    auto MGN = cast<MaskedGatherSDNode>(N);
+    return (MGN->getExtensionType() == ISD::EXTLOAD ||
+            MGN->getExtensionType() == ISD::ZEXTLOAD) &&
+           MGN->getMemoryVT().getScalarType() == MVT::i16;
+  }]>;
+  def azext_#NAME#_i32 :
+    PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+            (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{
+    auto MGN = cast<MaskedGatherSDNode>(N);
+    return (MGN->getExtensionType() == ISD::EXTLOAD ||
+            MGN->getExtensionType() == ISD::ZEXTLOAD) &&
+           MGN->getMemoryVT().getScalarType() == MVT::i32;
+  }]>;
+
+  // Sign extending masked gather fragments.
+  def sext_#NAME#_i8 :
+    PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+            (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{
+    auto MGN = cast<MaskedGatherSDNode>(N);
+    return MGN->getExtensionType() == ISD::SEXTLOAD &&
+           MGN->getMemoryVT().getScalarType() == MVT::i8;
+  }]>;
+  def sext_#NAME#_i16 :
+    PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+            (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{
+    auto MGN = cast<MaskedGatherSDNode>(N);
+    return MGN->getExtensionType() == ISD::SEXTLOAD &&
+           MGN->getMemoryVT().getScalarType() == MVT::i16;
+  }]>;
+  def sext_#NAME#_i32 :
+    PatFrag<(ops node:$def, node:$pred, node:$ptr, node:$idx),
+            (GatherOp node:$def, node:$pred, node:$ptr, node:$idx), [{
+    auto MGN = cast<MaskedGatherSDNode>(N);
+    return MGN->getExtensionType() == ISD::SEXTLOAD &&
+           MGN->getMemoryVT().getScalarType() == MVT::i32;
+  }]>;
+}
+
+defm masked_gather_signed_scaled : masked_gather<masked_gather_signed_scaled>;
+defm masked_gather_signed_unscaled : masked_gather<masked_gather_signed_unscaled>;
+defm masked_gather_unsigned_scaled : masked_gather<masked_gather_unsigned_scaled>;
+defm masked_gather_unsigned_unscaled : masked_gather<masked_gather_unsigned_unscaled>;
+
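These fragments classify an MGATHER node purely through MaskedGatherSDNode queries: signedness and scaling of the index, the load extension kind, and the in-memory element type. A hedged C++ restatement of two of the predicates, illustrative only (it is meant to compile inside LLVM and simply re-expresses the checks embedded in the fragments above; any-extending loads are grouped with zero-extending ones, presumably because an any-extend can be satisfied by the zero-extending instruction):

#include "llvm/CodeGen/SelectionDAGNodes.h"
using namespace llvm;

// Same condition as masked_gather_signed_unscaled above.
static bool isSignedUnscaledGather(const MaskedGatherSDNode *MGN) {
  bool Signed = MGN->isIndexSigned() ||
                MGN->getIndex().getValueType() == MVT::nxv2i64;
  return Signed && !MGN->isIndexScaled();
}

// Same condition as the azext_*_i16 fragments.
static bool isAnyOrZeroExtI16Gather(const MaskedGatherSDNode *MGN) {
  return (MGN->getExtensionType() == ISD::EXTLOAD ||
          MGN->getExtensionType() == ISD::ZEXTLOAD) &&
         MGN->getMemoryVT().getScalarType() == MVT::i16;
}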
+def SDT_SVEMaskedScatter : SDTypeProfile<0, 4, [
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisPtrTy<2>, SDTCisVec<3>,
+  SDTCVecEltisVT<1, i1>, SDTCisSameNumEltsAs<0, 1>, SDTCisSameNumEltsAs<0, 3>
+]>;
+def sve_masked_scatter : SDNode<"ISD::MSCATTER", SDT_SVEMaskedScatter,
+                                [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>;
+
+// masked scatter fragment (signed scaled offsets).
+def masked_scatter_signed_scaled :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  auto MSN = cast<MaskedScatterSDNode>(N);
+  bool Signed =
+      MSN->isIndexSigned() || MSN->getIndex().getValueType() == MVT::nxv2i64;
+  return Signed && MSN->isIndexScaled();
+}]>;
+// masked scatter fragment (signed unscaled offsets).
+def masked_scatter_signed_unscaled :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  auto MSN = cast<MaskedScatterSDNode>(N);
+  bool Signed =
+      MSN->isIndexSigned() || MSN->getIndex().getValueType() == MVT::nxv2i64;
+  return Signed && !MSN->isIndexScaled();
+}]>;
+// masked scatter fragment (unsigned scaled offsets).
+def masked_scatter_unsigned_scaled :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  auto MSN = cast<MaskedScatterSDNode>(N);
+  bool Signed =
+      MSN->isIndexSigned() || MSN->getIndex().getValueType() == MVT::nxv2i64;
+  return !Signed && MSN->isIndexScaled();
+}]>;
+// masked scatter fragment (unsigned unscaled offsets).
+def masked_scatter_unsigned_unscaled :
+  PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+          (sve_masked_scatter node:$val, node:$pred, node:$ptr, node:$idx), [{
+  auto MSN = cast<MaskedScatterSDNode>(N);
+  bool Signed =
+      MSN->isIndexSigned() || MSN->getIndex().getValueType() == MVT::nxv2i64;
+  return !Signed && !MSN->isIndexScaled();
+}]>;
+
+multiclass masked_scatter<SDPatternOperator ScatterOp> {
+  def nontrunc_#NAME :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{
+    return !cast<MaskedScatterSDNode>(N)->isTruncatingStore();
+  }]>;
+
+  // Truncating masked scatter fragments.
+  def trunc_#NAME#_i8 :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{
+    auto MSN = cast<MaskedScatterSDNode>(N);
+    return MSN->isTruncatingStore() &&
+           MSN->getMemoryVT().getScalarType() == MVT::i8;
+  }]>;
+  def trunc_#NAME#_i16 :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{
+    auto MSN = cast<MaskedScatterSDNode>(N);
+    return MSN->isTruncatingStore() &&
+           MSN->getMemoryVT().getScalarType() == MVT::i16;
+  }]>;
+  def trunc_#NAME#_i32 :
+    PatFrag<(ops node:$val, node:$pred, node:$ptr, node:$idx),
+            (ScatterOp node:$val, node:$pred, node:$ptr, node:$idx), [{
+    auto MSN = cast<MaskedScatterSDNode>(N);
+    return MSN->isTruncatingStore() &&
+           MSN->getMemoryVT().getScalarType() == MVT::i32;
+  }]>;
+}
+
+defm masked_scatter_signed_scaled : masked_scatter<masked_scatter_signed_scaled>;
+defm masked_scatter_signed_unscaled : masked_scatter<masked_scatter_signed_unscaled>;
+defm masked_scatter_unsigned_scaled : masked_scatter<masked_scatter_unsigned_scaled>;
+defm masked_scatter_unsigned_unscaled : masked_scatter<masked_scatter_unsigned_unscaled>;
+
 // top16Zero - answer true if the upper 16 bits of $src are 0, false otherwise
 def top16Zero: PatLeaf<(i32 GPR32:$src), [{
   return SDValue(N,0)->getValueType(0) == MVT::i32 &&
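The patterns in the next file pick the UXTW forms when the 64-bit offsets are an AND with a splat of 0xFFFFFFFF and the SXTW forms when they are a sext_inreg from nxv2i32, the same shapes the deleted getGatherScatterIndexIsExtended recognised. A standalone check of the underlying zero-extension equivalence, for illustration only:

#include <cassert>
#include <cstdint>

int main() {
  // Masking a 64-bit lane with 0xFFFFFFFF is exactly a zero-extension of its
  // low 32 bits, so the AND-with-splat form can be selected as a UXTW offset.
  uint64_t Lane = 0xDEADBEEF12345678ULL;
  uint64_t Masked = Lane & 0xFFFFFFFFULL;
  uint64_t ZeroExtended = static_cast<uint64_t>(static_cast<uint32_t>(Lane));
  assert(Masked == ZeroExtended);
  return Masked == ZeroExtended ? 0 : 1;
}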
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1036,6 +1036,92 @@
   defm GLDFF1W_D : sve_mem_64b_gld_sv_32_scaled<0b1011, "ldff1w", AArch64ldff1_gather_sxtw_scaled_z, AArch64ldff1_gather_uxtw_scaled_z, ZPR64ExtSXTW32, ZPR64ExtUXTW32, nxv2i32>;
   defm GLD1D : sve_mem_64b_gld_sv_32_scaled<0b1110, "ld1d", AArch64ld1_gather_sxtw_scaled_z, AArch64ld1_gather_uxtw_scaled_z, ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>;
   defm GLDFF1D : sve_mem_64b_gld_sv_32_scaled<0b1111, "ldff1d", AArch64ldff1_gather_sxtw_scaled_z, AArch64ldff1_gather_uxtw_scaled_z, ZPR64ExtSXTW64, ZPR64ExtUXTW64, nxv2i64>;
+
+  multiclass sve_masked_gather_x2_scaled<ValueType Ty, SDPatternOperator Load, string Inst> {
+    // base + vector of scaled offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))),
+              (!cast<Instruction>(Inst # _SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit scaled offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))),
+              (!cast<Instruction>(Inst # _SXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit scaled offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))),
+              (!cast<Instruction>(Inst # _UXTW_SCALED) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_gather_x2_unscaled<ValueType Ty, SDPatternOperator Load, string Inst, Operand ImmTy> {
+    // vector of pointers + immediate offset (includes zero)
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs))),
+              (!cast<Instruction>(Inst # _IMM) PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>;
+    // base + vector of offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs))),
+              (!cast<Instruction>(Inst) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32))),
+              (!cast<Instruction>(Inst # _SXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit offsets
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF)))))),
+              (!cast<Instruction>(Inst # _UXTW) PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_gather_x4<ValueType Ty, SDPatternOperator Load, Instruction Inst> {
+    def : Pat<(Ty (Load (SVEDup0Undef), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs))),
+              (Inst PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+  defm : sve_masked_gather_x2_scaled;
+
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+  defm : sve_masked_gather_x2_unscaled;
+
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
+  defm : sve_masked_gather_x4;
 } // End HasSVE
 
 let Predicates = [HasSVEorStreamingSVE] in {
@@ -1126,6 +1212,81 @@
   defm SST1H_D : sve_mem_sst_sv_64_scaled<0b01, "st1h", AArch64st1_scatter_scaled, ZPR64ExtLSL16, nxv2i16>;
   defm SST1W_D : sve_mem_sst_sv_64_scaled<0b10, "st1w", AArch64st1_scatter_scaled, ZPR64ExtLSL32, nxv2i32>;
   defm SST1D : sve_mem_sst_sv_64_scaled<0b11, "st1d", AArch64st1_scatter_scaled, ZPR64ExtLSL64, nxv2i64>;
+
+  multiclass sve_masked_scatter_x2_scaled<ValueType Ty, SDPatternOperator Store, string Inst> {
+    // base + vector of scaled offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)),
+              (!cast<Instruction>(Inst # _SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit scaled offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)),
+              (!cast<Instruction>(Inst # _SXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit scaled offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))),
+              (!cast<Instruction>(Inst # _UXTW_SCALED) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_scatter_x2_unscaled<ValueType Ty, SDPatternOperator Store, string Inst, Operand ImmTy> {
+    // vector of pointers + immediate offset (includes zero)
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), (i64 ImmTy:$imm), (nxv2i64 ZPR:$ptrs)),
+              (!cast<Instruction>(Inst # _IMM) ZPR:$data, PPR:$gp, ZPR:$ptrs, ImmTy:$imm)>;
+    // base + vector of offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (nxv2i64 ZPR:$offs)),
+              (!cast<Instruction>(Inst) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of signed 32bit offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (sext_inreg (nxv2i64 ZPR:$offs), nxv2i32)),
+              (!cast<Instruction>(Inst # _SXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+    // base + vector of unsigned 32bit offsets
+    def : Pat<(Store (Ty ZPR:$data), (nxv2i1 PPR:$gp), GPR64:$base, (and (nxv2i64 ZPR:$offs), (nxv2i64 (AArch64dup (i64 0xFFFFFFFF))))),
+              (!cast<Instruction>(Inst # _UXTW) ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  multiclass sve_masked_scatter_x4<ValueType Ty, SDPatternOperator Store, Instruction Inst> {
+    def : Pat<(Store (Ty ZPR:$data), (nxv4i1 PPR:$gp), GPR64:$base, (nxv4i32 ZPR:$offs)),
+              (Inst ZPR:$data, PPR:$gp, GPR64:$base, ZPR:$offs)>;
+  }
+
+  defm : sve_masked_scatter_x2_scaled;
+  defm : sve_masked_scatter_x2_scaled;
+  defm : sve_masked_scatter_x2_scaled;
+  defm : sve_masked_scatter_x2_scaled;
+  defm : sve_masked_scatter_x2_scaled;
+  defm : sve_masked_scatter_x2_scaled;
+  defm : sve_masked_scatter_x2_scaled;
+
+  defm : sve_masked_scatter_x2_unscaled;
+  defm : sve_masked_scatter_x2_unscaled;
+  defm : sve_masked_scatter_x2_unscaled;
+  defm : sve_masked_scatter_x2_unscaled;
+  defm : sve_masked_scatter_x2_unscaled;
+  defm : sve_masked_scatter_x2_unscaled;
+  defm : sve_masked_scatter_x2_unscaled;
+  defm : sve_masked_scatter_x2_unscaled;
+
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
+  defm : sve_masked_scatter_x4;
 } // End HasSVE
 
 let Predicates = [HasSVEorStreamingSVE] in {