diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h --- a/llvm/include/llvm/ADT/PointerUnion.h +++ b/llvm/include/llvm/ADT/PointerUnion.h @@ -18,6 +18,7 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/PointerLikeTypeTraits.h" #include #include @@ -87,6 +88,9 @@ }; } +// forward declaration of CastInfoPointerUnionImpl +template struct CastInfoPointerUnionImpl; + /// A discriminated union of two or more pointer types, with the discriminator /// in the low bit of the pointer. /// @@ -122,6 +126,8 @@ using First = TypeAtIndex<0, PTs...>; using Base = typename PointerUnion::PointerUnionMembers; + friend struct CastInfoPointerUnionImpl>; + public: PointerUnion() = default; @@ -134,17 +140,19 @@ explicit operator bool() const { return !isNull(); } + // FIXME: Given that we're able to use isa, cast and dyn_cast, + // remove the following is(), get() and dyn_cast() after all the uses + // have been replaced. + /// Test if the Union currently holds the type matching T. - template bool is() const { - return this->Val.getInt() == FirstIndexOfType::value; - } + template bool is() const { return isa(*this); } /// Returns the value of the specified pointer type. /// /// If the specified pointer type is incorrect, assert. template T get() const { - assert(is() && "Invalid accessor called"); - return PointerLikeTypeTraits::getFromVoidPointer(this->Val.getPointer()); + assert(isa(*this) && "Invalid accessor called"); + return cast(*this); } /// Returns the current pointer if it is of the specified pointer type, @@ -205,6 +213,31 @@ return lhs.getOpaqueValue() < rhs.getOpaqueValue(); } +template struct CastInfoPointerUnionImpl { + template + static inline bool isPossible(PointerUnion &PU) { + return PU.Val.getInt() == FirstIndexOfType::value; + } + + template static To doCast(PointerUnion &PU) { + assert(isPossible(PU) && "cast to an incompatible type of pointer union !"); + return PointerLikeTypeTraits::getFromVoidPointer(PU.Val.getPointer()); + } +}; + +// Specialization of CastInfo for PointerUnion +template +struct CastInfo> + : CastInfoPointerUnionImpl, + NullableValueCastFailed, + DefaultDoCastIfPossible, + CastInfo>> {}; + +template +struct CastInfo> + : ConstStrippingForwardingCast, + CastInfo>> {}; + // Teach SmallPtrSet that PointerUnion is "basically a pointer", that has // # low bits available = min(PT1bits,PT2bits)-1. template diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp --- a/llvm/unittests/ADT/PointerUnionTest.cpp +++ b/llvm/unittests/ADT/PointerUnionTest.cpp @@ -156,4 +156,120 @@ EXPECT_TRUE((void *)n.getAddrOfPtr1() == (void *)&n); } +TEST_F(PointerUnionTest, NewCastInfra) { + // test isa<> + EXPECT_TRUE(isa(a)); + EXPECT_TRUE(isa(b)); + EXPECT_TRUE(isa(c)); + EXPECT_TRUE(isa(n)); + EXPECT_TRUE(isa(i3)); + EXPECT_TRUE(isa(f3)); + EXPECT_TRUE(isa(l3)); + EXPECT_TRUE(isa(i4)); + EXPECT_TRUE(isa(f4)); + EXPECT_TRUE(isa(l4)); + EXPECT_TRUE(isa(d4)); + EXPECT_TRUE(isa(i4null)); + EXPECT_TRUE(isa(f4null)); + EXPECT_TRUE(isa(l4null)); + EXPECT_TRUE(isa(d4null)); + EXPECT_FALSE(isa(a)); + EXPECT_FALSE(isa(b)); + EXPECT_FALSE(isa(c)); + EXPECT_FALSE(isa(n)); + EXPECT_FALSE(isa(i3)); + EXPECT_FALSE(isa(i3)); + EXPECT_FALSE(isa(f3)); + EXPECT_FALSE(isa(f3)); + EXPECT_FALSE(isa(l3)); + EXPECT_FALSE(isa(l3)); + EXPECT_FALSE(isa(i4)); + EXPECT_FALSE(isa(i4)); + EXPECT_FALSE(isa(i4)); + EXPECT_FALSE(isa(f4)); + EXPECT_FALSE(isa(f4)); + EXPECT_FALSE(isa(f4)); + EXPECT_FALSE(isa(l4)); + EXPECT_FALSE(isa(l4)); + EXPECT_FALSE(isa(l4)); + EXPECT_FALSE(isa(d4)); + EXPECT_FALSE(isa(d4)); + EXPECT_FALSE(isa(d4)); + EXPECT_FALSE(isa(i4null)); + EXPECT_FALSE(isa(i4null)); + EXPECT_FALSE(isa(i4null)); + EXPECT_FALSE(isa(f4null)); + EXPECT_FALSE(isa(f4null)); + EXPECT_FALSE(isa(f4null)); + EXPECT_FALSE(isa(l4null)); + EXPECT_FALSE(isa(l4null)); + EXPECT_FALSE(isa(l4null)); + EXPECT_FALSE(isa(d4null)); + EXPECT_FALSE(isa(d4null)); + EXPECT_FALSE(isa(d4null)); + + // test cast<> + EXPECT_EQ(cast(a), &f); + EXPECT_EQ(cast(b), &i); + EXPECT_EQ(cast(c), &i); + EXPECT_EQ(cast(i3), &i); + EXPECT_EQ(cast(f3), &f); + EXPECT_EQ(cast(l3), &l); + EXPECT_EQ(cast(i4), &i); + EXPECT_EQ(cast(f4), &f); + EXPECT_EQ(cast(l4), &l); + EXPECT_EQ(cast(d4), &d); + + // test dyn_cast + EXPECT_EQ(dyn_cast(a), nullptr); + EXPECT_EQ(dyn_cast(a), &f); + EXPECT_EQ(dyn_cast(b), &i); + EXPECT_EQ(dyn_cast(b), nullptr); + EXPECT_EQ(dyn_cast(c), &i); + EXPECT_EQ(dyn_cast(c), nullptr); + EXPECT_EQ(dyn_cast(n), nullptr); + EXPECT_EQ(dyn_cast(n), nullptr); + EXPECT_EQ(dyn_cast(i3), &i); + EXPECT_EQ(dyn_cast(i3), nullptr); + EXPECT_EQ(dyn_cast(i3), nullptr); + EXPECT_EQ(dyn_cast(f3), nullptr); + EXPECT_EQ(dyn_cast(f3), &f); + EXPECT_EQ(dyn_cast(f3), nullptr); + EXPECT_EQ(dyn_cast(l3), nullptr); + EXPECT_EQ(dyn_cast(l3), nullptr); + EXPECT_EQ(dyn_cast(l3), &l); + EXPECT_EQ(dyn_cast(i4), &i); + EXPECT_EQ(dyn_cast(i4), nullptr); + EXPECT_EQ(dyn_cast(i4), nullptr); + EXPECT_EQ(dyn_cast(i4), nullptr); + EXPECT_EQ(dyn_cast(f4), nullptr); + EXPECT_EQ(dyn_cast(f4), &f); + EXPECT_EQ(dyn_cast(f4), nullptr); + EXPECT_EQ(dyn_cast(f4), nullptr); + EXPECT_EQ(dyn_cast(l4), nullptr); + EXPECT_EQ(dyn_cast(l4), nullptr); + EXPECT_EQ(dyn_cast(l4), &l); + EXPECT_EQ(dyn_cast(l4), nullptr); + EXPECT_EQ(dyn_cast(d4), nullptr); + EXPECT_EQ(dyn_cast(d4), nullptr); + EXPECT_EQ(dyn_cast(d4), nullptr); + EXPECT_EQ(dyn_cast(d4), &d); + EXPECT_EQ(dyn_cast(i4null), nullptr); + EXPECT_EQ(dyn_cast(i4null), nullptr); + EXPECT_EQ(dyn_cast(i4null), nullptr); + EXPECT_EQ(dyn_cast(i4null), nullptr); + EXPECT_EQ(dyn_cast(f4null), nullptr); + EXPECT_EQ(dyn_cast(f4null), nullptr); + EXPECT_EQ(dyn_cast(f4null), nullptr); + EXPECT_EQ(dyn_cast(f4null), nullptr); + EXPECT_EQ(dyn_cast(l4null), nullptr); + EXPECT_EQ(dyn_cast(l4null), nullptr); + EXPECT_EQ(dyn_cast(l4null), nullptr); + EXPECT_EQ(dyn_cast(l4null), nullptr); + EXPECT_EQ(dyn_cast(d4null), nullptr); + EXPECT_EQ(dyn_cast(d4null), nullptr); + EXPECT_EQ(dyn_cast(d4null), nullptr); + EXPECT_EQ(dyn_cast(d4null), nullptr); +} + } // end anonymous namespace