diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -36,6 +36,7 @@
 #include <iterator>
 #include <limits>
 #include <memory>
+#include <optional>
 #include <tuple>
 #include <type_traits>
 #include <utility>
@@ -208,6 +209,131 @@
 //     Extra additions to <iterator>
 //===----------------------------------------------------------------------===//
 
+namespace callable_detail {
+
+/// Templated storage wrapper for a callable.
+///
+/// This class is consistently default constructible, copy / move
+/// constructible / assignable.
+///
+/// Supported callable types:
+///  - Function pointer
+///  - Function reference
+///  - Lambda
+///  - Function object
+template <typename T,
+          bool = std::is_function_v<std::remove_pointer_t<remove_cvref_t<T>>>>
+class Callable {
+  using value_type = std::remove_reference_t<T>;
+  using reference = value_type &;
+  using const_reference = value_type const &;
+
+  std::optional<value_type> Obj;
+
+  static_assert(!std::is_pointer_v<value_type>,
+                "Pointers to non-functions are not callable.");
+
+public:
+  Callable() = default;
+  Callable(T const &O) : Obj(std::in_place, O) {}
+
+  Callable(Callable const &Other) = default;
+  Callable(Callable &&Other) = default;
+
+  // Wrapped objects (e.g. lambdas) may not be assignable: re-construct them.
+  Callable &operator=(Callable const &Other) {
+    return *this = Callable(Other); // Copy first so self-assignment is safe.
+  }
+
+  Callable &operator=(Callable &&Other) {
+    if (this == &Other) // Resetting Obj would also destroy Other's object.
+      return *this;
+    Obj = std::nullopt;
+    if (Other.Obj)
+      Obj.emplace(std::move(*Other.Obj));
+    return *this;
+  }
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) {
+    return (*Obj)(std::forward<Pn>(Params)...);
+  }
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T const, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) const {
+    return (*Obj)(std::forward<Pn>(Params)...);
+  }
+
+  bool valid() const { return Obj != std::nullopt; }
+  void reset() { Obj = std::nullopt; } // Drop the held callable, if any.
+
+  operator reference() { return *Obj; }
+  operator const_reference() const { return *Obj; }
+};
+
+// Function specialization. No need to waste extra space wrapping with a
+// std::optional.
+template <typename T> class Callable<T, true> {
+  static constexpr bool IsPtr = std::is_pointer_v<remove_cvref_t<T>>;
+
+  // A function reference is stored as the address of the referenced
+  // function, so both supported forms are held as a function pointer.
+  using StorageT = std::conditional_t<IsPtr, T, std::remove_reference_t<T> *>;
+
+  StorageT Func = nullptr;
+
+  // Map a function pointer or function reference onto StorageT.
+  template <typename In> static constexpr auto convertIn(In &&I) {
+    if constexpr (IsPtr) {
+      // Pointer... just echo it back.
+      return I;
+    } else {
+      // Must be a function reference. Return its address.
+      return &I;
+    }
+  }
+
+public:
+  Callable() = default;
+
+  // Construct from a function pointer or reference.
+  //
+  // Disable this constructor for references to 'Callable' so we don't violate
+  // the rule of 0.
+  template < // clang-format off
+    typename FnPtrOrRef,
+    std::enable_if_t<
+      !std::is_same_v<remove_cvref_t<FnPtrOrRef>, Callable>, int
+    > = 0
+  > // clang-format on
+  Callable(FnPtrOrRef &&F) : Func(convertIn(F)) {}
+
+  template <typename... Pn,
+            std::enable_if_t<std::is_invocable_v<T, Pn...>, int> = 0>
+  decltype(auto) operator()(Pn &&...Params) const {
+    return Func(std::forward<Pn>(Params)...);
+  }
+
+  bool valid() const { return Func != nullptr; }
+  void reset() { Func = nullptr; }
+
+  operator T const &() const {
+    if constexpr (IsPtr) {
+      // T is a pointer... just echo it back.
+      return Func;
+    } else {
+      static_assert(std::is_reference_v<T>,
+                    "Expected a reference to a function.");
+      // T is a function reference... dereference the stored pointer.
+      return *Func;
+    }
+  }
+};
+
+} // namespace callable_detail
+
 namespace adl_detail {
 
 using std::begin;
@@ -291,6 +417,7 @@
           typename std::iterator_traits<ItTy>::difference_type,
           std::remove_reference_t<ReferenceTy> *, ReferenceTy> {
 public:
+  mapped_iterator() = default;
   mapped_iterator(ItTy U, FuncTy F)
     : mapped_iterator::iterator_adaptor_base(std::move(U)), F(std::move(F)) {}
 
@@ -301,7 +428,7 @@
   ReferenceTy operator*() const { return F(*this->I); }
 
 private:
-  FuncTy F;
+  callable_detail::Callable<FuncTy> F{};
 };
 
 // map_iterator - Provide a convenient way to create mapped_iterators, just like
diff --git a/llvm/unittests/ADT/MappedIteratorTest.cpp b/llvm/unittests/ADT/MappedIteratorTest.cpp
--- a/llvm/unittests/ADT/MappedIteratorTest.cpp
+++ b/llvm/unittests/ADT/MappedIteratorTest.cpp
@@ -13,10 +13,201 @@
 
 namespace {
 
-TEST(MappedIteratorTest, ApplyFunctionOnDereference) {
+template <typename T> class MappedIteratorTestBasic : public testing::Test {};
+
+struct Plus1Lambda {
+  auto operator()() const {
+    return [](int X) { return X + 1; };
+  }
+};
+
+struct Plus1LambdaWithCapture {
+  const int One = 1;
+
+  auto operator()() const {
+    return [=](int X) { return X + One; };
+  }
+};
+
+struct Plus1FunctionRef {
+  static int plus1(int X) { return X + 1; }
+
+  using FuncT = int (&)(int);
+
+  FuncT operator()() const { return *plus1; }
+};
+
+struct Plus1FunctionPtr {
+  static int plus1(int X) { return X + 1; }
+
+  using FuncT = int (*)(int);
+
+  FuncT operator()() const { return plus1; }
+};
+
+struct Plus1Functor {
+  struct Plus1 {
+    int operator()(int X) const { return X + 1; }
+  };
+
+  auto operator()() const { return Plus1(); }
+};
+
+struct Plus1FunctorNotDefaultConstructible {
+  class PlusN {
+    const int N;
+
+  public:
+    PlusN(int NArg) : N(NArg) {}
+
+    int operator()(int X) const { return X + N; }
+  };
+
+  auto operator()() const { return PlusN(1); }
+};
+
+// clang-format off
+using FunctionTypes =
+  ::testing::Types<
+    Plus1Lambda,
+    Plus1LambdaWithCapture,
+    Plus1FunctionRef,
+    Plus1FunctionPtr,
+    Plus1Functor,
+    Plus1FunctorNotDefaultConstructible
+  >;
+// clang-format on
+
+TYPED_TEST_SUITE(MappedIteratorTestBasic, FunctionTypes, );
+
+template <typename T> using GetFuncT = decltype(std::declval<T>().operator()());
+
+TYPED_TEST(MappedIteratorTestBasic, DefaultConstruct) {
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<std::vector<int>::iterator, FuncT>;
+
+  // Callable<FuncT> must be default constructible for every FuncT, even when
+  // FuncT itself is not (e.g. a capturing lambda or a function reference);
+  // that is what lets IterT be default constructed below.
+  constexpr bool DefaultConstruct =
+      std::is_default_constructible_v<callable_detail::Callable<FuncT>>;
+  EXPECT_TRUE(DefaultConstruct);
+  EXPECT_TRUE(std::is_default_constructible_v<IterT>);
+
+  if constexpr (std::is_default_constructible_v<IterT>) {
+    IterT I;
+    (void)I;
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, CopyConstruct) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_copy_constructible_v<IterT>);
+
+  if constexpr (std::is_copy_constructible_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(I1);
+
+    EXPECT_EQ(I2, I1) << "copy constructed iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, MoveConstruct) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_move_constructible_v<IterT>);
+
+  if constexpr (std::is_move_constructible_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.begin(), GetCallable());
+    IterT I3(std::move(I2));
+
+    EXPECT_EQ(I3, I1) << "move constructed iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, CopyAssign) {
   std::vector<int> V({0});
 
-  auto I = map_iterator(V.begin(), [](int X) { return X + 1; });
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_copy_assignable_v<IterT>);
+
+  if constexpr (std::is_copy_assignable_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.end(), GetCallable());
+
+    I2 = I1;
+
+    EXPECT_EQ(I2, I1) << "copy assigned iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, MoveAssign) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  EXPECT_TRUE(std::is_move_assignable_v<IterT>);
+
+  if constexpr (std::is_move_assignable_v<IterT>) {
+    TypeParam GetCallable;
+
+    IterT I1(V.begin(), GetCallable());
+    IterT I2(V.begin(), GetCallable());
+    IterT I3(V.end(), GetCallable());
+
+    I3 = std::move(I2);
+    // I2 is moved-from; the assignment target I3 is what must now equal I1.
+    EXPECT_EQ(I3, I1) << "move assigned iterator is a different position";
+  }
+}
+
+TYPED_TEST(MappedIteratorTestBasic, GetFunction) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  TypeParam GetCallable;
+  IterT I(V.begin(), GetCallable());
+
+  EXPECT_EQ(I.getFunction()(200), 201);
+}
+
+TYPED_TEST(MappedIteratorTestBasic, GetCurrent) {
+  std::vector<int> V({0});
+
+  using FuncT = GetFuncT<TypeParam>;
+  using IterT = mapped_iterator<decltype(V)::iterator, FuncT>;
+
+  TypeParam GetCallable;
+  IterT I(V.begin(), GetCallable());
+
+  EXPECT_EQ(I.getCurrent(), V.begin());
+  EXPECT_EQ(std::next(I).getCurrent(), V.end());
+}
+
+TYPED_TEST(MappedIteratorTestBasic, ApplyFunctionOnDereference) {
+  std::vector<int> V({0});
+  TypeParam GetCallable;
+
+  auto I = map_iterator(V.begin(), GetCallable());
 
   EXPECT_EQ(*I, 1) << "should have applied function in dereference";
 }
@@ -28,9 +219,9 @@
 
   std::vector<int> V({0});
   S Y;
-  S* P = &Y;
+  S *P = &Y;
 
-  auto I = map_iterator(V.begin(), [&](int X) -> S& { return *(P + X); });
+  auto I = map_iterator(V.begin(), [&](int X) -> S & { return *(P + X); });
 
   I->Z = 42;
 
@@ -39,9 +230,9 @@
 
 TEST(MappedIteratorTest, FunctionPreservesReferences) {
   std::vector<int> V({1});
-  std::map<int, int> M({ {1, 1} });
+  std::map<int, int> M({{1, 1}});
 
-  auto I = map_iterator(V.begin(), [&](int X) -> int& { return M[X]; });
+  auto I = map_iterator(V.begin(), [&](int X) -> int & { return M[X]; });
   *I = 42;
 
   EXPECT_EQ(M[1], 42) << "assignment should have modified M";