diff --git a/llvm/include/llvm/ADT/PointerUnion.h b/llvm/include/llvm/ADT/PointerUnion.h
--- a/llvm/include/llvm/ADT/PointerUnion.h
+++ b/llvm/include/llvm/ADT/PointerUnion.h
@@ -18,6 +18,7 @@
 #include "llvm/ADT/DenseMapInfo.h"
 #include "llvm/ADT/PointerIntPair.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 #include <algorithm>
 #include <cassert>
@@ -87,6 +88,9 @@
   };
 }
 
+// This is a forward declaration of CastInfoPointerUnionImpl
+// Refer to its definition below for further details
+template <typename... PTs> struct CastInfoPointerUnionImpl;
 /// A discriminated union of two or more pointer types, with the discriminator
 /// in the low bit of the pointer.
 ///
@@ -122,6 +126,11 @@
   using First = TypeAtIndex<0, PTs...>;
   using Base = typename PointerUnion::PointerUnionMembers;
 
+  /// This is needed to give the CastInfo implementation below access
+  /// to protected members.
+  /// Refer to its definition for further details.
+  friend struct CastInfoPointerUnionImpl<PTs...>;
+
 public:
   PointerUnion() = default;
 
@@ -134,25 +143,24 @@
 
   explicit operator bool() const { return !isNull(); }
 
+  // FIXME: Replace the uses of is(), get() and dyn_cast() with
+  //        isa<T>, cast<T> and the llvm::dyn_cast<T>
+
   /// Test if the Union currently holds the type matching T.
-  template <typename T> bool is() const {
-    return this->Val.getInt() == FirstIndexOfType<T, PTs...>::value;
-  }
+  template <typename T> inline bool is() const { return isa<T>(*this); }
 
   /// Returns the value of the specified pointer type.
   ///
   /// If the specified pointer type is incorrect, assert.
-  template <typename T> T get() const {
-    assert(is<T>() && "Invalid accessor called");
-    return PointerLikeTypeTraits<T>::getFromVoidPointer(this->Val.getPointer());
+  template <typename T> inline T get() const {
+    assert(isa<T>(*this) && "Invalid accessor called");
+    return cast<T>(*this);
   }
 
   /// Returns the current pointer if it is of the specified pointer type,
   /// otherwise returns null.
-  template <typename T> T dyn_cast() const {
-    if (is<T>())
-      return get<T>();
-    return T();
+  template <typename T> inline T dyn_cast() const {
+    return llvm::dyn_cast<T>(*this);
   }
 
   /// If the union is set to the first pointer type get an address pointing to
@@ -205,6 +213,56 @@
   return lhs.getOpaqueValue() < rhs.getOpaqueValue();
 }
 
+/// We can't (at least, at this moment with C++14) declare CastInfo
+/// as a friend of PointerUnion like this:
+/// ```
+///   template<typename To>
+///   friend struct CastInfo<To, PointerUnion<PTs...>>;
+/// ```
+/// The compiler complains 'Partial specialization cannot be declared as a
+/// friend'.
+/// So we define this struct to be a bridge between CastInfo and
+/// PointerUnion.
+template <typename... PTs> struct CastInfoPointerUnionImpl {
+  using From = PointerUnion<PTs...>;
+
+  /// True iff the discriminator of \p F is the index of \p To in \p PTs.
+  template <typename To> static inline bool isPossible(From &F) {
+    return F.Val.getInt() == FirstIndexOfType<To, PTs...>::value;
+  }
+
+  /// Returns the pointer held by \p F as \p To; asserts that \p To is held.
+  template <typename To> static To doCast(From &F) {
+    assert(isPossible<To>(F) && "cast to an incompatible type !");
+    return PointerLikeTypeTraits<To>::getFromVoidPointer(F.Val.getPointer());
+  }
+};
+
+// Specialization of CastInfo for PointerUnion
+template <typename To, typename... PTs>
+struct CastInfo<To, PointerUnion<PTs...>>
+    : public DefaultDoCastIfPossible<To, PointerUnion<PTs...>,
+                                     CastInfo<To, PointerUnion<PTs...>>> {
+  using From = PointerUnion<PTs...>;
+  using Impl = CastInfoPointerUnionImpl<PTs...>;
+
+  static inline bool isPossible(From &f) {
+    return Impl::template isPossible<To>(f);
+  }
+
+  static To doCast(From &f) { return Impl::template doCast<To>(f); }
+
+  static inline To castFailed() { return To(); }
+};
+
+// Specialization of CastInfo for const PointerUnion: strips const and
+// forwards to the non-const specialization above.
+template <typename To, typename... PTs>
+struct CastInfo<To, const PointerUnion<PTs...>>
+    : public ConstStrippingForwardingCast<To, const PointerUnion<PTs...>,
+                                          CastInfo<To, PointerUnion<PTs...>>> {
+};
+
 // Teach SmallPtrSet that PointerUnion is "basically a pointer", that has
 // # low bits available = min(PT1bits,PT2bits)-1.
 template <typename ...PTs>
diff --git a/llvm/unittests/ADT/PointerUnionTest.cpp b/llvm/unittests/ADT/PointerUnionTest.cpp
--- a/llvm/unittests/ADT/PointerUnionTest.cpp
+++ b/llvm/unittests/ADT/PointerUnionTest.cpp
@@ -156,4 +156,138 @@
   EXPECT_TRUE((void *)n.getAddrOfPtr1() == (void *)&n);
 }
 
