Index: include/llvm/ADT/STLExtras.h
===================================================================
--- include/llvm/ADT/STLExtras.h
+++ include/llvm/ADT/STLExtras.h
@@ -26,6 +26,7 @@
 #include <memory>
 #include <utility> // for std::pair
 
+#include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Compiler.h"
 
@@ -235,6 +236,89 @@
                     llvm::make_reverse_iterator(std::begin(C)));
 }
 
+/// An iterator adaptor that filters the elements of given inner iterators.
+///
+/// The predicate parameter should be a callable object that accepts the wrapped
+/// iterator's reference type and returns a bool. When incrementing or
+/// decrementing the iterator, it will call the predicate on each element and
+/// skip any where it returns false.
+///
+/// \code
+///   int A[] = { 1, 2, 3, 4 };
+///   const auto R = make_filter_range(make_range(A, A + 4),
+///                                    [](int N) { return N % 2 == 1; });
+///   // R contains { 1, 3 }.
+/// \endcode
+template <typename WrappedIteratorT, typename Predicate>
+class filter_iterator
+    : public iterator_adaptor_base<
+          filter_iterator<WrappedIteratorT, Predicate>, WrappedIteratorT,
+          typename std::common_type<
+              std::bidirectional_iterator_tag,
+              typename std::iterator_traits<
+                  WrappedIteratorT>::iterator_category>::type,
+          typename std::iterator_traits<WrappedIteratorT>::value_type,
+          typename std::iterator_traits<WrappedIteratorT>::difference_type,
+          typename std::iterator_traits<WrappedIteratorT>::pointer,
+          typename std::iterator_traits<WrappedIteratorT>::reference> {
+  using IteratorCategoryT = typename std::common_type<
+      std::bidirectional_iterator_tag,
+      typename std::iterator_traits<WrappedIteratorT>::iterator_category>::type;
+  using BaseT = iterator_adaptor_base<
+      filter_iterator<WrappedIteratorT, Predicate>, WrappedIteratorT,
+      IteratorCategoryT,
+      typename std::iterator_traits<WrappedIteratorT>::value_type,
+      typename std::iterator_traits<WrappedIteratorT>::difference_type,
+      typename std::iterator_traits<WrappedIteratorT>::pointer,
+      typename std::iterator_traits<WrappedIteratorT>::reference>;
+
+  WrappedIteratorT End;
+  Predicate Pred;
+
+  void findNextValid() {
+    while (this->I != End && !Pred(*this->I))
+      BaseT::operator++();
+  }
+
+public:
+  filter_iterator(WrappedIteratorT Begin, WrappedIteratorT End, Predicate Pred)
+      : BaseT(std::move(Begin)), End(std::move(End)), Pred(std::move(Pred)) {
+    findNextValid();
+  }
+
+  using BaseT::operator++;
+  using BaseT::operator--;
+
+  filter_iterator &operator++() {
+    BaseT::operator++();
+    findNextValid();
+    return *this;
+  }
+
+  filter_iterator &operator--() {
+    static_assert(
+        std::is_same<IteratorCategoryT, std::bidirectional_iterator_tag>::value,
+        "operator--() requires the underlying type to be at least "
+        "bidirectional");
+    BaseT::operator--();
+    // Precondition: an earlier element satisfies Pred; otherwise this is UB.
+    while (!Pred(*this->I))
+      BaseT::operator--();
+    return *this;
+  }
+};
+
+/// \brief Convenience function that takes a range of elements and a predicate,
+/// and returns a new filter_iterator range.
+template <typename WrappedIteratorT, typename Predicate>
+iterator_range<filter_iterator<WrappedIteratorT, Predicate>>
+make_filter_range(iterator_range<WrappedIteratorT> Range,
+                  const Predicate &Pred) {
+  using IterT = filter_iterator<WrappedIteratorT, Predicate>;
+  return make_range(IterT(Range.begin(), Range.end(), Pred),
+                    IterT(Range.end(), Range.end(), Pred));
+}
+
 //===----------------------------------------------------------------------===//
 //     Extra additions to <functional>
 //===----------------------------------------------------------------------===//
Index: include/llvm/ADT/iterator.h
===================================================================
--- include/llvm/ADT/iterator.h
+++ include/llvm/ADT/iterator.h
@@ -168,15 +168,7 @@
 
   iterator_adaptor_base() = default;
 
-  template <typename U>
-  explicit iterator_adaptor_base(
-      U &&u,
-      typename std::enable_if<
-          !std::is_base_of<typename std::remove_cv<
-                               typename std::remove_reference<U>::type>::type,
-                           DerivedT>::value,
-          int>::type = 0)
-      : I(std::forward<U &&>(u)) {}
+  explicit iterator_adaptor_base(WrappedIteratorT u) : I(std::move(u)) {}
 
   const WrappedIteratorT &wrapped() const { return I; }
 
Index: unittests/Support/IteratorTest.cpp
===================================================================
--- unittests/Support/IteratorTest.cpp
+++ unittests/Support/IteratorTest.cpp
@@ -12,6 +12,8 @@
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 
+#include <functional>
+
 using namespace llvm;
 
 namespace {
@@ -98,4 +100,77 @@
   EXPECT_EQ(End, I);
 }
 
+TEST(FilterIteratorTest, Lambda) {
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, StdFunction) {
+  std::function<bool(int)> IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, FunctionPointer) {
+  bool (*IsOdd)(int) = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, Composition) {
+  const auto IsOdd = [](int a) { return a % 2 == 1; };
+  std::unique_ptr<int> a[] = {make_unique<int>(0), make_unique<int>(1),
+                              make_unique<int>(2), make_unique<int>(3),
+                              make_unique<int>(4), make_unique<int>(5),
+                              make_unique<int>(6)};
+  using PointeeIterator = pointee_iterator<std::unique_ptr<int> *>;
+  const auto Range = make_filter_range(
+      make_range(PointeeIterator(a), PointeeIterator(a + 7)), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, InputIterator) {
+  struct InputIterator
+      : iterator_adaptor_base<InputIterator, int *, std::input_iterator_tag> {
+    using BaseT =
+        iterator_adaptor_base<InputIterator, int *, std::input_iterator_tag>;
+
+    InputIterator(int *It) : BaseT(It) {}
+  };
+
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(
+      make_range(InputIterator(a), InputIterator(a + 7)), IsOdd);
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, Decrement) {
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  auto Iter = make_filter_range(make_range(a, a + 7), IsOdd).begin();
+  EXPECT_EQ(1, *Iter);
+  EXPECT_EQ(1, *Iter++); // Postfix ++ yields the old position, then skips.
+  EXPECT_EQ(3, *Iter);
+  --Iter;
+  EXPECT_EQ(1, *Iter);
+  std::advance(Iter, 2);
+  EXPECT_EQ(5, *Iter);
+  EXPECT_EQ(5, *Iter--); // Postfix -- yields the old position, then skips.
+  EXPECT_EQ(3, *Iter);
+}
+
 } // anonymous namespace