Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -2047,6 +2047,12 @@
       if (!this->emitCall(DtorFunc, SourceInfo{}))
         return false;
     }
+    // FIXME: Returning here skips destroying the base classes below. It
+    // presumably avoids Call dispatching a base's virtual destructor back to
+    // the most-derived one; implicit base destructor calls should instead be
+    // emitted as non-virtual calls.
+    if (Dtor->isVirtual())
+      return this->emitPopPtr(SourceInfo{});
   }
 
   for (const Record::Base &Base : llvm::reverse(R->bases())) {
Index: clang/lib/AST/Interp/Context.h
===================================================================
--- clang/lib/AST/Interp/Context.h
+++ clang/lib/AST/Interp/Context.h
@@ -61,6 +61,15 @@
   /// Classifies an expression.
   std::optional<PrimType> classify(QualType T) const;
 
+  /// Returns the method overriding \p InitialFunction (a virtual method of
+  /// \p StaticDecl) for an object of dynamic type \p DynamicDecl, searching
+  /// from \p DynamicDecl towards \p StaticDecl. Returns nullptr if none is
+  /// found.
+  const CXXMethodDecl *
+  getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                        const CXXRecordDecl *StaticDecl,
+                        const CXXMethodDecl *InitialFunction) const;
+
 private:
   /// Runs a function.
   bool Run(State &Parent, Function *Func, APValue &Result);
Index: clang/lib/AST/Interp/Context.cpp
===================================================================
--- clang/lib/AST/Interp/Context.cpp
+++ clang/lib/AST/Interp/Context.cpp
@@ -152,3 +152,46 @@
   });
   return false;
 }
+
+// TODO: Virtual bases?
+const CXXMethodDecl *
+Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                               const CXXRecordDecl *StaticDecl,
+                               const CXXMethodDecl *InitialFunction) const {
+  assert(DynamicDecl);
+  assert(StaticDecl);
+  assert(InitialFunction);
+
+  // Walk from the dynamic type up towards StaticDecl and return the first
+  // method that corresponds to InitialFunction. Every step moves to a base
+  // class (or to nullptr), so the walk always terminates.
+  const CXXRecordDecl *CurRecord = DynamicDecl;
+  while (CurRecord) {
+    if (const CXXMethodDecl *Overrider =
+            InitialFunction->getCorrespondingMethodDeclaredInClass(
+                CurRecord, /*MayBeBase=*/false))
+      return Overrider;
+
+    // Common case of only one base class.
+    if (CurRecord->getNumBases() == 1) {
+      CurRecord = CurRecord->bases_begin()->getType()->getAsCXXRecordDecl();
+      continue;
+    }
+
+    // Otherwise, go to the base class that will lead to the StaticDecl.
+    // If there is none (e.g. CurRecord has no bases at all), DynamicDecl is
+    // not derived from StaticDecl and the walk stops.
+    const CXXRecordDecl *Next = nullptr;
+    for (const CXXBaseSpecifier &Spec : CurRecord->bases()) {
+      const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl();
+      if (Base && (Base == StaticDecl || Base->isDerivedFrom(StaticDecl))) {
+        Next = Base;
+        break;
+      }
+    }
+    CurRecord = Next;
+  }
+
+  // No overrider found; presumably DynamicDecl is unrelated to StaticDecl.
+  return nullptr;
+}
Index: clang/lib/AST/Interp/Function.h
===================================================================
--- clang/lib/AST/Interp/Function.h
+++ clang/lib/AST/Interp/Function.h
@@ -130,6 +130,13 @@
   /// Checks if the function is a constructor.
   bool isConstructor() const { return isa<CXXConstructorDecl>(F); }
 
+  /// Returns the parent record decl, or nullptr if F is not a method.
+  const CXXRecordDecl *getParentDecl() const {
+    if (const auto *MD = dyn_cast<CXXMethodDecl>(F))
+      return MD->getParent();
+    return nullptr;
+  }
+
   /// Checks if the function is fully done compiling.
   bool isFullyCompiled() const { return IsFullyCompiled; }
 
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -1509,6 +1509,34 @@
 
     if (S.checkingPotentialConstantExpression())
       return false;
+
+    // For a virtual call, we need to get the right function here.
+    // FIXME: Every call to a virtual function is dispatched, including
+    // qualified calls (e.g. a.A::foo()) and implicit destructor calls on base
+    // subobjects, neither of which may use the virtual call mechanism.
+    if (Func->isVirtual()) {
+      // NOTE(review): The type of ThisPtr's declaration is taken as the
+      // dynamic type. That is only right if ThisPtr points into an object
+      // declared with that type (not e.g. into a member subobject); if it is
+      // not derived from StaticDecl, no overrider is found and we fail.
+      const CXXRecordDecl *DynamicDecl =
+          ThisPtr.getDeclDesc()->getType()->getAsCXXRecordDecl();
+      if (!DynamicDecl)
+        return false;
+      const auto *StaticDecl = cast<CXXRecordDecl>(Func->getParentDecl());
+      const auto *InitialFunction = cast<CXXMethodDecl>(Func->getDecl());
+      const CXXMethodDecl *Overrider = S.getContext().getOverridingFunction(
+          DynamicDecl, StaticDecl, InitialFunction);
+      if (!Overrider)
+        return false;
+      if (Overrider != InitialFunction) {
+        // Guard against getFunction() returning nullptr, presumably for
+        // overriders that were never compiled (e.g. non-constexpr ones).
+        Func = S.P.getFunction(Overrider);
+        if (!Func)
+          return false;
+      }
+    }
   }
 
   if (!CheckCallable(S, PC, Func))
