diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -16,6 +16,7 @@
 
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/VariantTraits.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 #include <cassert>
 #include <cstddef>
@@ -149,6 +150,9 @@
   using First = TypeAtIndex<0, PTs...>;
   using Base = typename PointerUnion::PointerUnionMembers;
 
+  // VariantTraits<PointerUnion>::getIndex reads Val directly.
+  friend struct variant_traits_detail::VariantTraits<PointerUnion>;
+
 public:
   PointerUnion() = default;
 
@@ -272,6 +276,22 @@
   }
 };
 
+/// VariantTraits specialization for PointerUnion: reports the number of
+/// alternatives, the index of the active alternative (the integer part of
+/// Val), and returns the alternative at index I via PointerUnion::get.
+template <typename... PTs>
+struct variant_traits_detail::VariantTraits<PointerUnion<PTs...>> {
+  static constexpr size_t numAlts() { return sizeof...(PTs); }
+  static constexpr size_t getIndex(const PointerUnion<PTs...> &Variant) {
+    return Variant.Val.getInt();
+  }
+  template <size_t I, typename VariantT>
+  static constexpr decltype(auto) getAlt(VariantT &&Variant) {
+    return std::forward<VariantT>(Variant)
+        .template get<TypeAtIndex<I, PTs...>>();
+  }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_ADT_POINTERUNION_H
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -156,4 +156,27 @@
   EXPECT_TRUE((void *)n.getAddrOfPtr1() == (void *)&n);
 }
 
+TEST_F(PointerUnionTest, Visitor) {
+  // nullptr so a visit that runs no handler fails below instead of UB.
+  void *ptr = nullptr;
+
+  visit(makeVisitor([&](int *intPtr) { ptr = intPtr; },
+                    [&](float *floatPtr) { ptr = floatPtr; }),
+        a);
+  EXPECT_EQ(&f, ptr);
+
+  visit(makeVisitor([&](int *intPtr) { ptr = intPtr; },
+                    [&](float *floatPtr) { ptr = floatPtr; },
+                    [&](long long *longPtr) { ptr = longPtr; }),
+        i3);
+  EXPECT_EQ(&i, ptr);
+
+  visit(makeVisitor([&](int *intPtr) { ptr = intPtr; },
+                    [&](float *floatPtr) { ptr = floatPtr; },
+                    [&](long long *longPtr) { ptr = longPtr; },
+                    [&](double *doublePtr) { ptr = doublePtr; }),
+        d4);
+  EXPECT_EQ(&d, ptr);
+}
+
 } // end anonymous namespace