Index: include-fixer/find-all-symbols/CMakeLists.txt
===================================================================
--- include-fixer/find-all-symbols/CMakeLists.txt
+++ include-fixer/find-all-symbols/CMakeLists.txt
@@ -4,6 +4,7 @@
 
 add_clang_library(findAllSymbols
   FindAllSymbols.cpp
+  PragmaCommentHandler.cpp
   SymbolInfo.cpp
 
   LINK_LIBS
Index: include-fixer/find-all-symbols/FindAllSymbols.h
===================================================================
--- include-fixer/find-all-symbols/FindAllSymbols.h
+++ include-fixer/find-all-symbols/FindAllSymbols.h
@@ -10,8 +10,10 @@
 #ifndef LLVM_CLANG_TOOLS_EXTRA_FIND_ALL_SYMBOLS_SYMBOL_MATCHER_H
 #define LLVM_CLANG_TOOLS_EXTRA_FIND_ALL_SYMBOLS_SYMBOL_MATCHER_H
 
+#include "PragmaCommentHandler.h"
 #include "SymbolInfo.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "llvm/ADT/StringRef.h"
 #include <string>
 
 namespace clang {
@@ -39,7 +41,12 @@
                               const SymbolInfo &Symbol) = 0;
   };
 
-  explicit FindAllSymbols(ResultReporter *Reporter) : Reporter(Reporter) {}
+  /// \param Reporter  Receives the reported symbols; not owned.
+  /// \param Collector Header remappings looked up while filling in a
+  ///                  symbol's file path; not owned and must not be null.
+  explicit FindAllSymbols(ResultReporter *Reporter,
+                          PragmaCommentHandler::HeaderMapCollector *Collector)
+      : Reporter(Reporter), Collector(Collector) {}
 
   void registerMatchers(clang::ast_matchers::MatchFinder *MatchFinder);
 
@@ -48,6 +55,8 @@
 
 private:
   ResultReporter *const Reporter;
+  // Header remappings collected from IWYU pragmas; not owned.
+  PragmaCommentHandler::HeaderMapCollector *const Collector;
 };
 
 } // namespace find_all_symbols
Index: include-fixer/find-all-symbols/FindAllSymbols.cpp
===================================================================
--- include-fixer/find-all-symbols/FindAllSymbols.cpp
+++ include-fixer/find-all-symbols/FindAllSymbols.cpp
@@ -43,8 +43,9 @@
   }
 }
 
