diff --git a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
--- a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
+++ b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp
@@ -18,8 +18,9 @@
   llvm::SmallVector<Header> Results;
   switch (Loc.kind()) {
   case SymbolLocation::Physical: {
-    // FIXME: Handle macro locations.
-    FileID FID = SM.getFileID(Loc.physical());
+    // Attribute symbols produced by a macro expansion to the file where the
+    // macro was expanded, not to the file that defines the macro.
+    FileID FID = SM.getFileID(SM.getExpansionLoc(Loc.physical()));
     const FileEntry *FE = SM.getFileEntryForID(FID);
     if (!PI) {
       return FE ? llvm::SmallVector<Header>{Header(FE)}
diff --git a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
--- a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
+++ b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp
@@ -7,11 +7,14 @@
 //===----------------------------------------------------------------------===//
 
 #include "AnalysisInternal.h"
+#include "clang-include-cleaner/Analysis.h"
 #include "clang-include-cleaner/Record.h"
+#include "clang-include-cleaner/Types.h"
+#include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/Basic/FileEntry.h"
 #include "clang/Basic/FileManager.h"
 #include "clang/Frontend/FrontendActions.h"
 #include "clang/Testing/TestAST.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <memory>
@@ -31,7 +34,7 @@
   std::unique_ptr<TestAST> AST;
   FindHeadersTest() {
     Inputs.MakeAction = [this] {
-      struct Hook : public PreprocessOnlyAction {
+      struct Hook : public SyntaxOnlyAction {
       public:
         Hook(PragmaIncludes *Out) : Out(Out) {}
         bool BeginSourceFileAction(clang::CompilerInstance &CI) override {
@@ -153,5 +156,69 @@
                                    physicalHeader("exporter.h")));
 }
 
+TEST_F(FindHeadersTest, TargetIsExpandedFromMacroInHeader) {
+  // Records the single decl named "Foo" or "FLAG_foo" in the TU.
+  struct CustomVisitor : RecursiveASTVisitor<CustomVisitor> {
+    const Decl *Out = nullptr;
+    bool VisitNamedDecl(const NamedDecl *ND) {
+      // NamedDecl::getName() asserts on non-identifier names (e.g.
+      // constructors, operators), so go through getIdentifier() instead.
+      const IdentifierInfo *II = ND->getIdentifier();
+      if (II && (II->getName() == "FLAG_foo" || II->getName() == "Foo")) {
+        EXPECT_EQ(Out, nullptr) << "more than one target decl found";
+        Out = ND;
+      }
+      return true;
+    }
+  };
+
+  struct {
+    llvm::StringRef MacroHeader;
+    llvm::StringRef DeclareHeader;
+  } TestCases[] = {
+      {/*MacroHeader=*/R"cpp(
+    #define DEFINE_CLASS(name) class name {};
+  )cpp",
+       /*DeclareHeader=*/R"cpp(
+    #include "macro.h"
+    DEFINE_CLASS(Foo)
+  )cpp"},
+      {/*MacroHeader=*/R"cpp(
+    #define DEFINE_Foo class Foo {};
+  )cpp",
+       /*DeclareHeader=*/R"cpp(
+    #include "macro.h"
+    DEFINE_Foo
+  )cpp"},
+      {/*MacroHeader=*/R"cpp(
+    #define DECLARE_FLAGS(name) extern int FLAG_##name
+  )cpp",
+       /*DeclareHeader=*/R"cpp(
+    #include "macro.h"
+    DECLARE_FLAGS(foo);
+  )cpp"},
+  };
+
+  for (const auto &T : TestCases) {
+    SCOPED_TRACE(T.DeclareHeader.str());
+    Inputs.Code = R"cpp(#include "declare.h")cpp";
+    Inputs.ExtraFiles["declare.h"] = guard(T.DeclareHeader);
+    Inputs.ExtraFiles["macro.h"] = guard(T.MacroHeader);
+    buildAST();
+
+    CustomVisitor Visitor;
+    Visitor.TraverseDecl(AST->context().getTranslationUnitDecl());
+    // Abort the test rather than dereference null if no target was found.
+    ASSERT_NE(Visitor.Out, nullptr);
+
+    llvm::SmallVector<Header> Headers = clang::include_cleaner::findHeaders(
+        Visitor.Out->getLocation(), AST->sourceManager(),
+        /*PragmaIncludes=*/nullptr);
+    // Each target comes from a macro expanded in declare.h, so it must be
+    // attributed to declare.h rather than to macro.h.
+    EXPECT_THAT(Headers, UnorderedElementsAre(physicalHeader("declare.h")));
+  }
+}
+
 } // namespace
 } // namespace clang::include_cleaner