diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -5760,39 +5760,40 @@ // Ensure MaxVF is a power of 2; the dependence distance bound may not be. // Note that both WidestRegister and WidestType may not be a powers of 2. - unsigned MaxVectorSize = PowerOf2Floor(WidestRegister / WidestType); + auto MaxVectorSize = + ElementCount::getFixed(PowerOf2Floor(WidestRegister / WidestType)); LLVM_DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType << " / " << WidestType << " bits.\n"); LLVM_DEBUG(dbgs() << "LV: The Widest register safe to use is: " << WidestRegister << " bits.\n"); - assert(MaxVectorSize <= WidestRegister && + assert(MaxVectorSize.getFixedValue() <= WidestRegister && "Did not expect to pack so many elements" " into one vector!"); - if (MaxVectorSize == 0) { + if (MaxVectorSize.getFixedValue() == 0) { LLVM_DEBUG(dbgs() << "LV: The target has no vector registers.\n"); - MaxVectorSize = 1; - return ElementCount::getFixed(MaxVectorSize); - } else if (ConstTripCount && ConstTripCount < MaxVectorSize && + return ElementCount::getFixed(1); + } else if (ConstTripCount && ConstTripCount < MaxVectorSize.getFixedValue() && isPowerOf2_32(ConstTripCount)) { // We need to clamp the VF to be the ConstTripCount. There is no point in // choosing a higher viable VF as done in the loop below. LLVM_DEBUG(dbgs() << "LV: Clamping the MaxVF to the constant trip count: " << ConstTripCount << "\n"); - MaxVectorSize = ConstTripCount; - return ElementCount::getFixed(MaxVectorSize); + return ElementCount::getFixed(ConstTripCount); } - unsigned MaxVF = MaxVectorSize; + ElementCount MaxVF = MaxVectorSize; if (TTI.shouldMaximizeVectorBandwidth(!isScalarEpilogueAllowed()) || (MaximizeBandwidth && isScalarEpilogueAllowed())) { // Collect all viable vectorization factors larger than the default MaxVF // (i.e. MaxVectorSize). SmallVector VFs; - unsigned NewMaxVectorSize = WidestRegister / SmallestType; - for (unsigned VS = MaxVectorSize * 2; VS <= NewMaxVectorSize; VS *= 2) - VFs.push_back(ElementCount::getFixed(VS)); + auto MaxVectorSizeMaxBW = + ElementCount::getFixed(WidestRegister / SmallestType); + for (ElementCount VS = MaxVectorSize * 2; + ElementCount::isKnownLE(VS, MaxVectorSizeMaxBW); VS *= 2) + VFs.push_back(VS); // For each VF calculate its register usage. auto RUs = calculateRegisterUsage(VFs); @@ -5801,25 +5802,25 @@ // ones. for (int i = RUs.size() - 1; i >= 0; --i) { bool Selected = true; - for (auto& pair : RUs[i].MaxLocalUsers) { + for (auto &pair : RUs[i].MaxLocalUsers) { unsigned TargetNumRegisters = TTI.getNumberOfRegisters(pair.first); if (pair.second > TargetNumRegisters) Selected = false; } if (Selected) { - MaxVF = VFs[i].getKnownMinValue(); + MaxVF = VFs[i]; break; } } - if (unsigned MinVF = TTI.getMinimumVF(SmallestType)) { - if (MaxVF < MinVF) { + if (auto MinVF = ElementCount::getFixed(TTI.getMinimumVF(SmallestType))) { + if (ElementCount::isKnownLT(MaxVF, MinVF)) { LLVM_DEBUG(dbgs() << "LV: Overriding calculated MaxVF(" << MaxVF << ") with target's minimum: " << MinVF << '\n'); MaxVF = MinVF; } } } - return ElementCount::getFixed(MaxVF); + return MaxVF; } VectorizationFactor