[TTI] Return ElementCount from getMinimumVF and add an IsScalableVF flag

This patch changes TargetTransformInfo::getMinimumVF from
  unsigned getMinimumVF(unsigned ElemWidth) const
to
  ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalableVF) const

The change is applied in four places:
  * the public TargetTransformInfo API and TargetTransformInfo.cpp;
  * the Concept (pure virtual) and Model (forwarding) wrappers;
  * TargetTransformInfoImplBase, whose default now returns
    ElementCount::get(0, IsScalableVF), i.e. a zero count of the requested
    kind, matching the doc text "0 if there is no minimum VF".

Per-target changes:
  * AArch64: adds an override. When IsScalableVF is true it returns
    ElementCount::getScalable(2); its comment gives the reason as "SVE always
    needs a minimum of two elements per vector". Otherwise it defers to BaseT.
  * Hexagon: asserts !IsScalableVF and returns
    ElementCount::getFixed((8 * ST.getVectorLength()) / ElemWidth).
  * LoopVectorize: passes IsScalableVF = false and uses the returned
    ElementCount directly instead of wrapping an unsigned in
    ElementCount::getFixed.

Review notes:
  NOTE(review): This copy of the patch has lost its line breaks; each file's
    hunks are run together on one line. Context lines will not match and
    the patch cannot be applied as-is. Regenerate it from the source tree
    rather than hand-repairing the whitespace.
  NOTE(review): The only caller shown (LoopVectorize) passes a literal
    false, so the AArch64 scalable branch is not exercised by this patch.
    Consider spelling the argument as /*IsScalableVF=*/false for readability.
  NOTE(review): AArch64TargetTransformInfo.cpp names the first parameter
    ElementWidth, while the AArch64TargetTransformInfo.h declaration uses
    ElemWidth.
  NOTE(review): The Hexagon check is an assert only. With assertions
    disabled, a scalable request would silently get a fixed count. That
    contradicts the new header doc: "If IsScalableVF is true, the returned
    ElementCount must be a scalable VF".
  NOTE(review): The AArch64 override returns a scalable count without
    checking ST->hasSVE(). Presumably callers only request scalable VFs on
    SVE targets; confirm.
  NOTE(review): The header doc still says "or 0 if there is no minimum VF".
    With an ElementCount return, this now means a zero-valued count, fixed
    or scalable; LoopVectorize tests it through the ElementCount bool
    conversion.

