diff --git a/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h b/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h --- a/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h +++ b/clang/include/clang/Tooling/Inclusions/HeaderIncludes.h @@ -40,8 +40,6 @@ const IncludeStyle Style; bool IsMainFile; std::string FileName; - // This refers to a substring in FileName. - StringRef FileStem; SmallVector CategoryRegexs; }; diff --git a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp --- a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp +++ b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp @@ -190,7 +190,6 @@ IncludeCategoryManager::IncludeCategoryManager(const IncludeStyle &Style, StringRef FileName) : Style(Style), FileName(FileName) { - FileStem = matchingStem(FileName); for (const auto &Category : Style.IncludeCategories) CategoryRegexs.emplace_back(Category.Regex, llvm::Regex::IgnoreCase); IsMainFile = FileName.endswith(".c") || FileName.endswith(".cc") || @@ -234,16 +233,30 @@ if (!IncludeName.startswith("\"")) return false; + IncludeName = + IncludeName.drop_front(1).drop_back(1); // remove the surrounding "" or <> // Not matchingStem: implementation files may have compound extensions but // headers may not. - StringRef HeaderStem = - llvm::sys::path::stem(IncludeName.drop_front(1).drop_back( - 1) /* remove the surrounding "" or <> */); - if (FileStem.startswith(HeaderStem) || - FileStem.startswith_lower(HeaderStem)) { + StringRef HeaderStem = llvm::sys::path::stem(IncludeName); + StringRef FileStem = llvm::sys::path::stem(FileName); // foo.cu for foo.cu.cc + StringRef MatchingFileStem = matchingStem(FileName); // foo for foo.cu.cc + // main-header examples: + // 1) foo.h => foo.cc + // 2) foo.h => foo.cu.cc + // 3) foo.proto.h => foo.proto.cc + // + // non-main-header examples: + // 1) foo.h => bar.cc + // 2) foo.proto.h => foo.cc + StringRef Matching; + if (MatchingFileStem.startswith_lower(HeaderStem)) + Matching = MatchingFileStem; // example 1), 2) + else if (FileStem.equals_lower(HeaderStem)) + Matching = FileStem; // example 3) + if (!Matching.empty()) { llvm::Regex MainIncludeRegex(HeaderStem.str() + Style.IncludeIsMainRegex, llvm::Regex::IgnoreCase); - if (MainIncludeRegex.match(FileStem)) + if (MainIncludeRegex.match(Matching)) return true; } return false; diff --git a/clang/unittests/Format/SortIncludesTest.cpp b/clang/unittests/Format/SortIncludesTest.cpp --- a/clang/unittests/Format/SortIncludesTest.cpp +++ b/clang/unittests/Format/SortIncludesTest.cpp @@ -151,7 +151,7 @@ EXPECT_TRUE(sortIncludes(FmtStyle, Code, GetCodeRange(Code), "a.cc").empty()); } -TEST_F(SortIncludesTest, NoMainFileHeader) { +TEST_F(SortIncludesTest, MainFileHeader) { std::string Code = "#include \n" "\n" "#include \"a/extra_action.proto.h\"\n"; @@ -159,6 +159,13 @@ EXPECT_TRUE( sortIncludes(FmtStyle, Code, GetCodeRange(Code), "a/extra_action.cc") .empty()); + + EXPECT_EQ("#include \"foo.bar.h\"\n" + "\n" + "#include \"a.h\"\n", + sort("#include \"a.h\"\n" + "#include \"foo.bar.h\"\n", + "foo.bar.cc")); } TEST_F(SortIncludesTest, SortedIncludesInMultipleBlocksAreMerged) {