diff --git a/clang-tools-extra/include-cleaner/lib/Record.cpp b/clang-tools-extra/include-cleaner/lib/Record.cpp --- a/clang-tools-extra/include-cleaner/lib/Record.cpp +++ b/clang-tools-extra/include-cleaner/lib/Record.cpp @@ -12,6 +12,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/DeclGroup.h" #include "clang/Basic/FileEntry.h" +#include "clang/Basic/FileManager.h" #include "clang/Basic/LLVM.h" #include "clang/Basic/SourceLocation.h" #include "clang/Basic/SourceManager.h" @@ -24,16 +25,21 @@ #include "clang/Tooling/Inclusions/HeaderAnalysis.h" #include "clang/Tooling/Inclusions/StandardLibrary.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem/UniqueID.h" #include "llvm/Support/StringSaver.h" #include #include #include #include +#include #include #include @@ -387,24 +393,48 @@ return It->getSecond(); } +template static llvm::SmallVector -toFileEntries(llvm::ArrayRef FileNames, FileManager &FM) { +toFileEntries(Iter FileNamesBegin, Iter FileNamesEnd, FileManager &FM) { llvm::SmallVector Results; - for (auto FName : FileNames) { + for (auto FNameIt = FileNamesBegin; FNameIt != FileNamesEnd; ++FNameIt) { // FIMXE: log the failing cases? 
- if (auto FE = expectedToOptional(FM.getFileRef(FName))) + if (auto FE = expectedToOptional(FM.getFileRef(*FNameIt))) Results.push_back(*FE); } return Results; } -llvm::SmallVector -PragmaIncludes::getExporters(const FileEntry *File, FileManager &FM) const { - auto It = IWYUExportBy.find(File->getUniqueID()); + +/// Adds to \p Result every header that exports the header identified by +/// \p UID, either directly or transitively through other exporters. +static void collectExportersRecursively( + const llvm::DenseMap> + &IWYUExportBy, + const llvm::sys::fs::UniqueID &UID, std::set &Result, + FileManager &FM) { + auto It = IWYUExportBy.find(UID); if (It == IWYUExportBy.end()) - return {}; + return; + auto Exporters = + toFileEntries(It->getSecond().begin(), It->getSecond().end(), FM); + for (const auto &E : Exporters) { + // Only recurse into exporters seen for the first time: this avoids + // re-walking shared chains and guarantees termination on export cycles. + if (Result.insert(E).second) + collectExportersRecursively(IWYUExportBy, E->getUniqueID(), Result, FM); + } +} - return toFileEntries(It->getSecond(), FM); +llvm::SmallVector +PragmaIncludes::getExporters(const FileEntry *File, FileManager &FM) const { + std::set UniqueExporters; + collectExportersRecursively(IWYUExportBy, File->getUniqueID(), + UniqueExporters, FM); + llvm::SmallVector Exporters{UniqueExporters.begin(), + UniqueExporters.end()}; + return Exporters; } llvm::SmallVector PragmaIncludes::getExporters(tooling::stdlib::Header StdHeader, @@ -412,7 +442,7 @@ auto It = StdIWYUExportBy.find(StdHeader); if (It == StdIWYUExportBy.end()) return {}; - return toFileEntries(It->getSecond(), FM); + return toFileEntries(It->getSecond().begin(), It->getSecond().end(), FM); } bool PragmaIncludes::isSelfContained(const FileEntry *FE) const { diff --git a/clang-tools-extra/include-cleaner/unittests/RecordTest.cpp b/clang-tools-extra/include-cleaner/unittests/RecordTest.cpp --- 
a/clang-tools-extra/include-cleaner/unittests/RecordTest.cpp +++ b/clang-tools-extra/include-cleaner/unittests/RecordTest.cpp @@ -439,6 +439,45 @@ PI.getExporters(SM.getFileEntryForID(SM.getMainFileID()), FM).empty()); } +TEST_F(PragmaIncludeTest, IWYUTransitiveExport) { + Inputs.Code = R"cpp( + #include "export1.h" + )cpp"; + 
Inputs.ExtraFiles["export1.h"] = R"cpp( + #include "export2.h" // IWYU pragma: export + )cpp"; + Inputs.ExtraFiles["export2.h"] = R"cpp( + #include "provider.h" // IWYU pragma: export + )cpp"; + Inputs.ExtraFiles["provider.h"] = ""; + TestAST Processed = build(); + auto &FM = Processed.fileManager(); + + EXPECT_THAT(PI.getExporters(FM.getFile("provider.h").get(), FM), + testing::UnorderedElementsAre(FileNamed("export1.h"), + FileNamed("export2.h"))); +} + +TEST_F(PragmaIncludeTest, IWYUExportCycle) { + Inputs.Code = R"cpp( + #include "export1.h" + )cpp"; + Inputs.ExtraFiles["export1.h"] = R"cpp( + #pragma once + #include "export2.h" // IWYU pragma: export + )cpp"; + Inputs.ExtraFiles["export2.h"] = R"cpp( + #pragma once + #include "export1.h" // IWYU pragma: export + )cpp"; + TestAST Processed = build(); + auto &FM = Processed.fileManager(); + + // Must terminate even though export1.h and export2.h export each other. + EXPECT_THAT(PI.getExporters(FM.getFile("export1.h").get(), FM), + testing::Contains(FileNamed("export2.h"))); +} + TEST_F(PragmaIncludeTest, IWYUExportForStandardHeaders) { Inputs.Code = R"cpp( #include "export.h" @@ -484,9 +523,11 @@ testing::UnorderedElementsAre(FileNamed("export1.h"), FileNamed("normal.h"))); EXPECT_THAT(PI.getExporters(FM.getFile("private2.h").get(), FM), - testing::UnorderedElementsAre(FileNamed("export1.h"))); + testing::UnorderedElementsAre(FileNamed("export1.h"), + FileNamed("normal.h"))); EXPECT_THAT(PI.getExporters(FM.getFile("private3.h").get(), FM), - testing::UnorderedElementsAre(FileNamed("export1.h"))); + testing::UnorderedElementsAre(FileNamed("export1.h"), + FileNamed("normal.h"))); EXPECT_TRUE(PI.getExporters(FM.getFile("foo.h").get(), FM).empty()); EXPECT_TRUE(PI.getExporters(FM.getFile("bar.h").get(), FM).empty());