diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.h b/clang/lib/AST/Interp/ByteCodeExprGen.h
--- a/clang/lib/AST/Interp/ByteCodeExprGen.h
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.h
@@ -267,6 +267,8 @@
   bool emitRecordDestruction(const Descriptor *Desc);
-  bool emitDerivedToBaseCasts(const RecordType *DerivedType,
-                              const RecordType *BaseType, const Expr *E);
+  /// Returns the summed offset of \p BaseType inside \p DerivedType,
+  /// accumulated one inheritance level at a time.
+  unsigned collectBaseOffset(const RecordType *BaseType,
+                             const RecordType *DerivedType);
 
 protected:
   /// Variable to storage mapping.
diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
--- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -92,8 +92,20 @@
     if (!this->visit(SubExpr))
       return false;
 
-    return this->emitDerivedToBaseCasts(getRecordTy(SubExpr->getType()),
-                                        getRecordTy(CE->getType()), CE);
+    unsigned BaseOffset = collectBaseOffset(getRecordTy(CE->getType()),
+                                            getRecordTy(SubExpr->getType()));
+
+    return this->emitGetPtrBasePop(BaseOffset, CE);
+  }
+
+  case CK_BaseToDerived: {
+    if (!this->visit(SubExpr))
+      return false;
+
+    unsigned BaseOffset = collectBaseOffset(getRecordTy(SubExpr->getType()),
+                                            getRecordTy(CE->getType()));
+
+    return this->emitGetPtrDerivedPop(BaseOffset, CE);
   }
 
   case CK_FloatingCast: {
@@ -2262,13 +2274,15 @@
 }
 
 template <class Emitter>
-bool ByteCodeExprGen<Emitter>::emitDerivedToBaseCasts(
-    const RecordType *DerivedType, const RecordType *BaseType, const Expr *E) {
-  // Pointer of derived type is already on the stack.
+unsigned
+ByteCodeExprGen<Emitter>::collectBaseOffset(const RecordType *BaseType,
+                                            const RecordType *DerivedType) {
   const auto *FinalDecl = cast<CXXRecordDecl>(BaseType->getDecl());
   const RecordDecl *CurDecl = DerivedType->getDecl();
   const Record *CurRecord = getRecord(CurDecl);
   assert(CurDecl && FinalDecl);
+
+  unsigned OffsetSum = 0;
   for (;;) {
     assert(CurRecord->getNumBases() > 0);
     // One level up
@@ -2276,21 +2290,18 @@
       const auto *BaseDecl = cast<CXXRecordDecl>(B.Decl);
 
       if (BaseDecl == FinalDecl || BaseDecl->isDerivedFrom(FinalDecl)) {
-        // This decl will lead us to the final decl, so emit a base cast.
-        if (!this->emitGetPtrBasePop(B.Offset, E))
-          return false;
-
+        OffsetSum += B.Offset;
         CurRecord = B.R;
         CurDecl = BaseDecl;
         break;
       }
     }
     if (CurDecl == FinalDecl)
-      return true;
+      break;
   }
 
-  llvm_unreachable("Couldn't find the base class?");
-  return false;
+  assert(OffsetSum > 0);
+  return OffsetSum;
 }
 
 /// When calling this, we have a pointer of the local-to-destroy
diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h
--- a/clang/lib/AST/Interp/Interp.h
+++ b/clang/lib/AST/Interp/Interp.h
@@ -67,6 +67,10 @@
 bool CheckRange(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
                 CheckSubobjectKind CSK);
 
+/// Checks if accessing a base or derived record of the given pointer is valid.
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK);
+
 /// Checks if a pointer points to const storage.
 bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr);
 
@@ -1157,10 +1161,25 @@
   return true;
 }
 
+/// 1) Pops a Pointer to a base-class subobject from the stack.
+/// 2) Pushes a Pointer to the enclosing derived record, computed by
+///    moving the popped pointer back by \p Off (see atFieldSub()).
+inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off) {
+  const Pointer &Ptr = S.Stk.pop<Pointer>();
+  if (!CheckNull(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Derived))
+    return false;
+  S.Stk.push<Pointer>(Ptr.atFieldSub(Off));
+  return true;
+}
+
 inline bool GetPtrBase(InterpState &S, CodePtr OpPC, uint32_t Off) {
   const Pointer &Ptr = S.Stk.peek<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
+    return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
 }
@@ -1169,6 +1188,8 @@
   const Pointer &Ptr = S.Stk.pop<Pointer>();
   if (!CheckNull(S, OpPC, Ptr, CSK_Base))
     return false;
