diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h
--- a/llvm/include/llvm/Support/MachineValueType.h
+++ b/llvm/include/llvm/Support/MachineValueType.h
@@ -478,7 +478,7 @@
 
     /// Returns true if the given vector is a power of 2.
     bool isPow2VectorType() const {
-      unsigned NElts = getVectorNumElements();
+      unsigned NElts = getVectorMinNumElements();
       return !(NElts & (NElts - 1));
     }
 
@@ -488,9 +488,10 @@
       if (isPow2VectorType())
         return *this;
 
-      unsigned NElts = getVectorNumElements();
-      unsigned Pow2NElts = 1 << Log2_32_Ceil(NElts);
-      return MVT::getVectorVT(getVectorElementType(), Pow2NElts);
+      ElementCount NElts = getVectorElementCount();
+      unsigned NewMinCount = 1 << Log2_32_Ceil(NElts.getKnownMinValue());
+      NElts = ElementCount::get(NewMinCount, NElts.isScalable());
+      return MVT::getVectorVT(getVectorElementType(), NElts);
     }
 
     /// If this is a vector, return the element type, otherwise return this.
@@ -651,7 +652,8 @@
       }
     }
 
-    unsigned getVectorNumElements() const {
+    /// Given a vector type, return the minimum number of elements it contains.
+    unsigned getVectorMinNumElements() const {
       switch (SimpleTy) {
       default:
         llvm_unreachable("Not a vector MVT!");
@@ -805,12 +807,12 @@
     }
 
     ElementCount getVectorElementCount() const {
-      return ElementCount::get(getVectorNumElements(), isScalableVector());
+      return ElementCount::get(getVectorMinNumElements(), isScalableVector());
     }
 
-    /// Given a vector type, return the minimum number of elements it contains.
-    unsigned getVectorMinNumElements() const {
-      return getVectorElementCount().getKnownMinValue();
+    unsigned getVectorNumElements() const {
+      // TODO: Check that this isn't a scalable vector.
+      return getVectorMinNumElements();
     }
 
     /// Returns the size of the specified MVT in bits.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -2101,7 +2101,7 @@
       !Subtarget.hasBWI())
     return TypeSplitVector;
 
-  if (VT.getVectorNumElements() != 1 &&
+  if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
       VT.getVectorElementType() != MVT::i1)
     return TypeWidenVector;
 
diff --git a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
--- a/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
+++ b/llvm/utils/TableGen/CodeGenDAGPatterns.cpp
@@ -630,7 +630,7 @@
       return false;
     if (B.getVectorElementType() != P.getVectorElementType())
       return false;
-    return B.getVectorNumElements() < P.getVectorNumElements();
+    return B.getVectorMinNumElements() < P.getVectorMinNumElements();
   };
 
   /// Return true if S has no element (vector type) that T is a sub-vector of,
@@ -696,8 +696,10 @@
   // An actual vector type cannot have 0 elements, so we can treat scalars
   // as zero-length vectors. This way both vectors and scalars can be
   // processed identically.
-  auto NoLength = [](const SmallSet<unsigned,2> &Lengths, MVT T) -> bool {
-    return !Lengths.count(T.isVector() ? T.getVectorNumElements() : 0);
+  auto NoLength = [](const SmallSet<std::pair<bool, unsigned>, 2> &Lengths,
+                     MVT T) -> bool {
+    return !Lengths.count(std::make_pair(
+        T.isScalableVector(), T.isVector() ? T.getVectorMinNumElements() : 0));
   };
 
   SmallVector<unsigned, 4> Modes;
@@ -706,11 +708,13 @@
     TypeSetByHwMode::SetType &VS = V.get(M);
     TypeSetByHwMode::SetType &WS = W.get(M);
 
-    SmallSet<unsigned,2> VN, WN;
+    SmallSet<std::pair<bool, unsigned>, 2> VN, WN;
     for (MVT T : VS)
-      VN.insert(T.isVector() ? T.getVectorNumElements() : 0);
+      VN.insert(std::make_pair(T.isScalableVector(),
+                               T.isVector() ? T.getVectorMinNumElements() : 0));
     for (MVT T : WS)
-      WN.insert(T.isVector() ? T.getVectorNumElements() : 0);
+      WN.insert(std::make_pair(T.isScalableVector(),
+                               T.isVector() ? T.getVectorMinNumElements() : 0));
 
     Changed |= berase_if(VS, std::bind(NoLength, WN, std::placeholders::_1));
     Changed |= berase_if(WS, std::bind(NoLength, VN, std::placeholders::_1));
diff --git a/llvm/utils/TableGen/IntrinsicEmitter.cpp b/llvm/utils/TableGen/IntrinsicEmitter.cpp
--- a/llvm/utils/TableGen/IntrinsicEmitter.cpp
+++ b/llvm/utils/TableGen/IntrinsicEmitter.cpp
@@ -378,7 +378,7 @@
     MVT VVT = VT;
     if (VVT.isScalableVector())
       Sig.push_back(IIT_SCALABLE_VEC);
-    switch (VVT.getVectorNumElements()) {
+    switch (VVT.getVectorMinNumElements()) {
     default: PrintFatalError("unhandled vector type width in intrinsic!");
     case 1: Sig.push_back(IIT_V1); break;
     case 2: Sig.push_back(IIT_V2); break;