diff --git a/clang/lib/AST/Interp/ByteCodeExprGen.cpp b/clang/lib/AST/Interp/ByteCodeExprGen.cpp --- a/clang/lib/AST/Interp/ByteCodeExprGen.cpp +++ b/clang/lib/AST/Interp/ByteCodeExprGen.cpp @@ -1711,12 +1711,24 @@ assert(HasRVO == Func->hasRVO()); + bool HasQualifier = false; + if (const auto *ME = dyn_cast(E->getCallee())) + HasQualifier = ME->hasQualifier(); + + bool IsVirtual = false; + if (const auto *MD = dyn_cast(FuncDecl)) + IsVirtual = MD->isVirtual(); + // In any case call the function. The return value will end up on the stack // and if the function has RVO, we already have the pointer on the stack to // write the result into. - if (!this->emitCall(Func, E)) - return false; - + if (IsVirtual && !HasQualifier) { + if (!this->emitCallVirt(Func, E)) + return false; + } else { + if (!this->emitCall(Func, E)) + return false; + } } else { // Indirect call. Visit the callee, which will leave a FunctionPointer on // the stack. Cleanup of the returned value if necessary will be done after diff --git a/clang/lib/AST/Interp/Context.h b/clang/lib/AST/Interp/Context.h --- a/clang/lib/AST/Interp/Context.h +++ b/clang/lib/AST/Interp/Context.h @@ -63,6 +63,11 @@ /// Classifies an expression. std::optional classify(QualType T) const; + const CXXMethodDecl * + getOverridingFunction(const CXXRecordDecl *DynamicDecl, + const CXXRecordDecl *StaticDecl, + const CXXMethodDecl *InitialFunction) const; + private: /// Runs a function. bool Run(State &Parent, Function *Func, APValue &Result); diff --git a/clang/lib/AST/Interp/Context.cpp b/clang/lib/AST/Interp/Context.cpp --- a/clang/lib/AST/Interp/Context.cpp +++ b/clang/lib/AST/Interp/Context.cpp @@ -158,3 +158,38 @@ }); return false; } + +// TODO: Virtual bases? +const CXXMethodDecl * +Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl, + const CXXRecordDecl *StaticDecl, + const CXXMethodDecl *InitialFunction) const { + + const CXXRecordDecl *CurRecord = DynamicDecl; + const CXXMethodDecl *FoundFunction = InitialFunction; + for (;;) { + const CXXMethodDecl *Overrider = + FoundFunction->getCorrespondingMethodDeclaredInClass(CurRecord, false); + if (Overrider) + return Overrider; + + // Common case of only one base class. + if (CurRecord->getNumBases() == 1) { + CurRecord = CurRecord->bases_begin()->getType()->getAsCXXRecordDecl(); + continue; + } + + // Otherwise, go to the base class that will lead to the StaticDecl. + for (const CXXBaseSpecifier &Spec : CurRecord->bases()) { + const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl(); + if (Base == StaticDecl || Base->isDerivedFrom(StaticDecl)) { + CurRecord = Base; + break; + } + } + } + + llvm_unreachable( + "Couldn't find an overriding function in the class hierarchy?"); + return nullptr; +} diff --git a/clang/lib/AST/Interp/Descriptor.cpp b/clang/lib/AST/Interp/Descriptor.cpp --- a/clang/lib/AST/Interp/Descriptor.cpp +++ b/clang/lib/AST/Interp/Descriptor.cpp @@ -274,6 +274,8 @@ return E->getType(); if (auto *D = asValueDecl()) return D->getType(); + if (auto *T = dyn_cast(asDecl())) + return QualType(T->getTypeForDecl(), 0); llvm_unreachable("Invalid descriptor type"); } diff --git a/clang/lib/AST/Interp/Function.h b/clang/lib/AST/Interp/Function.h --- a/clang/lib/AST/Interp/Function.h +++ b/clang/lib/AST/Interp/Function.h @@ -137,6 +137,13 @@ /// Checks if the function is a destructor. bool isDestructor() const { return isa(F); } + /// Returns the parent record decl, if any. + const CXXRecordDecl *getParentDecl() const { + if (const auto *MD = dyn_cast(F)) + return MD->getParent(); + return nullptr; + } + /// Checks if the function is fully done compiling. bool isFullyCompiled() const { return IsFullyCompiled; } diff --git a/clang/lib/AST/Interp/Interp.h b/clang/lib/AST/Interp/Interp.h --- a/clang/lib/AST/Interp/Interp.h +++ b/clang/lib/AST/Interp/Interp.h @@ -1622,6 +1622,36 @@ return false; } +inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func) { + assert(Func->hasThisPointer()); + assert(Func->isVirtual()); + size_t ThisOffset = + Func->getArgSize() + (Func->hasRVO() ? primSize(PT_Ptr) : 0); + Pointer &ThisPtr = S.Stk.peek(ThisOffset); + + const CXXRecordDecl *DynamicDecl = + ThisPtr.getDeclDesc()->getType()->getAsCXXRecordDecl(); + const auto *StaticDecl = cast(Func->getParentDecl()); + const auto *InitialFunction = cast(Func->getDecl()); + const CXXMethodDecl *Overrider = S.getContext().getOverridingFunction( + DynamicDecl, StaticDecl, InitialFunction); + + if (Overrider != InitialFunction) { + Func = S.P.getFunction(Overrider); + + const CXXRecordDecl *ThisFieldDecl = + ThisPtr.getFieldDesc()->getType()->getAsCXXRecordDecl(); + if (Func->getParentDecl()->isDerivedFrom(ThisFieldDecl)) { + // If the function we call is further DOWN the hierarchy than the + // FieldDesc of our pointer, just get the DeclDesc instead, which + // is the furthest we might go up in the hierarchy. + ThisPtr = ThisPtr.getDeclPtr(); + } + } + + return Call(S, OpPC, Func); +} + inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func) { auto NewFrame = std::make_unique(S, Func, PC); diff --git a/clang/lib/AST/Interp/InterpState.h b/clang/lib/AST/Interp/InterpState.h --- a/clang/lib/AST/Interp/InterpState.h +++ b/clang/lib/AST/Interp/InterpState.h @@ -89,6 +89,8 @@ return M ? M->getSource(F, PC) : F->getSource(PC); } + Context &getContext() const { return Ctx; } + private: /// AST Walker state. State &Parent; diff --git a/clang/lib/AST/Interp/Opcodes.td b/clang/lib/AST/Interp/Opcodes.td --- a/clang/lib/AST/Interp/Opcodes.td +++ b/clang/lib/AST/Interp/Opcodes.td @@ -181,6 +181,11 @@ let Types = []; } +def CallVirt : Opcode { + let Args = [ArgFunction]; + let Types = []; +} + def CallBI : Opcode { let Args = [ArgFunction]; let Types = []; diff --git a/clang/lib/AST/Interp/Pointer.h b/clang/lib/AST/Interp/Pointer.h --- a/clang/lib/AST/Interp/Pointer.h +++ b/clang/lib/AST/Interp/Pointer.h @@ -200,6 +200,8 @@ /// Returns the type of the innermost field. QualType getType() const { return getFieldDesc()->getType(); } + Pointer getDeclPtr() const { return Pointer(Pointee); } + /// Returns the element size of the innermost field. size_t elemSize() const { if (Base == RootPtrMark) diff --git a/clang/test/AST/Interp/records.cpp b/clang/test/AST/Interp/records.cpp --- a/clang/test/AST/Interp/records.cpp +++ b/clang/test/AST/Interp/records.cpp @@ -1,8 +1,10 @@ // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify %s // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++14 -verify %s +// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++20 -verify %s // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -triple i686 -verify %s // RUN: %clang_cc1 -verify=ref %s // RUN: %clang_cc1 -verify=ref -std=c++14 %s +// RUN: %clang_cc1 -verify=ref -std=c++20 %s // RUN: %clang_cc1 -verify=ref -triple i686 %s struct BoolPair { @@ -380,6 +382,7 @@ }; namespace DeriveFailures { +#if __cplusplus < 202002L struct Base { // ref-note 2{{declared here}} expected-note {{declared here}} int Val; }; @@ -397,10 +400,12 @@ // ref-note {{declared here}} \ // expected-error {{must be initialized by a constant expression}} \ // expected-note {{in call to 'Derived(12)'}} + static_assert(D.Val == 0, ""); // ref-error {{not an integral constant expression}} \ // ref-note {{initializer of 'D' is not a constant expression}} \ // expected-error {{not an integral constant expression}} \ // expected-note {{read of object outside its lifetime}} +#endif struct AnotherBase { int Val; @@ -488,3 +493,201 @@ //static_assert(b.a.m == 100, ""); //static_assert(b.a.f == 100, ""); } + +#if __cplusplus >= 202002L +namespace VirtualCalls { +namespace Obvious { + + class A { + public: + constexpr A(){} + constexpr virtual int foo() { + return 3; + } + }; + class B : public A { + public: + constexpr int foo() override { + return 6; + } + }; + + constexpr int getFooB(bool b) { + A *a; + A myA; + B myB; + + if (b) + a = &myA; + else + a = &myB; + + return a->foo(); + } + static_assert(getFooB(true) == 3, ""); + static_assert(getFooB(false) == 6, ""); +} + +namespace MultipleBases { + class A { + public: + constexpr virtual int getInt() const { return 10; } + }; + class B { + public: + }; + class C : public A, public B { + public: + constexpr int getInt() const override { return 20; } + }; + + constexpr int callGetInt(const A& a) { return a.getInt(); } + static_assert(callGetInt(C()) == 20, ""); + static_assert(callGetInt(A()) == 10, ""); +} + +namespace Destructors { + class Base { + public: + int i; + constexpr Base(int &i) : i(i) {i++;} + constexpr virtual ~Base() {i--;} + }; + + class Derived : public Base { + public: + constexpr Derived(int &i) : Base(i) {} + constexpr virtual ~Derived() {i--;} + }; + + constexpr int test() { + int i = 0; + Derived d(i); + return i; + } + static_assert(test() == 1); +} + + +namespace VirtualDtors { + class A { + public: + unsigned &v; + constexpr A(unsigned &v) : v(v) {} + constexpr virtual ~A() { + v |= (1 << 0); + } + }; + class B : public A { + public: + constexpr B(unsigned &v) : A(v) {} + constexpr virtual ~B() { + v |= (1 << 1); + } + }; + class C : public B { + public: + constexpr C(unsigned &v) : B(v) {} + constexpr virtual ~C() { + v |= (1 << 2); + } + }; + + constexpr bool foo() { + unsigned a = 0; + { + C c(a); + } + return ((a & (1 << 0)) && (a & (1 << 1)) && (a & (1 << 2))); + } + + static_assert(foo()); + + +}; + +namespace QualifiedCalls { + class A { + public: + constexpr virtual int foo() const { + return 5; + } + }; + class B : public A {}; + class C : public B { + public: + constexpr int foo() const override { + return B::foo(); // B doesn't have a foo(), so this should call A::foo(). + } + constexpr int foo2() const { + return this->A::foo(); + } + }; + constexpr C c; + static_assert(c.foo() == 5); + static_assert(c.foo2() == 5); + + + struct S { + int _c = 0; + virtual constexpr int foo() const { return 1; } + }; + + struct SS : S { + int a; + constexpr SS() { + a = S::foo(); + } + constexpr int foo() const override { + return S::foo(); + } + }; + + constexpr SS ss; + static_assert(ss.a == 1); +} + +namespace CtorDtor { + struct Base { + int i = 0; + int j = 0; + + constexpr Base() : i(func()) { + j = func(); + } + constexpr Base(int i) : i(i), j(i) {} + + constexpr virtual int func() const { return 1; } + }; + + struct Derived : Base { + constexpr Derived() {} + constexpr Derived(int i) : Base(i) {} + constexpr int func() const override { return 2; } + }; + + struct Derived2 : Derived { + constexpr Derived2() : Derived(func()) {} // ref-note {{subexpression not valid in a constant expression}} + constexpr int func() const override { return 3; } + }; + + constexpr Base B; + static_assert(B.i == 1 && B.j == 1, ""); + + constexpr Derived D; + static_assert(D.i == 1, ""); // expected-error {{static assertion failed}} \ + // expected-note {{2 == 1}} + static_assert(D.j == 1, ""); // expected-error {{static assertion failed}} \ + // expected-note {{2 == 1}} + + constexpr Derived2 D2; // ref-error {{must be initialized by a constant expression}} \ + // ref-note {{in call to 'Derived2()'}} \ + // ref-note 2{{declared here}} + static_assert(D2.i == 3, ""); // ref-error {{not an integral constant expression}} \ + // ref-note {{initializer of 'D2' is not a constant expression}} + static_assert(D2.j == 3, ""); // ref-error {{not an integral constant expression}} \ + // ref-note {{initializer of 'D2' is not a constant expression}} + +} +}; +#endif