Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -838,6 +838,34 @@
   return false;
 }
 
+static bool isConstantSplatVectorMaskForType(SDNode *N, EVT ScalarTy) {
+  if (!ScalarTy.isSimple())
+    return false;
+
+  uint64_t MaskForTy = 0ull;
+  switch (ScalarTy.getSimpleVT().SimpleTy) {
+  case MVT::i8:
+    MaskForTy = 0xffull;
+    break;
+  case MVT::i16:
+    MaskForTy = 0xffffull;
+    break;
+  case MVT::i32:
+    MaskForTy = 0xffffffffull;
+    break;
+  default:
+    return false;
+    break;
+  }
+
+  APInt Val;
+  if (ISD::isConstantSplatVector(N, Val)) {
+    return Val.getLimitedValue() == MaskForTy;
+  }
+
+  return false;
+}
+
 // Returns the SDNode if it is a constant float BuildVector
 // or constant float.
 static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
@@ -5221,6 +5249,24 @@
     }
   }
 
+  if (auto *LN0 = dyn_cast<MaskedLoadSDNode>(N0)) {
+    EVT MemVT = LN0->getMemoryVT();
+    EVT ScalarVT = MemVT.getScalarType();
+    if (SDValue(LN0, 0).hasOneUse()
+        && isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT)
+        && TLI.isLoadExtLegal(ISD::ZEXTLOAD, VT, MemVT)) {
+      SDValue ZExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(),
+                                           LN0->getBasePtr(), LN0->getMask(),
+                                           LN0->getPassThru(), MemVT,
+                                           LN0->getMemOperand(), ISD::ZEXTLOAD);
+      CombineTo(N, ZExtLoad);
+      CombineTo(N0.getNode(), ZExtLoad, ZExtLoad.getValue(1));
+      AddToWorklist(ZExtLoad.getNode());
+      // Avoid recheck of N.
+      return SDValue(N, 0);
+    }
+  }
+
   // fold (and (load x), 255) -> (zextload x, i8)
   // fold (and (extload x, i16), 255) -> (zextload x, i8)
   // fold (and (any_ext (extload x, i16)), 255) -> (zextload x, i8)
@@ -9043,6 +9089,9 @@
   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
     return SDValue();
 
+  if (DstVT.isScalableVector())
+    return SDValue();
+
   SDLoc DL(N);
   const unsigned NumSplits =
       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
@@ -10337,6 +10386,21 @@
     AddToWorklist(ExtLoad.getNode());
     return SDValue(N, 0); // Return N so it doesn't get rechecked!
   }
+  // fold (sext_inreg (masked_load x)) -> (sext_masked_load x)
+  if (isa<MaskedLoadSDNode>(N0) &&
+      EVT == cast<MaskedLoadSDNode>(N0)->getMemoryVT() &&
+      ((!LegalOperations && !cast<MaskedLoadSDNode>(N0)->isVolatile()) ||
+       TLI.isLoadExtLegal(ISD::SEXTLOAD, VT, EVT))) {
+    MaskedLoadSDNode *LN0 = cast<MaskedLoadSDNode>(N0);
+    SDValue ExtLoad = DAG.getMaskedLoad(VT, SDLoc(N), LN0->getChain(),
+                                        LN0->getBasePtr(), LN0->getMask(),
+                                        LN0->getPassThru(), LN0->getMemoryVT(),
+                                        LN0->getMemOperand(), ISD::SEXTLOAD);
+    CombineTo(N, ExtLoad);
+    CombineTo(N0.getNode(), ExtLoad, ExtLoad.getValue(1));
+    AddToWorklist(ExtLoad.getNode());
+    return SDValue(N, 0); // Return N so it doesn't get rechecked!
+  }
   // fold (sext_inreg (zextload x)) -> (sextload x) iff load has one use
   if (ISD::isZEXTLoad(N0.getNode()) && ISD::isUNINDEXEDLoad(N0.getNode()) &&
       N0.hasOneUse() &&
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4447,12 +4447,15 @@
   const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
 
   // Do not serialize masked loads of constant memory with anything.
-  bool AddToChain =
-      !AA || !AA->pointsToConstantMemory(MemoryLocation(
-                 PtrOperand,
-                 LocationSize::precise(
-                     DAG.getDataLayout().getTypeStoreSize(I.getType())),
-                 AAInfo));
+  bool AddToChain = false;
+  if (!VT.isScalableVector())
+    AddToChain =
+        !AA || !AA->pointsToConstantMemory(MemoryLocation(
+                   PtrOperand,
+                   LocationSize::precise(
+                       DAG.getDataLayout().getTypeStoreSize(I.getType())),
+                   AAInfo));
+
   SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode();
 
   MachineMemOperand *MMO =
Index: llvm/lib/CodeGen/TargetLoweringBase.cpp
===================================================================
--- llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -1265,18 +1265,23 @@
     MVT EltVT = VT.getVectorElementType();
     unsigned NElts = VT.getVectorNumElements();
     bool IsLegalWiderType = false;
+    bool IsScalable = VT.isScalableVector();
     LegalizeTypeAction PreferredAction = getPreferredVectorAction(VT);
     switch (PreferredAction) {
-    case TypePromoteInteger:
+    case TypePromoteInteger: {
+      MVT::SimpleValueType EndVT = IsScalable ?
+                                   MVT::LAST_INTEGER_SCALABLE_VECTOR_VALUETYPE :
+                                   MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE;
      // Try to promote the elements of integer vectors. If no legal
      // promotion was found, fall through to the widen-vector method.
      for (unsigned nVT = i + 1;
-          nVT <= MVT::LAST_INTEGER_FIXEDLEN_VECTOR_VALUETYPE; ++nVT) {
+          (MVT::SimpleValueType) nVT <= EndVT; ++nVT) {
        MVT SVT = (MVT::SimpleValueType) nVT;
        // Promote vectors of integers to vectors with the same number
        // of elements, with a wider element type.
        if (SVT.getScalarSizeInBits() > EltVT.getSizeInBits() &&
-           SVT.getVectorNumElements() == NElts && isTypeLegal(SVT)) {
+           SVT.getVectorNumElements() == NElts && isTypeLegal(SVT) &&
+           SVT.isScalableVector() == IsScalable && isTypeLegal(SVT)) {
          TransformToType[i] = SVT;
          RegisterTypeForVT[i] = SVT;
          NumRegistersForVT[i] = 1;
@@ -1288,6 +1293,7 @@
      if (IsLegalWiderType)
        break;
      LLVM_FALLTHROUGH;
+    }
 
    case TypeWidenVector:
      if (isPowerOf2_32(NElts)) {
@@ -1295,6 +1301,7 @@
        for (unsigned nVT = i + 1; nVT <= MVT::LAST_VECTOR_VALUETYPE; ++nVT) {
          MVT SVT = (MVT::SimpleValueType) nVT;
          if (SVT.getVectorElementType() == EltVT
+              && SVT.isScalableVector() == IsScalable
              && SVT.getVectorNumElements() > NElts && isTypeLegal(SVT)) {
            TransformToType[i] = SVT;
            RegisterTypeForVT[i] = SVT;
Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -140,6 +140,11 @@
     return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift);
   }
 
+  bool SelectUndef(SDValue N) {
+    if (N->getOpcode() == ISD::UNDEF)
+      return true;
+    return false;
+  }
 
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
@@ -204,7 +209,7 @@
   bool SelectAddrModeXRO(SDValue N, unsigned Size, SDValue &Base,
                          SDValue &Offset, SDValue &SignExtend,
                          SDValue &DoShift);
-  bool isWorthFolding(SDValue V) const;
+  bool isWorthFolding(SDValue V, unsigned MaxUses = 1) const;
   bool SelectExtendedSHL(SDValue N, unsigned Size, bool WantExtend,
                          SDValue &Offset, SDValue &SignExtend);
@@ -375,7 +380,7 @@
 }
 
 /// Determine whether it is worth to fold V into an extended register.
-bool AArch64DAGToDAGISel::isWorthFolding(SDValue V) const {
+bool AArch64DAGToDAGISel::isWorthFolding(SDValue V, unsigned MaxUses) const {
   // Trivial if we are optimizing for code size or if there is only
   // one use of the value.
   if (ForCodeSize || V.hasOneUse())
@@ -394,6 +399,18 @@
     return true;
   }
 
+  // If it has more than one use, check they're all loads/stores
+  // from/to the same memory type (e.g. if you can fold for one
+  // addressing mode, you can fold for the others as well).
+  EVT VT;
+  for (auto *Use : V.getNode()->uses())
+    if (auto *MemNode = dyn_cast<MemSDNode>(Use))
+      if (MemNode->getMemoryVT() != VT && VT != EVT())
+        return false;
+
+  if (V.getNode()->use_size() <= MaxUses)
+    return true;
+
   // It hurts otherwise, since the value will be reused.
   return false;
 }
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -256,6 +256,55 @@
                                          SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>,
                                          SDTCisSameAs<1, 4>]>;
 
+// non-extending masked load fragment.
+def nonext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+// sign extending masked load fragments.
+def asext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def),[{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::EXTLOAD ||
+         cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::SEXTLOAD;
+}]>;
+def asext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def asext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def asext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+// zero extending masked load fragments.
+def zext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::ZEXTLOAD;
+}]>;
+def zext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def zext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def zext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
 
 // Node definitions.
 def AArch64adrp : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>;
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1070,6 +1070,44 @@
   def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
 
+  // Add more complex addressing modes here as required
+  multiclass pred_load<ValueType Ty, ValueType PredTy, SDPatternOperator Load,
+                       Instruction RegImmInst> {
+
+    def _default_z : Pat<(Ty (Load GPR64:$base, (PredTy PPR:$gp), (SVEUndef))),
+                         (RegImmInst PPR:$gp, GPR64:$base, (i64 0))>;
+  }
+
+  // 2-element contiguous loads
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i8,   LD1B_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i8,  LD1SB_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i16,  LD1H_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i16, LD1SH_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i32,  LD1W_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i32, LD1SW_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+  defm : pred_load<nxv2f16, nxv2i1, nonext_masked_load,    LD1H_D_IMM>;
+  defm : pred_load<nxv2f32, nxv2i1, nonext_masked_load,    LD1W_D_IMM>;
+  defm : pred_load<nxv2f64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+
+  // 4-element contiguous loads
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i8,   LD1B_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i8,  LD1SB_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i16,  LD1H_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i16, LD1SH_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+  defm : pred_load<nxv4f16, nxv4i1, nonext_masked_load,    LD1H_S_IMM>;
+  defm : pred_load<nxv4f32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+
+  // 8-element contiguous loads
+  defm : pred_load<nxv8i16, nxv8i1, zext_masked_load_i8,  LD1B_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, asext_masked_load_i8, LD1SB_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, nonext_masked_load,   LD1H_IMM>;
+  defm : pred_load<nxv8f16, nxv8i1, nonext_masked_load,   LD1H_IMM>;
+
+  // 16-element contiguous loads
+  defm : pred_load<nxv16i8, nxv16i1, nonext_masked_load, LD1B_IMM>;
+
 }
 
 let Predicates = [HasSVE2] in {
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -147,6 +147,13 @@
 
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info);
 
+  bool isLegalMaskedLoad(Type *DataType) {
+    return ST->hasSVE();
+  }
+  bool isLegalMaskedStore(Type *DataType) {
+    return ST->hasSVE();
+  }
+
   int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor,
                                  ArrayRef<unsigned> Indices, unsigned Alignment,
                                  unsigned AddressSpace,
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -293,6 +293,8 @@
 : Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3)),
       (inst $Op1, $Op2, $Op3)>;
 
+def SVEUndef : ComplexPattern<i64, 0, "SelectUndef", []>;
+
 //===----------------------------------------------------------------------===//
 // SVE Predicate Misc Group
 //===----------------------------------------------------------------------===//
@@ -4732,6 +4734,13 @@
                   (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>;
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]",
                   (!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>;
+
+  // We need a layer of indirection because early machine code passes balk at
+  // physical register (i.e. FFR) uses that have no previous definition.
+  let hasSideEffects = 1, hasNoSchedulingInfo = 1, mayLoad = 1 in {
+  def "" : Pseudo<(outs listty:$Zt), (ins PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), []>,
+           PseudoInstExpansion<(!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4)>;
+  }
 }
 
 multiclass sve_mem_cld_si<bits<4> dtype, string asm, RegisterOperand listty,
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -0,0 +1,87 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_load_nxv2i64(<vscale x 2 x i64> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2i64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %load
+}
+
+define <vscale x 4 x i32> @masked_load_nxv4i32(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4i32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %load
+}
+
+define <vscale x 8 x i16> @masked_load_nxv8i16(<vscale x 8 x i16> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8i16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+  ret <vscale x 8 x i16> %load
+}
+
+define <vscale x 16 x i8> @masked_load_nxv16i8(<vscale x 16 x i8> *%a, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv16i8:
+; CHECK: ld1b { [[IN:z[0-9]+]].b }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8> *%a, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %load
+}
+
+define <vscale x 2 x double> @masked_load_nxv2f64(<vscale x 2 x double> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef)
+  ret <vscale x 2 x double> %load
+}
+
+define <vscale x 2 x float> @masked_load_nxv2f32(<vscale x 2 x float> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float> *%a, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+  ret <vscale x 2 x float> %load
+}
+
+define <vscale x 2 x half> @masked_load_nxv2f16(<vscale x 2 x half> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half> *%a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x half> undef)
+  ret <vscale x 2 x half> %load
+}
+
+define <vscale x 4 x float> @masked_load_nxv4f32(<vscale x 4 x float> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %load
+}
+
+define <vscale x 4 x half> @masked_load_nxv4f16(<vscale x 4 x half> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half> *%a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x half> undef)
+  ret <vscale x 4 x half> %load
+}
+
+define <vscale x 8 x half> @masked_load_nxv8f16(<vscale x 8 x half> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x half> undef)
+  ret <vscale x 8 x half> %load
+}
+
+declare <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64>*, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
+
+declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double>*, i32, <vscale x 2 x i1>, <vscale x 2 x double>)
+declare <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float>*, i32, <vscale x 2 x i1>, <vscale x 2 x float>)
+declare <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half>*, i32, <vscale x 2 x i1>, <vscale x 2 x half>)
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float>*, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half>*, i32, <vscale x 4 x i1>, <vscale x 4 x half>)
+declare <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half>*, i32, <vscale x 8 x i1>, <vscale x 8 x half>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -0,0 +1,66 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_sload_nxv2i8(<vscale x 2 x i8> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = sext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i16(<vscale x 2 x i16> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = sext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i32(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i32:
+; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i8(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = sext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i16(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = sext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_sload_nxv8i8(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv8i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8> *%a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = sext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -0,0 +1,72 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_zload_nxv2i8(<vscale x 2 x i8>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i16(<vscale x 2 x i16>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = zext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i32(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i32:
+; CHECK-NOT: ld1sw
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i8(<vscale x 4 x i8>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = zext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i16(<vscale x 4 x i16>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = zext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_zload_nxv8i8(<vscale x 8 x i8>* %src, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv8i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>* %src, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = zext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
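
Editorial note (not part of the patch): the isLegalMaskedLoad/isLegalMaskedStore hooks added in
AArch64TargetTransformInfo.h are what let middle-end passes such as the loop vectorizer emit
llvm.masked.load/llvm.masked.store for if-converted memory accesses instead of scalarizing them.
The C function below is a minimal sketch of the kind of loop those intrinsics come from; the
function name and compile flags (e.g. -O3 -march=armv8-a+sve) are illustrative assumptions only.

    /* Lanes where cond[i] == 0 must not touch dst[i] or src[i], so a vectorizer
       needs masked (predicated) loads and stores to vectorize this loop safely. */
    void conditional_add(int *restrict dst, const int *restrict src,
                         const int *restrict cond, int n) {
      for (int i = 0; i < n; ++i)
        if (cond[i])
          dst[i] += src[i];
      /* With masked ops reported legal, each vector iteration loads cond normally,
         loads src and dst under the predicate, adds, and stores dst back under the
         same predicate - the masked loads then select to the SVE ld1* forms added
         in this patch. */
    }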