diff --git a/llvm/include/llvm/ADT/IntrusiveVariant.h b/llvm/include/llvm/ADT/IntrusiveVariant.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/ADT/IntrusiveVariant.h @@ -0,0 +1,638 @@ +//===- IntrusiveVariant.h - Compact type safe union -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides IntrusiveVariant, a class template modeled in the spirit +// of std::variant, but leveraging the "common initial sequence" rule for union +// members to store the runtime tag at the beginning of the IntrusiveVariant's +// alternative types, allowing for it to be packed more efficiently into bits +// that would otherwise be used for padding. +// +// However, this requires that several restrictions be placed on valid +// alternative types. All alternative types of an IntrusiveVariant must: +// +// * Be standard-layout. This implies (among other things): +// * All non-static data members must have the same access control. +// * All non-static data members must be declared in only one class in the +// inheritance hierarchy. +// * No virtual methods. +// * Begin their class definition by invoking the +// DECLARE_INTRUSIVE_ALTERNATIVE macro. This declares a member named +// `IntrusiveVariantTagMember` which must not be referenced outside of the +// implementation of IntrusiveVariant, and declares some `friend` types to +// make the tag accessible to the implementation. +// +// Additionally, some features were omitted that are present in the C++17 +// std::variant to keep the code simpler: +// +// * All alternative types must be trivially-destructible.
+// * All copy/move constructors and assignment operators for the variant are +// disabled if any type is not trivially copy/move-constructible and/or +// trivially copy/move-assignable, respectively. +// * Only one variant may be passed to the llvm::visit function. +// * All alternative types must be unique, and cannot be referred to by index. +// * No equivalent to std::monostate. An instantiation must have at least +// IntrusiveVariant::MinNumberOfAlternatives alternatives. +// * There is a static limit on the number of alternative types supported. +// An instantiation must have no more than +// IntrusiveVariant::MaxNumberOfAlternatives alternatives. The idea is that +// the cap is probably high enough to begin with, and if it isn't we likely +// want to raise the static limit and avoid any additional runtime cost +// required to implement it generally. +// +// If a use case for the above materializes, these can always be added +// retroactively. +// +// Example: +// +// class AltInt { +// DECLARE_INTRUSIVE_ALTERNATIVE +// int Int; +// +// public: +// AltInt() : Int(0) {} +// AltInt(int Int) : Int(Int) {} +// int getInt() const { return Int; } +// void setInt(int Int) { this->Int = Int; } +// }; +// +// class AltDouble { +// DECLARE_INTRUSIVE_ALTERNATIVE +// double Double; +// +// public: +// AltDouble(double Double) : Double(Double) {} +// double getDouble() const { return Double; } +// void setDouble(double Double) { this->Double = Double; } +// }; +// +// class AltComplexInt { +// DECLARE_INTRUSIVE_ALTERNATIVE +// int Real; +// int Imag; +// +// public: +// AltComplexInt(int Real, int Imag) : Real(Real), Imag(Imag) {} +// int getReal() const { return Real; } +// void setReal(int Real) { this->Real = Real; } +// int getImag() const { return Imag; } +// void setImag(int Imag) { this->Imag = Imag; } +// }; +// +// TEST(VariantTest, HeaderExample) { +// using MyVariant = IntrusiveVariant<AltInt, AltDouble, AltComplexInt>; +// +// MyVariant DefaultConstructedVariant; +//
ASSERT_TRUE(DefaultConstructedVariant.holdsAlternative<AltInt>()); +// ASSERT_EQ(DefaultConstructedVariant.get<AltInt>().getInt(), 0); +// MyVariant Variant{InPlaceType<AltComplexInt>{}, 4, 2}; +// ASSERT_TRUE(Variant.holdsAlternative<AltComplexInt>()); +// int NonSense = visit( +// makeVisitor( +// [](AltInt &AI) { return AI.getInt(); }, +// [](AltDouble &AD) { return static_cast<int>(AD.getDouble()); }, +// [](AltComplexInt &ACI) { return ACI.getReal() + ACI.getImag(); }), +// Variant); +// ASSERT_EQ(NonSense, 6); +// Variant.emplace<AltDouble>(2.0); +// ASSERT_TRUE(Variant.holdsAlternative<AltDouble>()); +// Variant.get<AltDouble>().setDouble(3.0); +// AltDouble AD = Variant.get<AltDouble>(); +// double D = AD.getDouble(); +// ASSERT_EQ(D, 3.0); +// Variant.emplace<AltComplexInt>(4, 5); +// ASSERT_EQ(Variant.get<AltComplexInt>().getReal(), 4); +// ASSERT_EQ(Variant.get<AltComplexInt>().getImag(), 5); +// } +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ADT_INTRUSIVEVARIANT_H +#define LLVM_ADT_INTRUSIVEVARIANT_H + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" +#include +#include +#include + +namespace llvm { + +template class IntrusiveVariant; + +/// A disambiguation tag used to convey the alternative type to be constructed +/// in-place within an IntrusiveVariant. +/// +/// This is a workaround required by the lack of support for specifying a +/// template parameter to a templated constructor. See std::in_place_type_t. +template struct InPlaceType { explicit InPlaceType() = default; }; + +/// Helper to get the number of alternative types of a (possibly cv-qualified) +/// IntrusiveVariant type as a constexpr. See std::variant_size.
+template struct IntrusiveVariantSize; +template +struct IntrusiveVariantSize : IntrusiveVariantSize {}; +template +struct IntrusiveVariantSize : IntrusiveVariantSize {}; +template +struct IntrusiveVariantSize : IntrusiveVariantSize {}; +template +struct IntrusiveVariantSize> + : std::integral_constant {}; + +/// Simple value type which must be the first member of all alternative types +/// of an IntrusiveVariant. See DECLARE_INTRUSIVE_ALTERNATIVE. +/// +/// The internal implementation assumes this is layout-compatible with the +/// "common initial sequence" of all alternative types contained in the private +/// union of the IntrusiveVariant. +struct IntrusiveVariantTag { + uint8_t Index = std::numeric_limits::max(); + IntrusiveVariantTag() {} + IntrusiveVariantTag(uint8_t Index) : Index(Index) {} +}; + +/// A helper macro to add the declarations needed to use a type as an +/// alternative for IntrusiveVariant. Must be the first declaration of the +/// class. +#define DECLARE_INTRUSIVE_ALTERNATIVE \ + ::llvm::IntrusiveVariantTag IntrusiveVariantTagMember; \ + template friend class ::llvm::IntrusiveVariant; \ + template \ + friend union ::llvm::detail::UnionImpl; + +namespace detail { +// This struct is used to access the intrusive tag of the alternative types. +// +// All such types must have an initial sequence which is layout-compatible +// with this struct or the access causes undefined behavior. +struct CommonInitialSequenceT { + IntrusiveVariantTag Tag; +}; + +// Returns the index of T in Ts..., or a numeric_limits max() sentinel. +template constexpr std::size_t indexOf() { + constexpr bool Matches[] = {std::is_same{}...}; + for (std::size_t I = 0; I < sizeof...(Ts); ++I) + if (Matches[I]) + return I; + return std::numeric_limits::max(); +} +// Wrapper "tag" type to work around inability to supply template parameters to +// constructor templates explicitly.
This can just be a detail in our +// implementation because we don't allow repeated alternative types. +template struct InPlaceIndex { + explicit InPlaceIndex() = default; +}; + +// Helper to extract the Nth type from a pack Ts... +template +using NthType = std::tuple_element_t>; + +// The inner implementation of the "type safe union". Members are only +// accessible directly via an Index, so IntrusiveVariant must use indexOf to +// convert a pair of T and Ts... into an index. +// +// Effectively implemented as a "linked list" of recursively defined union +// templates. This is the recursive portion of the definition. +// +// We use InPlaceIndex here both to disambiguate the constructor and to make +// defining the overload set for getMember more natural. +template +union UnionImpl { + using TailT = UnionImpl; + HeadT Head; + TailT Tail; + HeadT &getMember(InPlaceIndex) { return Head; } + const HeadT &getMember(InPlaceIndex) const { return Head; } + template decltype(auto) getMember(InPlaceIndex) { + return Tail.getMember(InPlaceIndex{}); + } + template decltype(auto) getMember(InPlaceIndex) const { + return Tail.getMember(InPlaceIndex{}); + } + template + UnionImpl(InPlaceIndex, ArgTs &&... Args) { + new (&Head) HeadT(std::forward(Args)...); + Head.IntrusiveVariantTagMember.Index = Index; + } + template + UnionImpl(InPlaceIndex, ArgTs &&... Args) { + new (&Tail) TailT(InPlaceIndex{}, std::forward(Args)...); + } + UnionImpl(const UnionImpl &) = default; + UnionImpl(UnionImpl &&) = default; + UnionImpl &operator=(const UnionImpl &) = default; + UnionImpl &operator=(UnionImpl &&) = default; + // No-op: IntrusiveVariant requires trivially-destructible alternatives. + ~UnionImpl() = default; +}; +// The base case for the above, i.e. when the tail pack is empty. This is the +// "(cons head nil)" of the linked list.
+template union UnionImpl { + HeadT Head; + HeadT &getMember(InPlaceIndex) { return Head; } + const HeadT &getMember(InPlaceIndex) const { return Head; } + template + UnionImpl(InPlaceIndex, ArgTs &&... Args) { + new (&Head) HeadT(std::forward(Args)...); + Head.IntrusiveVariantTagMember.Index = Index; + } + UnionImpl(const UnionImpl &) = default; + UnionImpl(UnionImpl &&) = default; + UnionImpl &operator=(const UnionImpl &) = default; + UnionImpl &operator=(UnionImpl &&) = default; + // No-op: IntrusiveVariant requires trivially-destructible alternatives. + ~UnionImpl() = default; +}; + +// A recursive callable visitor; by-value callables are moved into its bases. +template struct Visitor; +template +struct Visitor : HeadT { + typedef Visitor Type; + typedef ResultT ResultType; + using HeadT::operator(); + Visitor(HeadT Head) : HeadT(std::move(Head)) {} +}; +template +struct Visitor : HeadT, Visitor { + typedef Visitor Type; + typedef ResultT ResultType; + using HeadT::operator(); + using Visitor::Type::operator(); + Visitor(HeadT Head, TailTs... Tail) + : HeadT(std::move(Head)), Visitor(std::move(Tail)...) {} +}; + +// Template for callable which errors at runtime for the case when the given +// index doesn't refer to a valid alternative type for the IntrusiveVariant. +// +// Used to implement the switch-based dispatch in visit, because the "dead" +// case limbs still have to resolve to some overload, but we won't have an +// overload available for the cases where the index is too large (i.e. we want +// there to be no such overload, so we get compile-time errors for those cases +// everywhere else). As a workaround, we make the body of the out-of-bounds +// overloads be llvm_unreachable.
+template +struct FallibleIntrusiveVariantVisitor { + // We assume we are used in a context where either: + // * There is only one variant + // * The index has already been proven to be the same for all variants + using FirstVariantT = std::tuple_element_t<0, std::tuple>; + using VariantSize = + IntrusiveVariantSize>; + // If the index is in-bounds for the variants, forward them to the visitor. + template = 0> + constexpr decltype(auto) operator()(VisitorT &&Visitor, + VariantTs &&... Variants) { + return std::forward(Visitor)( + std::forward(Variants).Union.getMember( + detail::InPlaceIndex{})...); + } + // Otherwise, the index is out-of-bounds, and we fail at runtime. We declare + // the return type of this case to be the same as for Index==0 so we do not + // interfere with the return type deduction for the in-bounds case. + template = NumTs), int> = 0> + constexpr auto operator()(VisitorT &&Visitor, VariantTs &&... Variants) + -> decltype(std::forward(Visitor)( + std::forward(Variants).Union.getMember( + detail::InPlaceIndex<0>{})...)) { + llvm_unreachable("invalid index of IntrusiveVariant visited"); + } +}; +// Convenience function to infer some of the template parameters of +// FallibleIntrusiveVariantVisitor. +template +constexpr decltype(auto) visitOrError(VisitorT &&Visitor, + VariantTs &&... Variants) { + return FallibleIntrusiveVariantVisitor{}( + std::forward(Visitor), std::forward(Variants)...); +} +} // end namespace detail + +/// Convenience function to create a visitor from a variadic list of callables. +template +decltype(auto) makeVisitor(CallableTs &&...
 Callables) { + return detail::Visitor{ + std::forward(Callables)...}; +} + +// Macro for writing the 32-case switch statement from a definition of CASE(N) +#define CASES8(B) \ + CASE((B + 0)) \ + CASE((B + 1)) \ + CASE((B + 2)) \ + CASE((B + 3)) \ + CASE((B + 4)) \ + CASE((B + 5)) \ + CASE((B + 6)) \ + CASE((B + 7)) +#define CASES32() CASES8(8 * 0) CASES8(8 * 1) CASES8(8 * 2) CASES8(8 * 3) +#define SWITCH(INDEX) \ + switch ((INDEX)) { \ + default: \ + llvm_unreachable("invalid index of IntrusiveVariant visited"); \ + CASES32() \ + } + +/// Invokes the provided Visitor using overload resolution based on the +/// dynamic alternative type held in Variant. See std::visit. +/// +/// The return type is decltype(Visitor(Variant.get<T>())) for all T in the +/// alternative types Ts... of Variant. This must be a valid expression of the +/// same type and value category for all Ts... +template +constexpr decltype(auto) visit(VisitorT &&Visitor, VariantT &&Variant) { +#define CASE(N) \ + case N: \ + return detail::visitOrError(std::forward(Visitor), \ + std::forward(Variant)); + SWITCH(Variant.index()) +#undef CASE +} + +/// A class template modeled in the spirit of std::variant, but leveraging the +/// "common initial sequence" rule for union members to store the runtime tag +/// at the beginning of each variant alternative itself, allowing for it to be +/// packed more efficiently into bits that would otherwise be used for padding. +template class IntrusiveVariant { +public: + /// The static minimum number of alternative types supported for an + /// instantiation of IntrusiveVariant. + static constexpr std::size_t MinNumberOfAlternatives = 1; + /// The static maximum number of alternative types supported for an + /// instantiation of IntrusiveVariant.
+ static constexpr std::size_t MaxNumberOfAlternatives = 32; + +private: + static_assert(MinNumberOfAlternatives <= sizeof...(Ts), + "IntrusiveVariant must consist of no less than " + "IntrusiveVariant::MinNumberOfAlternatives alternatives."); + static_assert(sizeof...(Ts) <= MaxNumberOfAlternatives, + "IntrusiveVariant must consist of no more than " + "IntrusiveVariant::MaxNumberOfAlternatives alternatives."); + static_assert(llvm::conjunction...>::value, + "IntrusiveVariant alternatives must be standard-layout."); + static_assert( + llvm::conjunction...>::value, + "IntrusiveVariant alternatives must be trivially-destructible."); + template static constexpr bool tagIsFirstMember() { + constexpr bool IsFirstMember[] = { + !offsetof(Us, IntrusiveVariantTagMember)...}; + for (std::size_t I = 0; I < sizeof...(Us); ++I) + if (!IsFirstMember[I]) + return false; + return true; + } + static_assert( + tagIsFirstMember() && + llvm::conjunction< + std::is_same...>::value, + "IntrusiveVariant alternatives' class definition must begin with " + "DECLARE_INTRUSIVE_ALTERNATIVE"); + template static constexpr bool allTypesUnique() { + constexpr std::size_t IndexOf[] = {detail::indexOf()...}; + for (std::size_t I = 0; I < sizeof...(Us); ++I) + if (I != IndexOf[I]) + return false; + return true; + } + static_assert( + allTypesUnique(), + "Repeated alternative types in IntrusiveVariant are not allowed."); + + template + friend struct detail::FallibleIntrusiveVariantVisitor; + template + friend constexpr decltype(auto) visit(VisitorT &&, VariantT &&); + + // Alias for the UnionImpl of this IntrusiveVariant. + using UnionT = detail::UnionImpl<0, Ts...>; + // Helper to get the InPlaceIndex for T in Ts... + template + using IndexOfT = detail::InPlaceIndex()>; + // Helper to check if a type is in the set Ts...
+ template + using IsAlternativeType = llvm::disjunction...>; + + // The only data member of IntrusiveVariant, meaning the variant is the same + // size and has the same alignment requirements as the union of all of its + // alternative types. + union { + detail::CommonInitialSequenceT CommonInitialSequence; + UnionT Union; + }; + + // Returns the current dynamic index of this variant. + std::size_t index() const { return CommonInitialSequence.Tag.Index; } + + // Convenience methods to get the union member for an alternative type T. + template T &getAlt() { return Union.getMember(IndexOfT{}); } + template const T &getAlt() const { + return Union.getMember(IndexOfT{}); + } + +public: + /// A default constructed IntrusiveVariant holds a default constructed value + /// of its first alternative. Only enabled if the first alternative has a + /// default constructor. + template >{}, + typename std::enable_if_t = 0> + constexpr IntrusiveVariant() : Union(detail::InPlaceIndex<0>{}) {} + /// The forwarding constructor requires a disambiguation tag InPlaceType<T>, + /// and creates an IntrusiveVariant holding the alternative T constructed + /// with the constructor arguments Args... + template {}, int> = 0, + typename... ArgTs> + explicit constexpr IntrusiveVariant(InPlaceType, ArgTs &&... Args) + : Union(IndexOfT{}, std::forward(Args)...) {} + /// Converting constructor from alternative types. + template {}, int> = 0> + constexpr IntrusiveVariant(T &&Alt) + : Union(IndexOfT{}, std::forward(Alt)) {} + IntrusiveVariant(const IntrusiveVariant &) = default; + IntrusiveVariant(IntrusiveVariant &&) = default; + ~IntrusiveVariant() = default; + IntrusiveVariant &operator=(const IntrusiveVariant &) = default; + IntrusiveVariant &operator=(IntrusiveVariant &&) = default; + /// Replaces the held value with a new value of alternative type T in-place, + /// constructing the new value with constructor arguments Args... + /// + /// Returns the newly constructed alternative type value.
+ template T &emplace(ArgTs &&... Args) { + new (&Union) UnionT(IndexOfT{}, std::forward(Args)...); + return Union.getMember(IndexOfT{}); + } + /// Check if this variant holds a value of the given alternative type T. + template constexpr bool holdsAlternative() const { + return index() == detail::indexOf(); + } + /// Reads the value of alternative type T. + /// + /// Behavior undefined if this does not hold a value of alternative type T. + template constexpr T &get() { + assert(holdsAlternative()); + return getAlt(); + } + /// Reads the value of alternative type T. + /// + /// Behavior undefined if this does not hold a value of alternative type T. + template constexpr const T &get() const { + assert(holdsAlternative()); + return getAlt(); + } + /// Obtains a pointer to the value of alternative type T if this holds a + /// value of alternative type T. Otherwise, returns nullptr. + template constexpr T *getIf() { + if (holdsAlternative()) + return &getAlt(); + return nullptr; + } + /// Obtains a pointer to the value of alternative type T if this holds a + /// value of alternative type T. Otherwise, returns nullptr. + template constexpr const T *getIf() const { + if (holdsAlternative()) + return &getAlt(); + return nullptr; + } + + /// Equality operator. + /// + /// The alternative types held by LHS and RHS are T and U, respectively; then: + /// + /// If T != U, returns false. + /// Otherwise, returns LHS.get<T>() == RHS.get<T>(). + friend constexpr bool operator==(const IntrusiveVariant &LHS, + const IntrusiveVariant &RHS) { + if (LHS.index() != RHS.index()) + return false; +#define CASE(N) \ + case N: \ + return detail::visitOrError(std::equal_to<>{}, LHS, RHS); + SWITCH(LHS.index()) +#undef CASE + } + + /// Inequality operator. + /// + /// The alternative types held by LHS and RHS are T and U, respectively; then: + /// + /// If T != U, returns true. + /// Otherwise, returns LHS.get<T>() != RHS.get<T>().
+ friend constexpr bool operator!=(const IntrusiveVariant &LHS, + const IntrusiveVariant &RHS) { + if (LHS.index() != RHS.index()) + return true; +#define CASE(N) \ + case N: \ + return detail::visitOrError(std::not_equal_to<>{}, LHS, RHS); + SWITCH(LHS.index()) +#undef CASE + } + + /// Less-than operator. + /// + /// The alternative types held by LHS and RHS are T and U, respectively; then: + /// + /// If T precedes U in Ts..., returns true. + /// If U precedes T in Ts..., returns false. + /// Otherwise, returns LHS.get<T>() < RHS.get<T>(). + friend constexpr bool operator<(const IntrusiveVariant &LHS, + const IntrusiveVariant &RHS) { + if (LHS.index() < RHS.index()) + return true; + if (LHS.index() > RHS.index()) + return false; +#define CASE(N) \ + case N: \ + return detail::visitOrError(std::less<>{}, LHS, RHS); + SWITCH(LHS.index()) +#undef CASE + } + + /// Greater-than operator. + /// + /// The alternative types held by LHS and RHS are T and U, respectively; then: + /// + /// If T precedes U in Ts..., returns false. + /// If U precedes T in Ts..., returns true. + /// Otherwise, returns LHS.get<T>() > RHS.get<T>(). + friend constexpr bool operator>(const IntrusiveVariant &LHS, + const IntrusiveVariant &RHS) { + if (LHS.index() < RHS.index()) + return false; + if (LHS.index() > RHS.index()) + return true; +#define CASE(N) \ + case N: \ + return detail::visitOrError(std::greater<>{}, LHS, RHS); + SWITCH(LHS.index()) +#undef CASE + } + + /// Less-equal operator. + /// + /// The alternative types held by LHS and RHS are T and U, respectively; then: + /// + /// If T precedes U in Ts..., returns true. + /// If U precedes T in Ts..., returns false. + /// Otherwise, returns LHS.get<T>() <= RHS.get<T>().
+ friend constexpr bool operator<=(const IntrusiveVariant &LHS, + const IntrusiveVariant &RHS) { + if (LHS.index() < RHS.index()) + return true; + if (LHS.index() > RHS.index()) + return false; +#define CASE(N) \ + case N: \ + return detail::visitOrError(std::less_equal<>{}, LHS, RHS); + SWITCH(LHS.index()) +#undef CASE + } + + /// Greater-equal operator. + /// + /// The alternative types held by LHS and RHS are T and U, respectively; then: + /// + /// If T precedes U in Ts..., returns false. + /// If U precedes T in Ts..., returns true. + /// Otherwise, returns LHS.get<T>() >= RHS.get<T>(). + friend constexpr bool operator>=(const IntrusiveVariant &LHS, + const IntrusiveVariant &RHS) { + if (LHS.index() < RHS.index()) + return false; + if (LHS.index() > RHS.index()) + return true; +#define CASE(N) \ + case N: \ + return detail::visitOrError(std::greater_equal<>{}, LHS, RHS); + SWITCH(LHS.index()) +#undef CASE + } + + /// Requires hash_value for every alternative; mixes in the alternative index. + friend hash_code hash_value(const IntrusiveVariant &IV) { +#define CASE(N) \ + case N: \ + return detail::visitOrError( \ + [](auto &&Alt) { return hash_combine(N, hash_value(Alt)); }, IV); + SWITCH(IV.index()) +#undef CASE + } +}; + +#undef SWITCH +#undef CASES32 +#undef CASES8 + +} // end namespace llvm + +#endif // LLVM_ADT_INTRUSIVEVARIANT_H diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -69,6 +69,12 @@ struct conjunction : std::conditional, B1>::type {}; +template struct disjunction : std::false_type {}; +template struct disjunction : B1 {}; +template +struct disjunction + : std::conditional>::type {}; + template struct make_const_ptr { using type = typename std::add_pointer::type>::type; diff --git a/llvm/unittests/ADT/CMakeLists.txt b/llvm/unittests/ADT/CMakeLists.txt --- a/llvm/unittests/ADT/CMakeLists.txt +++ b/llvm/unittests/ADT/CMakeLists.txt @@ -41,6 +41,7 @@
IntEqClassesTest.cpp IntervalMapTest.cpp IntrusiveRefCntPtrTest.cpp + IntrusiveVariantTest.cpp IteratorTest.cpp MappedIteratorTest.cpp MapVectorTest.cpp diff --git a/llvm/unittests/ADT/IntrusiveVariantTest.cpp b/llvm/unittests/ADT/IntrusiveVariantTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/ADT/IntrusiveVariantTest.cpp @@ -0,0 +1,260 @@ +//===- IntrusiveVariantTest.cpp - IntrusiveVariant unit tests -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/IntrusiveVariant.h" +#include "gtest/gtest.h" +#include + +using namespace llvm; + +namespace { + +class A { + DECLARE_INTRUSIVE_ALTERNATIVE +}; + +class B { + DECLARE_INTRUSIVE_ALTERNATIVE +}; + +TEST(IntrusiveVariantTest, SingleAlternative) { IntrusiveVariant V; } + +TEST(IntrusiveVariantTest, ZeroArgConstructionAndAssignment) { + IntrusiveVariant V; + ASSERT_TRUE(V.holdsAlternative()); + visit(makeVisitor([](A) {}, [](B) { FAIL(); }), V); + visit(makeVisitor([](B) { FAIL(); }, [](A) {}), V); + visit(makeVisitor([](A &) {}, [](B &) { FAIL(); }), V); + visit(makeVisitor([](A) {}, [](B &) { FAIL(); }), V); + visit(makeVisitor([](A &) {}, [](B) { FAIL(); }), V); + visit(makeVisitor([](auto &&) {}), V); + + V.emplace(); + ASSERT_TRUE(V.holdsAlternative()); + + IntrusiveVariant W{V}; + ASSERT_TRUE(W.holdsAlternative()); + + const IntrusiveVariant X{V}; + ASSERT_TRUE(X.holdsAlternative()); +} + +#define DECLARE_ALT(NAME, TYPE) \ + class NAME { \ + DECLARE_INTRUSIVE_ALTERNATIVE \ + TYPE Val; \ + \ + public: \ + NAME(TYPE Val) : Val(Val) {} \ + TYPE getVal() const { return Val; } \ + friend bool operator==(const NAME &LHS, const NAME &RHS) { \ + return LHS.getVal() == RHS.getVal(); \ + } \ + friend bool operator!=(const NAME &LHS, const NAME &RHS)
{ \ + return LHS.getVal() != RHS.getVal(); \ + } \ + friend bool operator<(const NAME &LHS, const NAME &RHS) { \ + return LHS.getVal() < RHS.getVal(); \ + } \ + friend bool operator>(const NAME &LHS, const NAME &RHS) { \ + return LHS.getVal() > RHS.getVal(); \ + } \ + friend bool operator<=(const NAME &LHS, const NAME &RHS) { \ + return LHS.getVal() <= RHS.getVal(); \ + } \ + friend bool operator>=(const NAME &LHS, const NAME &RHS) { \ + return LHS.getVal() >= RHS.getVal(); \ + } \ + } +DECLARE_ALT(I, int); +DECLARE_ALT(F, float); +DECLARE_ALT(D, double); + +TEST(IntrusiveVariantTest, ConstructionAndAssignment) { + IntrusiveVariant V{InPlaceType{}, 2.0f}; + visit(makeVisitor([](I) { FAIL(); }, [](F X) { EXPECT_EQ(X.getVal(), 2.0f); }, + [](D) { FAIL(); }), + V); + IntrusiveVariant W{V}; + visit(makeVisitor([](I) { FAIL(); }, [](F X) { EXPECT_EQ(X.getVal(), 2.0f); }, + [](D) { FAIL(); }), + W); + W.emplace(42); + visit(makeVisitor([](I X) { EXPECT_EQ(X.getVal(), 42); }, [](F) { FAIL(); }, + [](D) { FAIL(); }), + W); + W = V; + visit(makeVisitor([](I) { FAIL(); }, [](F X) { EXPECT_EQ(X.getVal(), 2.0f); }, + [](D) { FAIL(); }), + W); +} + +TEST(IntrusiveVariantTest, Comparison) { + IntrusiveVariant V{InPlaceType{}, 1}; + IntrusiveVariant W{InPlaceType{}, 2.0f}; + IntrusiveVariant X{InPlaceType{}, 2.0f}; + IntrusiveVariant Y{InPlaceType{}, 3.0f}; + IntrusiveVariant Z{InPlaceType{}, 3.0}; + EXPECT_NE(V, W); + EXPECT_LT(V, W); + EXPECT_LE(V, W); + EXPECT_GT(W, V); + EXPECT_GE(W, V); + EXPECT_EQ(W, X); + EXPECT_LE(W, X); + EXPECT_GE(W, X); + EXPECT_NE(W, Y); + EXPECT_NE(X, Y); + EXPECT_LT(X, Y); + EXPECT_LE(X, Y); + EXPECT_GT(Y, X); + EXPECT_GE(Y, X); + EXPECT_NE(Y, Z); + std::swap(X, Y); + EXPECT_EQ(W, Y); + EXPECT_NE(W, X); +} + +TEST(IntrusiveVariantTest, IntrusiveVariantSize) { + constexpr auto One = IntrusiveVariantSize>{}; + EXPECT_EQ(One, 1u); + constexpr auto Two = IntrusiveVariantSize>{}; + EXPECT_EQ(Two, 2u); + constexpr auto Three = IntrusiveVariantSize>{}; +
EXPECT_EQ(Three, 3u); +} + +TEST(IntrusiveVariantTest, HoldsAlternative) { + IntrusiveVariant V{InPlaceType{}, 2.0}; + EXPECT_FALSE(V.holdsAlternative()); + EXPECT_FALSE(V.holdsAlternative()); + EXPECT_TRUE(V.holdsAlternative()); + V.emplace(1); + EXPECT_TRUE(V.holdsAlternative()); + EXPECT_FALSE(V.holdsAlternative()); + EXPECT_FALSE(V.holdsAlternative()); + const IntrusiveVariant C{InPlaceType{}, 2.0f}; + EXPECT_FALSE(C.holdsAlternative()); + EXPECT_TRUE(C.holdsAlternative()); + EXPECT_FALSE(C.holdsAlternative()); +} + +TEST(IntrusiveVariantTest, Get) { + IntrusiveVariant V{InPlaceType{}, 2.0}; + EXPECT_EQ(V.get(), D{2.0}); + EXPECT_EQ(V.get(), *V.getIf()); + EXPECT_EQ(&V.get(), V.getIf()); + V.emplace(1); + EXPECT_EQ(V.get(), I{1}); + EXPECT_EQ(V.get(), *V.getIf()); + EXPECT_EQ(&V.get(), V.getIf()); + const IntrusiveVariant C{InPlaceType{}, 2.0}; + EXPECT_EQ(C.get(), D{2.0}); + EXPECT_EQ(C.get(), *C.getIf()); + EXPECT_EQ(&C.get(), C.getIf()); +} + +TEST(IntrusiveVariantTest, GetIf) { + IntrusiveVariant V{InPlaceType{}, 2.0}; + EXPECT_EQ(V.getIf(), nullptr); + EXPECT_EQ(V.getIf(), nullptr); + EXPECT_NE(V.getIf(), nullptr); + V.emplace(1); + EXPECT_NE(V.getIf(), nullptr); + EXPECT_EQ(V.getIf(), nullptr); + EXPECT_EQ(V.getIf(), nullptr); + const IntrusiveVariant C{InPlaceType{}, 2.0f}; + EXPECT_EQ(C.getIf(), nullptr); + EXPECT_NE(C.getIf(), nullptr); + EXPECT_EQ(C.getIf(), nullptr); +} + +struct IntA { + DECLARE_INTRUSIVE_ALTERNATIVE + int Val; + IntA(int Val) : Val(Val) {} + friend hash_code hash_value(const IntA &IA) { return hash_value(IA.Val); } +}; + +struct IntB { + DECLARE_INTRUSIVE_ALTERNATIVE + int Val; + IntB(int Val) : Val(Val) {} + friend hash_code hash_value(const IntB &IB) { return hash_value(IB.Val); } +}; + +TEST(IntrusiveVariantTest, HashValue) { + IntrusiveVariant ATwo{InPlaceType{}, 2}; + IntrusiveVariant AThree{InPlaceType{}, 3}; + IntrusiveVariant BTwo{InPlaceType{}, 2}; + EXPECT_EQ(hash_value(ATwo), hash_value(ATwo)); +
EXPECT_NE(hash_value(ATwo), hash_value(AThree)); + EXPECT_NE(hash_value(ATwo), hash_value(BTwo)); +} + +class AltInt { + DECLARE_INTRUSIVE_ALTERNATIVE + int Int; + +public: + AltInt() : Int(0) {} + AltInt(int Int) : Int(Int) {} + int getInt() const { return Int; } + void setInt(int Int) { this->Int = Int; } +}; + +class AltDouble { + DECLARE_INTRUSIVE_ALTERNATIVE + double Double; + +public: + AltDouble(double Double) : Double(Double) {} + double getDouble() const { return Double; } + void setDouble(double Double) { this->Double = Double; } +}; + +class AltComplexInt { + DECLARE_INTRUSIVE_ALTERNATIVE + int Real; + int Imag; + +public: + AltComplexInt(int Real, int Imag) : Real(Real), Imag(Imag) {} + int getReal() const { return Real; } + void setReal(int Real) { this->Real = Real; } + int getImag() const { return Imag; } + void setImag(int Imag) { this->Imag = Imag; } +}; + +TEST(IntrusiveVariantTest, HeaderExample) { + using MyVariant = IntrusiveVariant; + + MyVariant DefaultConstructedVariant; + ASSERT_TRUE(DefaultConstructedVariant.holdsAlternative()); + ASSERT_EQ(DefaultConstructedVariant.get().getInt(), 0); + MyVariant Variant{InPlaceType{}, 4, 2}; + ASSERT_TRUE(Variant.holdsAlternative()); + int NonSense = visit( + makeVisitor( + [](AltInt &AI) { return AI.getInt(); }, + [](AltDouble &AD) { return static_cast(AD.getDouble()); }, + [](AltComplexInt &ACI) { return ACI.getReal() + ACI.getImag(); }), + Variant); + ASSERT_EQ(NonSense, 6); + Variant.emplace(2.0); + ASSERT_TRUE(Variant.holdsAlternative()); + Variant.get().setDouble(3.0); + AltDouble AD = Variant.get(); + double D = AD.getDouble(); + ASSERT_EQ(D, 3.0); + Variant.emplace(4, 5); + ASSERT_EQ(Variant.get().getReal(), 4); + ASSERT_EQ(Variant.get().getImag(), 5); +} + +} // anonymous namespace