Use CalleeDecl's class when loading the vtable pointer for the cfi-nvcall check.

CodeGen previously passed MD->getParent() to LoadVTablePtr before emitting the
CFITCK_NVCall check; this patch passes CalleeDecl->getParent() instead.
The new tests call final overrides of B through pointers to its bases A1 and
A2. The CodeGen test expects the emitted llvm.type.test calls to check the
vtable pointer loaded from the B object against !"_ZTS1B" and !"all-vtables";
the runtime test expects no control flow integrity failure.
Related to Bugzilla 43390.

diff --git a/clang/lib/CodeGen/CGExprCXX.cpp b/clang/lib/CodeGen/CGExprCXX.cpp
--- a/clang/lib/CodeGen/CGExprCXX.cpp
+++ b/clang/lib/CodeGen/CGExprCXX.cpp
@@ -382,7 +382,7 @@
       const CXXRecordDecl *RD;
       std::tie(VTable, RD) =
           CGM.getCXXABI().LoadVTablePtr(*this, This.getAddress(),
-                                        MD->getParent());
+                                        CalleeDecl->getParent());
       EmitVTablePtrCheckForCall(RD, VTable, CFITCK_NVCall, CE->getBeginLoc());
     }
 
diff --git a/clang/test/CodeGenCXX/cfi-multiple-inheritance.cpp b/clang/test/CodeGenCXX/cfi-multiple-inheritance.cpp
new file mode 100644
--- /dev/null
+++ b/clang/test/CodeGenCXX/cfi-multiple-inheritance.cpp
@@ -0,0 +1,31 @@
+// Test that correct vtable ptr and type metadata are passed to llvm.type.test
+// Related to Bugzilla 43390.
+
+// RUN: %clang_cc1 -triple x86_64-unknown-linux -fvisibility hidden -std=c++11 -fsanitize=cfi-nvcall -emit-llvm -o - %s | FileCheck %s
+
+class A1 {
+public:
+  virtual int f1() = 0;
+};
+
+class A2 {
+public:
+  virtual int f2() = 0;
+};
+
+
+class B : public A1, public A2 {
+public:
+  int f2() final { return 1; }
+  int f1() final { return 2; }
+};
+
+// CHECK-LABEL: define hidden i32 @_Z3foov
+int foo() {
+  B b;
+  return static_cast<A2*>(&b)->f2();
+  // CHECK: [[P:%[^ ]*]] = bitcast %class.B* %b to i8**
+  // CHECK: [[V:%[^ ]*]] = load i8*, i8** [[P]], align 8
+  // CHECK: call i1 @llvm.type.test(i8* [[V]], metadata !"_ZTS1B")
+  // CHECK: call i1 @llvm.type.test(i8* [[V]], metadata !"all-vtables")
+}
diff --git a/compiler-rt/test/cfi/multiple-inheritance2.cpp b/compiler-rt/test/cfi/multiple-inheritance2.cpp
new file mode 100644
--- /dev/null
+++ b/compiler-rt/test/cfi/multiple-inheritance2.cpp
@@ -0,0 +1,38 @@
+// Test that virtual functions of the derived class can be called through
+// pointers of both base classes without CFI errors.
+// Related to Bugzilla 43390.
+
+// RUN: %clangxx_cfi -o %t1 %s
+// RUN: %run %t1 2>&1 | FileCheck --check-prefix=CFI %s
+
+// CFI: In f1
+// CFI: In f2
+// CFI-NOT: control flow integrity check
+
+// REQUIRES: cxxabi
+
+#include <stdio.h>
+
+class A1 {
+public:
+  virtual void f1() = 0;
+};
+
+class A2 {
+public:
+  virtual void f2() = 0;
+};
+
+
+class B : public A1, public A2 {
+public:
+  void f2() final { fprintf(stderr, "In f2\n"); }
+  void f1() final { fprintf(stderr, "In f1\n"); }
+};
+
+int main() {
+  B b;
+
+  static_cast<A1*>(&b)->f1();
+  static_cast<A2*>(&b)->f2();
+}