Index: include/llvm/ADT/STLExtras.h
===================================================================
--- include/llvm/ADT/STLExtras.h
+++ include/llvm/ADT/STLExtras.h
@@ -26,6 +26,7 @@
 #include <memory>
 #include <utility> // for std::pair
 
+#include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Support/Compiler.h"
 
@@ -235,6 +236,105 @@
                     llvm::make_reverse_iterator(std::begin(C)));
 }
 
+/// An iterator adaptor that filters the elements of given inner iterators.
+///
+/// The predicate parameter should be a callable object that accepts the wrapped
+/// iterator's reference type and returns a bool. When incrementing or
+/// decrementing the iterator, it will call the predicate on each element and
+/// skip any where it returns false.
+///
+/// The iterator category is the weaker of the wrapped iterator's category and
+/// bidirectional, so random access is never offered.
+///
+/// \code
+///   int A[] = { 1, 2, 3, 4 };
+///   const auto R = make_filter_range(make_range(A, A + 4),
+///                                    [](int A) { return A % 2 == 1; });
+///   // R contains { 1, 3 }.
+/// \endcode
+template <typename WrappedIteratorT, typename Predicate>
+class filter_iterator
+    : public iterator_adaptor_base<
+          filter_iterator<WrappedIteratorT, Predicate>, WrappedIteratorT,
+          typename std::common_type<
+              std::bidirectional_iterator_tag,
+              typename std::iterator_traits<
+                  WrappedIteratorT>::iterator_category>::type,
+          typename std::iterator_traits<WrappedIteratorT>::value_type,
+          typename std::iterator_traits<WrappedIteratorT>::difference_type,
+          typename std::iterator_traits<WrappedIteratorT>::pointer,
+          typename std::iterator_traits<WrappedIteratorT>::reference> {
+  using IteratorCategoryT = typename std::common_type<
+      std::bidirectional_iterator_tag,
+      typename std::iterator_traits<WrappedIteratorT>::iterator_category>::type;
+  using BaseT = iterator_adaptor_base<
+      filter_iterator<WrappedIteratorT, Predicate>, WrappedIteratorT,
+      IteratorCategoryT,
+      typename std::iterator_traits<WrappedIteratorT>::value_type,
+      typename std::iterator_traits<WrappedIteratorT>::difference_type,
+      typename std::iterator_traits<WrappedIteratorT>::pointer,
+      typename std::iterator_traits<WrappedIteratorT>::reference>;
+
+  /// End of the wrapped range; forward scanning stops here.
+  WrappedIteratorT End;
+  /// Elements for which this returns false are skipped.
+  Predicate Pred;
+
+  /// Advances the wrapped iterator until it reaches End or an element for
+  /// which the predicate returns true.
+  void findNextValid() {
+    while (this->I != End && !Pred(*this->I))
+      BaseT::operator++();
+  }
+
+public:
+  /// Positions the iterator on the first element of [Begin, End) accepted by
+  /// Pred, or on End when there is none.
+  filter_iterator(WrappedIteratorT Begin, WrappedIteratorT End, Predicate Pred)
+      : BaseT(std::move(Begin)), End(std::move(End)), Pred(std::move(Pred)) {
+    findNextValid();
+  }
+
+  // Re-expose the base class overloads; the prefix operators declared below
+  // would otherwise hide them.
+  using BaseT::operator++;
+  using BaseT::operator--;
+
+  /// Moves to the next element accepted by the predicate, or to End.
+  filter_iterator &operator++() {
+    BaseT::operator++();
+    findNextValid();
+    return *this;
+  }
+
+  /// Moves to the closest preceding element accepted by the predicate.
+  ///
+  /// The start of the range is not stored, so the caller must guarantee that
+  /// such an element exists; otherwise the scan runs off the front.
+  filter_iterator &operator--() {
+    static_assert(
+        std::is_same<std::bidirectional_iterator_tag, IteratorCategoryT>::value,
+        "operator--() requires the underlying type to be at least "
+        "bidirectional");
+    BaseT::operator--();
+    // Unbounded on purpose: see the precondition documented above.
+    while (!Pred(*this->I))
+      BaseT::operator--();
+    return *this;
+  }
+};
+
+/// \brief Convenience function that takes a range of elements and a predicate,
+/// and returns a new filter_iterator range.
+template <typename WrappedIteratorT, typename Predicate>
+iterator_range<filter_iterator<WrappedIteratorT, Predicate>>
+make_filter_range(iterator_range<WrappedIteratorT> Range, Predicate Pred) {
+  // Pred is taken by value so that a plain function name decays to a pointer.
+  using IterT = filter_iterator<WrappedIteratorT, Predicate>;
+  return make_range(IterT(Range.begin(), Range.end(), Pred),
+                    IterT(Range.end(), Range.end(), Pred));
+}
+
 //===----------------------------------------------------------------------===//
 //     Extra additions to <utility>
 //===----------------------------------------------------------------------===//
