diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -4473,14 +4473,22 @@ assert((EGW == 128 || EGW == 256) && "EGW can only be 128 or 256 bits"); // LMUL * VLEN >= EGW - uint64_t ElemSize = Type->isRVVType(32, false) ? 32 : 64; - uint64_t ElemCount = Type->isRVVType(1) ? 1 : + unsigned ElemSize = Type->isRVVType(32, false) ? 32 : 64; + unsigned ElemCount = Type->isRVVType(1) ? 1 : Type->isRVVType(2) ? 2 : Type->isRVVType(4) ? 4 : Type->isRVVType(8) ? 8 : 16; - float Lmul = (float)(ElemSize * ElemCount) / llvm::RISCV::RVVBitsPerBlock; - uint64_t MinRequiredVLEN = std::max(EGW / Lmul, (float)ElemSize); + + unsigned EGS = EGW / ElemSize; + // If EGS is less than or equal to the minimum number of elements, the type is valid. + if (EGS <= ElemCount) + return false; + + // We need vscale to be at least this value. + unsigned VScaleFactor = EGS / ElemCount; + // Vscale is VLEN/RVVBitsPerBlock. + unsigned MinRequiredVLEN = VScaleFactor * llvm::RISCV::RVVBitsPerBlock; std::string RequiredExt = "zvl" + std::to_string(MinRequiredVLEN) + "b"; if (!TI.hasFeature(RequiredExt)) return S.Diag(TheCall->getBeginLoc(),