Index: llvm/include/llvm/CodeGen/SelectionDAG.h
===================================================================
--- llvm/include/llvm/CodeGen/SelectionDAG.h
+++ llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -779,6 +779,20 @@
     return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
   }
 
+  /// Return an ISD::SPLAT_VECTOR node consisting of Op splatted to all
+  /// elements. If Op is UNDEF, an UNDEF vector of type VT is returned instead.
+  SDValue getSplatVector(EVT VT, const SDLoc &DL, SDValue Op) {
+    if (Op.getOpcode() == ISD::UNDEF) {
+      assert((VT.getVectorElementType() == Op.getValueType() ||
+              (VT.isInteger() &&
+               VT.getVectorElementType().bitsLE(Op.getValueType()))) &&
+             "A splatted value must have a width equal or (for integers) "
+             "greater than the vector element type!");
+      return getNode(ISD::UNDEF, SDLoc(), VT);
+    }
+    return getNode(ISD::SPLAT_VECTOR, DL, VT, Op);
+  }
+
   /// Returns an ISD::VECTOR_SHUFFLE node semantically equivalent to
   /// the shuffle node in input but with swapped operands.
   ///
Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9108,6 +9108,8 @@
   if (!TLI.isLoadExtLegalOrCustom(ExtType, SplitDstVT, SplitSrcVT))
     return SDValue();
 
+  assert(!DstVT.isScalableVector() && "Unexpected scalable vector type");
+
   SDLoc DL(N);
   const unsigned NumSplits =
       DstVT.getVectorNumElements() / SplitDstVT.getVectorNumElements();
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1279,7 +1279,9 @@
   }
 
   SDValue Result(N, 0);
-  if (VT.isVector())
+  if (VT.isScalableVector())
+    Result = getSplatVector(VT, DL, Result);
+  else if (VT.isVector())
     Result = getSplatBuildVector(VT, DL, Result);
 
   return Result;
Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4462,12 +4462,18 @@
   const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range);
 
   // Do not serialize masked loads of constant memory with anything.
-  bool AddToChain =
-      !AA || !AA->pointsToConstantMemory(MemoryLocation(
-                 PtrOperand,
-                 LocationSize::precise(
-                     DAG.getDataLayout().getTypeStoreSize(I.getType())),
-                 AAInfo));
+  // A scalable vector's store size is a multiple of vscale rather than a
+  // compile-time constant, so the location size can only be precise for
+  // fixed-length vectors. The AA metadata (AAInfo) is kept in both cases.
+  MemoryLocation ML(
+      PtrOperand,
+      VT.isScalableVector()
+          ? LocationSize::unknown()
+          : LocationSize::precise(
+                DAG.getDataLayout().getTypeStoreSize(I.getType())),
+      AAInfo);
+  bool AddToChain = !AA || !AA->pointsToConstantMemory(ML);
+
   SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode();
 
   MachineMemOperand *MMO =
Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -140,6 +140,24 @@
     return SelectAddrModeXRO(N, Width / 8, Base, Offset, SignExtend, DoShift);
   }
 
+  /// Return true if N is UNDEF, or a DUP/SPLAT_VECTOR whose scalar operand is
+  /// an integer zero or floating-point +0.0 (i.e. a splat of all-zero bits).
+  bool SelectDupZeroOrUndef(SDValue N) {
+    switch (N->getOpcode()) {
+    case ISD::UNDEF:
+      return true;
+    case AArch64ISD::DUP:
+    case ISD::SPLAT_VECTOR: {
+      SDValue Opnd0 = N->getOperand(0);
+      // isNullFPConstant() is used instead of ConstantFPSDNode::isZero()
+      // because the latter also accepts -0.0, whose sign bit is set.
+      return isNullConstant(Opnd0) || isNullFPConstant(Opnd0);
+    }
+    default:
+      return false;
+    }
+  }
+
   /// Form sequences of consecutive 64/128-bit registers for use in NEON
   /// instructions making use of a vector-list (e.g. ldN, tbl). Vecs must have
   /// between 1 and 4 elements. If it contains a single element that is returned
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -741,6 +741,7 @@
     return TargetLowering::getInlineAsmMemConstraint(ConstraintCode);
   }
 
+  bool isVectorLoadExtDesirable(SDValue ExtVal) const override;
   bool isUsedByReturnOnly(SDNode *N, SDValue &Chain) const override;
   bool mayBeEmittedAsTailCall(const CallInst *CI) const override;
   bool getIndexedAddressParts(SDNode *Op, SDValue &Base, SDValue &Offset,
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2886,6 +2886,12 @@
   }
 }
 
