diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -859,6 +859,8 @@
   SDValue LowerMGATHER(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerMSCATTER(SDValue Op, SelectionDAG &DAG) const;
 
+  SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
 
   bool isEligibleForTailCallOptimization(
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1154,6 +1154,7 @@
       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
+      setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::MUL, VT, Custom);
       setOperationAction(ISD::MULHS, VT, Custom);
       setOperationAction(ISD::MULHU, VT, Custom);
@@ -1245,6 +1246,7 @@
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
+      setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::FADD, VT, Custom);
@@ -1280,6 +1282,7 @@
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
+      setOperationAction(ISD::MLOAD, VT, Custom);
     }
 
     setOperationAction(ISD::SPLAT_VECTOR, MVT::nxv8bf16, Custom);
@@ -4476,6 +4479,32 @@
   return DAG.getNode(Opcode, DL, VTs, Ops);
 }
 
+SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  MaskedLoadSDNode *LoadNode = cast<MaskedLoadSDNode>(Op);
+  assert(LoadNode && "Expected custom lowering of a masked load node");
+  EVT VT = Op->getValueType(0);
+
+  if (useSVEForFixedLengthVectorVT(VT, true))
+    return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
+
+  SDValue PassThru = LoadNode->getPassThru();
+  SDValue Mask = LoadNode->getMask();
+
+  if (PassThru->isUndef() || isZerosVector(PassThru.getNode()))
+    return Op;
+
+  SDValue Load = DAG.getMaskedLoad(
+      VT, DL, LoadNode->getChain(), LoadNode->getBasePtr(),
+      LoadNode->getOffset(), Mask, DAG.getUNDEF(VT), LoadNode->getMemoryVT(),
+      LoadNode->getMemOperand(), LoadNode->getAddressingMode(),
+      LoadNode->getExtensionType());
+
+  SDValue Result = DAG.getSelect(DL, VT, Mask, Load, PassThru);
+
+  return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
+}
+
 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
                                         EVT VT, EVT MemVT,
@@ -4854,7 +4883,7 @@
   case ISD::TRUNCATE:
     return LowerTRUNCATE(Op, DAG);
   case ISD::MLOAD:
-    return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
+    return LowerMLOAD(Op, DAG);
   case ISD::LOAD:
     if (useSVEForFixedLengthVectorVT(Op.getValueType()))
       return LowerFixedLengthVectorLoadToSVE(Op, DAG);
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-nonext.ll
@@ -92,6 +92,15 @@
   ret %load
 }
 
+define <vscale x 4 x i32> @masked_load_passthru(<vscale x 4 x i32> *%a, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) nounwind {
+; CHECK-LABEL: masked_load_passthru:
+; CHECK-NEXT: ld1w { z1.s }, p0/z, [x0]
+; CHECK-NEXT: mov z0.s, p0/m, z1.s
+; CHECK-NEXT: ret
+  %load = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32(<vscale x 4 x i32> *%a, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru)
+  ret <vscale x 4 x i32> %load
+}
+
 ;
 ; Masked Stores
 ;
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-sext.ll
@@ -58,6 +58,18 @@
   ret %ext
 }
 
+define <vscale x 2 x i64> @masked_sload_passthru(<vscale x 2 x i32> *%a, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
+; CHECK-LABEL: masked_sload_passthru:
+; CHECK: ld1sw { [[IN:z[0-9]+]].d }, [[PG1:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: ptrue [[PG2:p[0-9]+]].d
+; CHECK-NEXT: sxtw z0.d, [[PG2]]/m, z0.d
+; CHECK-NEXT: mov z0.d, [[PG1]]/m, [[IN]].d
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32> *%a, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru)
+  %ext = sext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -64,6 +64,18 @@
   ret %ext
 }
 
+define <vscale x 2 x i64> @masked_zload_passthru(<vscale x 2 x i32>* %src, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru) {
+; CHECK-LABEL: masked_zload_passthru:
+; CHECK-NOT: ld1sw
+; CHECK: ld1w { [[IN:z[0-9]+]].d }, [[PG:p[0-9]+]]/z, [x0]
+; CHECK-NEXT: and z0.d, z0.d, #0xffffffff
+; CHECK-NEXT: mov z0.d, [[PG]]/m, [[IN]].d
+; CHECK-NEXT: ret
+  %load = call <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>* %src, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i32> %passthru)
+  %ext = zext <vscale x 2 x i32> %load to <vscale x 2 x i64>
+  ret <vscale x 2 x i64> %ext
+}
+
 declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)
 declare <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>*, i32, <vscale x 2 x i1>, <vscale x 2 x i16>)
 declare <vscale x 2 x i32> @llvm.masked.load.nxv2i32(<vscale x 2 x i32>*, i32, <vscale x 2 x i1>, <vscale x 2 x i32>)