Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6054,9 +6054,11 @@
     return N0;
 
   // fold (and (masked_load) (build_vec (x, ...))) to zext_masked_load
+  // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
   auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
-  auto *BVec = dyn_cast<BuildVectorSDNode>(N1);
-  if (MLoad && BVec && MLoad->getExtensionType() == ISD::EXTLOAD &&
+  if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD &&
+      (dyn_cast<BuildVectorSDNode>(N1) ||
+       N1->getOpcode() == ISD::SPLAT_VECTOR) &&
       N0.hasOneUse() && N1.hasOneUse()) {
     EVT LoadVT = MLoad->getMemoryVT();
     EVT ExtVT = VT;
@@ -6064,7 +6066,7 @@
       // For this AND to be a zero extension of the masked load the elements
       // of the BuildVec must mask the bottom bits of the extended element
       // type
-      if (ConstantSDNode *Splat = BVec->getConstantSplatNode()) {
+      if (ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true)) {
         uint64_t ElementSize =
             LoadVT.getVectorElementType().getScalarSizeInBits();
         if (Splat->getAPIntValue().isMask(ElementSize)) {
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1247,7 +1247,6 @@
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i16, Legal);
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i32, Legal);
       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i8, Legal);
-      setLoadExtAction(Op, MVT::nxv2i32, MVT::nxv2i16, Legal);
       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i16, Legal);
       setLoadExtAction(Op, MVT::nxv8i16, MVT::nxv8i8, Legal);
     }
Index: llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -97,7 +97,7 @@
 ; CHECK-LABEL: masked_zload_2i16_2f64:
 ; CHECK: ld1h { z0.d }, p0/z, [x0]
 ; CHECK-NEXT: ptrue p0.d
-; CHECK-NEXT: ucvtf z0.d, p0/m, z0.s
+; CHECK-NEXT: ucvtf z0.d, p0/m, z0.d
 ; CHECK-NEXT: ret
   %wide.load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %in, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
   %zext = zext <vscale x 2 x i16> %wide.load to <vscale x 2 x i64>
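
Note (illustrative, not part of the patch): isConstOrConstSplat(N1, /*AllowUndefs=*/true, /*AllowTruncation=*/true) returns the splat constant for both a fixed-width BUILD_VECTOR and a scalable SPLAT_VECTOR, which is why the BuildVectorSDNode-specific getConstantSplatNode() call can be dropped. The combine matches a pattern produced by legalisation: a zero-extended masked load becomes an any-extending masked load followed by an AND with a splatted mask (0xFF for i8 elements), and for scalable vectors that splat is a SPLAT_VECTOR, which the old code never matched. Below is a minimal IR sketch of an input that should now take the new path; the function name and element widths are chosen here for illustration, not taken from the patch.

; Build with something like: llc -mtriple=aarch64 -mattr=+sve
declare <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>*, i32, <vscale x 2 x i1>, <vscale x 2 x i8>)

define <vscale x 2 x i64> @masked_zload_2i8_2i64(<vscale x 2 x i8>* %in, <vscale x 2 x i1> %mask) {
  ; Legalisation turns the zext below into
  ;   (and (any-extending masked load of %in), (splat_vector 0xFF))
  ; which the extended combine should fold to a zero-extending ld1b { z0.d }.
  %load = call <vscale x 2 x i8> @llvm.masked.load.nxv2i8(<vscale x 2 x i8>* %in, i32 1, <vscale x 2 x i1> %mask, <vscale x 2 x i8> undef)
  %ext = zext <vscale x 2 x i8> %load to <vscale x 2 x i64>
  ret <vscale x 2 x i64> %ext
}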