Index: llvm/include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- llvm/include/llvm/Analysis/TargetTransformInfo.h +++ llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -945,7 +945,8 @@ /// \return The minimum vectorization factor for types of given element /// bit width, or 0 if there is no minimum VF. The returned value only /// applies when shouldMaximizeVectorBandwidth returns true. - unsigned getMinimumVF(unsigned ElemWidth) const; + /// If IsScalableVF is true, the returned ElementCount must be a scalable VF. + ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalableVF) const; /// \return The maximum vectorization factor for types of given element /// bit width and opcode, or 0 if there is no maximum VF. @@ -1523,7 +1524,8 @@ virtual unsigned getMinVectorRegisterBitWidth() = 0; virtual Optional getMaxVScale() const = 0; virtual bool shouldMaximizeVectorBandwidth(bool OptSize) const = 0; - virtual unsigned getMinimumVF(unsigned ElemWidth) const = 0; + virtual ElementCount getMinimumVF(unsigned ElemWidth, + bool IsScalableVF) const = 0; virtual unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const = 0; virtual bool shouldConsiderAddressTypePromotion( const Instruction &I, bool &AllowPromotionWithoutCommonHeader) = 0; @@ -1951,8 +1953,9 @@ bool shouldMaximizeVectorBandwidth(bool OptSize) const override { return Impl.shouldMaximizeVectorBandwidth(OptSize); } - unsigned getMinimumVF(unsigned ElemWidth) const override { - return Impl.getMinimumVF(ElemWidth); + ElementCount getMinimumVF(unsigned ElemWidth, + bool IsScalableVF) const override { + return Impl.getMinimumVF(ElemWidth, IsScalableVF); } unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const override { return Impl.getMaximumVF(ElemWidth, Opcode); Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- 
llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -374,7 +374,9 @@ bool shouldMaximizeVectorBandwidth(bool OptSize) const { return false; } - unsigned getMinimumVF(unsigned ElemWidth) const { return 0; } + ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalableVF) const { + return ElementCount::get(0, IsScalableVF); + } unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const { return 0; } Index: llvm/lib/Analysis/TargetTransformInfo.cpp =================================================================== --- llvm/lib/Analysis/TargetTransformInfo.cpp +++ llvm/lib/Analysis/TargetTransformInfo.cpp @@ -640,8 +640,9 @@ return TTIImpl->shouldMaximizeVectorBandwidth(OptSize); } -unsigned TargetTransformInfo::getMinimumVF(unsigned ElemWidth) const { - return TTIImpl->getMinimumVF(ElemWidth); +ElementCount TargetTransformInfo::getMinimumVF(unsigned ElemWidth, + bool IsScalableVF) const { + return TTIImpl->getMinimumVF(ElemWidth, IsScalableVF); } unsigned TargetTransformInfo::getMaximumVF(unsigned ElemWidth, Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -115,6 +115,8 @@ return ST->getMinVectorRegisterBitWidth(); } + ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalableVF) const; + Optional getMaxVScale() const { if (ST->hasSVE()) return AArch64::SVEMaxBitsPerVector / AArch64::SVEBitsPerBlock; Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp +++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp @@ -27,6 +27,15 @@ static cl::opt EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix", cl::init(true), cl::Hidden); +ElementCount 
AArch64TTIImpl::getMinimumVF(unsigned ElementWidth, + bool IsScalableVF) const { + // SVE always needs a minimum of two elements per vector. + if (IsScalableVF) + return ElementCount::getScalable(2); + + return BaseT::getMinimumVF(ElementWidth, IsScalableVF); +} + bool AArch64TTIImpl::areInlineCompatible(const Function *Caller, const Function *Callee) const { const TargetMachine &TM = getTLI()->getTargetMachine(); Index: llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h =================================================================== --- llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -82,7 +82,7 @@ unsigned getMaxInterleaveFactor(unsigned VF); unsigned getRegisterBitWidth(bool Vector) const; unsigned getMinVectorRegisterBitWidth() const; - unsigned getMinimumVF(unsigned ElemWidth) const; + ElementCount getMinimumVF(unsigned ElemWidth, bool IsScalableVF) const; bool shouldMaximizeVectorBandwidth(bool OptSize) const { return true; Index: llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp =================================================================== --- llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -104,8 +104,10 @@ return useHVX() ? 
ST.getVectorLength()*8 : 32; } -unsigned HexagonTTIImpl::getMinimumVF(unsigned ElemWidth) const { - return (8 * ST.getVectorLength()) / ElemWidth; +ElementCount HexagonTTIImpl::getMinimumVF(unsigned ElemWidth, + bool IsScalableVF) const { + assert(!IsScalableVF && "Scalable VFs are not supported for Hexagon"); + return ElementCount::getFixed((8 * ST.getVectorLength()) / ElemWidth); } unsigned HexagonTTIImpl::getScalarizationOverhead(VectorType *Ty, Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp =================================================================== --- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -5770,7 +5770,7 @@ break; } } - if (auto MinVF = ElementCount::getFixed(TTI.getMinimumVF(SmallestType))) { + if (ElementCount MinVF = TTI.getMinimumVF(SmallestType, false)) { if (ElementCount::isKnownLT(MaxVF, MinVF)) { LLVM_DEBUG(dbgs() << "LV: Overriding calculated MaxVF(" << MaxVF << ") with target's minimum: " << MinVF << '\n');