Make the zip and zip_longest iterators advance in place.

tup_inc/tup_dec previously built a fresh tuple from std::next/std::prev and
the callers assigned it back to `iterators`. They now apply ++/-- directly to
each element of the tuple and return void, so operator++/operator-- no longer
assign to `iterators`. For zip_longest, next_or_end (which returned the new
position) is replaced by increment_if_not_end, which advances its argument
only while it differs from End.

The new unit test drives zip and zip_longest with an input iterator whose
copy assignment operator is deleted.

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -645,14 +645,12 @@
     return value_type(*std::get<Ns>(iterators)...);
   }
 
-  template <size_t... Ns>
-  decltype(iterators) tup_inc(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::next(std::get<Ns>(iterators))...);
+  template <size_t... Ns> void tup_inc(std::index_sequence<Ns...>) {
+    (++std::get<Ns>(iterators), ...);
   }
 
-  template <size_t... Ns>
-  decltype(iterators) tup_dec(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::prev(std::get<Ns>(iterators))...);
+  template <size_t... Ns> void tup_dec(std::index_sequence<Ns...>) {
+    (--std::get<Ns>(iterators), ...);
   }
 
   template <size_t... Ns>
@@ -671,14 +669,14 @@
   }
 
   ZipType &operator++() {
-    iterators = tup_inc(std::index_sequence_for<Iters...>{});
+    tup_inc(std::index_sequence_for<Iters...>{});
     return *reinterpret_cast<ZipType *>(this);
   }
 
   ZipType &operator--() {
     static_assert(Base::IsBidirectional,
                   "All inner iterators must be at least bidirectional.");
-    iterators = tup_dec(std::index_sequence_for<Iters...>{});
+    tup_dec(std::index_sequence_for<Iters...>{});
     return *reinterpret_cast<ZipType *>(this);
   }
 
@@ -768,11 +766,9 @@
 }
 
 namespace detail {
-template <typename Iter>
-Iter next_or_end(const Iter &I, const Iter &End) {
-  if (I == End)
-    return End;
-  return std::next(I);
+template <typename Iter> void increment_if_not_end(Iter &I, const Iter &End) {
+  if (I != End)
+    ++I;
 }
 
 template <typename Iter>
@@ -825,10 +821,9 @@
         deref_or_none(std::get<Ns>(iterators), std::get<Ns>(end_iterators))...);
   }
 
-  template <size_t... Ns>
-  decltype(iterators) tup_inc(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(
-        next_or_end(std::get<Ns>(iterators), std::get<Ns>(end_iterators))...);
+  template <size_t... Ns> void tup_inc(std::index_sequence<Ns...>) {
+    (increment_if_not_end(std::get<Ns>(iterators), std::get<Ns>(end_iterators)),
+     ...);
   }
 
 public:
@@ -841,7 +836,7 @@
   }
 
   zip_longest_iterator<Iters...> &operator++() {
-    iterators = tup_inc(std::index_sequence_for<Iters...>{});
+    tup_inc(std::index_sequence_for<Iters...>{});
     return *this;
   }
 
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -881,6 +881,63 @@
   EXPECT_EQ(2, Destructors);
 }
 
+TEST(STLExtrasTest, Zip) {
+  struct CountingIterator {
+    // For std::iterator_traits
+    using difference_type = int;
+    using value_type = int;
+    using pointer = void;
+    using reference = int;
+    using iterator_category = std::input_iterator_tag;
+
+    // Test that zip does not require copy assign operator
+    CountingIterator &operator=(const CountingIterator &other) = delete;
+
+    CountingIterator(int count) : count(count) {}
+    CountingIterator(const CountingIterator &other) : count(other.count) {}
+    bool operator==(const CountingIterator &other) const {
+      return count == other.count;
+    }
+    bool operator!=(const CountingIterator &other) const {
+      return !(*this == other);
+    }
+    CountingIterator &operator++() {
+      ++count;
+      return *this;
+    }
+    int operator*() const { return count; }
+    int count;
+  };
+
+  auto r1 = make_range(CountingIterator(5), CountingIterator(9));
+  auto r2 = make_range(CountingIterator(8), CountingIterator(11));
+
+  {
+    auto zip_range = zip(r1, r2);
+    auto it = zip_range.begin();
+    EXPECT_EQ(*it, std::make_tuple(5, 8));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(6, 9));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(7, 10));
+    ++it;
+    EXPECT_EQ(it, zip_range.end());
+  }
+  {
+    auto zip_range = zip_longest(r1, r2);
+    auto it = zip_range.begin();
+    EXPECT_EQ(*it, std::make_tuple(5, 8));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(6, 9));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(7, 10));
+    ++it;
+    EXPECT_EQ(*it, std::make_tuple(8, None));
+    ++it;
+    EXPECT_EQ(it, zip_range.end());
+  }
+}
+
 TEST(STLExtrasTest, AllOfZip) {
   std::vector<int> v1 = {0, 4, 2, 1};
   std::vector<int> v2 = {1, 4, 3, 6};