Index: include-fixer/IncludeFixer.cpp =================================================================== --- include-fixer/IncludeFixer.cpp +++ include-fixer/IncludeFixer.cpp @@ -365,6 +365,9 @@ .getLocWithOffset(Range.getOffset()) .print(llvm::dbgs(), CI->getSourceManager())); DEBUG(llvm::dbgs() << " ..."); + llvm::StringRef FileName = CI->getSourceManager().getFilename( + CI->getSourceManager().getLocForStartOfFile( + CI->getSourceManager().getMainFileID())); QuerySymbolInfos.push_back({Query.str(), ScopedQualifiers, Range}); @@ -385,9 +388,10 @@ // context, it might treat the identifier as a nested class of the scoped // namespace. std::vector MatchedSymbols = - SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false); + SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName); if (MatchedSymbols.empty()) - MatchedSymbols = SymbolIndexMgr.search(Query); + MatchedSymbols = + SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName); DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size() << " symbols\n"); // We store a copy of MatchedSymbols in a place where it's globally reachable. Index: include-fixer/SymbolIndexManager.h =================================================================== --- include-fixer/SymbolIndexManager.h +++ include-fixer/SymbolIndexManager.h @@ -42,7 +42,8 @@ /// /// \returns A list of symbol candidates.
std::vector - search(llvm::StringRef Identifier, bool IsNestedSearch = true) const; + search(llvm::StringRef Identifier, bool IsNestedSearch = true, + llvm::StringRef FileName = "") const; private: std::vector>> SymbolIndices; Index: include-fixer/SymbolIndexManager.cpp =================================================================== --- include-fixer/SymbolIndexManager.cpp +++ include-fixer/SymbolIndexManager.cpp @@ -12,6 +12,8 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Path.h" +#include <cmath> #define DEBUG_TYPE "include-fixer" @@ -20,30 +22,57 @@ using clang::find_all_symbols::SymbolInfo; -/// Sorts SymbolInfos based on the popularity info in SymbolInfo. -static void rankByPopularity(std::vector &Symbols) { - // First collect occurrences per header file. - llvm::DenseMap HeaderPopularity; - for (const SymbolInfo &Symbol : Symbols) { - unsigned &Popularity = HeaderPopularity[Symbol.getFilePath()]; - Popularity = std::max(Popularity, Symbol.getNumOccurrences()); +// Calculate a score based on whether we think the given header is closely +// related to the given source file. +static double similarityScore(llvm::StringRef FileName, + llvm::StringRef Header) { + // Compute the maximum number of leading path segments that Header shares + // with some suffix of FileName (at least 1, so popularity still counts). + // We do not do a full longest common substring computation, as Header + // specifies the path we would directly #include, so we assume it is rooted + // relative to a subproject of the repository.
+ int MaxSegments = 1; + for (auto FileI = llvm::sys::path::begin(FileName), + FileE = llvm::sys::path::end(FileName); + FileI != FileE; ++FileI) { + int Segments = 0; + for (auto HeaderI = llvm::sys::path::begin(Header), + HeaderE = llvm::sys::path::end(Header), I = FileI; + HeaderI != HeaderE && I != FileE && *I == *HeaderI; ++I, ++HeaderI) { + ++Segments; + } + MaxSegments = std::max(Segments, MaxSegments); } + return MaxSegments; +} - // Sort by the gathered popularities. Use file name as a tie breaker so we can +static void rank(std::vector &Symbols, + llvm::StringRef FileName) { + llvm::DenseMap Score; + for (const SymbolInfo &Symbol : Symbols) { + // Calculate a score from the similarity of the header the symbol is in + // with the current file and the popularity of the symbol. + double NewScore = similarityScore(FileName, Symbol.getFilePath()) * + (1.0 + std::log2(1 + Symbol.getNumOccurrences())); + double &S = Score[Symbol.getFilePath()]; + S = std::max(S, NewScore); + } + // Sort by the gathered scores. Use file name as a tie breaker so we can // deduplicate. std::sort(Symbols.begin(), Symbols.end(), [&](const SymbolInfo &A, const SymbolInfo &B) { - auto APop = HeaderPopularity[A.getFilePath()]; - auto BPop = HeaderPopularity[B.getFilePath()]; - if (APop != BPop) - return APop > BPop; + auto AS = Score[A.getFilePath()]; + auto BS = Score[B.getFilePath()]; + if (AS != BS) + return AS > BS; return A.getFilePath() < B.getFilePath(); }); } std::vector SymbolIndexManager::search(llvm::StringRef Identifier, - bool IsNestedSearch) const { + bool IsNestedSearch, + llvm::StringRef FileName) const { // The identifier may be fully qualified, so split it and get all the context // names.
llvm::SmallVector Names; @@ -119,7 +148,7 @@ TookPrefix = true; } while (MatchedSymbols.empty() && !Names.empty() && IsNestedSearch); - rankByPopularity(MatchedSymbols); + rank(MatchedSymbols, FileName); return MatchedSymbols; } Index: include-fixer/tool/ClangIncludeFixer.cpp =================================================================== --- include-fixer/tool/ClangIncludeFixer.cpp +++ include-fixer/tool/ClangIncludeFixer.cpp @@ -332,7 +332,8 @@ // Query symbol mode. if (!QuerySymbol.empty()) { - auto MatchedSymbols = SymbolIndexMgr->search(QuerySymbol); + auto MatchedSymbols = SymbolIndexMgr->search( + QuerySymbol, /*IsNestedSearch=*/true, SourceFilePath); for (auto &Symbol : MatchedSymbols) { std::string HeaderPath = Symbol.getFilePath().str(); Symbol.SetFilePath(((HeaderPath[0] == '"' || HeaderPath[0] == '<') Index: test/include-fixer/Inputs/fake_yaml_db.yaml =================================================================== --- test/include-fixer/Inputs/fake_yaml_db.yaml +++ test/include-fixer/Inputs/fake_yaml_db.yaml @@ -9,7 +9,6 @@ LineNumber: 1 Type: Class NumOccurrences: 1 -... --- Name: bar Contexts: @@ -21,7 +20,7 @@ LineNumber: 1 Type: Class NumOccurrences: 1 -... +--- Name: bar Contexts: - ContextType: Namespace @@ -32,7 +31,7 @@ LineNumber: 2 Type: Class NumOccurrences: 3 -... +--- Name: bar Contexts: - ContextType: Namespace @@ -50,4 +49,12 @@ LineNumber: 1 Type: Variable NumOccurrences: 1 -...
+--- +Name: bar +Contexts: + - ContextType: Namespace + ContextName: c +FilePath: test/include-fixer/baz.h +LineNumber: 1 +Type: Class +NumOccurrences: 1 Index: test/include-fixer/ranking.cpp =================================================================== --- test/include-fixer/ranking.cpp +++ test/include-fixer/ranking.cpp @@ -1,6 +1,9 @@ // RUN: clang-include-fixer -db=yaml -input=%S/Inputs/fake_yaml_db.yaml -output-headers %s -- | FileCheck %s +// RUN: clang-include-fixer -query-symbol bar -db=yaml -input=%S/Inputs/fake_yaml_db.yaml -output-headers %s -- | FileCheck %s // CHECK: "HeaderInfos": [ +// CHECK-NEXT: {"Header": "\"test/include-fixer/baz.h\"", +// CHECK-NEXT: "QualifiedName": "c::bar"}, // CHECK-NEXT: {"Header": "\"../include/bar.h\"", // CHECK-NEXT: "QualifiedName": "b::a::bar"}, // CHECK-NEXT: {"Header": "\"../include/zbar.h\"",