diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -4625,13 +4625,13 @@ if (!PredVT.isScalableVector() || PredVT.getVectorElementType() != MVT::i1) return EVT(); - const unsigned NumElts = PredVT.getVectorNumElements(); - - if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16) + if (PredVT != MVT::nxv16i1 && PredVT != MVT::nxv8i1 && + PredVT != MVT::nxv4i1 && PredVT != MVT::nxv2i1) return EVT(); - EVT ScalarVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / NumElts); - EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, NumElts, /*IsScalable=*/true); + ElementCount EC = PredVT.getVectorElementCount(); + EVT ScalarVT = EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.Min); + EVT MemVT = EVT::getVectorVT(Ctx, ScalarVT, EC); return MemVT; }