[ADT] SetVector: compare via DenseMapInfo in small mode

In small mode, insert(), contains() and count() searched the vector with
operator== (llvm::find / is_contained), so element types that provide
equality only through DenseMapInfo (e.g. one with a deleted operator==)
failed to compile. Use DenseMapInfo<value_type>::isEqual for the linear
scan, implement count() in terms of contains(), and fix a comment typo.

The new unit test checks both the small-mode path and the set-backed path
after the vector grows past N.

diff --git a/llvm/include/llvm/ADT/SetVector.h b/llvm/include/llvm/ADT/SetVector.h
--- a/llvm/include/llvm/ADT/SetVector.h
+++ b/llvm/include/llvm/ADT/SetVector.h
@@ -56,7 +56,7 @@
           typename Set = DenseSet<T>, unsigned N = 0>
 class SetVector {
   // Much like in SmallPtrSet, this value should not be too high to prevent
-  // excessively long linear scans from occuring.
+  // excessively long linear scans from occurring.
   static_assert(N <= 32, "Small size should be less than or equal to 32!");
 
 public:
@@ -162,7 +162,12 @@
   bool insert(const value_type &X) {
     if constexpr (canBeSmall())
       if (isSmall()) {
-        if (llvm::find(vector_, X) == vector_.end()) {
+        // Compare via DenseMapInfo rather than operator== so that element
+        // types without an operator== can still be stored in small mode.
+        auto comparator = [&X](const value_type &element) -> bool {
+          return llvm::DenseMapInfo<value_type>::isEqual(X, element);
+        };
+        if (llvm::find_if(vector_, comparator) == vector_.end()) {
           vector_.push_back(X);
           if (vector_.size() > N)
             makeBig();
@@ -253,8 +258,13 @@
   /// Check if the SetVector contains the given key.
   bool contains(const key_type &key) const {
     if constexpr (canBeSmall())
-      if (isSmall())
-        return is_contained(vector_, key);
+      if (isSmall()) {
+        // Use the same DenseMapInfo-based equality as insert().
+        auto comparator = [&key](const value_type &element) -> bool {
+          return llvm::DenseMapInfo<value_type>::isEqual(key, element);
+        };
+        return llvm::find_if(vector_, comparator) != vector_.end();
+      }
 
     return set_.find(key) != set_.end();
   }
@@ -262,11 +272,7 @@
   /// Count the number of elements of a given key in the SetVector.
   /// \returns 0 if the element is not in the SetVector, 1 if it is.
   size_type count(const key_type &key) const {
-    if constexpr (canBeSmall())
-      if (isSmall())
-        return is_contained(vector_, key);
-
-    return set_.count(key);
+    return contains(key);
   }
 
   /// Completely clear the SetVector
diff --git a/llvm/unittests/ADT/SetVectorTest.cpp b/llvm/unittests/ADT/SetVectorTest.cpp
--- a/llvm/unittests/ADT/SetVectorTest.cpp
+++ b/llvm/unittests/ADT/SetVectorTest.cpp
@@ -86,3 +86,45 @@
   EXPECT_FALSE(S.contains(&j));
   EXPECT_FALSE(S.contains((const int *)&j));
 }
+
+// Element type whose operator== is deleted, so equality is only available
+// through DenseMapInfo<NonEquatableType>::isEqual.
+struct NonEquatableType {
+  int Value;
+  bool operator==(const NonEquatableType &) const = delete;
+
+  NonEquatableType(int Value) : Value(Value) {}
+};
+
+namespace llvm {
+template <> struct DenseMapInfo<NonEquatableType> {
+  static inline NonEquatableType getEmptyKey() { return NonEquatableType(-1); }
+  static inline NonEquatableType getTombstoneKey() {
+    return NonEquatableType(-2);
+  }
+  static unsigned getHashValue(const NonEquatableType &Val) {
+    return Val.Value;
+  }
+
+  static bool isEqual(const NonEquatableType &LHS,
+                      const NonEquatableType &RHS) {
+    return LHS.Value == RHS.Value;
+  }
+};
+} // namespace llvm
+
+TEST(SmallSetVector, InsertNonEquatableType) {
+  SmallSetVector<NonEquatableType, 2> S;
+  EXPECT_TRUE(S.insert(NonEquatableType(1)));
+  EXPECT_TRUE(S.insert(NonEquatableType(2)));
+  // Size does not exceed N yet, so these go through the small-mode scan.
+  EXPECT_TRUE(S.contains(NonEquatableType(2)));
+  EXPECT_FALSE(S.contains(NonEquatableType(3)));
+  EXPECT_EQ(S.count(NonEquatableType(1)), 1u);
+  EXPECT_FALSE(S.insert(NonEquatableType(2)));
+  // Inserting a third distinct element exceeds N and triggers makeBig().
+  EXPECT_TRUE(S.insert(NonEquatableType(3)));
+  EXPECT_TRUE(S.contains(NonEquatableType(2)));
+  EXPECT_FALSE(S.insert(NonEquatableType(3)));
+  EXPECT_EQ(S.size(), 3u);
+}