+TEST_F(PointerUnionTest, NewCastInfra) {
+  // test isa<>
+  EXPECT_TRUE(isa<float *>(a));
+  EXPECT_TRUE(isa<int *>(b));
+  EXPECT_TRUE(isa<int *>(c));
+  EXPECT_TRUE(isa<int *>(n));
+  EXPECT_TRUE(isa<int *>(i3));
+  EXPECT_TRUE(isa<float *>(f3));
+  EXPECT_TRUE(isa<long long *>(l3));
+  EXPECT_TRUE(isa<int *>(i4));
+  EXPECT_TRUE(isa<float *>(f4));
+  EXPECT_TRUE(isa<long long *>(l4));
+  EXPECT_TRUE(isa<double *>(d4));
+  EXPECT_TRUE(isa<int *>(i4null));
+  EXPECT_TRUE(isa<float *>(f4null));
+  EXPECT_TRUE(isa<long long *>(l4null));
+  EXPECT_TRUE(isa<double *>(d4null));
+  EXPECT_FALSE(isa<int *>(a));
+  EXPECT_FALSE(isa<float *>(b));
+  EXPECT_FALSE(isa<float *>(c));
+  EXPECT_FALSE(isa<float *>(n));
+  EXPECT_FALSE(isa<float *>(i3));
+  EXPECT_FALSE(isa<long long *>(i3));
+  EXPECT_FALSE(isa<int *>(f3));
+  EXPECT_FALSE(isa<long long *>(f3));
+  EXPECT_FALSE(isa<int *>(l3));
+  EXPECT_FALSE(isa<float *>(l3));
+  EXPECT_FALSE(isa<float *>(i4));
+  EXPECT_FALSE(isa<long long *>(i4));
+  EXPECT_FALSE(isa<double *>(i4));
+  EXPECT_FALSE(isa<int *>(f4));
+  EXPECT_FALSE(isa<long long *>(f4));
+  EXPECT_FALSE(isa<double *>(f4));
+  EXPECT_FALSE(isa<int *>(l4));
+  EXPECT_FALSE(isa<float *>(l4));
+  EXPECT_FALSE(isa<double *>(l4));
+  EXPECT_FALSE(isa<int *>(d4));
+  EXPECT_FALSE(isa<float *>(d4));
+  EXPECT_FALSE(isa<long long *>(d4));
+  EXPECT_FALSE(isa<float *>(i4null));
+  EXPECT_FALSE(isa<long long *>(i4null));
+  EXPECT_FALSE(isa<double *>(i4null));
+  EXPECT_FALSE(isa<int *>(f4null));
+  EXPECT_FALSE(isa<long long *>(f4null));
+  EXPECT_FALSE(isa<double *>(f4null));
+  EXPECT_FALSE(isa<int *>(l4null));
+  EXPECT_FALSE(isa<float *>(l4null));
+  EXPECT_FALSE(isa<double *>(l4null));
+  EXPECT_FALSE(isa<int *>(d4null));
+  EXPECT_FALSE(isa<float *>(d4null));
+  EXPECT_FALSE(isa<long long *>(d4null));
+
+  // test cast<>
+  EXPECT_EQ(cast<float *>(a), &f);
+  EXPECT_EQ(cast<int *>(b), &i);
+  EXPECT_EQ(cast<int *>(c), &i);
+  EXPECT_EQ(cast<int *>(i3), &i);
+  EXPECT_EQ(cast<float *>(f3), &f);
+  EXPECT_EQ(cast<long long *>(l3), &l);
+  EXPECT_EQ(cast<int *>(i4), &i);
+  EXPECT_EQ(cast<float *>(f4), &f);
+  EXPECT_EQ(cast<long long *>(l4), &l);
+  EXPECT_EQ(cast<double *>(d4), &d);
+
+  // test dyn_cast
+  EXPECT_EQ(dyn_cast<int *>(a), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(a), &f);
+  EXPECT_EQ(dyn_cast<int *>(b), &i);
+  EXPECT_EQ(dyn_cast<float *>(b), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(c), &i);
+  EXPECT_EQ(dyn_cast<float *>(c), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(n), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(n), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(i3), &i);
+  EXPECT_EQ(dyn_cast<float *>(i3), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(i3), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(f3), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(f3), &f);
+  EXPECT_EQ(dyn_cast<long long *>(f3), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(l3), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(l3), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(l3), &l);
+  EXPECT_EQ(dyn_cast<int *>(i4), &i);
+  EXPECT_EQ(dyn_cast<float *>(i4), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(i4), nullptr);
+  EXPECT_EQ(dyn_cast<double *>(i4), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(f4), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(f4), &f);
+  EXPECT_EQ(dyn_cast<long long *>(f4), nullptr);
+  EXPECT_EQ(dyn_cast<double *>(f4), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(l4), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(l4), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(l4), &l);
+  EXPECT_EQ(dyn_cast<double *>(l4), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(d4), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(d4), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(d4), nullptr);
+  EXPECT_EQ(dyn_cast<double *>(d4), &d);
+  EXPECT_EQ(dyn_cast<int *>(i4null), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(i4null), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(i4null), nullptr);
+  EXPECT_EQ(dyn_cast<double *>(i4null), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(f4null), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(f4null), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(f4null), nullptr);
+  EXPECT_EQ(dyn_cast<double *>(f4null), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(l4null), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(l4null), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(l4null), nullptr);
+  EXPECT_EQ(dyn_cast<double *>(l4null), nullptr);
+  EXPECT_EQ(dyn_cast<int *>(d4null), nullptr);
+  EXPECT_EQ(dyn_cast<float *>(d4null), nullptr);
+  EXPECT_EQ(dyn_cast<long long *>(d4null), nullptr);
+  EXPECT_EQ(dyn_cast<double *>(d4null), nullptr);
+
+  // test for const
+  const PU4 constd4(&d);
+  EXPECT_TRUE(isa<double *>(constd4));
+  EXPECT_FALSE(isa<int *>(constd4));
+  EXPECT_EQ(cast<double *>(constd4), &d);
+  EXPECT_EQ(dyn_cast<long long *>(constd4), nullptr);
+
+  auto *result1 = cast<double *>(constd4);
+  static_assert(std::is_same<double *, decltype(result1)>::value,
+                "type mismatch for cast with PointerUnion");
+
+  // A union holding a pointer-to-const must cast back to that exact type.
+  PointerUnion<int *, const double *> constd2(&d);
+  auto *result2 = cast<const double *>(constd2);
+  EXPECT_EQ(result2, &d);
+  static_assert(std::is_same<const double *, decltype(result2)>::value,
+                "type mismatch for cast with PointerUnion");
+}
+
 } // end anonymous namespace