[ADT] DenseMapInfo for std::variant: make isEqual dispatch to the
DenseMapInfo of the active alternative instead of operator==

Previously isEqual(LHS, RHS) for a std::variant simply returned LHS == RHS,
which bypasses any custom DenseMapInfo<T>::isEqual of the alternative types.
With this change isEqual:
  * returns false when the active alternatives differ (index() mismatch);
  * returns true when the indices match and the variants are
    valueless_by_exception() (equal indices imply both are valueless);
  * otherwise forwards to DenseMapInfo<T>::isEqual, where T is the active
    alternative type. LHS is type-erased to const void * by one std::visit,
    and a second std::visit over RHS recovers T.

The unit test adds AlwaysEqType. Its operator== always returns true, but its
DenseMapInfo::isEqual always returns false. The new EXPECT_FALSE therefore
passes only if isEqual no longer goes through operator==.

NOTE: DenseMapInfo<AlwaysEqType> returns the same value from getEmptyKey() and
getTombstoneKey(), and its isEqual never reports equality. It is intended
only for direct isEqual() calls in this test. Do not use AlwaysEqType as an
actual DenseMap key.

diff --git a/llvm/include/llvm/ADT/DenseMapInfo.h b/llvm/include/llvm/ADT/DenseMapInfo.h
--- a/llvm/include/llvm/ADT/DenseMapInfo.h
+++ b/llvm/include/llvm/ADT/DenseMapInfo.h
@@ -318,7 +318,22 @@
   }
 
   static bool isEqual(const Variant &LHS, const Variant &RHS) {
-    return LHS == RHS;
+    if (LHS.index() != RHS.index())
+      return false;
+    if (LHS.valueless_by_exception())
+      return true;
+    // We want to dispatch to DenseMapInfo<T>::isEqual(LHS.get(I), RHS.get(I))
+    // We know the types are the same, but std::visit(V, LHS, RHS) doesn't.
+    // We erase the type held in LHS to void*, and dispatch over RHS.
+    const void *ErasedLHS =
+        std::visit([](const auto &LHS) -> const void * { return &LHS; }, LHS);
+    return std::visit(
+        [&](const auto &RHS) -> bool {
+          using T = std::remove_cv_t<std::remove_reference_t<decltype(RHS)>>;
+          return DenseMapInfo<T>::isEqual(*static_cast<const T *>(ErasedLHS),
+                                          RHS);
+        },
+        RHS);
   }
 };
 } // end namespace llvm
diff --git a/llvm/unittests/ADT/DenseMapTest.cpp b/llvm/unittests/ADT/DenseMapTest.cpp
--- a/llvm/unittests/ADT/DenseMapTest.cpp
+++ b/llvm/unittests/ADT/DenseMapTest.cpp
@@ -690,6 +690,10 @@
 struct B : public A {
   using A::A;
 };
+
+struct AlwaysEqType {
+  bool operator==(const AlwaysEqType &RHS) const { return true; }
+};
 } // namespace
 
 namespace llvm {
@@ -702,6 +706,16 @@
     return LHS.value == RHS.value;
   }
 };
+
+template <> struct DenseMapInfo<AlwaysEqType> {
+  using T = AlwaysEqType;
+  static inline T getEmptyKey() { return {}; }
+  static inline T getTombstoneKey() { return {}; }
+  static unsigned getHashValue(const T &Val) { return 0; }
+  static bool isEqual(const T &LHS, const T &RHS) {
+    return false;
+  }
+};
 } // namespace llvm
 
 namespace {
@@ -725,16 +739,20 @@
 }
 
 TEST(DenseMapCustomTest, VariantSupport) {
-  using variant = std::variant<int, A>;
+  using variant = std::variant<int, A, AlwaysEqType>;
   DenseMap<variant, int> Map;
   variant Keys[] = {
       variant(std::in_place_index<0>, 1),
       variant(std::in_place_index<1>, 1),
+      variant(std::in_place_index<2>),
   };
   Map.try_emplace(Keys[0], 0);
   Map.try_emplace(Keys[1], 1);
   EXPECT_THAT(Map, testing::SizeIs(2));
   EXPECT_NE(DenseMapInfo<variant>::getHashValue(Keys[0]),
             DenseMapInfo<variant>::getHashValue(Keys[1]));
+  // Check that isEqual dispatches to isEqual of underlying type, and not to
+  // operator==.
+  EXPECT_FALSE(DenseMapInfo<variant>::isEqual(Keys[2], Keys[2]));
 }
 } // namespace