Index: include/llvm/ADT/STLExtras.h =================================================================== --- include/llvm/ADT/STLExtras.h +++ include/llvm/ADT/STLExtras.h @@ -26,6 +26,8 @@ #include #include // for std::pair +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/iterator.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Compiler.h" @@ -235,6 +237,92 @@ llvm::make_reverse_iterator(std::begin(C))); } +/// An iterator adaptor that filters the elements of given inner iterators. +/// +/// The predicate parameter should be a callable object that accepts the wrapped +/// iterator's reference type and returns a bool. When constructing the begin +/// iterator or incrementing, it calls the predicate on each element and skips +/// any where it returns false. The result is at most a forward iterator. +/// +/// \code +/// int A[] = { 1, 2, 3, 4 }; +/// auto R = make_filter_range(make_range(std::begin(A), std::end(A)), +/// [](int N) { return N % 2 == 1; }); +/// // R contains { 1, 3 }. +/// \endcode +template +class filtered_iterator + : public iterator_adaptor_base< + filtered_iterator, WrappedIteratorT, + typename std::common_type< + std::forward_iterator_tag, + typename std::iterator_traits< + WrappedIteratorT>::iterator_category>::type, + typename std::iterator_traits::value_type, + typename std::iterator_traits::difference_type, + typename std::iterator_traits::pointer, + typename std::iterator_traits::reference> { + using BaseT = iterator_adaptor_base< + filtered_iterator, WrappedIteratorT, + typename std::common_type::iterator_category>::type, + typename std::iterator_traits::value_type, + typename std::iterator_traits::difference_type, + typename std::iterator_traits::pointer, + typename std::iterator_traits::reference>; + + struct PayloadType { + WrappedIteratorT End; + PredicateT Pred; + }; + + Optional Payload; + + void findNextValid() { + assert(Payload && "Payload should be engaged when findNextValid is called"); + while (this->I != Payload->End && !Payload->Pred(*this->I)) + BaseT::operator++();
+ } + + // Construct the begin iterator. It needs to know where the end is so that + // skipping rejected elements stops there instead of running past it. + filtered_iterator(WrappedIteratorT Begin, WrappedIteratorT End, + PredicateT Pred) + : BaseT(std::move(Begin)), + Payload(PayloadType{std::move(End), std::move(Pred)}) { + findNextValid(); + } + + // Construct the end iterator. It's not incrementable, so Payload doesn't + // have to be engaged. + explicit filtered_iterator(WrappedIteratorT End) : BaseT(std::move(End)) {} + +public: + using BaseT::operator++; + + filtered_iterator &operator++() { + BaseT::operator++(); + findNextValid(); + return *this; + } + + template + friend iterator_range> + make_filter_range(iterator_range, PT); +}; + +/// Convenience function that takes a range of elements and a predicate, +/// and returns a new filtered_iterator range over the elements it accepts. +template +iterator_range> +make_filter_range(iterator_range Range, PredicateT Pred) { + using FilterIteratorT = filtered_iterator; + return make_range( + FilterIteratorT(Range.begin(), Range.end(), std::move(Pred)), + FilterIteratorT(Range.end())); +} + //===----------------------------------------------------------------------===// // Extra additions to //===----------------------------------------------------------------------===// Index: include/llvm/ADT/iterator.h =================================================================== --- include/llvm/ADT/iterator.h +++ include/llvm/ADT/iterator.h @@ -168,15 +168,7 @@ iterator_adaptor_base() = default; - template - explicit iterator_adaptor_base( - U &&u, - typename std::enable_if< - !std::is_base_of::type>::type, - DerivedT>::value, - int>::type = 0) - : I(std::forward(u)) {} + explicit iterator_adaptor_base(WrappedIteratorT u) : I(std::move(u)) {} const WrappedIteratorT &wrapped() const { return I; } Index: unittests/Support/IteratorTest.cpp =================================================================== --- unittests/Support/IteratorTest.cpp +++ unittests/Support/IteratorTest.cpp
@@ -98,4 +98,88 @@ EXPECT_EQ(End, I); } +TEST(FilterIteratorTest, Lambda) { + auto IsOdd = [](int N) { return N % 2 == 1; }; + int A[] = {0, 1, 2, 3, 4, 5, 6}; + auto Range = make_filter_range(make_range(std::begin(A), std::end(A)), IsOdd); + ASSERT_EQ(3, std::distance(Range.begin(), Range.end())); + SmallVector Actual(Range.begin(), Range.end()); + EXPECT_EQ((SmallVector{1, 3, 5}), Actual); +} + +TEST(FilterIteratorTest, CallableObject) { + int Counter = 0; + struct Callable { + int &Counter; + + Callable(int &Counter) : Counter(Counter) {} + + bool operator()(int N) { + Counter++; + return N % 2 == 1; + } + }; + Callable IsOdd(Counter); + int A[] = {0, 1, 2, 3, 4, 5, 6}; + auto Range = make_filter_range(make_range(std::begin(A), std::end(A)), IsOdd); + ASSERT_EQ(3, std::distance(Range.begin(), Range.end())); + // Walking from begin() to end() calls the predicate once per element. + EXPECT_EQ(7, Counter); + SmallVector Actual(Range.begin(), Range.end()); + EXPECT_EQ((SmallVector{1, 3, 5}), Actual); +} + +TEST(FilterIteratorTest, FunctionPointer) { + bool (*IsOdd)(int) = [](int N) { return N % 2 == 1; }; + int A[] = {0, 1, 2, 3, 4, 5, 6}; + auto Range = make_filter_range(make_range(std::begin(A), std::end(A)), IsOdd); + ASSERT_EQ(3, std::distance(Range.begin(), Range.end())); + SmallVector Actual(Range.begin(), Range.end()); + EXPECT_EQ((SmallVector{1, 3, 5}), Actual); +} + +TEST(FilterIteratorTest, Composition) { + auto IsOdd = [](int N) { return N % 2 == 1; }; + std::unique_ptr A[] = {make_unique(0), make_unique(1), + make_unique(2), make_unique(3), + make_unique(4), make_unique(5), + make_unique(6)}; + using PointeeIterator = pointee_iterator *>; + auto Range = make_filter_range( + make_range(PointeeIterator(std::begin(A)), PointeeIterator(std::end(A))), + IsOdd); + ASSERT_EQ(3, std::distance(Range.begin(), Range.end())); + SmallVector Actual(Range.begin(), Range.end()); + EXPECT_EQ((SmallVector{1, 3, 5}), Actual); +} + +TEST(FilterIteratorTest, InputIterator) { + struct InputIterator + : iterator_adaptor_base { + using BaseT =
+ iterator_adaptor_base; + + InputIterator(int *It) : BaseT(It) {} + }; + + auto IsOdd = [](int N) { return N % 2 == 1; }; + int A[] = {0, 1, 2, 3, 4, 5, 6}; + auto Range = make_filter_range( + make_range(InputIterator(std::begin(A)), InputIterator(std::end(A))), + IsOdd); + SmallVector Actual(Range.begin(), Range.end()); + EXPECT_EQ((SmallVector{1, 3, 5}), Actual); +} + +TEST(FilterIteratorTest, NoMatchesOrEmpty) { + auto IsOdd = [](int N) { return N % 2 == 1; }; + int A[] = {0, 2, 4, 6}; + // Skipping must stop at end when the predicate rejects everything. + auto Range = make_filter_range(make_range(std::begin(A), std::end(A)), IsOdd); + EXPECT_EQ(0, std::distance(Range.begin(), Range.end())); + auto Empty = + make_filter_range(make_range(std::begin(A), std::begin(A)), IsOdd); + EXPECT_EQ(0, std::distance(Empty.begin(), Empty.end())); +} + } // anonymous namespace