Make llvm::is_contained prefer member lookup over a linear scan.

is_contained(Range, Element) now dispatches at compile time, in this order:
  1. Range.contains(Element), when that expression is well-formed;
  2. Range.find(Element) != Range.end(), when that expression is well-formed;
  3. llvm::find(Range, Element) != adl_end(Range) otherwise.

The assert in LoopInfoImpl.h is changed to search ParentLoop->getSubLoops()
instead of *ParentLoop.
NOTE(review): presumably needed because the loop type has its own `contains`
member that the new dispatch would pick up -- confirm against LoopInfo.h.

NOTE(review): the template parameter/argument lists in the added lines were
reconstructed from how the names are used within this patch; confirm them
against the upstream change before applying.

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1906,11 +1906,39 @@
   return std::move(adl_begin(Range), adl_end(Range), Out);
 }
 
-/// Wrapper function around std::find to detect if an element exists
-/// in a container.
+namespace detail {
+template <typename Range, typename Element>
+using check_has_member_contains_t =
+    decltype(std::declval<Range &>().contains(std::declval<const Element &>()));
+
+template <typename Range, typename Element>
+static constexpr bool HasMemberContains =
+    is_detected<check_has_member_contains_t, Range, Element>::value;
+
+template <typename Range, typename Element>
+using check_has_member_find_t =
+    decltype(std::declval<Range &>().find(std::declval<const Element &>()) !=
+             std::declval<Range &>().end());
+
+template <typename Range, typename Element>
+static constexpr bool HasMemberFind =
+    is_detected<check_has_member_find_t, Range, Element>::value;
+
+} // namespace detail
+
+/// Returns true if \p Element is found in \p Range. Delegates the check to
+/// either `.contains(Element)`, `.find(Element)`, or `llvm::find`, in this
+/// order of preference. This is intended as the canonical way to check if an
+/// element exists in a range in generic code, or for range types that do not
+/// expose a `.contains(Element)` member.
 template <typename R, typename E>
 bool is_contained(R &&Range, const E &Element) {
-  return std::find(adl_begin(Range), adl_end(Range), Element) != adl_end(Range);
+  if constexpr (detail::HasMemberContains<R, E>)
+    return Range.contains(Element);
+  else if constexpr (detail::HasMemberFind<R, E>)
+    return Range.find(Element) != Range.end();
+  else
+    return llvm::find(Range, Element) != adl_end(Range);
 }
 
 /// Returns true iff \p Element exists in \p Set. This overload takes \p Set as
diff --git a/llvm/include/llvm/Analysis/LoopInfoImpl.h b/llvm/include/llvm/Analysis/LoopInfoImpl.h
--- a/llvm/include/llvm/Analysis/LoopInfoImpl.h
+++ b/llvm/include/llvm/Analysis/LoopInfoImpl.h
@@ -371,7 +371,7 @@
 
   // Check the parent loop pointer.
   if (ParentLoop) {
-    assert(is_contained(*ParentLoop, this) &&
+    assert(is_contained(ParentLoop->getSubLoops(), this) &&
            "Loop is not a subloop of its parent!");
   }
 #endif
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -1029,6 +1029,52 @@
   static_assert(!is_contained({1, 2, 3, 4}, 5), "It's not there :(");
 }
 
+TEST(STLExtrasTest, IsContainedMemberContains) {
+  // Check that `llvm::is_contained` uses the member `.contains()` when
+  // available.
+  struct Foo {
+    bool contains(int) const {
+      ++NumContainsCalls;
+      return ContainsResult;
+    }
+    bool ContainsResult = false;
+    mutable unsigned NumContainsCalls = 0;
+  } Container;
+
+  EXPECT_EQ(Container.NumContainsCalls, 0u);
+  EXPECT_FALSE(is_contained(Container, 1));
+  EXPECT_EQ(Container.NumContainsCalls, 1u);
+
+  Container.ContainsResult = true;
+  EXPECT_TRUE(is_contained(Container, 1));
+  EXPECT_EQ(Container.NumContainsCalls, 2u);
+}
+
+TEST(STLExtrasTest, IsContainedMemberFind) {
+  // Check that `llvm::is_contained` uses the member `.find(x)` when available.
+  struct Foo {
+    auto begin() const { return Data.begin(); }
+    auto end() const { return Data.end(); }
+    auto find(int X) const {
+      ++NumFindCalls;
+      return std::find(begin(), end(), X);
+    }
+
+    std::vector<int> Data;
+    mutable unsigned NumFindCalls = 0;
+  } Container;
+
+  Container.Data = {1, 2, 3};
+
+  EXPECT_EQ(Container.NumFindCalls, 0u);
+  EXPECT_TRUE(is_contained(Container, 1));
+  EXPECT_TRUE(is_contained(Container, 3));
+  EXPECT_EQ(Container.NumFindCalls, 2u);
+
+  EXPECT_FALSE(is_contained(Container, 4));
+  EXPECT_EQ(Container.NumFindCalls, 3u);
+}
+
 TEST(STLExtrasTest, addEnumValues) {
   enum A { Zero = 0, One = 1 };
   enum B { IntMax = INT_MAX, ULongLongMax = ULLONG_MAX };