diff --git a/llvm/include/llvm/ADT/DenseMapInfo.h b/llvm/include/llvm/ADT/DenseMapInfo.h
--- a/llvm/include/llvm/ADT/DenseMapInfo.h
+++ b/llvm/include/llvm/ADT/DenseMapInfo.h
@@ -318,7 +318,27 @@
   }
 
   static bool isEqual(const Variant &LHS, const Variant &RHS) {
-    return LHS == RHS;
+    if (LHS.index() != RHS.index())
+      return false;
+    // The indices match, so either both or neither are valueless; std::visit
+    // would throw on a valueless variant.
+    if (LHS.valueless_by_exception())
+      return true;
+    // Dispatch to DenseMapInfo<T>::isEqual for the active alternative T, not
+    // to operator==, which need not agree with it.
+    // std::visit(V, LHS, RHS) does not know both sides hold the same T and
+    // would instantiate V for every pair of alternatives. Instead, erase the
+    // type of the value held in LHS to void*, and dispatch over RHS alone.
+    const void *ErasedLHS =
+        std::visit([](const auto &LHSVal) -> const void * { return &LHSVal; },
+                   LHS);
+    return std::visit(
+        [&](const auto &RHSVal) -> bool {
+          using T = std::remove_cv_t<std::remove_reference_t<decltype(RHSVal)>>;
+          return DenseMapInfo<T>::isEqual(*static_cast<const T *>(ErasedLHS),
+                                          RHSVal);
+        },
+        RHS);
   }
 };
 } // end namespace llvm
diff --git a/llvm/unittests/ADT/DenseMapTest.cpp b/llvm/unittests/ADT/DenseMapTest.cpp
--- a/llvm/unittests/ADT/DenseMapTest.cpp
+++ b/llvm/unittests/ADT/DenseMapTest.cpp
@@ -690,6 +690,10 @@
 struct B : public A {
   using A::A;
 };
+
+struct AlwaysEqType {
+  bool operator==(const AlwaysEqType &RHS) const { return true; }
+};
 } // namespace
 
 namespace llvm {
@@ -702,6 +706,16 @@
     return LHS.value == RHS.value;
   }
 };
+
+// isEqual deliberately disagrees with operator==, so tests can tell which one
+// DenseMapInfo<std::variant> dispatches to. Not usable as a real DenseMap key.
+template <> struct DenseMapInfo<AlwaysEqType> {
+  using T = AlwaysEqType;
+  static inline T getEmptyKey() { return {}; }
+  static inline T getTombstoneKey() { return {}; }
+  static unsigned getHashValue(const T &Val) { return 0; }
+  static bool isEqual(const T &LHS, const T &RHS) { return false; }
+};
 } // namespace llvm
 
 namespace {
@@ -725,16 +739,24 @@
 }
 
 TEST(DenseMapCustomTest, VariantSupport) {
-  using variant = std::variant<int, B>;
+  using variant = std::variant<int, B, AlwaysEqType>;
   DenseMap<variant, int> Map;
   variant Keys[] = {
       variant(std::in_place_index<0>, 1),
       variant(std::in_place_index<1>, 1),
+      variant(std::in_place_index<2>),
   };
   Map.try_emplace(Keys[0], 0);
   Map.try_emplace(Keys[1], 1);
   EXPECT_THAT(Map, testing::SizeIs(2));
   EXPECT_NE(DenseMapInfo<variant>::getHashValue(Keys[0]),
             DenseMapInfo<variant>::getHashValue(Keys[1]));
+  // Equal values of the same alternative compare equal; values of different
+  // alternatives never do.
+  EXPECT_TRUE(DenseMapInfo<variant>::isEqual(Keys[0], Keys[0]));
+  EXPECT_FALSE(DenseMapInfo<variant>::isEqual(Keys[0], Keys[1]));
+  // Check that isEqual dispatches to isEqual of underlying type, and not to
+  // operator==.
+  EXPECT_FALSE(DenseMapInfo<variant>::isEqual(Keys[2], Keys[2]));
 }
 } // namespace