Index: include/llvm/ADT/iterator.h
===================================================================
--- include/llvm/ADT/iterator.h
+++ include/llvm/ADT/iterator.h
@@ -168,15 +168,7 @@
 
   iterator_adaptor_base() = default;
 
-  template <typename U>
-  explicit iterator_adaptor_base(
-      U &&u,
-      typename std::enable_if<
-          !std::is_base_of<typename std::remove_cv<
-                               typename std::remove_reference<U>::type>::type,
-                           DerivedT>::value,
-          int>::type = 0)
-      : I(std::forward<U &&>(u)) {}
+  explicit iterator_adaptor_base(WrappedIteratorT u) : I(std::move(u)) {}
 
   const WrappedIteratorT &wrapped() const { return I; }
 
Index: unittests/Support/IteratorTest.cpp
===================================================================
--- unittests/Support/IteratorTest.cpp
+++ unittests/Support/IteratorTest.cpp
@@ -12,6 +12,8 @@
 #include "llvm/ADT/SmallVector.h"
 #include "gtest/gtest.h"
 
+#include <functional>
+
 using namespace llvm;
 
 namespace {
@@ -98,4 +100,95 @@
   EXPECT_EQ(End, I);
 }
 
+// A lambda predicate keeps only the odd elements.
+TEST(FilterIteratorTest, Lambda) {
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+// A std::function predicate behaves like the lambda it wraps.
+TEST(FilterIteratorTest, StdFunction) {
+  std::function<bool(int)> IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+// A function pointer variable is accepted as the predicate.
+TEST(FilterIteratorTest, FunctionPointer) {
+  bool (*IsOdd)(int) = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+// Free-function predicate used by the FunctionName test below.
+bool isOddFunction(int a) { return a % 2 == 1; }
+
+// A function passed by name (not via a pointer variable) is accepted too.
+TEST(FilterIteratorTest, FunctionName) {
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(make_range(a, a + 7), isOddFunction);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+// Filtering composes with another adaptor (pointee_iterator over unique_ptr).
+TEST(FilterIteratorTest, Composition) {
+  const auto IsOdd = [](int a) { return a % 2 == 1; };
+  std::unique_ptr<int> a[] = {make_unique<int>(0), make_unique<int>(1),
+                              make_unique<int>(2), make_unique<int>(3),
+                              make_unique<int>(4), make_unique<int>(5),
+                              make_unique<int>(6)};
+  using PointeeIterator = pointee_iterator<std::unique_ptr<int> *>;
+  const auto Range = make_filter_range(
+      make_range(PointeeIterator(a), PointeeIterator(a + 7)), IsOdd);
+  ASSERT_EQ(3, std::distance(Range.begin(), Range.end()));
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+// A wrapped iterator tagged as input-only can still be filtered forwards.
+TEST(FilterIteratorTest, InputIterator) {
+  struct InputIterator
+      : iterator_adaptor_base<InputIterator, int *, std::input_iterator_tag> {
+    using BaseT =
+        iterator_adaptor_base<InputIterator, int *, std::input_iterator_tag>;
+
+    InputIterator(int *It) : BaseT(It) {}
+  };
+
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  const auto Range = make_filter_range(
+      make_range(InputIterator(a), InputIterator(a + 7)), IsOdd);
+  SmallVector<int, 3> Actual(Range.begin(), Range.end());
+  EXPECT_EQ((SmallVector<int, 3>{1, 3, 5}), Actual);
+}
+
+// Decrementing steps back to the previous accepted element.
+TEST(FilterIteratorTest, Decrement) {
+  auto IsOdd = [](int a) { return a % 2 == 1; };
+  int a[] = {0, 1, 2, 3, 4, 5, 6};
+  auto Iter = make_filter_range(make_range(a, a + 7), IsOdd).begin();
+  EXPECT_EQ(1, *Iter);
+  ++Iter;
+  EXPECT_EQ(3, *Iter);
+  --Iter;
+  EXPECT_EQ(1, *Iter);
+  std::advance(Iter, 2);
+  EXPECT_EQ(5, *Iter);
+  --Iter;
+  EXPECT_EQ(3, *Iter);
+}
+
 } // anonymous namespace