+// Folding an extend into a vector load is only reported as desirable for
+// scalable (SVE) vector types; fixed-length vector types return false.
+bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
+  return ExtVal.getValueType().isScalableVector();
+}
+
 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
                                         EVT VT, EVT MemVT,
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -259,6 +259,55 @@
                                          SDTCisSameAs<1, 2>, SDTCisSameAs<1, 3>,
                                          SDTCisSameAs<1, 4>]>;
 
+// non-extending masked load fragment.
+def nonext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::NON_EXTLOAD;
+}]>;
+// sign extending masked load fragments.
+def asext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::EXTLOAD ||
+         cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::SEXTLOAD;
+}]>;
+def asext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def asext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def asext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (asext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
+// zero extending masked load fragments.
+def zext_masked_load :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (masked_ld node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getExtensionType() == ISD::ZEXTLOAD;
+}]>;
+def zext_masked_load_i8 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i8;
+}]>;
+def zext_masked_load_i16 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i16;
+}]>;
+def zext_masked_load_i32 :
+  PatFrag<(ops node:$ptr, node:$pred, node:$def),
+          (zext_masked_load node:$ptr, node:$pred, node:$def), [{
+  return cast<MaskedLoadSDNode>(N)->getMemoryVT().getScalarType() == MVT::i32;
+}]>;
 
 // Node definitions.
 def AArch64adrp : SDNode<"AArch64ISD::ADRP", SDTIntUnaryOp, []>;
Index: llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -1070,6 +1070,44 @@
   def : Pat<(nxv2f64 (bitconvert (nxv8f16 ZPR:$src))), (nxv2f64 ZPR:$src)>;
   def : Pat<(nxv2f64 (bitconvert (nxv4f32 ZPR:$src))), (nxv2f64 ZPR:$src)>;
 
+  // Add more complex addressing modes here as required
+  multiclass pred_load<ValueType Ty, ValueType PredTy, SDPatternOperator Load,
+                       Instruction RegImmInst> {
+
+    def _default_z : Pat<(Ty (Load GPR64:$base, (PredTy PPR:$gp), (SVEDup0Undef))),
+                         (RegImmInst PPR:$gp, GPR64:$base, (i64 0))>;
+  }
+
+  // 2-element contiguous loads
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i8,   LD1B_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i8,  LD1SB_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i16,  LD1H_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i16, LD1SH_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, zext_masked_load_i32,  LD1W_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, asext_masked_load_i32, LD1SW_D_IMM>;
+  defm : pred_load<nxv2i64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+  defm : pred_load<nxv2f16, nxv2i1, nonext_masked_load,    LD1H_D_IMM>;
+  defm : pred_load<nxv2f32, nxv2i1, nonext_masked_load,    LD1W_D_IMM>;
+  defm : pred_load<nxv2f64, nxv2i1, nonext_masked_load,    LD1D_IMM>;
+
+  // 4-element contiguous loads
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i8,   LD1B_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i8,  LD1SB_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, zext_masked_load_i16,  LD1H_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, asext_masked_load_i16, LD1SH_S_IMM>;
+  defm : pred_load<nxv4i32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+  defm : pred_load<nxv4f16, nxv4i1, nonext_masked_load,    LD1H_S_IMM>;
+  defm : pred_load<nxv4f32, nxv4i1, nonext_masked_load,    LD1W_IMM>;
+
+  // 8-element contiguous loads
+  defm : pred_load<nxv8i16, nxv8i1, zext_masked_load_i8,   LD1B_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, asext_masked_load_i8,  LD1SB_H_IMM>;
+  defm : pred_load<nxv8i16, nxv8i1, nonext_masked_load,    LD1H_IMM>;
+  defm : pred_load<nxv8f16, nxv8i1, nonext_masked_load,    LD1H_IMM>;
+
+  // 16-element contiguous loads
+  defm : pred_load<nxv16i8, nxv16i1, nonext_masked_load, LD1B_IMM>;
+
 }
 
 let Predicates = [HasSVE2] in {
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -147,6 +147,21 @@
 
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info);
 
