Index: clang/lib/AST/Interp/ByteCodeExprGen.h
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.h
+++ clang/lib/AST/Interp/ByteCodeExprGen.h
@@ -269,8 +269,10 @@
   }
 
   bool emitRecordDestruction(const Descriptor *Desc);
-  bool emitDerivedToBaseCasts(const RecordType *DerivedType,
-                              const RecordType *BaseType, const Expr *E);
+  /// Returns the byte offset of the \p BaseType subobject within an object
+  /// of type \p DerivedType, summed over every level of the class hierarchy.
+  unsigned collectBaseOffset(const RecordType *BaseType,
+                             const RecordType *DerivedType);
   bool emitBuiltinBitCast(const CastExpr *E);
 
 protected:
Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -138,8 +138,20 @@
     if (!this->visit(SubExpr))
       return false;
 
-    return this->emitDerivedToBaseCasts(getRecordTy(SubExpr->getType()),
-                                        getRecordTy(CE->getType()), CE);
+    unsigned BaseOffset = collectBaseOffset(getRecordTy(CE->getType()),
+                                            getRecordTy(SubExpr->getType()));
+
+    return this->emitGetPtrBasePop(BaseOffset, CE);
+  }
+
+  case CK_BaseToDerived: {
+    if (!this->visit(SubExpr))
+      return false;
+
+    unsigned BaseOffset = collectBaseOffset(getRecordTy(SubExpr->getType()),
+                                            getRecordTy(CE->getType()));
+
+    return this->emitGetPtrDerivedPop(BaseOffset, CE);
   }
 
   case CK_FloatingCast: {
@@ -2124,13 +2136,15 @@
 }
 
 template <class Emitter>
-bool ByteCodeExprGen<Emitter>::emitDerivedToBaseCasts(
-    const RecordType *DerivedType, const RecordType *BaseType, const Expr *E) {
-  // Pointer of derived type is already on the stack.
+unsigned
+ByteCodeExprGen<Emitter>::collectBaseOffset(const RecordType *BaseType,
+                                            const RecordType *DerivedType) {
   const auto *FinalDecl = cast<CXXRecordDecl>(BaseType->getDecl());
   const RecordDecl *CurDecl = DerivedType->getDecl();
   const Record *CurRecord = getRecord(CurDecl);
   assert(CurDecl && FinalDecl);
+
+  unsigned OffsetSum = 0;
   for (;;) {
     assert(CurRecord->getNumBases() > 0);
     // One level up
@@ -2138,21 +2152,18 @@
       const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);
 
       if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
-        // This decl will lead us to the final decl, so emit a base cast.
-        if (!this->emitGetPtrBasePop(B.Offset, E))
-          return false;
-
+        OffsetSum += B.Offset;
         CurRecord = B.R;
         CurDecl = BaseDecl;
         break;
       }
     }
     if (CurDecl == FinalDecl)
-      return true;
+      break;
   }
 
-  llvm_unreachable("Couldn't find the base class?");
-  return false;
+  assert(OffsetSum > 0);
+  return OffsetSum;
 }
 
 /// When calling this, we have a pointer of the local-to-destroy
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -67,8 +67,9 @@
 bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
                 CheckSubobjectKind CSK);
 
-/// Checks if accessing a base of the given pointer is valid.
-bool CheckBase(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
+/// Checks if accessing a base or derived record of the given pointer is valid.
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK);
 
 /// Checks if a pointer points to const storage.
 bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
@@ -1078,11 +1079,30 @@
   return true;
 }
 
