diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -16,6 +16,8 @@
 
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/Visitor.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 #include <cassert>
 #include <cstddef>
@@ -64,18 +66,31 @@
     return std::min<int>({PointerLikeTypeTraits<Ts>::NumLowBitsAvailable...});
   }
 
-  /// Find the index of a type in a list of types. TypeIndex<T, Us...>::Index
-  /// is the index of T in Us, or sizeof...(Us) if T does not appear in the
-  /// list.
-  template <typename T, typename ...Us> struct TypeIndex;
-  template <typename T, typename ...Us> struct TypeIndex<T, T, Us...> {
+  /// Select the type at a given position in a list of types.
+  /// SelectTypeAtIndex<I, Ts...>::type is the I-th (0-based) type in Ts; it
+  /// has no `type` member when I >= sizeof...(Ts).
+  template <size_t I, typename... Ts> struct SelectTypeAtIndex {};
+
+  template <size_t I, typename T0, typename... Ts>
+  struct SelectTypeAtIndex<I, T0, Ts...> : SelectTypeAtIndex<I - 1, Ts...> {};
+
+  template <typename T0, typename... Ts>
+  struct SelectTypeAtIndex<0, T0, Ts...> {
+    using type = T0;
+  };
+
+  /// Find the index of a type in a list of types.
+  /// FindIndexOfType<T, Us...>::Index is the index of T in Us, or
+  /// sizeof...(Us) if T does not appear in the list.
+  template <typename T, typename... Us> struct FindIndexOfType;
+  template <typename T, typename... Us> struct FindIndexOfType<T, T, Us...> {
     static constexpr int Index = 0;
   };
   template <typename T, typename U, typename... Us>
-  struct TypeIndex<T, U, Us...> {
-    static constexpr int Index = 1 + TypeIndex<T, Us...>::Index;
+  struct FindIndexOfType<T, U, Us...> {
+    static constexpr int Index = 1 + FindIndexOfType<T, Us...>::Index;
   };
-  template <typename T> struct TypeIndex<T> {
+  template <typename T> struct FindIndexOfType<T> {
     static constexpr int Index = 0;
   };
 
@@ -145,6 +160,9 @@
 ///    P = (float*)0;
 ///    Y = P.get<float*>();   // ok.
 ///    X = P.get<int*>();     // runtime assertion failure.
+///    P.match(
+///      [](int* i) { std::cout << i; },
+///      [](float* f) { std::cout << f; }); // calls the float* lambda.
 template <typename... PTs>
 class PointerUnion
     : public pointer_union_detail::PointerUnionMembers<
@@ -175,7 +193,8 @@
 
   /// Test if the Union currently holds the type matching T.
   template <typename T> bool is() const {
-    constexpr int Index = pointer_union_detail::TypeIndex<T, PTs...>::Index;
+    constexpr int Index =
+        pointer_union_detail::FindIndexOfType<T, PTs...>::Index;
     static_assert(Index < sizeof...(PTs),
                   "PointerUnion::is<T> given type not in the union");
     return this->Val.getInt() == Index;
@@ -230,6 +249,86 @@
     V.Val = decltype(V.Val)::getFromOpaqueValue(VP);
     return V;
   }
+
+  /// The I-th type of PTs (0-based).
+  template <size_t I>
+  using TypeAtIndex =
+      typename pointer_union_detail::SelectTypeAtIndex<I, PTs...>::type;
+  /// Stand-in argument type used for switch cases that can never be taken.
+  using ArbitraryType = TypeAtIndex<0>;
+
+  /// Switch-case body for a discriminator value that no alternative uses.
+  struct UnreachableSwitchCase {
+    [[noreturn]] static ArbitraryType get(PointerUnion) {
+      llvm_unreachable("Unreachable case in PointerUnion visitor.");
+    }
+  };
+
+  /// Switch-case body that extracts the I-th alternative from the union.
+  template <size_t I> struct ReachableSwitchCase {
+    static TypeAtIndex<I> get(PointerUnion ptr) {
+      return ptr.template get<TypeAtIndex<I>>();
+    }
+  };
+
+  // Default Getter implements a case statement for unreachable branches.
+  // The defaulted Dummy parameter makes the specialization below a partial
+  // specialization: explicit specializations are only valid at class scope
+  // since C++17 (CWG 727), and GCC has long rejected them there.
+  template <bool Reachable, typename Dummy = void> struct GetterImpl {
+    template <size_t I> using Body = UnreachableSwitchCase;
+  };
+
+  template <typename Dummy> struct GetterImpl<true, Dummy> {
+    template <size_t I> using Body = ReachableSwitchCase<I>;
+  };
+
+  /// Getter<I>::get(pu) returns the I-th alternative of pu; it is marked
+  /// unreachable when I is not a valid index into PTs.
+  template <size_t I>
+  using Getter = typename GetterImpl<(I < sizeof...(PTs))>::template Body<I>;
+
+  /// Calls \p visitor with the pointer held by \p pu, cast to the alternative
+  /// type that is currently active, and returns the visitor's result.
+  ///
+  /// \p visitor must be callable with every type in PTs, and all of those
+  /// calls must return the same type. At most 8 alternatives are supported.
+  template <typename Visitor>
+  static auto visit(Visitor &&visitor, PointerUnion pu) {
+    static_assert(
+        sizeof...(PTs) <= 8,
+        "PointerUnion::visit implemented for at most 8 alternatives.");
+    switch (pu.Val.getInt()) {
+    case 0:
+      return visitor(Getter<0>::get(pu));
+    case 1:
+      return visitor(Getter<1>::get(pu));
+    case 2:
+      return visitor(Getter<2>::get(pu));
+    case 3:
+      return visitor(Getter<3>::get(pu));
+    case 4:
+      return visitor(Getter<4>::get(pu));
+    case 5:
+      return visitor(Getter<5>::get(pu));
+    case 6:
+      return visitor(Getter<6>::get(pu));
+    case 7:
+      return visitor(Getter<7>::get(pu));
+    }
+    // All discriminator values are handled above; without this the compiler
+    // warns that control can reach the end of a non-void function.
+    llvm_unreachable("Invalid discriminator in PointerUnion visitor.");
+  }
+
+  /// Slightly more ergonomic version of `visit`: builds the visitor from the
+  /// given callables, e.g. `P.match([](int *I) {...}, [](float *F) {...})`.
+  ///
+  /// Named differently because the calling convention is different. Name is
+  /// inspired by Rust's `match` keyword.
+  template <typename... Lambdas> auto match(Lambdas... lambdas) const {
+    return visit(llvm::makeVisitor(lambdas...), *this);
+  }
 };
 
 template <typename ...PTs>
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -156,4 +156,30 @@
   EXPECT_TRUE((void *)n.getAddrOfPtr1() == (void *)&n);
 }
 
+TEST_F(PointerUnionTest, Visitor) {
+  void *ptr = nullptr;
+
+  a.match([&](int *intPtr) { ptr = intPtr; },
+          [&](float *floatPtr) { ptr = floatPtr; });
+  EXPECT_EQ(&f, ptr);
+
+  i3.match([&](int *intPtr) { ptr = intPtr; },
+           [&](float *floatPtr) { ptr = floatPtr; },
+           [&](long long *longPtr) { ptr = longPtr; });
+  EXPECT_EQ(&i, ptr);
+
+  d4.match([&](int *intPtr) { ptr = intPtr; },
+           [&](float *floatPtr) { ptr = floatPtr; },
+           [&](long long *longPtr) { ptr = longPtr; },
+           [&](double *doublePtr) { ptr = doublePtr; });
+  EXPECT_EQ(&d, ptr);
+
+  // match() returns the selected callable's result and works on const unions.
+  const auto &constD4 = d4;
+  EXPECT_EQ(3, constD4.match([](int *) { return 0; },
+                             [](float *) { return 1; },
+                             [](long long *) { return 2; },
+                             [](double *) { return 3; }));
+}
+
 } // end anonymous namespace