diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -206,11 +206,9 @@ static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, const Query &Q) { - Type *Ty = V->getType(); + auto *VTy = dyn_cast(V->getType()); APInt DemandedElts = - Ty->isVectorTy() - ? APInt::getAllOnesValue(cast(Ty)->getNumElements()) - : APInt(1, 1); + VTy ? APInt::getAllOnesValue(VTy->getElementCount().Min) : APInt(1, 1); computeKnownBits(V, DemandedElts, Known, Depth, Q); } @@ -1878,22 +1876,35 @@ KnownBits &Known, unsigned Depth, const Query &Q) { assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); + Type *Ty = V->getType(); + + Optional EC; + if (auto *VTy = dyn_cast(Ty)) + EC = VTy->getElementCount(); + +#ifndef NDEBUG unsigned BitWidth = Known.getBitWidth(); - Type *Ty = V->getType(); assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) && "Not integer or pointer type!"); - assert(((Ty->isVectorTy() && cast(Ty)->getNumElements() == - DemandedElts.getBitWidth()) || - (!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) && - "Unexpected vector size"); + + if (EC) { + assert(EC->Min == DemandedElts.getBitWidth() && + "DemandedElt width should equal the vector min number of elements"); + } else { + assert(DemandedElts == APInt(1, 1) && + "DemandedElt width should be 1 for scalars"); + } Type *ScalarTy = Ty->getScalarType(); - unsigned ExpectedWidth = ScalarTy->isPointerTy() ? - Q.DL.getPointerTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy); - assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth"); - (void)BitWidth; - (void)ExpectedWidth; + if (ScalarTy->isPointerTy()) { + assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) && + "V and Known should have same BitWidth"); + } else { + assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) && + "V and Known should have same BitWidth"); + } +#endif if (!DemandedElts) { // No demanded elts, better to assume we don't know anything. @@ -1915,26 +1926,26 @@ } // Handle a constant vector by taking the intersection of the known bits of // each element. - if (const ConstantDataSequential *CDS = dyn_cast(V)) { - assert((!Ty->isVectorTy() || - CDS->getNumElements() == DemandedElts.getBitWidth()) && - "Unexpected vector size"); - // We know that CDS must be a vector of integers. Take the intersection of + if (const ConstantDataVector *CDV = dyn_cast(V)) { + assert((CDV->getNumElements() == EC->Min) && + "CDV->getNumElements() and EC->Min must agree"); + // We know that CDV must be a vector of integers. Take the intersection of // each element. Known.Zero.setAllBits(); Known.One.setAllBits(); - for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) { - if (Ty->isVectorTy() && !DemandedElts[i]) + for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) { + if (!DemandedElts[i]) continue; - APInt Elt = CDS->getElementAsAPInt(i); + APInt Elt = CDV->getElementAsAPInt(i); Known.Zero &= ~Elt; Known.One &= Elt; } return; } - if (const auto *CV = dyn_cast(V)) { - assert(CV->getNumOperands() == DemandedElts.getBitWidth() && - "Unexpected vector size"); + if (isa(V) && !EC->Scalable) { + const auto *CV = cast(V); + assert((CV->getNumOperands() == EC->Min) && + "CV->getNumOperands() and EC->Min must agree"); // We know that CV must be a vector of integers. Take the intersection of // each element. Known.Zero.setAllBits(); Known.One.setAllBits();