diff --git a/clang/include/clang/CrossTU/CrossTranslationUnit.h b/clang/include/clang/CrossTU/CrossTranslationUnit.h --- a/clang/include/clang/CrossTU/CrossTranslationUnit.h +++ b/clang/include/clang/CrossTU/CrossTranslationUnit.h @@ -165,6 +165,7 @@ private: bool checkThresholdReached() const; llvm::Error lazyInitCTUIndex(StringRef CrossTUDir, StringRef IndexName); + ASTUnit *getCachedASTUnitForName(StringRef LookupName) const; void lazyInitImporterSharedSt(TranslationUnitDecl *ToTU); ASTImporter &getOrCreateASTImporter(ASTContext &From); template diff --git a/clang/lib/CrossTU/CrossTranslationUnit.cpp b/clang/lib/CrossTU/CrossTranslationUnit.cpp --- a/clang/lib/CrossTU/CrossTranslationUnit.cpp +++ b/clang/lib/CrossTU/CrossTranslationUnit.cpp @@ -371,6 +371,15 @@ }; } +ASTUnit *CrossTranslationUnitContext::getCachedASTUnitForName( + StringRef LookupName) const { + auto CacheEntry = NameASTUnitMap.find(LookupName); + if (CacheEntry != NameASTUnitMap.end()) + return CacheEntry->second; + else + return nullptr; +} + llvm::Expected CrossTranslationUnitContext::loadExternalAST( StringRef LookupName, StringRef CrossTUDir, StringRef IndexName, bool DisplayCTUProgress) { @@ -384,45 +393,45 @@ return llvm::make_error( index_error_code::load_threshold_reached); - ASTUnit *Unit = nullptr; - auto NameUnitCacheEntry = NameASTUnitMap.find(LookupName); - if (NameUnitCacheEntry == NameASTUnitMap.end()) { - // Lazily initialize the mapping from function names to AST files. - if (llvm::Error InitFailed = lazyInitCTUIndex(CrossTUDir, IndexName)) - return std::move(InitFailed); - - auto It = NameFileMap.find(LookupName); - if (It == NameFileMap.end()) { - ++NumNotInOtherTU; - return llvm::make_error(index_error_code::missing_definition); - } - StringRef ASTFileName = It->second; - auto ASTCacheEntry = FileASTUnitMap.find(ASTFileName); - if (ASTCacheEntry == FileASTUnitMap.end()) { - IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions(); - TextDiagnosticPrinter *DiagClient = - new TextDiagnosticPrinter(llvm::errs(), &*DiagOpts); - IntrusiveRefCntPtr DiagID(new DiagnosticIDs()); - IntrusiveRefCntPtr Diags( - new DiagnosticsEngine(DiagID, &*DiagOpts, DiagClient)); - - std::unique_ptr LoadedUnit(ASTUnit::LoadFromASTFile( - ASTFileName, CI.getPCHContainerOperations()->getRawReader(), - ASTUnit::LoadEverything, Diags, CI.getFileSystemOpts())); - Unit = LoadedUnit.get(); - FileASTUnitMap[ASTFileName] = std::move(LoadedUnit); - ++NumASTLoaded; - if (DisplayCTUProgress) { - llvm::errs() << "CTU loaded AST file: " - << ASTFileName << "\n"; - } - } else { - Unit = ASTCacheEntry->second.get(); + // First try to access the cache to get the ASTUnit for the function name + // specified by LookupName. + ASTUnit *Unit = getCachedASTUnitForName(LookupName); + if (Unit) + return Unit; + + // Lazily initialize the mapping from function names to AST files. + if (llvm::Error InitFailed = lazyInitCTUIndex(CrossTUDir, IndexName)) + return std::move(InitFailed); + + auto It = NameFileMap.find(LookupName); + if (It == NameFileMap.end()) { + ++NumNotInOtherTU; + return llvm::make_error(index_error_code::missing_definition); + } + StringRef ASTFileName = It->second; + auto ASTCacheEntry = FileASTUnitMap.find(ASTFileName); + if (ASTCacheEntry == FileASTUnitMap.end()) { + IntrusiveRefCntPtr DiagOpts = new DiagnosticOptions(); + TextDiagnosticPrinter *DiagClient = + new TextDiagnosticPrinter(llvm::errs(), &*DiagOpts); + IntrusiveRefCntPtr DiagID(new DiagnosticIDs()); + IntrusiveRefCntPtr Diags( + new DiagnosticsEngine(DiagID, &*DiagOpts, DiagClient)); + + std::unique_ptr LoadedUnit(ASTUnit::LoadFromASTFile( + ASTFileName, CI.getPCHContainerOperations()->getRawReader(), + ASTUnit::LoadEverything, Diags, CI.getFileSystemOpts())); + Unit = LoadedUnit.get(); + FileASTUnitMap[ASTFileName] = std::move(LoadedUnit); + ++NumASTLoaded; + if (DisplayCTUProgress) { + llvm::errs() << "CTU loaded AST file: " << ASTFileName << "\n"; } - NameASTUnitMap[LookupName] = Unit; } else { - Unit = NameUnitCacheEntry->second; + Unit = ASTCacheEntry->second.get(); } + NameASTUnitMap[LookupName] = Unit; + if (!Unit) return llvm::make_error( index_error_code::failed_to_get_external_ast);