+  // Only SVE scalable vectors with f16/f32/f64 or i8/i16/i32/i64 elements are
+  // legal; no masked-load selection patterns exist for fixed-length vectors.
+  bool isLegalMaskedLoad(Type *DataType, MaybeAlign Alignment) {
+    auto *VecTy = dyn_cast<VectorType>(DataType);
+    if (!VecTy || !VecTy->isScalable() || !ST->hasSVE())
+      return false;
+
+    Type *Ty = VecTy->getElementType();
+    if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())
+      return true;
+
+    return Ty->isIntegerTy(8) || Ty->isIntegerTy(16) ||
+           Ty->isIntegerTy(32) || Ty->isIntegerTy(64);
+  }
+
   int getInterleavedMemoryOpCost(unsigned Opcode, Type *VecTy, unsigned Factor,
                                  ArrayRef<unsigned> Indices, unsigned Alignment,
                                  unsigned AddressSpace,
Index: llvm/lib/Target/AArch64/SVEInstrFormats.td
===================================================================
--- llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -293,6 +293,8 @@
 : Pat<(vtd (op vt1:$Op1, vt2:$Op2, vt3:$Op3)),
       (inst $Op1, $Op2, $Op3)>;
 
+def SVEDup0Undef : ComplexPattern<i64, 0, "SelectDupZeroOrUndef", []>;
+
 //===----------------------------------------------------------------------===//
 // SVE Predicate Misc Group
 //===----------------------------------------------------------------------===//
@@ -4736,14 +4738,14 @@
 
 multiclass sve_mem_cld_si_base<bits<4> dtype, bit nf, string asm,
                                RegisterOperand listty, ZPRRegOp zprty> {
-  def _REAL : sve_mem_cld_si_base<dtype, nf, asm, listty>;
+  def "" : sve_mem_cld_si_base<dtype, nf, asm, listty>;
 
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]",
-                  (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 0>;
+                  (!cast<Instruction>(NAME) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 0>;
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn, $imm4, mul vl]",
-                  (!cast<Instruction>(NAME # _REAL) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>;
+                  (!cast<Instruction>(NAME) zprty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, simm4s1:$imm4), 0>;
   def : InstAlias<asm # "\t$Zt, $Pg/z, [$Rn]",
-                  (!cast<Instruction>(NAME # _REAL) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>;
+                  (!cast<Instruction>(NAME) listty:$Zt, PPR3bAny:$Pg, GPR64sp:$Rn, 0), 1>;
 }
 
 multiclass sve_mem_cld_si<bits<4> dtype, string asm, RegisterOperand listty,
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -0,0 +1,87 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_load_nxv2i64(<vscale x 2 x i64> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2i64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %load
+}
+
+define <vscale x 4 x i32> @masked_load_nxv4i32(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4i32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> undef)
+  ret <vscale x 4 x i32> %load
+}
+
+define <vscale x 8 x i16> @masked_load_nxv8i16(<vscale x 8 x i16> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8i16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x i16> undef)
+  ret <vscale x 8 x i16> %load
+}
+
+define <vscale x 16 x i8> @masked_load_nxv16i8(<vscale x 16 x i8> *%a, <vscale x 16 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv16i8:
+; CHECK: ld1b { [[IN:z[0-9]+]].b }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8> *%a, i32 1, <vscale x 16 x i1> %mask, <vscale x 16 x i8> undef)
+  ret <vscale x 16 x i8> %load
+}
+
+define <vscale x 2 x double> @masked_load_nxv2f64(<vscale x 2 x double> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f64:
+; CHECK: ld1d { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double> *%a, i32 8, <vscale x 2 x i1> %mask, <vscale x 2 x double> undef)
+  ret <vscale x 2 x double> %load
+}
+
+define <vscale x 2 x float> @masked_load_nxv2f32(<vscale x 2 x float> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float> *%a, i32 4, <vscale x 2 x i1> %mask, <vscale x 2 x float> undef)
+  ret <vscale x 2 x float> %load
+}
+
+define <vscale x 2 x half> @masked_load_nxv2f16(<vscale x 2 x half> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv2f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half> *%a, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x half> undef)
+  ret <vscale x 2 x half> %load
+}
+
+define <vscale x 4 x float> @masked_load_nxv4f32(<vscale x 4 x float> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f32:
+; CHECK: ld1w { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x float> undef)
+  ret <vscale x 4 x float> %load
+}
+
+define <vscale x 4 x half> @masked_load_nxv4f16(<vscale x 4 x half> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv4f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half> *%a, i32 2, <vscale x 4 x i1> %mask, <vscale x 4 x half> undef)
+  ret <vscale x 4 x half> %load
+}
+
+define <vscale x 8 x half> @masked_load_nxv8f16(<vscale x 8 x half> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_load_nxv8f16:
+; CHECK: ld1h { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+  %load = call <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half> *%a, i32 2, <vscale x 8 x i1> %mask, <vscale x 8 x half> undef)
+  ret <vscale x 8 x half> %load
+}
+
+declare <vscale x 2 x i64> @llvm.masked.load.nxv2i64(<vscale x 2 x i64>*, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+declare <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32>*, i32, <vscale x 4 x i1>, <vscale x 4 x i32>)
+declare <vscale x 8 x i16> @llvm.masked.load.nxv8i16(<vscale x 8 x i16>*, i32, <vscale x 8 x i1>, <vscale x 8 x i16>)
+declare <vscale x 16 x i8> @llvm.masked.load.nxv16i8(<vscale x 16 x i8>*, i32, <vscale x 16 x i1>, <vscale x 16 x i8>)
+
+declare <vscale x 2 x double> @llvm.masked.load.nxv2f64(<vscale x 2 x double>*, i32, <vscale x 2 x i1>, <vscale x 2 x double>)
+declare <vscale x 2 x float> @llvm.masked.load.nxv2f32(<vscale x 2 x float>*, i32, <vscale x 2 x i1>, <vscale x 2 x float>)
+declare <vscale x 2 x half> @llvm.masked.load.nxv2f16(<vscale x 2 x half>*, i32, <vscale x 2 x i1>, <vscale x 2 x half>)
+declare <vscale x 4 x float> @llvm.masked.load.nxv4f32(<vscale x 4 x float>*, i32, <vscale x 4 x i1>, <vscale x 4 x float>)
+declare <vscale x 4 x half> @llvm.masked.load.nxv4f16(<vscale x 4 x half>*, i32, <vscale x 4 x i1>, <vscale x 4 x half>)
+declare <vscale x 8 x half> @llvm.masked.load.nxv8f16(<vscale x 8 x half>*, i32, <vscale x 8 x i1>, <vscale x 8 x half>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -0,0 +1,66 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_sload_nxv2i8(<vscale x 2 x i8> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = sext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i16(<vscale x 2 x i16> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = sext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_sload_nxv2i32(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv2i32:
+; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i8(<vscale x 4 x i8> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = sext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_sload_nxv4i16(<vscale x 4 x i16> *%a, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv4i16:
+; CHECK: ld1sh { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16> *%a, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = sext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_sload_nxv8i8(<vscale x 8 x i8> *%a, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_sload_nxv8i8:
+; CHECK: ld1sb { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8> *%a, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = sext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -0,0 +1,72 @@
+; RUN: llc -mtriple=aarch64--linux-gnu -mattr=+sve < %s | FileCheck %s
+
+;
+; Masked Loads
+;
+
+define <vscale x 2 x i64> @masked_zload_nxv2i8(<vscale x 2 x i8>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
+  %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i16(<vscale x 2 x i16>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
+  %ext = zext <vscale x 2 x i16> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 2 x i64> @masked_zload_nxv2i32(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv2i32:
+; CHECK-NOT: ld1sw
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> undef)
+  %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i8(<vscale x 4 x i8>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
+  %ext = zext <vscale x 4 x i8> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 4 x i32> @masked_zload_nxv4i16(<vscale x 4 x i16>* %src, <vscale x 4 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv4i16:
+; CHECK-NOT: ld1sh
+; CHECK: ld1h { [[IN:z[0-9]+]].s }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i16> undef)
+  %ext = zext <vscale x 4 x i16> %load to <vscale x 4 x i32>
+  ret <vscale x 4 x i32> %ext
+}
+
+define <vscale x 8 x i16> @masked_zload_nxv8i8(<vscale x 8 x i8>* %src, <vscale x 8 x i1> %mask) {
+; CHECK-LABEL: masked_zload_nxv8i8:
+; CHECK-NOT: ld1sb
+; CHECK: ld1b { [[IN:z[0-9]+]].h }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ret
+  %load = call <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>* %src, i32 1, <vscale x 8 x i1> %mask, <vscale x 8 x i8> undef)
+  %ext = zext <vscale x 8 x i8> %load to <vscale x 8 x i16>
+  ret <vscale x 8 x i16> %ext
+}
+
+declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
+declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
+declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
+declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)
+declare <vscale x 4 x i16> @llvm.masked.load.nxv4i16(<vscale x 4 x i16>*, i32, <vscale x 4 x i1>, <vscale x 4 x i16>)
+declare <vscale x 8 x i8> @llvm.masked.load.nxv8i8(<vscale x 8 x i8>*, i32, <vscale x 8 x i1>, <vscale x 8 x i8>)