Add llvm::filter_iterator and llvm::make_filter_range to STLExtras.h, plus a
unit test in IteratorTest.cpp.

NOTE(review): this patch was recovered from a copy in which the line breaks
and all angle-bracketed text (include targets, template parameter and argument
lists) had been stripped. The restored text below is consistent with the hunk
line counts, but the <memory> include target, the SmallVector inline size and
the template arguments in the \code example are assumptions -- confirm against
the original review before applying.

Index: include/llvm/ADT/STLExtras.h
===================================================================
--- include/llvm/ADT/STLExtras.h
+++ include/llvm/ADT/STLExtras.h
@@ -26,6 +26,7 @@
 #include <memory>
 #include <utility> // for std::pair
 
+#include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Compiler.h"
 
@@ -235,6 +236,78 @@
                     llvm::make_reverse_iterator(std::begin(C)));
 }
 
+/// \brief An iterator adaptor that filters the elements of given inner
+/// iterators.
+///
+/// Predicate should take a parameter such that predicate(*iter) returns a bool,
+/// where iter is a WrappedIteratorT object. If the predicate returns true, the
+/// element is kept.
+///
+/// Currently it only supports operator++ but not operator--, so it's an input
+/// iterator if the underlying type is input iterator; otherwise it's a forward
+/// iterator.
+///
+/// \code
+/// const auto IsOdd = [](int a) { return a % 2 == 1; };
+/// filter_iterator<WrappedIteratorT, decltype(IsOdd)> Begin(..., IsOdd);
+/// \endcode
+template <typename WrappedIteratorT, typename Predicate>
+class filter_iterator
+    : public iterator_adaptor_base<
+          filter_iterator<WrappedIteratorT, Predicate>, WrappedIteratorT,
+          typename std::conditional<
+              std::is_same<std::input_iterator_tag,
+                           typename std::iterator_traits<
+                               WrappedIteratorT>::iterator_category>::value,
+              std::input_iterator_tag, std::forward_iterator_tag>::type,
+          typename std::iterator_traits<WrappedIteratorT>::value_type,
+          typename std::iterator_traits<WrappedIteratorT>::difference_type,
+          typename std::iterator_traits<WrappedIteratorT>::pointer,
+          typename std::iterator_traits<WrappedIteratorT>::reference> {
+  using BaseT = iterator_adaptor_base<
+      filter_iterator<WrappedIteratorT, Predicate>, WrappedIteratorT,
+      typename std::conditional<
+          std::is_same<std::input_iterator_tag,
+                       typename std::iterator_traits<
+                           WrappedIteratorT>::iterator_category>::value,
+          std::input_iterator_tag, std::forward_iterator_tag>::type,
+      typename std::iterator_traits<WrappedIteratorT>::value_type,
+      typename std::iterator_traits<WrappedIteratorT>::difference_type,
+      typename std::iterator_traits<WrappedIteratorT>::pointer,
+      typename std::iterator_traits<WrappedIteratorT>::reference>;
+
+  WrappedIteratorT End;
+  Predicate Pred;
+
+  void findNextValid() {
+    while (this->I != End && !Pred(*this->I))
+      BaseT::operator++();
+  }
+
+public:
+  filter_iterator(WrappedIteratorT Begin, WrappedIteratorT End, Predicate Pred)
+      : BaseT(Begin), End(End), Pred(Pred) {
+    findNextValid();
+  }
+
+  using BaseT::operator++;
+
+  filter_iterator &operator++() {
+    BaseT::operator++();
+    findNextValid();
+    return *this;
+  }
+};
+
+template <typename WrappedIteratorT, typename Predicate>
+iterator_range<filter_iterator<WrappedIteratorT, Predicate>>
+make_filter_range(iterator_range<WrappedIteratorT> Range,
+                  const Predicate &Pred) {
+  using IterT = filter_iterator<WrappedIteratorT, Predicate>;
+  return make_range(IterT(Range.begin(), Range.end(), Pred),
+                    IterT(Range.end(), Range.end(), Pred));
+}
+
 //===----------------------------------------------------------------------===//
 //     Extra additions to <utility>
 //===----------------------------------------------------------------------===//
Index: unittests/Support/IteratorTest.cpp
===================================================================
--- unittests/Support/IteratorTest.cpp
+++ unittests/Support/IteratorTest.cpp
@@ -98,4 +98,13 @@
   EXPECT_EQ(End, I);
 }
 
+TEST(FilterIteratorTest, Basic) {
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
 } // anonymous namespace