diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -646,13 +646,15 @@
   }
 
   template <size_t... Ns>
-  decltype(iterators) tup_inc(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::next(std::get<Ns>(iterators))...);
+  void tup_inc(std::index_sequence<Ns...>) {
+    // Increment in place so the inner iterators need not be copy-assignable.
+    (++std::get<Ns>(iterators), ...);
   }
 
   template <size_t... Ns>
-  decltype(iterators) tup_dec(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::prev(std::get<Ns>(iterators))...);
+  void tup_dec(std::index_sequence<Ns...>) {
+    // Decrement in place so the inner iterators need not be copy-assignable.
+    (--std::get<Ns>(iterators), ...);
   }
 
   template <size_t... Ns>
@@ -671,14 +673,14 @@
   }
 
   ZipType &operator++() {
-    iterators = tup_inc(std::index_sequence_for<Iters...>{});
+    tup_inc(std::index_sequence_for<Iters...>{});
     return *reinterpret_cast<ZipType *>(this);
   }
 
   ZipType &operator--() {
     static_assert(Base::IsBidirectional,
                   "All inner iterators must be at least bidirectional.");
-    iterators = tup_dec(std::index_sequence_for<Iters...>{});
+    tup_dec(std::index_sequence_for<Iters...>{});
     return *reinterpret_cast<ZipType *>(this);
   }
 
@@ -769,10 +771,9 @@
 namespace detail {
 
 template <typename Iter>
-Iter next_or_end(const Iter &I, const Iter &End) {
-  if (I == End)
-    return End;
-  return std::next(I);
+void increment_if_not_end(Iter &I, const Iter &End) {
+  if (I != End)
+    ++I;
 }
 
 template <typename Iter>
@@ -826,9 +827,11 @@
   }
 
   template <size_t... Ns>
-  decltype(iterators) tup_inc(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(
-        next_or_end(std::get<Ns>(iterators), std::get<Ns>(end_iterators))...);
+  void tup_inc(std::index_sequence<Ns...>) {
+    // Advance each iterator in place, leaving exhausted ones at their end.
+    (increment_if_not_end(std::get<Ns>(iterators),
+                          std::get<Ns>(end_iterators)),
+     ...);
   }
 
 public:
@@ -841,7 +844,7 @@
   }
 
   zip_longest_iterator<Iters...> &operator++() {
-    iterators = tup_inc(std::index_sequence_for<Iters...>{});
+    tup_inc(std::index_sequence_for<Iters...>{});
     return *this;
   }
 
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -881,6 +881,63 @@
   EXPECT_EQ(2, Destructors);
 }
 
+TEST(STLExtrasTest, Zip) {
+  struct CountingIterator {
+    // For std::iterator_traits
+    using difference_type = int;
+    using value_type = int;
+    using pointer = void;
+    using reference = int;
+    using iterator_category = std::input_iterator_tag;
+
+    // zip and zip_longest must not require a copy assignment operator.
+    CountingIterator &operator=(const CountingIterator &other) = delete;
+
+    CountingIterator(int count) : count(count) {}
+    CountingIterator(const CountingIterator &other) : count(other.count) {}
+    bool operator==(const CountingIterator &other) const {
+      return count == other.count;
+    }
+    bool operator!=(const CountingIterator &other) const {
+      return !(*this == other);
+    }
+    CountingIterator &operator++() {
+      ++count;
+      return *this;
+    }
+    int operator*() const { return count; }
+    int count;
+  };
+
+  auto r1 = make_range(CountingIterator(5), CountingIterator(9));
+  auto r2 = make_range(CountingIterator(8), CountingIterator(11));
+
+  {
+    auto zip_range = zip(r1, r2);
+    auto it = zip_range.begin();
+    EXPECT_EQ(*it, std::make_tuple(5, 8));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(6, 9));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(7, 10));
+    ++it;
+    EXPECT_EQ(it, zip_range.end());
+  }
+  {
+    auto zip_range = zip_longest(r1, r2);
+    auto it = zip_range.begin();
+    EXPECT_EQ(*it, std::make_tuple(5, 8));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(6, 9));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(7, 10));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(8, None));
+    ++it;
+    EXPECT_EQ(it, zip_range.end());
+  }
+}
+
 TEST(STLExtrasTest, AllOfZip) {
   std::vector<int> v1 = {0, 4, 2, 1};
   std::vector<int> v2 = {1, 4, 3, 6};