-bool SetCommonInfo(const MatchFinder::MatchResult &Result,
-                   const NamedDecl *ND, SymbolInfo *Symbol) {
+bool SetCommonInfo(const MatchFinder::MatchResult &Result, const NamedDecl *ND,
+                   SymbolInfo *Symbol,
+                   const PragmaCommentHandler::HeaderMapCollector *Collector) {
   SetContext(ND, Symbol);
 
   Symbol->Name = ND->getNameAsString();
@@ -60,6 +61,14 @@
 
   Symbol->LineNumber = SM->getExpansionLineNumber(Loc);
 
+  // If the file declaring the symbol carries an IWYU private pragma, report
+  // the header named by the pragma instead of the file itself.
+  auto PragmaHeader = Collector->getHeaderMapping(SM->getFileID(Loc));
+  if (PragmaHeader) {
+    Symbol->FilePath = PragmaHeader.getValue().str();
+    return true;
+  }
+
   llvm::StringRef FilePath = SM->getFilename(Loc);
   if (FilePath.empty())
     return false;
@@ -174,7 +183,7 @@
   const SourceManager *SM = Result.SourceManager;
 
   SymbolInfo Symbol;
-  if (!SetCommonInfo(Result, ND, &Symbol))
+  if (!SetCommonInfo(Result, ND, &Symbol, Collector))
     return;
 
   if (const auto *VD = llvm::dyn_cast<VarDecl>(ND)) {
Index: include-fixer/find-all-symbols/PragmaCommentHandler.h
===================================================================
--- /dev/null
+++ include-fixer/find-all-symbols/PragmaCommentHandler.h
@@ -0,0 +1,60 @@
+//===-- PragmaCommentHandler.h - find all symbols----------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLS_EXTRA_FIND_ALL_SYMBOLS_PRAGMACOMMENTHANDLER_H
+#define LLVM_CLANG_TOOLS_EXTRA_FIND_ALL_SYMBOLS_PRAGMACOMMENTHANDLER_H
+
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Lex/Preprocessor.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include <map>
+#include <string>
+
+namespace clang {
+namespace find_all_symbols {
+
+/// \brief Comment handler that looks for "IWYU pragma: private, include ..."
+/// comments and records, per file, the header named by the pragma.
+class PragmaCommentHandler : public clang::CommentHandler {
+public:
+  /// \brief Maps a file to the header that should be reported for it.
+  class HeaderMapCollector {
+  public:
+    virtual ~HeaderMapCollector() = default;
+
+    /// \brief Records \p FilePath for file \p ID, replacing any mapping
+    /// recorded earlier for the same file.
+    void addHeaderMapping(FileID ID, llvm::StringRef FilePath);
+
+    /// \brief Returns the header recorded for \p ID, or llvm::None if there
+    /// is none. The result points into storage owned by this collector.
+    llvm::Optional<llvm::StringRef> getHeaderMapping(FileID ID) const;
+
+  private:
+    std::map<FileID, std::string> HeaderMap;
+  };
+
+  /// \param Collector Receives the mappings that are found; not owned and
+  /// must not be null.
+  explicit PragmaCommentHandler(HeaderMapCollector *Collector)
+      : Collector(Collector) {}
+
+  /// \brief Records a mapping if the comment in \p Range is an IWYU private
+  /// pragma. Always returns false.
+  bool HandleComment(Preprocessor &PP, SourceRange Range) override;
+
+private:
+  HeaderMapCollector *const Collector;
+};
+
+} // namespace find_all_symbols
+} // namespace clang
+
+#endif // LLVM_CLANG_TOOLS_EXTRA_FIND_ALL_SYMBOLS_PRAGMACOMMENTHANDLER_H
Index: include-fixer/find-all-symbols/PragmaCommentHandler.cpp
===================================================================
--- /dev/null
+++ include-fixer/find-all-symbols/PragmaCommentHandler.cpp
@@ -0,0 +1,57 @@
+//===-- PragmaCommentHandler.cpp - find all symbols -----------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PragmaCommentHandler.h"
+#include "clang/Lex/Lexer.h"
+#include "clang/Lex/Preprocessor.h"
+
+namespace clang {
+namespace find_all_symbols {
+namespace {
+// Text that precedes the header name in an IWYU private pragma comment.
+const char IWYUPragma[] = "// IWYU pragma: private, include ";
+} // namespace
+
+void PragmaCommentHandler::HeaderMapCollector::addHeaderMapping(
+    FileID ID, llvm::StringRef FilePath) {
+  HeaderMap[ID] = FilePath.str();
+}
+
+llvm::Optional<llvm::StringRef>
+PragmaCommentHandler::HeaderMapCollector::getHeaderMapping(FileID ID) const {
+  auto It = HeaderMap.find(ID);
+  if (It == HeaderMap.end())
+    return llvm::None;
+  return llvm::StringRef(It->second);
+}
+
+bool PragmaCommentHandler::HandleComment(Preprocessor &PP, SourceRange Range) {
+  const SourceManager &SM = PP.getSourceManager();
+  StringRef Text = Lexer::getSourceText(CharSourceRange::getCharRange(Range),
+                                        SM, PP.getLangOpts());
+
+  // The pragma has no variable part before the header name, so a substring
+  // search finds it without building a regular expression on every call.
+  const StringRef Prefix(IWYUPragma);
+  const size_t Pos = Text.find(Prefix);
+  if (Pos == StringRef::npos)
+    return false;
+
+  // Drop surrounding whitespace first; otherwise trailing blanks or a '\r'
+  // would keep the closing quote or angle bracket from being trimmed.
+  StringRef Header = Text.substr(Pos + Prefix.size()).trim().trim("\"<>");
+  if (Header.empty())
+    return false;
+
+  Collector->addHeaderMapping(SM.getFileID(Range.getBegin()), Header);
+  return false;
+}
+
+} // namespace find_all_symbols
+} // namespace clang
Index: include-fixer/find-all-symbols/tool/FindAllSymbolsMain.cpp
===================================================================
--- include-fixer/find-all-symbols/tool/FindAllSymbolsMain.cpp
+++ include-fixer/find-all-symbols/tool/FindAllSymbolsMain.cpp
@@ -9,7 +9,12 @@
 
 #include "FindAllSymbols.h"
+#include "PragmaCommentHandler.h"
 #include "SymbolInfo.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Frontend/FrontendActions.h"
+#include "clang/Lex/Preprocessor.h"
 #include "clang/Tooling/CommonOptionsParser.h"
 #include "clang/Tooling/Tooling.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -82,6 +87,34 @@
   std::map<std::string, std::set<SymbolInfo>> Symbols;
 };
 
+/// \brief Frontend action that collects the symbols of one source file and
+/// writes them to OutputDir when that source file is finished.
+class FindAllSymbolsAction : public clang::ASTFrontendAction {
+public:
+  FindAllSymbolsAction()
+      : Handler(&Collector), Matcher(&Reporter, &Collector) {
+    Matcher.registerMatchers(&MatchFinder);
+  }
+
+  std::unique_ptr<clang::ASTConsumer>
+  CreateASTConsumer(clang::CompilerInstance &Compiler,
+                    StringRef InFile) override {
+    // Handler is handed over by pointer and stays owned by this action.
+    Compiler.getPreprocessor().addCommentHandler(&Handler);
+    return MatchFinder.newASTConsumer();
+  }
+
+  void EndSourceFileAction() override { Reporter.Write(OutputDir); }
+
+private:
+  YamlReporter Reporter;
+  clang::ast_matchers::MatchFinder MatchFinder;
+  // Filled by Handler while preprocessing and read by Matcher.
+  PragmaCommentHandler::HeaderMapCollector Collector;
+  PragmaCommentHandler Handler;
+  FindAllSymbols Matcher;
+};
+
 bool Merge(llvm::StringRef MergeDir, llvm::StringRef OutputFile) {
   std::error_code EC;
   std::set<SymbolInfo> UniqueSymbols;
@@ -142,11 +175,8 @@
     return 0;
   }
 
-  clang::find_all_symbols::YamlReporter Reporter;
-  clang::find_all_symbols::FindAllSymbols Matcher(&Reporter);
-  clang::ast_matchers::MatchFinder MatchFinder;
-  Matcher.registerMatchers(&MatchFinder);
-  Tool.run(newFrontendActionFactory(&MatchFinder).get());
-  Reporter.Write(OutputDir);
-  return 0;
+  // Propagate the tool's status so that failures are visible to the caller.
+  return Tool.run(
+      newFrontendActionFactory<clang::find_all_symbols::FindAllSymbolsAction>()
+          .get());
 }
Index: unittests/include-fixer/find-all-symbols/FindAllSymbolsTests.cpp
===================================================================
--- unittests/include-fixer/find-all-symbols/FindAllSymbolsTests.cpp
+++ unittests/include-fixer/find-all-symbols/FindAllSymbolsTests.cpp
@@ -8,11 +8,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "FindAllSymbols.h"
+#include "PragmaCommentHandler.h"
 #include "SymbolInfo.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/Basic/FileManager.h"
 #include "clang/Basic/FileSystemOptions.h"
 #include "clang/Basic/VirtualFileSystem.h"
+#include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/PCHContainerOperations.h"
 #include "clang/Tooling/Tooling.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
@@ -62,6 +64,43 @@
   std::vector<SymbolInfo> Symbols;
 };
 
