diff --git a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp --- a/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp +++ b/clang-tools-extra/include-cleaner/lib/FindHeaders.cpp @@ -25,6 +25,8 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include +#include +#include #include namespace clang::include_cleaner { @@ -188,13 +190,13 @@ if (!PI) return {{FE, Hints::PublicHeader | Hints::OriginHeader}}; bool IsOrigin = true; + std::queue Exporters; while (FE) { Results.emplace_back(FE, isPublicHeader(FE, *PI) | (IsOrigin ? Hints::OriginHeader : Hints::None)); - // FIXME: compute transitive exporter headers. for (const auto *Export : PI->getExporters(FE, SM.getFileManager())) - Results.emplace_back(Export, isPublicHeader(Export, *PI)); + Exporters.push(Export); if (auto Verbatim = PI->getPublic(FE); !Verbatim.empty()) { Results.emplace_back(Verbatim, @@ -209,6 +211,20 @@ FE = SM.getFileEntryForID(FID); IsOrigin = false; } + // Now traverse provider trees rooted at exporters. + // Note that we only traverse export edges, and ignore private -> public + // mappings, as those pragmas apply to the exporter, and not the main + // provider being exported in this header. 
+ std::set SeenExports; + while (!Exporters.empty()) { + auto *Export = Exporters.front(); + Exporters.pop(); + if (!SeenExports.insert(Export).second) // In case of cyclic exports + continue; + Results.emplace_back(Export, isPublicHeader(Export, *PI)); + for (const auto *Export : PI->getExporters(Export, SM.getFileManager())) + Exporters.push(Export); + } return Results; } case SymbolLocation::Standard: { diff --git a/clang-tools-extra/include-cleaner/lib/Record.cpp b/clang-tools-extra/include-cleaner/lib/Record.cpp --- a/clang-tools-extra/include-cleaner/lib/Record.cpp +++ b/clang-tools-extra/include-cleaner/lib/Record.cpp @@ -12,6 +12,7 @@ #include "clang/AST/ASTContext.h" #include "clang/AST/DeclGroup.h" #include "clang/Basic/FileEntry.h" +#include "clang/Basic/FileManager.h" #include "clang/Basic/LLVM.h" #include "clang/Basic/SourceLocation.h" #include "clang/Basic/SourceManager.h" @@ -24,16 +25,21 @@ #include "clang/Tooling/Inclusions/HeaderAnalysis.h" #include "clang/Tooling/Inclusions/StandardLibrary.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem/UniqueID.h" #include "llvm/Support/StringSaver.h" #include #include #include #include +#include #include #include diff --git a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp --- a/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp +++ b/clang-tools-extra/include-cleaner/unittests/FindHeadersTest.cpp @@ -465,6 +465,25 @@ physicalHeader("foo.h"))); } +TEST_F(HeadersForSymbolTest, IWYUTransitiveExport) { + Inputs.Code = R"cpp( + #include "export1.h" + )cpp"; + Inputs.ExtraFiles["export1.h"] = R"cpp( + #include "export2.h" // IWYU pragma: 
export + )cpp"; + Inputs.ExtraFiles["export2.h"] = R"cpp( + #include "foo.h" // IWYU pragma: export + )cpp"; + Inputs.ExtraFiles["foo.h"] = guard(R"cpp( + struct foo {}; + )cpp"); + buildAST(); + EXPECT_THAT(headersForFoo(), + ElementsAre(physicalHeader("foo.h"), physicalHeader("export1.h"), + physicalHeader("export2.h"))); +} + TEST_F(HeadersForSymbolTest, AmbiguousStdSymbols) { struct { llvm::StringRef Code;