Index: clang/lib/AST/Interp/InterpState.h
===================================================================
--- clang/lib/AST/Interp/InterpState.h
+++ clang/lib/AST/Interp/InterpState.h
@@ -86,6 +86,9 @@
     return M ? M->getSource(F, PC) : F->getSource(PC);
   }
 
+  /// Returns the interpreter context.
+  Context &getContext() const { return Ctx; }
+
 private:
   /// AST Walker state.
   State &Parent;
Index: clang/test/AST/Interp/records.cpp
===================================================================
--- clang/test/AST/Interp/records.cpp
+++ clang/test/AST/Interp/records.cpp
@@ -1,8 +1,10 @@
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++14 -verify %s
+// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++20 -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -triple i686 -verify %s
 // RUN: %clang_cc1 -verify=ref %s
 // RUN: %clang_cc1 -verify=ref -std=c++14 %s
+// RUN: %clang_cc1 -verify=ref -std=c++20 %s
 // RUN: %clang_cc1 -verify=ref -triple i686 %s
 
 struct BoolPair {
@@ -286,6 +288,7 @@
 };
 
 namespace DeriveFailures {
+#if __cplusplus < 202002L
   struct Base { // ref-note 2{{declared here}}
     int Val;
   };
@@ -301,10 +304,12 @@
                            // ref-note {{in call to 'Derived(12)'}} \
                            // ref-note {{declared here}} \
                            // expected-error {{must be initialized by a constant expression}}
+
   static_assert(D.Val == 0, ""); // ref-error {{not an integral constant expression}} \
                                  // ref-note {{initializer of 'D' is not a constant expression}} \
                                  // expected-error {{not an integral constant expression}} \
                                  // expected-note {{read of object outside its lifetime}}
+#endif
 
   struct AnotherBase {
     int Val;
@@ -354,3 +359,98 @@
   static_assert(getS(true).a == 12, "");
   static_assert(getS(false).a == 13, "");
 };
+
+#if __cplusplus >= 202002L
+namespace VirtualCalls {
+namespace Obvious {
+
+  class A {
+  public:
+    constexpr A(){}
+    constexpr virtual int foo() {
+      return 3;
+    }
+  };
+  class B : public A {
+  public:
+    constexpr int foo() override {
+      return 6;
+    }
+  };
+
+  constexpr int getFooB(bool b) {
+    A *a;
+    A myA;
+    B myB;
+
+    if (b)
+      a = &myA;
+    else
+      a = &myB;
+
+    return a->foo();
+  }
+  static_assert(getFooB(true) == 3, "");
+  static_assert(getFooB(false) == 6, "");
+}
+
+namespace MultipleBases {
+  class A {
+  public:
+    constexpr virtual int getInt() const { return 10; }
+  };
+  class B {
+  public:
+  };
+  class C : public A, public B {
+  public:
+    constexpr int getInt() const override { return 20; }
+  };
+
+  constexpr int callGetInt(const A& a) { return a.getInt(); }
+  static_assert(callGetInt(C()) == 20, "");
+  static_assert(callGetInt(A()) == 10, "");
+}
+
+// The overrider is declared in a class between the static and the
+// dynamic type.
+namespace Intermediate {
+  class A {
+  public:
+    constexpr virtual int getInt() const { return 1; }
+  };
+  class B : public A {
+  public:
+    constexpr int getInt() const override { return 2; }
+  };
+  class C : public B {};
+
+  constexpr int callGetInt(const A &a) { return a.getInt(); }
+  static_assert(callGetInt(C()) == 2, "");
+  static_assert(callGetInt(B()) == 2, "");
+  static_assert(callGetInt(A()) == 1, "");
+}
+
+namespace Destructors {
+  class Base {
+  public:
+    int i;
+    constexpr Base(int &i) : i(i) {i++;}
+    constexpr virtual ~Base() {i--;}
+  };
+
+  class Derived : public Base {
+  public:
+    constexpr Derived(int &i) : Base(i) {}
+    constexpr virtual ~Derived() {i--;}
+  };
+
+  constexpr int test() {
+    int i = 0;
+    Derived d(i);
+    return i;
+  }
+  static_assert(test() == 1);
+}
+} // namespace VirtualCalls
+#endif