diff --git a/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp b/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp --- a/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp +++ b/clang-tools-extra/include-cleaner/unittests/AnalysisTest.cpp @@ -311,12 +311,9 @@ Results = {}; Results.Missing.push_back("\"d.h\""); Code = R"cpp(#include "a.h")cpp"; - // FIXME: this isn't correct, the main-file header d.h should be added before - // a.h. EXPECT_EQ(fixIncludes(Results, "d.cc", Code, format::getLLVMStyle()), -R"cpp(#include "a.h" -#include "d.h" -)cpp"); +R"cpp(#include "d.h" +#include "a.h")cpp"); } MATCHER_P3(expandedAt, FileID, Offset, SM, "") { diff --git a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp --- a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp +++ b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp @@ -335,7 +335,7 @@ // Only record the offset of current #include if we can insert after it. if (CurInclude.R.getOffset() <= MaxInsertOffset) { int Priority = Categories.getIncludePriority( - CurInclude.Name, /*CheckMainHeader=*/FirstIncludeOffset < 0); + CurInclude.Name, /*CheckMainHeader=*/true); CategoryEndOffsets[Priority] = NextLineOffset; IncludesByPriority[Priority].push_back(&CurInclude); if (FirstIncludeOffset < 0) @@ -362,7 +362,7 @@ std::string(llvm::formatv(IsAngled ? "<{0}>" : "\"{0}\"", IncludeName)); StringRef QuotedName = Quoted; int Priority = Categories.getIncludePriority( - QuotedName, /*CheckMainHeader=*/FirstIncludeOffset < 0); + QuotedName, /*CheckMainHeader=*/true); auto CatOffset = CategoryEndOffsets.find(Priority); assert(CatOffset != CategoryEndOffsets.end()); unsigned InsertOffset = CatOffset->second; // Fall back offset diff --git a/clang/unittests/Tooling/HeaderIncludesTest.cpp b/clang/unittests/Tooling/HeaderIncludesTest.cpp --- a/clang/unittests/Tooling/HeaderIncludesTest.cpp +++ b/clang/unittests/Tooling/HeaderIncludesTest.cpp @@ -143,6 +143,17 @@ EXPECT_NE(Expected, insert(Code, "")) << "Not main header"; } +TEST_F(HeaderIncludesTest, InsertMainHeader) { + std::string Code = R"cpp(#include "a.h")cpp"; + std::string Expected = R"cpp(#include "fix.h" +#include "a.h")cpp"; + + Style = format::getGoogleStyle(format::FormatStyle::LanguageKind::LK_Cpp) + .IncludeStyle; + FileName = "fix.cpp"; + EXPECT_EQ(Expected, insert(Code, "\"fix.h\"")); +} + TEST_F(HeaderIncludesTest, InsertBeforeSystemHeaderLLVM) { std::string Code = "#include \n" "\n"