diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1166,7 +1166,13 @@
 // We traverse this in the type case as well, but how is it not reached through
 // the pointee type?
 DEF_TRAVERSE_TYPELOC(MemberPointerType, {
-  TRY_TO(TraverseType(QualType(TL.getTypePtr()->getClass(), 0)));
+  // Traverse the class through its TypeLoc when type-source info is present,
+  // so TypeLoc visitors (e.g. VisitRecordTypeLoc) reach it; otherwise fall
+  // back to traversing the class as a plain type.
+  if (auto *TSI = TL.getClassTInfo())
+    TRY_TO(TraverseTypeLoc(TSI->getTypeLoc()));
+  else
+    TRY_TO(TraverseType(QualType(TL.getTypePtr()->getClass(), 0)));
   TRY_TO(TraverseTypeLoc(TL.getPointeeLoc()));
 })
 
diff --git a/clang/unittests/Tooling/CMakeLists.txt b/clang/unittests/Tooling/CMakeLists.txt
--- a/clang/unittests/Tooling/CMakeLists.txt
+++ b/clang/unittests/Tooling/CMakeLists.txt
@@ -42,6 +42,7 @@
   RecursiveASTVisitorTests/LambdaDefaultCapture.cpp
   RecursiveASTVisitorTests/LambdaExpr.cpp
   RecursiveASTVisitorTests/LambdaTemplateParams.cpp
+  RecursiveASTVisitorTests/MemberPointerTypeLoc.cpp
   RecursiveASTVisitorTests/NestedNameSpecifiers.cpp
   RecursiveASTVisitorTests/ParenExpr.cpp
   RecursiveASTVisitorTests/TemplateArgumentLocTraverser.cpp
diff --git a/clang/unittests/Tooling/RecursiveASTVisitorTests/MemberPointerTypeLoc.cpp b/clang/unittests/Tooling/RecursiveASTVisitorTests/MemberPointerTypeLoc.cpp
new file mode 100644
--- /dev/null
+++ b/clang/unittests/Tooling/RecursiveASTVisitorTests/MemberPointerTypeLoc.cpp
@@ -0,0 +1,43 @@
+//===- unittest/Tooling/RecursiveASTVisitorTests/MemberPointerTypeLoc.cpp -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestVisitor.h"
+
+using namespace clang;
+
+namespace {
+
+/// Reports every RecordTypeLoc reached during traversal as a match on the
+/// record's name at the TypeLoc's name location.
+class MemberPointerTypeLocVisitor
+    : public ExpectedLocationVisitor<MemberPointerTypeLocVisitor> {
+public:
+  bool VisitRecordTypeLoc(RecordTypeLoc RTL) {
+    // Nothing to report for a null TypeLoc.
+    if (!RTL)
+      return true;
+    Match(RTL.getDecl()->getName(), RTL.getNameLoc());
+    return true;
+  }
+};
+
+// The class named in a member-pointer declarator must be visited as a TypeLoc.
+TEST(RecursiveASTVisitor, VisitTypeLocInMemberPointerTypeLoc) {
+  MemberPointerTypeLocVisitor Visitor;
+  // "Bar" in `void(Bar::*Foo)(int)` sits at line 4, column 36 of the snippet
+  // below; the snippet's leading whitespace is significant for the column.
+  Visitor.ExpectMatch("Bar", 4, 36);
+  EXPECT_TRUE(Visitor.runOver(R"cpp(
+     class Bar { void func(int); };
+     class Foo {
+       void bind(const char*, void(Bar::*Foo)(int)) {}
+     };
+  )cpp"));
+}
+
+} // end anonymous namespace