Index: include/clang/AST/ASTImporter.h
===================================================================
--- include/clang/AST/ASTImporter.h
+++ include/clang/AST/ASTImporter.h
@@ -142,6 +142,14 @@
 
     void AddToLookupTable(Decl *ToD);
 
+  protected:
+    /// Can be overridden by subclasses to implement their own import logic.
+    /// An overriding method must register any non-null decl it returns via
+    /// RegisterImportedDecl(); this is asserted after ImportImpl returns.
+    /// It should call this base implementation for every decl it does not
+    /// import on its own.
+    virtual Expected<Decl *> ImportImpl(Decl *From);
+
   public:
 
     /// \param ToContext The context we'll be importing into.
@@ -427,6 +435,10 @@
     /// \c To declaration mappings as they are imported.
     virtual void Imported(Decl *From, Decl *To) {}
 
+    /// Record \p ToD as the imported counterpart of \p FromD (MapImported)
+    /// and add \p ToD to the lookup table (AddToLookupTable).
+    void RegisterImportedDecl(Decl *FromD, Decl *ToD);
+
     /// Store and assign the imported declaration to its counterpart.
     Decl *MapImported(Decl *From, Decl *To);
 
Index: lib/AST/ASTImporter.cpp
===================================================================
--- lib/AST/ASTImporter.cpp
+++ lib/AST/ASTImporter.cpp
@@ -255,8 +255,7 @@
         return true; // Already imported.
       ToD = CreateFun(std::forward<Args>(args)...);
       // Keep track of imported Decls.
-      Importer.MapImported(FromD, ToD);
-      Importer.AddToLookupTable(ToD);
+      Importer.RegisterImportedDecl(FromD, ToD);
       InitializeImportedDecl(FromD, ToD);
       return false; // A new Decl is created.
     }
@@ -7656,6 +7655,17 @@
       LookupTable->add(ToND);
 }
 
+Expected<Decl *> ASTImporter::ImportImpl(Decl *FromD) {
+  // Import the decl using ASTNodeImporter.
+  ASTNodeImporter Importer(*this);
+  return Importer.Visit(FromD);
+}
+
+void ASTImporter::RegisterImportedDecl(Decl *FromD, Decl *ToD) {
+  MapImported(FromD, ToD);
+  AddToLookupTable(ToD);
+}
+
 Expected<QualType> ASTImporter::Import_New(QualType FromT) {
   if (FromT.isNull())
     return QualType{};
@@ -7749,7 +7759,5 @@
   if (!FromD)
     return nullptr;
 
-  ASTNodeImporter Importer(*this);
-
   // Check whether we've already imported this declaration.
   Decl *ToD = GetAlreadyImportedOrNull(FromD);
@@ -7760,7 +7768,7 @@
   }
 
   // Import the declaration.
-  ExpectedDecl ToDOrErr = Importer.Visit(FromD);
+  ExpectedDecl ToDOrErr = ImportImpl(FromD);
   if (!ToDOrErr)
     return ToDOrErr;
   ToD = *ToDOrErr;
@@ -7771,6 +7779,10 @@
     return nullptr;
   }
 
+  // Make sure that ImportImpl registered the imported decl.
+  assert(ImportedDecls.count(FromD) != 0 &&
+         "Missing call to RegisterImportedDecl in ImportImpl?");
+
   // Once the decl is connected to the existing declarations, i.e. when the
   // redecl chain is properly set then we populate the lookup again.
   // This way the primary context will be able to find all decls.
Index: unittests/AST/ASTImporterTest.cpp
===================================================================
--- unittests/AST/ASTImporterTest.cpp
+++ unittests/AST/ASTImporterTest.cpp
@@ -313,6 +313,17 @@
   const char *const InputFileName = "input.cc";
   const char *const OutputFileName = "output.cc";
 
+public:
+  /// Allocates an ASTImporter (or one of its subclasses).
+  using ImporterConstructor =
+      std::function<ASTImporter *(ASTContext &, FileManager &, ASTContext &,
+                                  FileManager &, bool,
+                                  ASTImporterLookupTable *)>;
+
+  // The lambda that constructs the ASTImporter we use in this test.
+  ImporterConstructor Creator;
+
+private:
   // Buffer for the To context, must live in the test scope.
   std::string ToCode;
 
@@ -325,22 +336,32 @@
     std::unique_ptr<ASTUnit> Unit;
     TranslationUnitDecl *TUDecl = nullptr;
     std::unique_ptr<ASTImporter> Importer;
-    TU(StringRef Code, StringRef FileName, ArgVector Args)
+    ImporterConstructor Creator;
+    TU(StringRef Code, StringRef FileName, ArgVector Args,
+       ImporterConstructor C = ImporterConstructor())
         : Code(Code), FileName(FileName),
           Unit(tooling::buildASTFromCodeWithArgs(this->Code, Args,
                                                  this->FileName)),
