diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -206,11 +206,9 @@ static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, const Query &Q) { - Type *Ty = V->getType(); + auto *VTy = dyn_cast<VectorType>(V->getType()); APInt DemandedElts = - Ty->isVectorTy() - ? APInt::getAllOnesValue(cast<VectorType>(Ty)->getNumElements()) - : APInt(1, 1); + VTy ? APInt::getAllOnesValue(VTy->getElementCount().Min) : APInt(1, 1); computeKnownBits(V, DemandedElts, Known, Depth, Q); } @@ -1874,32 +1872,48 @@ /// where V is a vector, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the demanded elements in the vector specified by DemandedElts. +/// +/// FIXME: If the type of V is ScalableVectorType, then the value of +/// DemandedElts is treated as a boolean. If no bits are demanded, then it +/// immediately bails. If any bits are demanded, then the specific value is +/// ignored. In this case, no code may use DemandedElts to derive any +/// information about the known bits void computeKnownBits(const Value *V, const APInt &DemandedElts, KnownBits &Known, unsigned Depth, const Query &Q) { + if (!DemandedElts) { + // No demanded elts, better to assume we don't know anything. + Known.resetAll(); + return; + } + assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); - unsigned BitWidth = Known.getBitWidth(); +#ifndef NDEBUG Type *Ty = V->getType(); + unsigned BitWidth = Known.getBitWidth(); + assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) && "Not integer or pointer type!"); - assert(((Ty->isVectorTy() && cast<VectorType>(Ty)->getNumElements() == - DemandedElts.getBitWidth()) || - (!Ty->isVectorTy() && DemandedElts == APInt(1, 1))) && - "Unexpected vector size"); - Type *ScalarTy = Ty->getScalarType(); - unsigned ExpectedWidth = ScalarTy->isPointerTy() ? 
- Q.DL.getPointerTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy); - assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth"); - (void)BitWidth; - (void)ExpectedWidth; + if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) { + assert( + FVTy->getNumElements() == DemandedElts.getBitWidth() && + "DemandedElt width should equal the fixed vector number of elements"); + } else { + assert(DemandedElts == APInt(1, 1) && + "DemandedElt width should be 1 for scalars"); + } - if (!DemandedElts) { - // No demanded elts, better to assume we don't know anything. - Known.resetAll(); - return; + Type *ScalarTy = Ty->getScalarType(); + if (ScalarTy->isPointerTy()) { + assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) && + "V and Known should have same BitWidth"); + } else { + assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) && + "V and Known should have same BitWidth"); } +#endif const APInt *C; if (match(V, m_APInt(C))) { @@ -1915,17 +1929,14 @@ } // Handle a constant vector by taking the intersection of the known bits of // each element. - if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) { - assert((!Ty->isVectorTy() || - CDS->getNumElements() == DemandedElts.getBitWidth()) && - "Unexpected vector size"); - // We know that CDS must be a vector of integers. Take the intersection of + if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) { + // We know that CDV must be a vector of integers. Take the intersection of // each element. 
Known.Zero.setAllBits(); Known.One.setAllBits(); - for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) { - if (Ty->isVectorTy() && !DemandedElts[i]) + for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) { + if (!DemandedElts[i]) continue; - APInt Elt = CDS->getElementAsAPInt(i); + APInt Elt = CDV->getElementAsAPInt(i); Known.Zero &= ~Elt; Known.One &= Elt; } @@ -1933,8 +1944,6 @@ } if (const auto *CV = dyn_cast<ConstantVector>(V)) { - assert(CV->getNumOperands() == DemandedElts.getBitWidth() && - "Unexpected vector size"); // We know that CV must be a vector of integers. Take the intersection of // each element. Known.Zero.setAllBits(); Known.One.setAllBits(); @@ -1982,7 +1991,7 @@ computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q); // Aligned pointers have trailing zeros - refine Known.Zero set - if (Ty->isPointerTy()) { + if (isa<PointerType>(V->getType())) { const MaybeAlign Align = V->getPointerAlignment(Q.DL); if (Align) Known.Zero.setLowBits(countTrailingZeros(Align->value()));