diff --git a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp --- a/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp +++ b/clang/lib/Tooling/Inclusions/HeaderIncludes.cpp @@ -233,9 +233,6 @@ return Ret; } bool IncludeCategoryManager::isMainHeader(StringRef IncludeName) const { - if (!IncludeName.startswith("\"")) - return false; - IncludeName = IncludeName.drop_front(1).drop_back(1); // remove the surrounding "" or <> // Not matchingStem: implementation files may have compound extensions but @@ -259,8 +256,22 @@ if (!Matching.empty()) { llvm::Regex MainIncludeRegex(HeaderStem.str() + Style.IncludeIsMainRegex, llvm::Regex::IgnoreCase); - if (MainIncludeRegex.match(Matching)) + if (MainIncludeRegex.match(Matching)) { + // Matching is non-empty so these should be non-empty as well, making + // ++rbegin(IncludeName) and ++rbegin(FileName) safe. + assert(!IncludeName.empty()); + assert(!FileName.empty()); + // Checked stems above. Check remaining common path components here. + auto IncludePathRIter = ++llvm::sys::path::rbegin(IncludeName); + auto FilePathRiter = ++llvm::sys::path::rbegin(FileName); + for (; IncludePathRIter != llvm::sys::path::rend(IncludeName) && + FilePathRiter != llvm::sys::path::rend(FileName); + ++IncludePathRIter, ++FilePathRiter) { + if (*IncludePathRIter != *FilePathRiter) + return false; + } return true; + } } return false; } diff --git a/clang/unittests/Tooling/HeaderIncludesTest.cpp b/clang/unittests/Tooling/HeaderIncludesTest.cpp --- a/clang/unittests/Tooling/HeaderIncludesTest.cpp +++ b/clang/unittests/Tooling/HeaderIncludesTest.cpp @@ -91,6 +91,28 @@ EXPECT_EQ(Expected, insert(Code, "\"a.h\"")); } +TEST_F(HeaderIncludesTest, IsMainHeader) { + Style = format::getGoogleStyle(format::FormatStyle::LanguageKind::LK_Cpp) + .IncludeStyle; + std::vector FileNames{"foo/bar/baz.cpp", "foo/bar/baz.cu.cpp", + "foo/bar/baz_test.cu.cpp"}; + for (const StringRef &FileName : FileNames) { + IncludeCategoryManager Manager(Style, FileName); + // These framework-style includes should all be considered "main". + EXPECT_EQ(Manager.getIncludePriority("", true), 0) + << "for source file " << FileName; + EXPECT_EQ(Manager.getIncludePriority("\"bar/baz.h\"", true), 0) + << "for source file " << FileName; + EXPECT_EQ(Manager.getIncludePriority("", true), 0) + << "for source file " << FileName; + // These should not be considered "main" as the paths to baz.h differ. + EXPECT_NE(Manager.getIncludePriority("", true), 0) + << "for source file " << FileName; + EXPECT_NE(Manager.getIncludePriority("\"foo/baz.h\"", true), 0) + << "for source file " << FileName; + } +} + TEST_F(HeaderIncludesTest, InsertAfterMainHeader) { std::string Code = "#include \"fix.h\"\n" "\n"