Index: include/llvm/Analysis/ValueTracking.h =================================================================== --- include/llvm/Analysis/ValueTracking.h +++ include/llvm/Analysis/ValueTracking.h @@ -138,9 +138,9 @@ /// the other bits. We know that at least 1 bit is always equal to the sign /// bit (itself), but other cases can give us information. For example, /// immediately after an "ashr X, 2", we know that the top 3 bits are all - /// equal to each other, so we return 3. - /// - /// 'Op' must have a scalar integer type. + /// equal to each other, so we return 3. For vectors, return the number of + /// sign bits for the vector element with the minimum number of known sign + /// bits. unsigned ComputeNumSignBits(Value *Op, const DataLayout &DL, unsigned Depth = 0, AssumptionCache *AC = nullptr, const Instruction *CxtI = nullptr, Index: lib/Analysis/ValueTracking.cpp =================================================================== --- lib/Analysis/ValueTracking.cpp +++ lib/Analysis/ValueTracking.cpp @@ -1918,16 +1918,12 @@ return (KnownZero & Mask) == Mask; } - - /// Return the number of times the sign bit of the register is replicated into /// the other bits. We know that at least 1 bit is always equal to the sign bit /// (itself), but other cases can give us information. For example, immediately /// after an "ashr X, 2", we know that the top 3 bits are all equal to each -/// other, so we return 3. -/// -/// 'Op' must have a scalar integer type. -/// +/// other, so we return 3. For vectors, return the number of sign bits for the +/// vector element with the minimum number of known sign bits. unsigned ComputeNumSignBits(Value *V, unsigned Depth, const Query &Q) { unsigned TyBits = Q.DL.getTypeSizeInBits(V->getType()->getScalarType()); unsigned Tmp, Tmp2; @@ -2123,6 +2119,35 @@ // Finally, if we can prove that the top bits of the result are 0's or 1's, // use this information. 
+ + // For vector constants, loop over the elements and find the constant with the + // minimum number of sign bits. + auto *CV = dyn_cast(V); + if (CV && CV->getType()->isVectorTy()) { + unsigned MinSignBits = TyBits; + unsigned NumElts = CV->getType()->getVectorNumElements(); + bool UnknownElt = false; + for (unsigned i = 0; i != NumElts; ++i) { + // If we find a non-ConstantInt, bail out. + auto *Elt = dyn_cast_or_null(CV->getAggregateElement(i)); + if (!Elt) { + UnknownElt = true; + break; + } + + // If the sign bit is 1, flip the bits, so we always count leading zeros. + APInt EltVal = Elt->getValue(); + if (EltVal.isNegative()) + EltVal = ~EltVal; + MinSignBits = std::min(MinSignBits, EltVal.countLeadingZeros()); + } + + // If we examined all elements successfully, we're done (we can't do any + // better than this). If not, keep trying with computeKnownBits() below. + if (!UnknownElt) + return MinSignBits; + } + APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); APInt Mask; computeKnownBits(V, KnownZero, KnownOne, Depth, Q); Index: test/Transforms/InstSimplify/shr-nop.ll =================================================================== --- test/Transforms/InstSimplify/shr-nop.ll +++ test/Transforms/InstSimplify/shr-nop.ll @@ -423,8 +423,7 @@ define <2 x i4> @ashr_zero_minus1_vec(<2 x i4> %shiftval) { ; CHECK-LABEL: @ashr_zero_minus1_vec( -; CHECK-NEXT: [[SHR:%.*]] = ashr <2 x i4> , %shiftval -; CHECK-NEXT: ret <2 x i4> [[SHR]] +; CHECK-NEXT: ret <2 x i4> ; %shr = ashr <2 x i4> , %shiftval ret <2 x i4> %shr