diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1906,11 +1906,39 @@
   return std::move(adl_begin(Range), adl_end(Range), Out);
 }
 
-/// Wrapper function around std::find to detect if an element exists
-/// in a container.
+namespace detail {
+template <typename Range, typename Element>
+using check_has_member_contains_t =
+    decltype(std::declval<Range &>().contains(std::declval<const Element &>()));
+
+template <typename Range, typename Element>
+inline constexpr bool HasMemberContains =
+    is_detected<check_has_member_contains_t, Range, Element>::value;
+
+template <typename Range, typename Element>
+using check_has_member_find_t =
+    decltype(std::declval<Range &>().find(std::declval<const Element &>()) !=
+             std::declval<Range &>().end());
+
+template <typename Range, typename Element>
+inline constexpr bool HasMemberFind =
+    is_detected<check_has_member_find_t, Range, Element>::value;
+} // namespace detail
+
+/// Returns true if \p Element is found in \p Range. The check is delegated to
+/// `.contains(Element)`, `.find(Element) != .end()`, or `llvm::find`, in that
+/// order of preference, so ranges that provide their own lookup (e.g. sets and
+/// maps) use it instead of a linear scan. Note that the member functions use
+/// the container's own notion of equivalence (e.g. its comparator or hashing
+/// traits) rather than `operator==`.
 template <typename R, typename E>
 bool is_contained(R &&Range, const E &Element) {
-  return std::find(adl_begin(Range), adl_end(Range), Element) != adl_end(Range);
+  if constexpr (detail::HasMemberContains<R, E>)
+    return Range.contains(Element);
+  else if constexpr (detail::HasMemberFind<R, E>)
+    return Range.find(Element) != Range.end();
+  else
+    return llvm::find(Range, Element) != adl_end(Range);
 }
 
 /// Returns true iff \p Element exists in \p Set. This overload takes \p Set as
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1029,6 +1029,56 @@
   static_assert(!is_contained({1, 2, 3, 4}, 5), "It's not there :(");
 }
 
+TEST(STLExtrasTest, IsContainedMemberContains) {
+  // Check that `llvm::is_contained` uses `.contains()` when available.
+  struct Foo {
+    bool contains(int Value) const {
+      ++NumContainsCalls;
+      LastValue = Value;
+      return ContainsResult;
+    }
+    bool ContainsResult = false;
+    mutable unsigned NumContainsCalls = 0;
+    mutable int LastValue = 0;
+  } Container;
+
+  EXPECT_EQ(Container.NumContainsCalls, 0u);
+  EXPECT_FALSE(is_contained(Container, 1));
+  EXPECT_EQ(Container.NumContainsCalls, 1u);
+  EXPECT_EQ(Container.LastValue, 1);
+
+  Container.ContainsResult = true;
+  EXPECT_TRUE(is_contained(Container, 2));
+  EXPECT_EQ(Container.NumContainsCalls, 2u);
+  EXPECT_EQ(Container.LastValue, 2);
+}
+
+TEST(STLExtrasTest, IsContainedMemberFind) {
+  // Check that `llvm::is_contained` falls back to `.find() != .end()` when
+  // `.contains()` is not available.
+  struct Foo {
+    int find(int Value) const {
+      ++NumFindCalls;
+      LastValue = Value;
+      return FindResult ? 0 : end();
+    }
+    int end() const { return 1; }
+    bool FindResult = false;
+    mutable unsigned NumFindCalls = 0;
+    mutable int LastValue = 0;
+  } Container;
+
+  EXPECT_EQ(Container.NumFindCalls, 0u);
+  EXPECT_FALSE(is_contained(Container, 1));
+  EXPECT_EQ(Container.NumFindCalls, 1u);
+  EXPECT_EQ(Container.LastValue, 1);
+
+  Container.FindResult = true;
+  EXPECT_TRUE(is_contained(Container, 2));
+  EXPECT_EQ(Container.NumFindCalls, 2u);
+  EXPECT_EQ(Container.LastValue, 2);
+}
+
 TEST(STLExtrasTest, addEnumValues) {
   enum A { Zero = 0, One = 1 };
   enum B { IntMax = INT_MAX, ULongLongMax = ULLONG_MAX };