[ADT] Increment zip iterators in place instead of rebuilding the tuple

zip_common::operator++ / operator-- used to build a brand-new std::tuple of
std::next / std::prev copies of every inner iterator on each step. Step each
inner iterator in place instead, and return the derived iterator with
static_cast rather than reinterpret_cast. Add a unit test checking that the
inner iterators are not copied on every `zippy` iterator increment.

diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -775,6 +775,7 @@
 template <typename ZipType, typename... Iters>
 struct zip_common : public zip_traits<ZipType, Iters...> {
   using Base = zip_traits<ZipType, Iters...>;
+  using IndexSequence = std::index_sequence_for<Iters...>;
   using value_type = typename Base::value_type;
 
   std::tuple<Iters...> iterators;
@@ -784,19 +785,19 @@
     return value_type(*std::get<Ns>(iterators)...);
   }
 
-  template <size_t... Ns>
-  decltype(iterators) tup_inc(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::next(std::get<Ns>(iterators))...);
+  // Step all inner iterators in place. Casting to void keeps an overloaded
+  // `operator,` on an iterator type from being picked up by the fold.
+  template <size_t... Ns> void tup_inc(std::index_sequence<Ns...>) {
+    ((void)++std::get<Ns>(iterators), ...);
   }
 
-  template <size_t... Ns>
-  decltype(iterators) tup_dec(std::index_sequence<Ns...>) const {
-    return std::tuple<Iters...>(std::prev(std::get<Ns>(iterators))...);
+  template <size_t... Ns> void tup_dec(std::index_sequence<Ns...>) {
+    ((void)--std::get<Ns>(iterators), ...);
   }
 
   template <size_t... Ns>
   bool test_all_equals(const zip_common &other,
-            std::index_sequence<Ns...>) const {
+                       std::index_sequence<Ns...>) const {
     return ((std::get<Ns>(this->iterators) == std::get<Ns>(other.iterators)) &&
             ...);
   }
@@ -804,25 +805,23 @@
 public:
   zip_common(Iters &&... ts) : iterators(std::forward<Iters>(ts)...) {}
 
-  value_type operator*() const {
-    return deref(std::index_sequence_for<Iters...>{});
-  }
+  value_type operator*() const { return deref(IndexSequence{}); }
 
   ZipType &operator++() {
-    iterators = tup_inc(std::index_sequence_for<Iters...>{});
-    return *reinterpret_cast<ZipType *>(this);
+    tup_inc(IndexSequence{});
+    return static_cast<ZipType &>(*this);
   }
 
   ZipType &operator--() {
     static_assert(Base::IsBidirectional,
                   "All inner iterators must be at least bidirectional.");
-    iterators = tup_dec(std::index_sequence_for<Iters...>{});
-    return *reinterpret_cast<ZipType *>(this);
+    tup_dec(IndexSequence{});
+    return static_cast<ZipType &>(*this);
   }
 
   /// Return true if all the iterator are matching `other`'s iterators.
   bool all_equals(zip_common &other) {
-    return test_all_equals(other, std::index_sequence_for<Iters...>{});
+    return test_all_equals(other, IndexSequence{});
   }
 };
 
diff --git a/llvm/unittests/ADT/IteratorTest.cpp b/llvm/unittests/ADT/IteratorTest.cpp
--- a/llvm/unittests/ADT/IteratorTest.cpp
+++ b/llvm/unittests/ADT/IteratorTest.cpp
@@ -692,6 +692,48 @@
   EXPECT_TRUE(all_of(ascending, [](unsigned n) { return (n & 0x01) == 0; }));
 }
 
+// Int iterator that keeps track of the number of its copies.
+struct CountingIntIterator : IntIterator {
+  unsigned *cnt;
+
+  CountingIntIterator(int *it, unsigned &counter)
+      : IntIterator(it), cnt(&counter) {}
+
+  CountingIntIterator(const CountingIntIterator &other)
+      : IntIterator(other.I), cnt(other.cnt) {
+    ++(*cnt);
+  }
+  CountingIntIterator &operator=(const CountingIntIterator &other) {
+    this->I = other.I;
+    this->cnt = other.cnt;
+    ++(*cnt);
+    return *this;
+  }
+};
+
+// Check that the iterators do not get copied with each `zippy` iterator
+// increment.
+TEST(ZipIteratorTest, IteratorCopies) {
+  std::vector<int> ints(1000, 42);
+  unsigned copy_count = 0;
+  CountingIntIterator begin(ints.data(), copy_count);
+  CountingIntIterator end(ints.data() + ints.size(), copy_count);
+
+  size_t iters = 0;
+  for (auto [a, b] : zip_equal(ints, llvm::make_range(begin, end))) {
+    EXPECT_EQ(a, b);
+    ++iters;
+  }
+  EXPECT_EQ(iters, ints.size());
+
+  // CountingIntIterator has no move constructor, so moves fall back to copies;
+  // expect at least two copies just when the range is moved into `zippy`.
+  EXPECT_GE(copy_count, 2u);
+  // We expect the number of copies to be much smaller than the number of loop
+  // iterations.
+  EXPECT_LT(copy_count, ints.size() / 10);
+}
+
 TEST(RangeTest, Distance) {
   std::vector<int> v1;
   std::vector<int> v2{1, 2, 3};