Index: include/llvm/Support/CastingIterators.h
===================================================================
--- /dev/null
+++ include/llvm/Support/CastingIterators.h
@@ -0,0 +1,96 @@
+//===- llvm/Support/CastingIterators.h ---------------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// Iterator adapters that visit only the elements of a range that are of a
+/// given derived type (per isa<>) and yield them already cast to that type.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_CASTING_ITERATORS_H
+#define LLVM_SUPPORT_CASTING_ITERATORS_H
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/Casting.h"
+#include <iterator>
+#include <utility>
+
+namespace llvm {
+
+namespace detail {
+
+/// Functor used as the predicate in a dyn_cast_iterator: it accepts exactly
+/// the elements for which isa<DerivedT> holds.
+template <typename DerivedT>
+struct dyn_predicate {
+  template <typename BaseT>
+  bool operator()(const BaseT *Base) const { return isa<DerivedT>(Base); }
+};
+
+} // end namespace detail
+
+/// Iterator adapter that 'dyn_cast's the elements of the given inner iterators
+/// and filters out the failed casts.
+///
+/// Dereferencing yields the element already cast to DerivedT. The increment
+/// operators are redeclared here so that they return dyn_cast_iterator rather
+/// than the filter_iterator base; otherwise `*++It` and `*It++` would yield
+/// the uncast element.
+template <typename DerivedT, typename WrappedIteratorT>
+struct dyn_cast_iterator
+    : filter_iterator<WrappedIteratorT, detail::dyn_predicate<DerivedT>> {
+  using FilterIt =
+      filter_iterator<WrappedIteratorT, detail::dyn_predicate<DerivedT>>;
+
+  // NOTE(review): relies on filter_iterator's (Begin, End, Pred) constructor
+  // being accessible from a derived class; confirm against STLExtras.h.
+  dyn_cast_iterator(WrappedIteratorT Begin, WrappedIteratorT End)
+      : FilterIt(std::move(Begin), std::move(End), /* predicate */{}) {}
+
+  auto operator*() const -> decltype(cast<DerivedT>(FilterIt::operator*())) {
+    // Perform the cast on dereference.
+    // The filtering ensures that the dynamic type is correct.
+    return cast<DerivedT>(FilterIt::operator*());
+  }
+
+  dyn_cast_iterator &operator++() {
+    FilterIt::operator++();
+    return *this;
+  }
+
+  dyn_cast_iterator operator++(int) {
+    dyn_cast_iterator Tmp = *this;
+    ++*this;
+    return Tmp;
+  }
+};
+
+/// Make a dyn_cast_iterator and infer the wrapped iterator type.
+template <typename DerivedT, typename WrappedIteratorT>
+dyn_cast_iterator<DerivedT, WrappedIteratorT>
+make_dyn_cast_iterator(WrappedIteratorT Begin, WrappedIteratorT End) {
+  return dyn_cast_iterator<DerivedT, WrappedIteratorT>(std::move(Begin),
+                                                       std::move(End));
+}
+
+/// Make a dyn_cast_iterator range from the input range.
+template <typename DerivedT, typename RangeT>
+auto make_dyn_cast_range(const RangeT &Rng)
+    -> iterator_range<dyn_cast_iterator<DerivedT, decltype(std::begin(Rng))>> {
+  // The returned iterators refer into Rng, so Rng must outlive the range.
+  using IterT = dyn_cast_iterator<DerivedT, decltype(std::begin(Rng))>;
+
+  return make_range(IterT(std::begin(Rng), std::end(Rng)),
+                    IterT(std::end(Rng), std::end(Rng)));
+}
+
+} // end namespace llvm
+
+#endif // LLVM_SUPPORT_CASTING_ITERATORS_H
Index: unittests/Support/CMakeLists.txt
===================================================================
--- unittests/Support/CMakeLists.txt
+++ unittests/Support/CMakeLists.txt
@@ -13,6 +13,7 @@
   CachePruningTest.cpp
   CrashRecoveryTest.cpp
   Casting.cpp
+  CastingIterators.cpp
   CheckedArithmeticTest.cpp
   Chrono.cpp
   CommandLineTest.cpp
Index: unittests/Support/CastingIterators.cpp
===================================================================
--- /dev/null
+++ unittests/Support/CastingIterators.cpp
@@ -0,0 +1,98 @@
+//===- llvm/unittest/Support/CastingIterators.cpp -------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/CastingIterators.h"
+#include "llvm/Support/Casting.h"
+#include "gtest/gtest.h"
+
+#include <array>
+#include <cassert>
+#include <memory>
+#include <vector>
+
+using namespace llvm;
+
+namespace {
+
+struct Number {
+  unsigned getVal() const { return Val; }
+  virtual ~Number() = default;
+
+protected:
+  Number(unsigned Val) : Val(Val) {}
+
+private:
+  unsigned Val;
+};
+
+struct EvenNumber : Number {
+  EvenNumber(unsigned Val) : Number(Val) {
+    assert((Val % 2) == 0 && "Not an even number!");
+  }
+  static bool classof(const Number *Num) { return (Num->getVal() % 2) == 0; }
+};
+
+struct OddNumber : Number {
+  OddNumber(unsigned Val) : Number(Val) {
+    assert((Val % 2) != 0 && "Not an odd number!");
+  }
+  static bool classof(const Number *Num) { return (Num->getVal() % 2) != 0; }
+};
+
+class DynCastIterTest : public testing::Test {
+protected:
+  /// Owns the numbers 1..10; released automatically with the fixture.
+  std::array<std::unique_ptr<Number>, 10> Storage;
+  /// Non-owning view of Storage; this is the range under test.
+  std::array<Number *, 10> Nums;
+
+  void SetUp() override {
+    for (unsigned I = 1; I <= Nums.size(); ++I) {
+      if (I % 2 == 0)
+        Storage[I - 1].reset(new EvenNumber(I));
+      else
+        Storage[I - 1].reset(new OddNumber(I));
+      Nums[I - 1] = Storage[I - 1].get();
+    }
+  }
+};
+
+TEST_F(DynCastIterTest, Simple) {
+  auto It = make_dyn_cast_iterator<OddNumber>(Nums.begin(), Nums.end());
+  auto End = make_dyn_cast_iterator<OddNumber>(Nums.end(), Nums.end());
+
+  // Only the odd numbers pass the filter, in their original order.
+  for (unsigned Expected : {1u, 3u, 5u, 7u, 9u}) {
+    // ASSERT, not EXPECT: dereferencing End would be undefined behavior.
+    ASSERT_NE(It, End);
+    EXPECT_EQ((*It)->getVal(), Expected);
+    ++It;
+  }
+  EXPECT_EQ(It, End);
+}
+
+TEST_F(DynCastIterTest, Range) {
+  std::vector<unsigned> Vals;
+  for (EvenNumber *N : make_dyn_cast_range<EvenNumber>(Nums))
+    Vals.push_back(N->getVal());
+  EXPECT_EQ((std::vector<unsigned>{2, 4, 6, 8, 10}), Vals);
+}
+
+TEST_F(DynCastIterTest, NoMatches) {
+  // When every cast fails, the range must be empty.
+  std::vector<Number *> OnlyOdd = {Nums[0], Nums[2], Nums[4]};
+  auto Evens = make_dyn_cast_range<EvenNumber>(OnlyOdd);
+  EXPECT_EQ(Evens.begin(), Evens.end());
+
+  std::vector<Number *> Empty;
+  auto None = make_dyn_cast_range<EvenNumber>(Empty);
+  EXPECT_EQ(None.begin(), None.end());
+}
+
+} // end anonymous namespace