-          TUDecl(Unit->getASTContext().getTranslationUnitDecl()) {
+          TUDecl(Unit->getASTContext().getTranslationUnitDecl()), Creator(C) {
       Unit->enableSourceFileDiagnostics();
+
+      // If the test doesn't need a specific ASTImporter, we just create a
+      // normal ASTImporter with it.
+      if (!Creator)
+        Creator = [](ASTContext &ToContext, FileManager &ToFileManager,
+                     ASTContext &FromContext, FileManager &FromFileManager,
+                     bool MinimalImport, ASTImporterLookupTable *LookupTable) {
+          return new ASTImporter(ToContext, ToFileManager, FromContext,
+                                 FromFileManager, MinimalImport, LookupTable);
+        };
     }
 
     void lazyInitImporter(ASTImporterLookupTable &LookupTable, ASTUnit *ToAST) {
       assert(ToAST);
-      if (!Importer) {
-        Importer.reset(
-            new ASTImporter(ToAST->getASTContext(), ToAST->getFileManager(),
-                            Unit->getASTContext(), Unit->getFileManager(),
-                            false, &LookupTable));
-      }
+      if (!Importer)
+        Importer.reset(Creator(ToAST->getASTContext(), ToAST->getFileManager(),
+                               Unit->getASTContext(), Unit->getFileManager(),
+                               false, &LookupTable));
       assert(&ToAST->getASTContext() == &Importer->getToContext());
       createVirtualFileIfNeeded(ToAST, FileName, Code);
     }
@@ -424,7 +445,7 @@
     ArgVector FromArgs = getArgVectorForLanguage(FromLang),
               ToArgs = getArgVectorForLanguage(ToLang);
 
-    FromTUs.emplace_back(FromSrcCode, InputFileName, FromArgs);
+    FromTUs.emplace_back(FromSrcCode, InputFileName, FromArgs, Creator);
     TU &FromTU = FromTUs.back();
 
     assert(!ToAST);
@@ -562,6 +583,74 @@
   EXPECT_THAT(RedeclsD1, ::testing::ContainerEq(RedeclsD2));
 }
 
+namespace {
+struct RedirectingImporter : public ASTImporter {
+  using ASTImporter::ASTImporter;
+
+protected:
+  llvm::Expected<Decl *> ImportImpl(Decl *FromD) override {
+    auto *ND = dyn_cast<NamedDecl>(FromD);
+    if (!ND || ND->getName() != "shouldNotBeImported")
+      return ASTImporter::ImportImpl(FromD);
+    for (Decl *D : getToContext().getTranslationUnitDecl()->decls()) {
+      if (auto *ToND = dyn_cast<NamedDecl>(D))
+        if (ToND->getName() == "realDecl") {
+          RegisterImportedDecl(FromD, ToND);
+          return ToND;
+        }
+    }
+    return ASTImporter::ImportImpl(FromD);
+  }
+};
+
+} // namespace
+
+struct RedirectingImporterTest : ASTImporterOptionSpecificTestBase {
+  RedirectingImporterTest() {
+    Creator = [](ASTContext &ToContext, FileManager &ToFileManager,
+                 ASTContext &FromContext, FileManager &FromFileManager,
+                 bool MinimalImport, ASTImporterLookupTable *LookupTable) {
+      return new RedirectingImporter(ToContext, ToFileManager, FromContext,
+                                     FromFileManager, MinimalImport,
+                                     LookupTable);
+    };
+  }
+};
+
+// Test that an ASTImporter subclass can intercept an import call.
+TEST_P(RedirectingImporterTest, InterceptImport) {
+  Decl *From, *To;
+  std::tie(From, To) =
+      getImportedDecl("class shouldNotBeImported {};", Lang_CXX,
+                      "class realDecl {};", Lang_CXX, "shouldNotBeImported");
+  auto *Imported = cast<CXXRecordDecl>(To);
+  EXPECT_EQ(Imported->getQualifiedNameAsString(), "realDecl");
+
+  // Make sure our importer prevented the importing of the decl.
+  auto *ToTU = Imported->getTranslationUnitDecl();
+  auto Pattern = cxxRecordDecl(hasName("shouldNotBeImported"));
+  unsigned count =
+      DeclCounterWithPredicate<CXXRecordDecl>().match(ToTU, Pattern);
+  EXPECT_EQ(0U, count);
+}
+
+// Test that when we indirectly import a declaration the custom ASTImporter
+// is still intercepting the import.
+TEST_P(RedirectingImporterTest, InterceptIndirectImport) {
+  Decl *From, *To;
+  std::tie(From, To) =
+      getImportedDecl("class shouldNotBeImported {};"
+                      "class F { shouldNotBeImported f; };",
+                      Lang_CXX, "class realDecl {};", Lang_CXX, "F");
+
+  // Make sure our ASTImporter prevented the importing of the decl.
+  auto *ToTU = To->getTranslationUnitDecl();
+  auto Pattern = cxxRecordDecl(hasName("shouldNotBeImported"));
+  unsigned count =
+      DeclCounterWithPredicate<CXXRecordDecl>().match(ToTU, Pattern);
+  EXPECT_EQ(0U, count);
+}
+
 TEST_P(ImportExpr, ImportStringLiteral) {
   MatchVerifier<Stmt> Verifier;
   testImport(
@@ -5549,6 +5638,9 @@
 INSTANTIATE_TEST_CASE_P(ParameterizedTests, ASTImporterOptionSpecificTestBase,
                         DefaultTestValuesForRunOptions, );
 
+INSTANTIATE_TEST_CASE_P(ParameterizedTests, RedirectingImporterTest,
+                        DefaultTestValuesForRunOptions, );
+
 INSTANTIATE_TEST_CASE_P(ParameterizedTests, ImportFunctions,
                         DefaultTestValuesForRunOptions, );
 