diff --git a/clang-tools-extra/include-cleaner/lib/Analysis.cpp b/clang-tools-extra/include-cleaner/lib/Analysis.cpp
--- a/clang-tools-extra/include-cleaner/lib/Analysis.cpp
+++ b/clang-tools-extra/include-cleaner/lib/Analysis.cpp
@@ -8,37 +8,53 @@
 
 #include "clang-include-cleaner/Analysis.h"
 #include "AnalysisInternal.h"
+#include "clang-include-cleaner/Record.h"
 #include "clang-include-cleaner/Types.h"
 #include "clang/AST/ASTContext.h"
+#include "clang/AST/Decl.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Tooling/Inclusions/StandardLibrary.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
+#include <unordered_map>
 
 namespace clang::include_cleaner {
 
+namespace {
+// Gets all the providers for a symbol by traversing each of its locations.
+// The result may contain duplicates when several locations map to one header.
+llvm::SmallVector<Header> findAllHeaders(const Symbol &S,
+                                         const SourceManager &SM,
+                                         const PragmaIncludes &PI) {
+  llvm::SmallVector<Header> Headers;
+  for (const auto &Loc : locateSymbol(S)) {
+    // FIXME: Propagate header hints.
+    Headers.append(findHeaders(Loc.first, SM, PI));
+  }
+  return Headers;
+}
+} // namespace
+
 void walkUsed(llvm::ArrayRef<Decl *> ASTRoots,
               llvm::ArrayRef<SymbolReference> MacroRefs,
               const PragmaIncludes &PI, const SourceManager &SM,
               UsedSymbolCB CB) {
-  tooling::stdlib::Recognizer Recognizer;
+  // Cache for decl to header mappings, as the same decl might be referenced in
+  // multiple locations and finding providers for each location is expensive.
+  std::unordered_map<const Decl *, llvm::SmallVector<Header>> DeclToHeaders;
   for (auto *Root : ASTRoots) {
     auto &SM = Root->getASTContext().getSourceManager();
     walkAST(*Root, [&](SourceLocation Loc, NamedDecl &ND, RefType RT) {
       SymbolReference SymRef{Loc, ND, RT};
-      if (auto SS = Recognizer(&ND)) {
-        // FIXME: Also report forward decls from main-file, so that the caller
-        // can decide to insert/ignore a header.
-        return CB(SymRef, findHeaders(*SS, SM, PI));
-      }
-      // FIXME: Extract locations from redecls.
-      return CB(SymRef, findHeaders(ND.getLocation(), SM, PI));
+      auto [It, Inserted] = DeclToHeaders.try_emplace(&ND);
+      if (Inserted)
+        It->second = findAllHeaders(ND, SM, PI);
+      return CB(SymRef, It->second);
     });
   }
   for (const SymbolReference &MacroRef : MacroRefs) {
     assert(MacroRef.Target.kind() == Symbol::Macro);
-    return CB(MacroRef,
-              findHeaders(MacroRef.Target.macro().Definition, SM, PI));
+    CB(MacroRef, findAllHeaders(MacroRef.Target, SM, PI));
   }
 }
 
diff --git a/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp b/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
--- a/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
+++ b/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp
@@ -25,11 +25,51 @@
 
 namespace clang::include_cleaner {
 namespace {
+using testing::Contains;
 using testing::Pair;
 using testing::UnorderedElementsAre;
 
-TEST(WalkUsed, Basic) {
-  // FIXME: Have a fixture for setting up tests.
+class WalkUsedTest : public testing::Test {
+protected:
+  TestInputs Inputs;
+  PragmaIncludes PI;
+  WalkUsedTest() {
+    Inputs.MakeAction = [this] {
+      struct Hook : public SyntaxOnlyAction {
+      public:
+        explicit Hook(PragmaIncludes *Out) : Out(Out) {}
+        bool BeginSourceFileAction(clang::CompilerInstance &CI) override {
+          Out->record(CI);
+          return true;
+        }
+
+        PragmaIncludes *Out;
+      };
+      return std::make_unique<Hook>(&PI);
+    };
+  }
+
+  // Runs walkUsed over AST's top-level decls, mapping each reference's file
+  // offset to its providers. References outside the main file fail the test.
+  llvm::DenseMap<size_t, std::vector<Header>>
+  offsetToProviders(TestAST &AST, SourceManager &SM) {
+    llvm::SmallVector<Decl *> TopLevelDecls;
+    for (Decl *D : AST.context().getTranslationUnitDecl()->decls()) {
+      TopLevelDecls.emplace_back(D);
+    }
+    llvm::DenseMap<size_t, std::vector<Header>> OffsetToProviders;
+    walkUsed(TopLevelDecls, /*MacroRefs=*/{}, PI, SM,
+             [&](const SymbolReference &Ref, llvm::ArrayRef<Header> Providers) {
+               auto [FID, Offset] = SM.getDecomposedLoc(Ref.RefLocation);
+               if (FID != SM.getMainFileID())
+                 ADD_FAILURE() << "Reference outside of the main file!";
+               OffsetToProviders.try_emplace(Offset, Providers.vec());
+             });
+    return OffsetToProviders;
+  }
+};
+
+TEST_F(WalkUsedTest, Basic) {
   llvm::Annotations Code(R"cpp(
   #include "header.h"
   #include "private.h"
@@ -39,7 +79,7 @@
     std::$vector^vector $vconstructor^v;
   }
   )cpp");
-  TestInputs Inputs(Code.code());
+  Inputs.Code = Code.code();
   Inputs.ExtraFiles["header.h"] = R"cpp(
   void foo();
   namespace std { class vector {}; }
@@ -49,40 +89,13 @@
     class Private {};
   )cpp";
 
-  PragmaIncludes PI;
-  Inputs.MakeAction = [&PI] {
-    struct Hook : public SyntaxOnlyAction {
-    public:
-      Hook(PragmaIncludes *Out) : Out(Out) {}
-      bool BeginSourceFileAction(clang::CompilerInstance &CI) override {
-        Out->record(CI);
-        return true;
-      }
-
-      PragmaIncludes *Out;
-    };
-    return std::make_unique<Hook>(&PI);
-  };
   TestAST AST(Inputs);
-
-  llvm::SmallVector<Decl *> TopLevelDecls;
-  for (Decl *D : AST.context().getTranslationUnitDecl()->decls()) {
-    TopLevelDecls.emplace_back(D);
-  }
   auto &SM = AST.sourceManager();
-  llvm::DenseMap<size_t, std::vector<Header>> OffsetToProviders;
-  walkUsed(TopLevelDecls, /*MacroRefs=*/{}, PI, SM,
-           [&](const SymbolReference &Ref, llvm::ArrayRef<Header> Providers) {
-             auto [FID, Offset] = SM.getDecomposedLoc(Ref.RefLocation);
-             EXPECT_EQ(FID, SM.getMainFileID());
-             OffsetToProviders.try_emplace(Offset, Providers.vec());
-           });
-
   auto HeaderFile = Header(AST.fileManager().getFile("header.h").get());
   auto MainFile = Header(SM.getFileEntryForID(SM.getMainFileID()));
   auto VectorSTL = Header(tooling::stdlib::Header::named("<vector>").value());
   EXPECT_THAT(
-      OffsetToProviders,
+      offsetToProviders(AST, SM),
       UnorderedElementsAre(
           Pair(Code.point("bar"), UnorderedElementsAre(MainFile)),
           Pair(Code.point("private"),
@@ -92,6 +105,35 @@
           Pair(Code.point("vconstructor"), UnorderedElementsAre(VectorSTL))));
 }
 
+TEST_F(WalkUsedTest, MultipleProviders) {
+  llvm::Annotations Code(R"cpp(
+  #include "header1.h"
+  #include "header2.h"
+  void foo();
+
+  void bar() {
+    $foo^foo();
+  }
+  )cpp");
+  Inputs.Code = Code.code();
+  Inputs.ExtraFiles["header1.h"] = R"cpp(
+  void foo();
+  )cpp";
+  Inputs.ExtraFiles["header2.h"] = R"cpp(
+  void foo();
+  )cpp";
+
+  TestAST AST(Inputs);
+  auto &SM = AST.sourceManager();
+  auto HeaderFile1 = Header(AST.fileManager().getFile("header1.h").get());
+  auto HeaderFile2 = Header(AST.fileManager().getFile("header2.h").get());
+  auto MainFile = Header(SM.getFileEntryForID(SM.getMainFileID()));
+  EXPECT_THAT(
+      offsetToProviders(AST, SM),
+      Contains(Pair(Code.point("foo"),
+                    UnorderedElementsAre(HeaderFile1, HeaderFile2, MainFile))));
+}
+
 TEST(WalkUsed, MacroRefs) {
   llvm::Annotations Hdr(R"cpp(
     #define ^ANSWER 42
@@ -129,6 +171,5 @@
       UnorderedElementsAre(Pair(Main.point(), UnorderedElementsAre(HdrFile))));
 }
 
-
 } // namespace
 } // namespace clang::include_cleaner