Index: include/llvm/ADT/STLExtras.h =================================================================== --- include/llvm/ADT/STLExtras.h +++ include/llvm/ADT/STLExtras.h @@ -626,8 +626,12 @@ } }; +template struct remove_rvalue_reference { typedef T type; }; + +template struct remove_rvalue_reference { typedef T type; }; + namespace detail { -template class enumerator_impl { +template class enumerator_impl { public: template struct result_pair { result_pair(std::size_t Index, X Value) : Index(Index), Value(Value) {} @@ -637,6 +641,9 @@ }; struct iterator { + using I = decltype(std::begin(std::declval())); + using V = decltype(*std::declval()); + iterator(I Iter, std::size_t Index) : Iter(Iter), Index(Index) {} result_pair operator*() const { @@ -657,24 +664,21 @@ std::size_t Index; }; - enumerator_impl(I Begin, I End) - : Begin(std::move(Begin)), End(std::move(End)) {} + explicit enumerator_impl(R &&Range) : Range(std::forward(Range)) {} - iterator begin() { return iterator(Begin, 0); } - iterator end() { return iterator(End, std::size_t(-1)); } + iterator begin() { return iterator(std::begin(Range), 0); } + iterator end() { return iterator(std::end(Range), std::size_t(-1)); } - iterator begin() const { return iterator(Begin, 0); } - iterator end() const { return iterator(End, std::size_t(-1)); } + iterator begin() const { return iterator(std::begin(Range), 0); } + iterator end() const { return iterator(std::end(Range), std::size_t(-1)); } private: - I Begin; - I End; + // If R is an r-value reference, it means we're trying to enumerate a + // temporary range. By removing the r-value reference, we store a copy of the + // range. If it's anything else (including a reference) store it as-is to + // enable mutation of the underlying range if the type supports it. + typename remove_rvalue_reference::type Range; }; - -template -auto make_enumerator(I Begin, I End) -> enumerator_impl { - return enumerator_impl(std::move(Begin), std::move(End)); -} } /// Given an input range, returns a new range whose values are are pair (A,B) @@ -692,10 +696,8 @@ /// Item 2 - C /// Item 3 - D /// -template -auto enumerate(R &&Range) - -> decltype(detail::make_enumerator(std::begin(Range), std::end(Range))) { - return detail::make_enumerator(std::begin(Range), std::end(Range)); +template detail::enumerator_impl enumerate(R &&Range) { + return detail::enumerator_impl(std::forward(Range)); } } // End llvm namespace Index: unittests/ADT/STLExtrasTest.cpp =================================================================== --- unittests/ADT/STLExtrasTest.cpp +++ unittests/ADT/STLExtrasTest.cpp @@ -86,4 +86,34 @@ EXPECT_EQ('c', foo[1]); EXPECT_EQ('d', foo[2]); } + +TEST(STLExtrasTest, EnumerateRValueRef) { + using ResultPair = std::pair; + std::vector Results; + for (auto X : llvm::enumerate(std::vector{1, 2, 3})) { + Results.push_back(std::make_pair(X.Index, X.Value)); + } + + EXPECT_EQ(3, Results.size()); + EXPECT_EQ(ResultPair(0, 1), Results[0]); + EXPECT_EQ(ResultPair(1, 2), Results[1]); + EXPECT_EQ(ResultPair(2, 3), Results[2]); +} + +TEST(STLExtrasTest, EnumerateNested) { + std::vector foo = {'a', 'b', 'c'}; + + using InternalPairType = std::pair; + using ExternalPairType = std::pair; + std::vector Results; + + for (auto X : enumerate(enumerate(foo))) { + Results.push_back( + std::make_pair(X.Index, std::make_pair(X.Value.Index, X.Value.Value))); + } + + EXPECT_EQ(ExternalPairType(0u, InternalPairType(0u, 'a')), Results[0]); + EXPECT_EQ(ExternalPairType(1u, InternalPairType(1u, 'b')), Results[1]); + EXPECT_EQ(ExternalPairType(2u, InternalPairType(2u, 'c')), Results[2]); +} }