+  if (!CheckBaseDerived(S, OpPC, Ptr, CSK_Base))
+    return false;
   S.Stk.push<Pointer>(Ptr.atField(Off));
   return true;
 }
diff --git a/clang/lib/AST/Interp/Interp.cpp b/clang/lib/AST/Interp/Interp.cpp
--- a/clang/lib/AST/Interp/Interp.cpp
+++ b/clang/lib/AST/Interp/Interp.cpp
@@ -213,6 +213,16 @@
   return false;
 }
 
+bool CheckBaseDerived(InterpState &S, CodePtr OpPC, const Pointer &Ptr,
+                      CheckSubobjectKind CSK) {
+  if (!Ptr.isOnePastEnd())
+    return true;
+
+  const SourceInfo &Loc = S.Current->getSource(OpPC);
+  S.FFDiag(Loc, diag::note_constexpr_past_end_subobject) << CSK;
+  return false;
+}
+
 bool CheckConst(InterpState &S, CodePtr OpPC, const Pointer &Ptr) {
   assert(Ptr.isLive() && "Pointer is not live");
   if (!Ptr.isConst())
diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td
--- a/clang/lib/AST/Interp/Opcodes.td
+++ b/clang/lib/AST/Interp/Opcodes.td
@@ -293,6 +293,12 @@
   let Args = [ArgUint32];
 }
 
+// [Pointer] -> [Pointer]
+def GetPtrDerivedPop : Opcode {
+  // Offset of the base-class subobject within the derived record.
+  let Args = [ArgUint32];
+}
+
 // [Pointer] -> [Pointer]
 def GetPtrVirtBase : Opcode {
   // RecordDecl of base class.
diff --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h
--- a/clang/lib/AST/Interp/Pointer.h
+++ b/clang/lib/AST/Interp/Pointer.h
@@ -109,6 +109,15 @@
     return Pointer(Pointee, Field, Field);
   }
 
+  /// Returns a pointer to the same Pointee whose Base and Offset are both
+  /// this pointer's Offset minus \p Off; the current Base is not consulted.
+  /// Used to step from a base-class subobject back to its derived record.
+  Pointer atFieldSub(unsigned Off) const {
+    assert(Offset >= Off);
+    unsigned O = Offset - Off;
+    return Pointer(Pointee, O, O);
+  }
+
   /// Restricts the scope of an array element pointer.
   Pointer narrow() const {
     // Null pointers cannot be narrowed.
diff --git a/clang/test/AST/Interp/records.cpp b/clang/test/AST/Interp/records.cpp
--- a/clang/test/AST/Interp/records.cpp
+++ b/clang/test/AST/Interp/records.cpp
@@ -625,6 +625,58 @@
                                // ref-note {{in call to 'testS()'}}
 }
 
+namespace BaseToDerived {
+namespace A {
+  struct A {};
+  struct B : A { int n; };
+  struct C : B {};
+  C c = {};
+  constexpr C *pb = (C*)((A*)&c + 1); // expected-error {{must be initialized by a constant expression}} \
+                                      // expected-note {{cannot access derived class of pointer past the end of object}} \
+                                      // ref-error {{must be initialized by a constant expression}} \
+                                      // ref-note {{cannot access derived class of pointer past the end of object}}
+}
+namespace B {
+  struct A {};
+  struct Z {};
+  struct B : Z, A {
+    int n;
+    constexpr B() : n(10) {}
+  };
+  struct C : B {
+   constexpr C() : B() {}
+  };
+
+  constexpr C c = {};
+  constexpr const A *pa = &c;
+  constexpr const C *cp = (C*)pa;
+  constexpr const B *cb = (B*)cp;
+
+  static_assert(cb->n == 10);
+  static_assert(cp->n == 10);
+}
+
+namespace C {
+  struct Base { int *a; };
+  struct Base2 : Base { int f[12]; };
+
+  struct Middle1 { int b[3]; };
+  struct Middle2 : Base2 { char c; };
+  struct Middle3 : Middle2 { char g[3]; };
+  struct Middle4 { int f[3]; };
+  struct Middle5 : Middle4, Middle3 { char g2[3]; };
+
+  struct NotQuiteDerived : Middle1, Middle5 { bool d; };
+  struct Derived : NotQuiteDerived { int e; };
+
+  constexpr NotQuiteDerived NQD1 = {};
+
+  constexpr Middle5 *M4 = (Middle5*)((Base2*)&NQD1);
+  static_assert(M4->a == nullptr);
+  static_assert(M4->g2[0] == 0);
+}
+}
+
 namespace VirtualDtors {
   class A {