+/// Test action that wires a PragmaCommentHandler and FindAllSymbols together
+/// and forwards the symbols to a reporter supplied by the caller.
+class MockFindAllSymbolsAction : public clang::ASTFrontendAction {
+public:
+  explicit MockFindAllSymbolsAction(FindAllSymbols::ResultReporter *Reporter)
+      : Handler(&Collector), Matcher(Reporter, &Collector) {
+    Matcher.registerMatchers(&MatchFinder);
+  }
+
+  std::unique_ptr<clang::ASTConsumer>
+  CreateASTConsumer(clang::CompilerInstance &Compiler,
+                    StringRef InFile) override {
+    Compiler.getPreprocessor().addCommentHandler(&Handler);
+    return MatchFinder.newASTConsumer();
+  }
+
+private:
+  ast_matchers::MatchFinder MatchFinder;
+  PragmaCommentHandler::HeaderMapCollector Collector;
+  PragmaCommentHandler Handler;
+  FindAllSymbols Matcher;
+};
+
+/// Creates a new MockFindAllSymbolsAction on every create() call.
+class MockFindAllSymbolsActionFactory
+    : public clang::tooling::FrontendActionFactory {
+public:
+  explicit MockFindAllSymbolsActionFactory(MockReporter *Reporter)
+      : Reporter(Reporter) {}
+  clang::FrontendAction *create() override {
+    return new MockFindAllSymbolsAction(Reporter);
+  }
+
+private:
+  MockReporter *const Reporter;
+};
+
 class FindAllSymbolsTest : public ::testing::Test {
 public:
   bool hasSymbol(const SymbolInfo &Symbol) {
@@ -73,18 +112,16 @@
   }
 
   bool runFindAllSymbols(StringRef Code) {
-    FindAllSymbols matcher(&Reporter);
-    clang::ast_matchers::MatchFinder MatchFinder;
-    matcher.registerMatchers(&MatchFinder);
-
     llvm::IntrusiveRefCntPtr<vfs::InMemoryFileSystem> InMemoryFileSystem(
         new vfs::InMemoryFileSystem);
     llvm::IntrusiveRefCntPtr<FileManager> Files(
         new FileManager(FileSystemOptions(), InMemoryFileSystem));
 
     std::string FileName = "symbol.cc";
-    std::unique_ptr<clang::tooling::FrontendActionFactory> Factory =
-        clang::tooling::newFrontendActionFactory(&MatchFinder);
+
+    std::unique_ptr<clang::tooling::FrontendActionFactory> Factory(
+        new MockFindAllSymbolsActionFactory(&Reporter));
+
     tooling::ToolInvocation Invocation(
         {std::string("find_all_symbols"), std::string("-fsyntax-only"),
          FileName},
@@ -371,5 +408,20 @@
   }
 }
 
+TEST_F(FindAllSymbolsTest, IWYUPrivatePragmaTest) {
+  static const char Code[] = R"(
+    // IWYU pragma: private, include "bar.h"
+    struct Bar {
+    };
+  )";
+  runFindAllSymbols(Code);
+
+  {
+    SymbolInfo Symbol =
+        CreateSymbolInfo("Bar", SymbolInfo::Class, "bar.h", 3, {});
+    EXPECT_TRUE(hasSymbol(Symbol));
+  }
+}
+
 } // namespace find_all_symbols
 } // namespace clang