diff --git a/llvm/include/llvm/ADT/DenseMapInfo.h b/llvm/include/llvm/ADT/DenseMapInfo.h
--- a/llvm/include/llvm/ADT/DenseMapInfo.h
+++ b/llvm/include/llvm/ADT/DenseMapInfo.h
@@ -318,8 +318,27 @@
   }
 
   static bool isEqual(const Variant &LHS, const Variant &RHS) {
-    return LHS == RHS;
+    if (LHS.index() != RHS.index())
+      return false;
+    // Two valueless variants compare equal, as with std::variant's operator==;
+    // std::visit would throw std::bad_variant_access on them instead.
+    if (LHS.valueless_by_exception())
+      return true;
+    // Dispatch to DenseMapInfo<T>::isEqual of the active alternative, not to
+    // T::operator==. The indices match, so both sides hold the same type T:
+    // erase LHS's alternative to a pointer and visit only RHS. This
+    // instantiates one lambda per alternative rather than the N*N
+    // combinations a two-variant std::visit would.
+    const void *ErasedLHS =
+        std::visit([](const auto &Alt) -> const void * { return &Alt; }, LHS);
+    return std::visit(
+        [&](const auto &Alt) -> bool {
+          using T = std::remove_cv_t<std::remove_reference_t<decltype(Alt)>>;
+          return DenseMapInfo<T>::isEqual(*static_cast<const T *>(ErasedLHS),
+                                          Alt);
+        },
+        RHS);
   }
 };
 
 } // end namespace llvm
diff --git a/llvm/unittests/ADT/DenseMapTest.cpp b/llvm/unittests/ADT/DenseMapTest.cpp
--- a/llvm/unittests/ADT/DenseMapTest.cpp
+++ b/llvm/unittests/ADT/DenseMapTest.cpp
@@ -690,6 +690,10 @@
 struct B : public A {
   using A::A;
 };
+
+struct AlwaysEqType {
+  bool operator==(const AlwaysEqType &RHS) const { return true; }
+};
 } // namespace
 
 namespace llvm {
@@ -702,6 +706,19 @@
     return LHS.value == RHS.value;
   }
 };
+
+// Deliberately makes isEqual disagree with operator== so a test can observe
+// which of the two DenseMapInfo<variant>::isEqual dispatches to.
+template <> struct DenseMapInfo<AlwaysEqType> {
+  using T = AlwaysEqType;
+  static inline T getEmptyKey() { return {}; }
+  static inline T getTombstoneKey() { return {}; }
+  static unsigned getHashValue(const T &Val) { return 0; }
+  static bool isEqual(const T &LHS, const T &RHS) {
+    return false;
+  }
+};
+
 } // namespace llvm
 
 namespace {
@@ -725,16 +742,23 @@
 }
 
 TEST(DenseMapCustomTest, VariantSupport) {
-  using variant = std::variant<int, B>;
+  using variant = std::variant<int, B, AlwaysEqType>;
   DenseMap<variant, int> Map;
   variant Keys[] = {
       variant(std::in_place_index<0>, 1),
       variant(std::in_place_index<1>, 1),
+      variant(std::in_place_index<2>),
   };
   Map.try_emplace(Keys[0], 0);
   Map.try_emplace(Keys[1], 1);
   EXPECT_THAT(Map, testing::SizeIs(2));
   EXPECT_NE(DenseMapInfo<variant>::getHashValue(Keys[0]),
             DenseMapInfo<variant>::getHashValue(Keys[1]));
+  // Variants holding different alternatives never compare equal.
+  EXPECT_FALSE(DenseMapInfo<variant>::isEqual(Keys[0], Keys[1]));
+  EXPECT_TRUE(DenseMapInfo<variant>::isEqual(Keys[0], Keys[0]));
+  // Check that isEqual dispatches to isEqual of underlying type, and not to
+  // operator==.
+  EXPECT_FALSE(DenseMapInfo<variant>::isEqual(Keys[2], Keys[2]));
 }
 } // namespace