diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -6067,27 +6067,25 @@
   if (ISD::isConstantSplatVectorAllOnes(N1.getNode()))
     return N0;
 
-  // fold (and (masked_load) (build_vec (x, ...))) to zext_masked_load
+  // fold (and (masked_load) (splat_vec (x, ...))) to zext_masked_load
   auto *MLoad = dyn_cast<MaskedLoadSDNode>(N0);
-  auto *BVec = dyn_cast<BuildVectorSDNode>(N1);
-  if (MLoad && BVec && MLoad->getExtensionType() == ISD::EXTLOAD &&
-      N0.hasOneUse() && N1.hasOneUse()) {
+  ConstantSDNode *Splat = isConstOrConstSplat(N1, true, true);
+  if (MLoad && MLoad->getExtensionType() == ISD::EXTLOAD && N0.hasOneUse() &&
+      Splat && N1.hasOneUse()) {
     EVT LoadVT = MLoad->getMemoryVT();
     EVT ExtVT = VT;
     if (TLI.isLoadExtLegal(ISD::ZEXTLOAD, ExtVT, LoadVT)) {
       // For this AND to be a zero extension of the masked load the elements
       // of the BuildVec must mask the bottom bits of the extended element
       // type
-      if (ConstantSDNode *Splat = BVec->getConstantSplatNode()) {
-        uint64_t ElementSize =
-            LoadVT.getVectorElementType().getScalarSizeInBits();
-        if (Splat->getAPIntValue().isMask(ElementSize)) {
-          return DAG.getMaskedLoad(
-              ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
-              MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
-              LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
-              ISD::ZEXTLOAD, MLoad->isExpandingLoad());
-        }
+      uint64_t ElementSize =
+          LoadVT.getVectorElementType().getScalarSizeInBits();
+      if (Splat->getAPIntValue().isMask(ElementSize)) {
+        return DAG.getMaskedLoad(
+            ExtVT, SDLoc(N), MLoad->getChain(), MLoad->getBasePtr(),
+            MLoad->getOffset(), MLoad->getMask(), MLoad->getPassThru(),
+            LoadVT, MLoad->getMemOperand(), MLoad->getAddressingMode(),
+            ISD::ZEXTLOAD, MLoad->isExpandingLoad());
       }
     }
   }
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1230,7 +1230,6 @@
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i16, Legal);
       setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i32, Legal);
       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i8, Legal);
-      setLoadExtAction(Op, MVT::nxv2i32, MVT::nxv2i16, Legal);
       setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i16, Legal);
       setLoadExtAction(Op, MVT::nxv8i16, MVT::nxv8i8, Legal);
     }
diff --git a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
--- a/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-masked-ldst-zext.ll
@@ -97,7 +97,7 @@
 ; CHECK-LABEL: masked_zload_2i16_2f64:
 ; CHECK: ld1h { z0.d }, p0/z, [x0]
 ; CHECK-NEXT: ptrue p0.d
-; CHECK-NEXT: ucvtf z0.d, p0/m, z0.s
+; CHECK-NEXT: ucvtf z0.d, p0/m, z0.d
 ; CHECK-NEXT: ret
   %wide.load = call <vscale x 2 x i16> @llvm.masked.load.nxv2i16(<vscale x 2 x i16>* %in, i32 2, <vscale x 2 x i1> %mask, <vscale x 2 x i16> undef)
   %zext = zext <vscale x 2 x i16> %wide.load to
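
Note (not part of the patch): the IR below is a minimal sketch of the kind of input the DAGCombiner change targets; the function name is made up, and the nxv4i8 -> nxv4i32 pairing is chosen because it is one of the combinations marked Legal above. When the zext of the masked load is type-legalized, it reaches the DAG as an any-extending masked load ANDed with a splat of 0xff; for scalable vectors that splat is a SPLAT_VECTOR rather than a BUILD_VECTOR, which is why the combine now recognizes it via isConstOrConstSplat and folds the pair into a single zero-extending masked load (e.g. ld1b { z0.s }, p0/z, [x0] on SVE).

; Illustrative sketch only -- not taken from the patch or its tests.
define <vscale x 4 x i32> @zext_masked_load_sketch(<vscale x 4 x i8>* %src, <vscale x 4 x i1> %mask) {
  ; Masked load of nxv4i8 followed by a zero-extend to nxv4i32. After type
  ; legalization this becomes (and (masked_load ... EXTLOAD) (splat_vec 255)),
  ; the pattern matched by the combine above.
  %load = call <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>* %src, i32 1, <vscale x 4 x i1> %mask, <vscale x 4 x i8> undef)
  %zext = zext <vscale x 4 x i8> %load to <vscale x 4 x i32>
  ret <vscale x 4 x i32> %zext
}

declare <vscale x 4 x i8> @llvm.masked.load.nxv4i8(<vscale x 4 x i8>*, i32, <vscale x 4 x i1>, <vscale x 4 x i8>)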