Index: include/llvm/ADT/STLExtras.h
===================================================================
--- include/llvm/ADT/STLExtras.h
+++ include/llvm/ADT/STLExtras.h
@@ -26,6 +26,8 @@
 #include <memory>
 #include <utility> // for std::pair
 
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Compiler.h"
 
@@ -235,6 +237,91 @@
                     llvm::make_reverse_iterator(std::begin(C)));
 }
 
+/// An iterator adaptor that filters the elements of given inner iterators.
+///
+/// The predicate parameter should be a callable object that accepts the wrapped
+/// iterator's reference type and returns a bool. When incrementing the
+/// iterator, it calls the predicate on each element and skips any where it
+/// returns false. Decrementing is not supported (forward iterator at most).
+///
+/// \code
+///   int A[] = { 1, 2, 3, 4 };
+///   const auto R = make_filter_range(make_range(A, A + 4),
+///                                    [](int N) { return N % 2 == 1; });
+///   // R contains { 1, 3 }.
+/// \endcode
+template <typename WrappedIteratorT, typename PredicateT>
+class filter_iterator
+    : public iterator_adaptor_base<
+          filter_iterator<WrappedIteratorT, PredicateT>, WrappedIteratorT,
+          typename std::common_type<
+              std::forward_iterator_tag,
+              typename std::iterator_traits<
+                  WrappedIteratorT>::iterator_category>::type,
+          typename std::iterator_traits<WrappedIteratorT>::value_type,
+          typename std::iterator_traits<WrappedIteratorT>::difference_type,
+          typename std::iterator_traits<WrappedIteratorT>::pointer,
+          typename std::iterator_traits<WrappedIteratorT>::reference> {
+  using BaseT = iterator_adaptor_base<
+      filter_iterator<WrappedIteratorT, PredicateT>, WrappedIteratorT,
+      typename std::common_type<std::forward_iterator_tag,
+                                typename std::iterator_traits<
+                                    WrappedIteratorT>::iterator_category>::type,
+      typename std::iterator_traits<WrappedIteratorT>::value_type,
+      typename std::iterator_traits<WrappedIteratorT>::difference_type,
+      typename std::iterator_traits<WrappedIteratorT>::pointer,
+      typename std::iterator_traits<WrappedIteratorT>::reference>;
+
+  struct PayloadType {
+    WrappedIteratorT End;
+    PredicateT Pred;
+  };
+
+  Optional<PayloadType> Payload;
+
+  void findNextValid() {
+    assert(Payload && "Payload should be engaged when findNextValid is called");
+    while (this->I != Payload->End && !Payload->Pred(*this->I))
+      BaseT::operator++();
+  }
+
+  // Construct the begin iterator. The begin iterator needs to know where end
+  // is, so that it can properly stop when it hits end.
+  filter_iterator(WrappedIteratorT Begin, WrappedIteratorT End, PredicateT Pred)
+      : BaseT(std::move(Begin)),
+        Payload(PayloadType{std::move(End), std::move(Pred)}) {
+    findNextValid();
+  }
+
+  // Construct the end iterator. It's not incrementable, so Payload doesn't
+  // have to be engaged.
+  explicit filter_iterator(WrappedIteratorT End) : BaseT(std::move(End)) {}
+
+public:
+  using BaseT::operator++;
+
+  filter_iterator &operator++() {
+    BaseT::operator++();
+    findNextValid();
+    return *this;
+  }
+
+  template <typename RT, typename PT>
+  friend iterator_range<filter_iterator<RT, PT>>
+  make_filter_range(iterator_range<RT>, PT);
+};
+
+/// Convenience function that takes a range of elements and a predicate,
+/// and returns a new filter_iterator range.
+template <typename WrappedIteratorT, typename PredicateT>
+iterator_range<filter_iterator<WrappedIteratorT, PredicateT>>
+make_filter_range(iterator_range<WrappedIteratorT> Range, PredicateT Pred) {
+  using FilterIteratorT = filter_iterator<WrappedIteratorT, PredicateT>;
+  return make_range(
+      FilterIteratorT(Range.begin(), Range.end(), std::move(Pred)),
+      FilterIteratorT(Range.end()));
+}
+
 //===----------------------------------------------------------------------===//
 // Extra additions to <algorithm>
 //===----------------------------------------------------------------------===//
Index: include/llvm/ADT/iterator.h
===================================================================
--- include/llvm/ADT/iterator.h
+++ include/llvm/ADT/iterator.h
@@ -168,15 +168,7 @@
 
   iterator_adaptor_base() = default;
 
-  template <typename U>
-  explicit iterator_adaptor_base(
-      U &&u,
-      typename std::enable_if<
-          !std::is_base_of<typename std::remove_cv<
-                               typename std::remove_reference<U>::type>::type,
-                           DerivedT>::value,
-          int>::type = 0)
-      : I(std::forward<U &&>(u)) {}
+  explicit iterator_adaptor_base(WrappedIteratorT u) : I(std::move(u)) {}
 
   const WrappedIteratorT &wrapped() const { return I; }
 
Index: unittests/Support/IteratorTest.cpp
===================================================================
--- unittests/Support/IteratorTest.cpp
+++ unittests/Support/IteratorTest.cpp
@@ -12,6 +12,8 @@
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 
+#include <functional>
+
 using namespace llvm;
 
 namespace {
@@ -98,4 +100,62 @@
   EXPECT_EQ(End, I);
 }
 
+TEST(FilterIteratorTest, Lambda) {
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, StdFunction) {
+  std::function<bool(int)> IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, FunctionPointer) {
+  bool (*IsOdd)(int) = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, Composition) {
+  const auto IsOdd = [](int a) { return a % 2 == 1; };
+  std::unique_ptr<int> a[] = {make_unique<int>(0), make_unique<int>(1),
+                              make_unique<int>(2), make_unique<int>(3),
+                              make_unique<int>(4), make_unique<int>(5),
+                              make_unique<int>(6)};
+  using PointeeIterator = pointee_iterator<std::unique_ptr<int> *>;
+  const auto Range = make_filter_range(
+      make_range(PointeeIterator(a), PointeeIterator(a + 7)), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+TEST(FilterIteratorTest, InputIterator) {
+  struct InputIterator
+      : iterator_adaptor_base<InputIterator, int *, std::input_iterator_tag> {
+    using BaseT =
+        iterator_adaptor_base<InputIterator, int *, std::input_iterator_tag>;
+
+    InputIterator(int *It) : BaseT(It) {}
+  };
+
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(
+      make_range(InputIterator(a), InputIterator(a + 7)), IsOdd);
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
 } // anonymous namespace