Index: cfe/trunk/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- cfe/trunk/include/clang/AST/RecursiveASTVisitor.h
+++ cfe/trunk/include/clang/AST/RecursiveASTVisitor.h
@@ -722,12 +722,6 @@
     break;
 #include "clang/AST/DeclNodes.inc"
   }
-
-  // Visit any attributes attached to this declaration.
-  for (auto *I : D->attrs()) {
-    if (!getDerived().TraverseAttr(I))
-      return false;
-  }
   return true;
 }
 
@@ -1407,6 +1401,11 @@
     { CODE; }                                                                  \
     if (ReturnValue && ShouldVisitChildren)                                    \
       TRY_TO(TraverseDeclContextHelper(dyn_cast<DeclContext>(D)));             \
+    if (ReturnValue) {                                                         \
+      /* Visit any attributes attached to this declaration. */                 \
+      for (auto *I : D->attrs())                                               \
+        TRY_TO(getDerived().TraverseAttr(I));                                  \
+    }                                                                          \
     if (ReturnValue && getDerived().shouldTraversePostOrder())                 \
       TRY_TO(WalkUpFrom##DECL(D));                                             \
     return ReturnValue;                                                        \
Index: cfe/trunk/unittests/AST/CMakeLists.txt
===================================================================
--- cfe/trunk/unittests/AST/CMakeLists.txt
+++ cfe/trunk/unittests/AST/CMakeLists.txt
@@ -26,6 +26,7 @@
   Language.cpp
   NamedDeclPrinterTest.cpp
   OMPStructuredBlockTest.cpp
+  RecursiveASTVisitorTest.cpp
   SourceLocationTest.cpp
   StmtPrinterTest.cpp
   StructuralEquivalenceTest.cpp
Index: cfe/trunk/unittests/AST/RecursiveASTVisitorTest.cpp
===================================================================
--- cfe/trunk/unittests/AST/RecursiveASTVisitorTest.cpp
+++ cfe/trunk/unittests/AST/RecursiveASTVisitorTest.cpp
@@ -0,0 +1,117 @@
+//===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/AST/RecursiveASTVisitor.h"
+#include "clang/AST/ASTConsumer.h"
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/Attr.h"
+#include "clang/Frontend/FrontendAction.h"
+#include "clang/Tooling/Tooling.h"
+#include "llvm/ADT/FunctionExtras.h"
+#include "llvm/ADT/STLExtras.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+#include <cassert>
+#include <vector>
+
+using namespace clang;
+using ::testing::ElementsAre;
+
+namespace {
+/// Frontend action that calls \p Process with the ASTContext of the parsed
+/// translation unit.
+class ProcessASTAction : public clang::ASTFrontendAction {
+public:
+  ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process)
+      : Process(std::move(Process)) {
+    assert(this->Process);
+  }
+
+  std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
+                                                 StringRef InFile) override {
+    class Consumer : public ASTConsumer {
+    public:
+      Consumer(llvm::function_ref<void(ASTContext &)> Process)
+          : Process(Process) {}
+
+      void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); }
+
+    private:
+      // Non-owning: refers to the action's Process member, so the consumer
+      // must not be used after the action is destroyed.
+      llvm::function_ref<void(ASTContext &)> Process;
+    };
+
+    return llvm::make_unique<Consumer>(Process);
+  }
+
+private:
+  llvm::unique_function<void(clang::ASTContext &)> Process;
+};
+
+/// Traversal events recorded by CollectInterestingEvents, in visiting order.
+enum class VisitEvent {
+  StartTraverseFunction,
+  EndTraverseFunction,
+  StartTraverseAttr,
+  EndTraverseAttr
+};
+
+/// Records where traversal of function declarations and attributes starts and
+/// ends, so tests can check how attribute traversal nests inside declarations.
+class CollectInterestingEvents
+    : public RecursiveASTVisitor<CollectInterestingEvents> {
+public:
+  bool TraverseFunctionDecl(FunctionDecl *D) {
+    Events.push_back(VisitEvent::StartTraverseFunction);
+    bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D);
+    Events.push_back(VisitEvent::EndTraverseFunction);
+
+    return Ret;
+  }
+
+  bool TraverseAttr(Attr *A) {
+    Events.push_back(VisitEvent::StartTraverseAttr);
+    bool Ret = RecursiveASTVisitor::TraverseAttr(A);
+    Events.push_back(VisitEvent::EndTraverseAttr);
+
+    return Ret;
+  }
+
+  std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
+
+private:
+  std::vector<VisitEvent> Events;
+};
+
+/// Parses \p Code and returns the events recorded while traversing its AST.
+std::vector<VisitEvent> collectEvents(llvm::StringRef Code) {
+  CollectInterestingEvents Visitor;
+  // Report a snippet that fails to compile directly, instead of only through
+  // a confusing mismatch in the recorded events.
+  bool Compiled = clang::tooling::runToolOnCode(
+      new ProcessASTAction(
+          [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
+      Code);
+  EXPECT_TRUE(Compiled) << "test code failed to compile";
+  return std::move(Visitor).takeEvents();
+}
+} // namespace
+
+TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
+  /// Check attributes are traversed inside TraverseFunctionDecl.
+  llvm::StringRef Code = R"cpp(
+__attribute__((annotate("something"))) int foo() { return 10; }
+  )cpp";
+
+  EXPECT_THAT(collectEvents(Code),
+              ElementsAre(VisitEvent::StartTraverseFunction,
+                          VisitEvent::StartTraverseAttr,
+                          VisitEvent::EndTraverseAttr,
+                          VisitEvent::EndTraverseFunction));
+}