diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1982,6 +1982,18 @@
   IterOfRange<R> Iter;
 };
 
+/// Tuple-like accessor that lets a result_pair be decomposed with structured
+/// bindings: element 0 is Pair.index() and element 1 is Pair.value().
+template <std::size_t I, typename R>
+decltype(auto) get(const result_pair<R> &Pair) {
+  static_assert(I < 2, "result_pair has exactly two elements");
+  if constexpr (I == 0) {
+    return Pair.index();
+  } else {
+    return Pair.value();
+  }
+}
+
 template <typename R>
 class enumerator_iter
     : public iterator_facade_base<enumerator_iter<R>, std::forward_iterator_tag,
@@ -2054,6 +2066,12 @@
 ///   printf("Item %d - %c\n", X.index(), X.value());
 /// }
 ///
+/// or using structured bindings:
+///
+/// for (auto [Index, Value] : enumerate(Items)) {
+///   printf("Item %d - %c\n", Index, Value);
+/// }
+///
 /// Output:
 ///   Item 0 - A
 ///   Item 1 - B
@@ -2192,4 +2210,19 @@
 
 } // end namespace llvm
 
+namespace std {
+// Specializations that make llvm::detail::result_pair tuple-like, so that
+// `for (auto [Index, Value] : llvm::enumerate(Range))` compiles.
+template <typename R>
+struct tuple_size<llvm::detail::result_pair<R>>
+    : std::integral_constant<std::size_t, 2> {};
+
+template <std::size_t I, typename R>
+struct tuple_element<I, llvm::detail::result_pair<R>>
+    : std::conditional<I == 0, std::size_t,
+                       typename llvm::detail::result_pair<R>::value_reference> {
+};
+
+} // namespace std
+
 #endif // LLVM_ADT_STLEXTRAS_H
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/ADT/STLExtras.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 #include <climits>
@@ -15,6 +16,8 @@
 
 using namespace llvm;
 
+using testing::ElementsAre;
+
 namespace {
 
 int f(rank<0>) { return 0; }
@@ -47,31 +50,29 @@
   typedef std::pair<std::size_t, char> CharPairType;
   std::vector<CharPairType> CharResults;
 
-  for (auto X : llvm::enumerate(foo)) {
-    CharResults.emplace_back(X.index(), X.value());
+  for (auto [Index, Value] : llvm::enumerate(foo)) {
+    CharResults.emplace_back(Index, Value);
   }
-  ASSERT_EQ(3u, CharResults.size());
-  EXPECT_EQ(CharPairType(0u, 'a'), CharResults[0]);
-  EXPECT_EQ(CharPairType(1u, 'b'), CharResults[1]);
-  EXPECT_EQ(CharPairType(2u, 'c'), CharResults[2]);
+
+  EXPECT_THAT(CharResults,
+              ElementsAre(CharPairType(0u, 'a'), CharPairType(1u, 'b'),
+                          CharPairType(2u, 'c')));
 
   // Test a const range of a different type.
   typedef std::pair<std::size_t, int> IntPairType;
   std::vector<IntPairType> IntResults;
   const std::vector<int> bar = {1, 2, 3};
-  for (auto X : llvm::enumerate(bar)) {
-    IntResults.emplace_back(X.index(), X.value());
+  for (auto [Index, Value] : llvm::enumerate(bar)) {
+    IntResults.emplace_back(Index, Value);
   }
-  ASSERT_EQ(3u, IntResults.size());
-  EXPECT_EQ(IntPairType(0u, 1), IntResults[0]);
-  EXPECT_EQ(IntPairType(1u, 2), IntResults[1]);
-  EXPECT_EQ(IntPairType(2u, 3), IntResults[2]);
+  EXPECT_THAT(IntResults, ElementsAre(IntPairType(0u, 1), IntPairType(1u, 2),
+                                      IntPairType(2u, 3)));
 
   // Test an empty range.
   IntResults.clear();
   const std::vector<int> baz{};
-  for (auto X : llvm::enumerate(baz)) {
-    IntResults.emplace_back(X.index(), X.value());
+  for (auto [Index, Value] : llvm::enumerate(baz)) {
+    IntResults.emplace_back(Index, Value);
   }
   EXPECT_TRUE(IntResults.empty());
 }
@@ -84,9 +85,15 @@
   for (auto X : llvm::enumerate(foo)) {
     ++X.value();
   }
-  EXPECT_EQ('b', foo[0]);
-  EXPECT_EQ('c', foo[1]);
-  EXPECT_EQ('d', foo[2]);
+  EXPECT_THAT(foo, ElementsAre('b', 'c', 'd'));
+
+  // Also test if this works with structured bindings.
+  foo = {'a', 'b', 'c'};
+
+  for (auto [Index, Value] : llvm::enumerate(foo)) {
+    ++Value;
+  }
+  EXPECT_THAT(foo, ElementsAre('b', 'c', 'd'));
 }
 
 TEST(STLExtrasTest, EnumerateRValueRef) {
@@ -100,10 +107,18 @@
     Results.emplace_back(X.index(), X.value());
   }
 
-  ASSERT_EQ(3u, Results.size());
-  EXPECT_EQ(PairType(0u, 1), Results[0]);
-  EXPECT_EQ(PairType(1u, 2), Results[1]);
-  EXPECT_EQ(PairType(2u, 3), Results[2]);
+  EXPECT_THAT(Results,
+              ElementsAre(PairType(0u, 1), PairType(1u, 2), PairType(2u, 3)));
+
+  // Also test if this works with structured bindings.
+  Results.clear();
+
+  for (auto [Index, Value] : llvm::enumerate(std::vector<int>{1, 2, 3})) {
+    Results.emplace_back(Index, Value);
+  }
+
+  EXPECT_THAT(Results,
+              ElementsAre(PairType(0u, 1), PairType(1u, 2), PairType(2u, 3)));
 }
 
 TEST(STLExtrasTest, EnumerateModifyRValue) {
@@ -118,10 +133,20 @@
     Results.emplace_back(X.index(), X.value());
   }
 
-  ASSERT_EQ(3u, Results.size());
-  EXPECT_EQ(PairType(0u, '2'), Results[0]);
-  EXPECT_EQ(PairType(1u, '3'), Results[1]);
-  EXPECT_EQ(PairType(2u, '4'), Results[2]);
+  EXPECT_THAT(Results, ElementsAre(PairType(0u, '2'), PairType(1u, '3'),
+                                   PairType(2u, '4')));
+
+  // Also test if this works with structured bindings.
+  Results.clear();
+
+  for (auto [Index, Value] :
+       llvm::enumerate(std::vector<char>{'1', '2', '3'})) {
+    ++Value;
+    Results.emplace_back(Index, Value);
+  }
+
+  EXPECT_THAT(Results, ElementsAre(PairType(0u, '2'), PairType(1u, '3'),
+                                   PairType(2u, '4')));
 }
 
 template <bool B> struct CanMove {};