diff --git a/llvm/include/llvm/ADT/BitVector.h b/llvm/include/llvm/ADT/BitVector.h
--- a/llvm/include/llvm/ADT/BitVector.h
+++ b/llvm/include/llvm/ADT/BitVector.h
@@ -532,24 +532,14 @@

   // Comparison operators.
+  /// Returns true if both vectors have the same size() and the same bits set.
+  /// Vectors of different sizes are never equal, even if no bits are set.
   bool operator==(const BitVector &RHS) const {
-    unsigned ThisWords = NumBitWords(size());
-    unsigned RHSWords = NumBitWords(RHS.size());
-    unsigned i;
-    for (i = 0; i != std::min(ThisWords, RHSWords); ++i)
-      if (Bits[i] != RHS.Bits[i])
-        return false;
-
-    // Verify that any extra words are all zeros.
-    if (i != ThisWords) {
-      for (; i != ThisWords; ++i)
-        if (Bits[i])
-          return false;
-    } else if (i != RHSWords) {
-      for (; i != RHSWords; ++i)
-        if (RHS.Bits[i])
-          return false;
-    }
-    return true;
+    if (size() != RHS.size())
+      return false;
+    // Compare whole words; this relies on any bits past size() in the last
+    // word being zero.
+    unsigned NumWords = NumBitWords(size());
+    return Bits.take_front(NumWords) == RHS.Bits.take_front(NumWords);
   }

   bool operator!=(const BitVector &RHS) const {
diff --git a/llvm/unittests/ADT/BitVectorTest.cpp b/llvm/unittests/ADT/BitVectorTest.cpp
--- a/llvm/unittests/ADT/BitVectorTest.cpp
+++ b/llvm/unittests/ADT/BitVectorTest.cpp
@@ -179,6 +179,36 @@
   EXPECT_TRUE(Vec.empty());
 }

+TYPED_TEST(BitVectorTest, Equality) {
+  TypeParam A;
+  TypeParam B;
+  EXPECT_TRUE(A == B);
+  A.resize(10);
+  EXPECT_FALSE(A == B);
+  B.resize(10);
+  EXPECT_TRUE(A == B);
+  A.set(5);
+  EXPECT_FALSE(A == B);
+  B.set(5);
+  EXPECT_TRUE(A == B);
+  A.resize(20);
+  EXPECT_FALSE(A == B);
+  B.resize(20);
+  EXPECT_TRUE(A == B);
+  // Mismatched sizes that need different numbers of words are unequal too,
+  // and operator!= must agree with operator==.
+  A.resize(100);
+  EXPECT_FALSE(A == B);
+  EXPECT_TRUE(A != B);
+  B.resize(100);
+  EXPECT_TRUE(A == B);
+  EXPECT_FALSE(A != B);
+  A.set(99);
+  EXPECT_FALSE(A == B);
+  B.set(99);
+  EXPECT_TRUE(A == B);
+}
+
 TYPED_TEST(BitVectorTest, SimpleFindOpsMultiWord) {
   TypeParam A;
