diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -16,6 +16,8 @@
 
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/Visitor.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 #include <algorithm>
 #include <cassert>
@@ -64,18 +66,31 @@
     return std::min<int>({PointerLikeTypeTraits<Ts>::NumLowBitsAvailable...});
   }
 
-  /// Find the index of a type in a list of types. TypeIndex<T, Us...>::Index
-  /// is the index of T in Us, or sizeof...(Us) if T does not appear in the
-  /// list.
-  template <typename T, typename ...Us> struct TypeIndex;
-  template <typename T, typename ...Us> struct TypeIndex<T, T, Us...> {
+  /// Select the type at position I (zero-based) of the list Ts.
+  /// SelectTypeAtIndex<I, Ts...>::type is that type; when I is out of range
+  /// there is no `type` member, so misuse is diagnosed at compile time.
+  template <int I, typename... Ts> struct SelectTypeAtIndex {};
+
+  template <int I, typename T0, typename... Ts>
+  struct SelectTypeAtIndex<I, T0, Ts...> : SelectTypeAtIndex<I - 1, Ts...> {};
+
+  template <typename T0, typename... Ts>
+  struct SelectTypeAtIndex<0, T0, Ts...> {
+    using type = T0;
+  };
+
+  /// Find the index of a type in a list of types.
+  /// FindIndexOfType<T, Us...>::Index is the index of T in Us, or
+  /// sizeof...(Us) if T does not appear in the list.
+  template <typename T, typename... Us> struct FindIndexOfType;
+  template <typename T, typename... Us> struct FindIndexOfType<T, T, Us...> {
     static constexpr int Index = 0;
   };
   template <typename T, typename U, typename... Us>
-  struct TypeIndex<T, U, Us...> {
-    static constexpr int Index = 1 + TypeIndex<T, Us...>::Index;
+  struct FindIndexOfType<T, U, Us...> {
+    static constexpr int Index = 1 + FindIndexOfType<T, Us...>::Index;
   };
-  template <typename T> struct TypeIndex<T> {
+  template <typename T> struct FindIndexOfType<T> {
     static constexpr int Index = 0;
   };
 
@@ -145,6 +160,9 @@
 ///    P = (float*)0;
 ///    Y = P.get<float*>();   // ok.
 ///    X = P.get<int*>();     // runtime assertion failure.
+///    P.visit(
+///        [](int* i) { std::cout << i; },
+///        [](float* f) { std::cout << f; });  // ok.
 template <typename... PTs>
 class PointerUnion
     : public pointer_union_detail::PointerUnionMembers<
@@ -175,7 +193,8 @@
 
   /// Test if the Union currently holds the type matching T.
   template <typename T> bool is() const {
-    constexpr int Index = pointer_union_detail::TypeIndex<T, PTs...>::Index;
+    constexpr int Index =
+        pointer_union_detail::FindIndexOfType<T, PTs...>::Index;
     static_assert(Index < sizeof...(PTs),
                   "PointerUnion::is given type not in the union");
     return this->Val.getInt() == Index;
@@ -230,6 +249,84 @@
     V.Val = decltype(V.Val)::getFromOpaqueValue(VP);
     return V;
   }
+
+  /// The type at position I of PTs.
+  template <int I>
+  using TypeAtIndex =
+      typename pointer_union_detail::SelectTypeAtIndex<I, PTs...>::type;
+
+  /// An arbitrary member of PTs. It gives the unreachable cases of visit()
+  /// a well-typed argument; no value of it is ever produced there.
+  using ArbitraryType = TypeAtIndex<0>;
+
+  /// Getter used by visit() for discriminator values that are not smaller
+  /// than sizeof...(PTs), i.e. switch cases with no matching alternative.
+  struct UnreachableSwitchCase {
+    [[noreturn]] static ArbitraryType get(PointerUnion) {
+      // llvm_unreachable does not return in any build mode, which keeps the
+      // [[noreturn]] promise (a bare assert(false) would return under NDEBUG).
+      llvm_unreachable("PointerUnion discriminator out of range");
+    }
+  };
+
+  /// Getter used by visit() for the alternative at index I.
+  template <int I> struct ReachableSwitchCase {
+    static TypeAtIndex<I> get(PointerUnion P) {
+      return P.template get<TypeAtIndex<I>>();
+    }
+  };
+
+  // Select a getter for a switch case. Explicit specializations (template <>)
+  // at class scope are rejected by GCC, so the `true` case is written as a
+  // partial specialization, enabled by the otherwise unused Dummy parameter.
+  template <bool Reachable, typename Dummy = void> struct GetterImpl {
+    template <int I> using Body = UnreachableSwitchCase;
+  };
+
+  template <typename Dummy> struct GetterImpl<true, Dummy> {
+    template <int I> using Body = ReachableSwitchCase<I>;
+  };
+
+  template <int I>
+  using Getter = typename GetterImpl<(
+      I < static_cast<int>(sizeof...(PTs)))>::template Body<I>;
+
+  /// Call \p visitor with the pointer currently held by the union, typed as
+  /// the alternative it was stored as, and return the visitor's result.
+  ///
+  /// The visitor must accept every type in PTs, and all of those calls must
+  /// return the same type (the return type is deduced from all of them).
+  template <typename Visitor> auto visit(Visitor &&visitor) const {
+    static_assert(
+        sizeof...(PTs) <= 8,
+        "PointerUnion::visit implemented for at most 8 alternatives.");
+    switch (this->Val.getInt()) {
+    case 0:
+      return visitor(Getter<0>::get(*this));
+    case 1:
+      return visitor(Getter<1>::get(*this));
+    case 2:
+      return visitor(Getter<2>::get(*this));
+    case 3:
+      return visitor(Getter<3>::get(*this));
+    case 4:
+      return visitor(Getter<4>::get(*this));
+    case 5:
+      return visitor(Getter<5>::get(*this));
+    case 6:
+      return visitor(Getter<6>::get(*this));
+    case 7:
+      return visitor(Getter<7>::get(*this));
+    }
+    // Discriminators 0-7 are all handled above, and the static_assert caps
+    // the number of alternatives at 8; without this the function falls off.
+    llvm_unreachable("PointerUnion discriminator out of range");
+  }
+
+  /// Convenience overload taking one callable per alternative: the callables
+  /// are combined with llvm::makeVisitor and the result is visited as above.
+  template <typename... LAMBDAS> auto visit(LAMBDAS... lambdas) const {
+    return visit(llvm::makeVisitor(lambdas...));
+  }
 };
 
 template <typename ...PTs>