+/// Pops a pointer to a base-class subobject and pushes a pointer to the
+/// enclosing derived object, which starts \p Off bytes before it.
+inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off) {
+  const Pointer &Ptr = S.Stk.pop<Pointer>();
+  if (!CheckNull(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  // A pointer whose byte offset is smaller than Off cannot be a base
+  // subobject of the requested derived type; atFieldSub() would underflow.
+  if (Ptr.getByteOffset() < Off) {
+    S.FFDiag(S.Current->getSource(OpPC),
+             diag::note_invalid_subexpr_in_const_expr);
+    return false;
+  }
+  S.Stk.push<Pointer>(Ptr.atFieldSub(Off));
+  return true;
+}
+
 inline bool GetPtrBase(InterpState &S, CodePtr OpPC, uint32_t Off) {
   const Pointer &Ptr = S.Stk.peek<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
-  if (!CheckBase(S, OpPC, Ptr))
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
     return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
@@ -1092,7 +1112,7 @@
   const Pointer &Ptr = S.Stk.pop<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
-  if (!CheckBase(S, OpPC, Ptr))
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
     return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
Index: clang/lib/AST/Interp/Interp.cpp
===================================================================
--- clang/lib/AST/Interp/Interp.cpp
+++ clang/lib/AST/Interp/Interp.cpp
@@ -211,12 +211,13 @@
   return false;
 }
 
-bool CheckBase(InterpState &S, CodePtr OpPC, const Pointer &Ptr) {
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK) {
   if (!Ptr.isOnePastEnd())
     return true;
 
   const SourceInfo &Loc = S.Current->getSource(OpPC);
-  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK_Base;
+  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK;
   return false;
 }
 
Index: clang/lib/AST/Interp/Opcodes.td
===================================================================
--- clang/lib/AST/Interp/Opcodes.td
+++ clang/lib/AST/Interp/Opcodes.td
@@ -291,6 +291,11 @@
   let Args = [ArgUint32];
 }
 
+// [Pointer] -> [Pointer]
+def GetPtrDerivedPop : Opcode {
+  let Args = [ArgUint32];
+}
+
 // [Pointer] -> [Pointer]
 def GetPtrVirtBase : Opcode {
   // RecordDecl of base class.
Index: clang/lib/AST/Interp/Pointer.h
===================================================================
--- clang/lib/AST/Interp/Pointer.h
+++ clang/lib/AST/Interp/Pointer.h
@@ -100,6 +100,15 @@
     return Pointer(Pointee, Field, Field);
   }
 
+  /// Returns a pointer into the same block whose Base and Offset are both
+  /// set to this pointer's Offset minus \p Off. Used to step from a base-class
+  /// subobject back to the enclosing derived object. Requires Off <= Offset.
+  Pointer atFieldSub(unsigned Off) const {
+    assert(Offset >= Off);
+    unsigned O = Offset - Off;
+    return Pointer(Pointee, O, O);
+  }
+
   /// Restricts the scope of an array element pointer.
   Pointer narrow() const {
     // Null pointers cannot be narrowed.
Index: clang/test/AST/Interp/records.cpp
===================================================================
--- clang/test/AST/Interp/records.cpp
+++ clang/test/AST/Interp/records.cpp
@@ -586,6 +586,58 @@
   static_assert(test() == 1);
 }
 
+namespace BaseToDerived {
+namespace A {
+  struct A {};
+  struct B : A { int n; };
+  struct C : B {};
+  C c = {};
+  constexpr C *pb = (C*)((A*)&c + 1); // expected-error {{must be initialized by a constant expression}} \
+                                      // expected-note {{cannot access derived class of pointer past the end of object}} \
+                                      // ref-error {{must be initialized by a constant expression}} \
+                                      // ref-note {{cannot access derived class of pointer past the end of object}}
+}
+namespace B {
+  struct A {};
+  struct Z {};
+  struct B : Z, A {
+    int n;
+    constexpr B() : n(10) {}
+  };
+  struct C : B {
+    constexpr C() : B() {}
+  };
+
+  constexpr C c = {};
+  constexpr const A *pa = &c;
+  constexpr const C *cp = (C*)pa;
+  constexpr const B *cb = (B*)cp;
+
+  static_assert(cb->n == 10);
+  static_assert(cp->n == 10);
+}
+
+namespace C {
+  struct Base { int *a; };
+  struct Base2 : Base { int f[12]; };
+
+  struct Middle1 { int b[3]; };
+  struct Middle2 : Base2 { char c; };
+  struct Middle3 : Middle2 { char g[3]; };
+  struct Middle4 { int f[3]; };
+  struct Middle5 : Middle4, Middle3 { char g2[3]; };
+
+  struct NotQuiteDerived : Middle1, Middle5 { bool d; };
+  struct Derived : NotQuiteDerived { int e; };
+
+  constexpr NotQuiteDerived NQD1 = {};
+
+  constexpr Middle5 *M4 = (Middle5*)((Base2*)&NQD1);
+  static_assert(M4->a == nullptr);
+  static_assert(M4->g2[0] == 0);
+}
+}
+
 namespace VirtualDtors {
   class A {