diff --git a/clang-tools-extra/clangd/IncludeCleaner.cpp b/clang-tools-extra/clangd/IncludeCleaner.cpp --- a/clang-tools-extra/clangd/IncludeCleaner.cpp +++ b/clang-tools-extra/clangd/IncludeCleaner.cpp @@ -440,8 +440,9 @@ return {std::move(UnusedIncludes), std::move(MissingIncludes)}; } -Fix removeAllUnusedIncludes(llvm::ArrayRef UnusedIncludes) { - assert(!UnusedIncludes.empty()); +std::optional removeAllUnusedIncludes(llvm::ArrayRef UnusedIncludes) { + if (UnusedIncludes.empty()) + return std::nullopt; Fix RemoveAll; RemoveAll.Message = "remove all unused includes"; @@ -465,8 +466,10 @@ } return RemoveAll; } -Fix addAllMissingIncludes(llvm::ArrayRef MissingIncludeDiags) { - assert(!MissingIncludeDiags.empty()); +std::optional +addAllMissingIncludes(llvm::ArrayRef MissingIncludeDiags) { + if (MissingIncludeDiags.empty()) + return std::nullopt; Fix AddAllMissing; AddAllMissing.Message = "add all missing includes"; @@ -516,15 +519,11 @@ llvm::StringRef Code) { std::vector UnusedIncludes = generateUnusedIncludeDiagnostics( AST.tuPath(), Findings.UnusedIncludes, Code); - std::optional RemoveAllUnused;; - if (UnusedIncludes.size() > 1) - RemoveAllUnused = removeAllUnusedIncludes(UnusedIncludes); + std::optional RemoveAllUnused = removeAllUnusedIncludes(UnusedIncludes); std::vector MissingIncludeDiags = generateMissingIncludeDiagnostics( AST, Findings.MissingIncludes, Code); - std::optional AddAllMissing; - if (MissingIncludeDiags.size() > 1) - AddAllMissing = addAllMissingIncludes(MissingIncludeDiags); + std::optional AddAllMissing = addAllMissingIncludes(MissingIncludeDiags); std::optional FixAll; if (RemoveAllUnused && AddAllMissing) @@ -535,11 +534,16 @@ Out->Fixes.push_back(*F); }; for (auto &Diag : MissingIncludeDiags) { - AddBatchFix(AddAllMissing, &Diag); + AddBatchFix(MissingIncludeDiags.size() > 1 + ? AddAllMissing + : std::nullopt, + &Diag); AddBatchFix(FixAll, &Diag); } for (auto &Diag : UnusedIncludes) { - AddBatchFix(RemoveAllUnused, &Diag); + AddBatchFix(UnusedIncludes.size() > 1 ? RemoveAllUnused + : std::nullopt, + &Diag); AddBatchFix(FixAll, &Diag); } diff --git a/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp b/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp --- a/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp +++ b/clang-tools-extra/clangd/unittests/IncludeCleanerTests.cpp @@ -197,10 +197,10 @@ ns::$bar[[Bar]] bar; bar.d(); - $f[[f]](); + $f[[f]](); // this should not be diagnosed, because it's ignored in the config - buzz(); + buzz(); $foobar[[foobar]](); @@ -244,7 +244,7 @@ TU.AdditionalFiles["foo.h"] = guard(R"cpp( #define BAR(x) Foo *x #define FOO 1 - struct Foo{}; + struct Foo{}; )cpp"); TU.Code = MainFile.code(); @@ -510,6 +510,71 @@ } } +TEST(IncludeCleaner, BatchFix) { + Config Cfg; + Cfg.Diagnostics.MissingIncludes = Config::IncludesPolicy::Strict; + Cfg.Diagnostics.UnusedIncludes = Config::IncludesPolicy::Strict; + WithContextValue Ctx(Config::Key, std::move(Cfg)); + + TestTU TU; + TU.Filename = "main.cpp"; + TU.AdditionalFiles["foo.h"] = guard("class Foo;"); + TU.AdditionalFiles["bar.h"] = guard("class Bar;"); + TU.AdditionalFiles["all.h"] = guard(R"cpp( + #include "foo.h" + #include "bar.h" + )cpp"); + + TU.Code = R"cpp( + #include "all.h" + + Foo* foo; + )cpp"; + auto AST = TU.build(); + EXPECT_THAT( + issueIncludeCleanerDiagnostics(AST, TU.Code), + UnorderedElementsAre(withFix({FixMessage("#include \"foo.h\""), + FixMessage("fix all includes")}), + withFix({FixMessage("remove #include directive"), + FixMessage("fix all includes")}))); + + TU.Code = R"cpp( + #include "all.h" + #include "bar.h" + + Foo* foo; + )cpp"; + AST = TU.build(); + EXPECT_THAT( + issueIncludeCleanerDiagnostics(AST, TU.Code), + UnorderedElementsAre(withFix({FixMessage("#include \"foo.h\""), + FixMessage("fix all includes")}), + withFix({FixMessage("remove #include directive"), + FixMessage("remove all unused includes"), + FixMessage("fix all includes")}), + withFix({FixMessage("remove #include directive"), + FixMessage("remove all unused includes"), + FixMessage("fix all includes")}))); + + TU.Code = R"cpp( + #include "all.h" + + Foo* foo; + Bar* bar; + )cpp"; + AST = TU.build(); + EXPECT_THAT( + issueIncludeCleanerDiagnostics(AST, TU.Code), + UnorderedElementsAre(withFix({FixMessage("#include \"foo.h\""), + FixMessage("add all missing includes"), + FixMessage("fix all includes")}), + withFix({FixMessage("#include \"bar.h\""), + FixMessage("add all missing includes"), + FixMessage("fix all includes")}), + withFix({FixMessage("remove #include directive"), + FixMessage("fix all includes")}))); +} + } // namespace } // namespace clangd } // namespace clang