Index: clang/include/clang/AST/ASTImporter.h
===================================================================
--- clang/include/clang/AST/ASTImporter.h
+++ clang/include/clang/AST/ASTImporter.h
@@ -142,6 +142,13 @@
 
     void AddToLookupTable(Decl *ToD);
 
+  protected:
+    /// Imports \p FromD, which has not been imported before. Subclasses can
+    /// override this to implement their own import logic; an override should
+    /// call this base implementation for every decl that it does not import
+    /// on its own.
+    virtual Expected<Decl *> ImportInternal(Decl *FromD);
+
   public:
 
     /// \param ToContext The context we'll be importing into.
Index: clang/lib/AST/ASTImporter.cpp
===================================================================
--- clang/lib/AST/ASTImporter.cpp
+++ clang/lib/AST/ASTImporter.cpp
@@ -7734,6 +7734,12 @@
       LookupTable->add(ToND);
 }
 
+Expected<Decl *> ASTImporter::ImportInternal(Decl *FromD) {
+  // Import the declaration.
+  ASTNodeImporter Importer(*this);
+  return Importer.Visit(FromD);
+}
+
 Expected<QualType> ASTImporter::Import_New(QualType FromT) {
   if (FromT.isNull())
     return QualType{};
@@ -7827,7 +7833,5 @@
   if (!FromD)
     return nullptr;
 
-  ASTNodeImporter Importer(*this);
-
   // Check whether we've already imported this declaration.
   Decl *ToD = GetAlreadyImportedOrNull(FromD);
@@ -7837,8 +7841,7 @@
     return ToD;
   }
 
-  // Import the declaration.
-  ExpectedDecl ToDOrErr = Importer.Visit(FromD);
+  ExpectedDecl ToDOrErr = ImportInternal(FromD);
   if (!ToDOrErr)
     return ToDOrErr;
   ToD = *ToDOrErr;
Index: clang/unittests/AST/ASTImporterTest.cpp
===================================================================
--- clang/unittests/AST/ASTImporterTest.cpp
+++ clang/unittests/AST/ASTImporterTest.cpp
@@ -304,6 +304,14 @@
   const char *const InputFileName = "input.cc";
   const char *const OutputFileName = "output.cc";
 
+public:
+  /// Allocates an ASTImporter (or one of its subclasses).
+  typedef std::function<ASTImporter *(ASTContext &, FileManager &, ASTContext &,
+                                      FileManager &, bool,
+                                      ASTImporterLookupTable *)>
+      ImporterConstructor;
+
+private:
   // Buffer for the To context, must live in the test scope.
   std::string ToCode;
 
@@ -316,22 +324,37 @@
     std::unique_ptr<ASTUnit> Unit;
     TranslationUnitDecl *TUDecl = nullptr;
     std::unique_ptr<ASTImporter> Importer;
-    TU(StringRef Code, StringRef FileName, ArgVector Args)
+    // The lambda that constructs the ASTImporter we use in this test.
+    ImporterConstructor Creator;
+    TU(StringRef Code, StringRef FileName, ArgVector Args,
+       ImporterConstructor C = ImporterConstructor())
         : Code(Code), FileName(FileName),
           Unit(tooling::buildASTFromCodeWithArgs(this->Code, Args,
                                                  this->FileName)),
-          TUDecl(Unit->getASTContext().getTranslationUnitDecl()) {
+          TUDecl(Unit->getASTContext().getTranslationUnitDecl()), Creator(C) {
       Unit->enableSourceFileDiagnostics();
+
+      // If the test doesn't need a specific ASTImporter, fall back to
+      // creating a plain ASTImporter.
+      if (!Creator)
+        Creator = [](ASTContext &ToContext, FileManager &ToFileManager,
+                     ASTContext &FromContext, FileManager &FromFileManager,
+                     bool MinimalImport, ASTImporterLookupTable *LookupTable) {
+          return new ASTImporter(ToContext, ToFileManager, FromContext,
+                                 FromFileManager, MinimalImport, LookupTable);
+        };
+    }
+
+    void setImporter(std::unique_ptr<ASTImporter> I) {
+      Importer = std::move(I);
     }
 
     void lazyInitImporter(ASTImporterLookupTable &LookupTable, ASTUnit *ToAST) {
       assert(ToAST);
-      if (!Importer) {
-        Importer.reset(
-            new ASTImporter(ToAST->getASTContext(), ToAST->getFileManager(),
-                            Unit->getASTContext(), Unit->getFileManager(),
-                            false, &LookupTable));
-      }
+      if (!Importer)
+        Importer.reset(Creator(ToAST->getASTContext(), ToAST->getFileManager(),
+                               Unit->getASTContext(), Unit->getFileManager(),
+                               false, &LookupTable));
       assert(&ToAST->getASTContext() == &Importer->getToContext());
       createVirtualFileIfNeeded(ToAST, FileName, Code);
     }
@@ -401,11 +424,12 @@
   // Must not be called more than once within the same test.
std::tuple getImportedDecl(StringRef FromSrcCode, Language FromLang, StringRef ToSrcCode, - Language ToLang, StringRef Identifier = DeclToImportID) { + Language ToLang, StringRef Identifier = DeclToImportID, + ImporterConstructor Creator = ImporterConstructor()) { ArgVector FromArgs = getArgVectorForLanguage(FromLang), ToArgs = getArgVectorForLanguage(ToLang); - FromTUs.emplace_back(FromSrcCode, InputFileName, FromArgs); + FromTUs.emplace_back(FromSrcCode, InputFileName, FromArgs, Creator); TU &FromTU = FromTUs.back(); assert(!ToAST); @@ -455,6 +479,12 @@ return ToAST->getASTContext().getTranslationUnitDecl(); } + ASTImporter &getImporter(Decl *From, Language ToLang) { + lazyInitToAST(ToLang, "", OutputFileName); + TU *FromTU = findFromTU(From); + return *FromTU->Importer; + } + // Import the given Decl into the ToCtx. // May be called several times in a given test. // The different instances of the param From may have different ASTContext. @@ -544,6 +574,75 @@ EXPECT_THAT(RedeclsD1, ::testing::ContainerEq(RedeclsD2)); } +struct CustomImporter : ASTImporterOptionSpecificTestBase {}; + +namespace { +struct RedirectingImporter : public ASTImporter { + using ASTImporter::ASTImporter; + // Returns a ImporterConstructor that constructs this class. 
+ static ASTImporterOptionSpecificTestBase::ImporterConstructor + getConstructor() { + return [](ASTContext &ToContext, FileManager &ToFileManager, + ASTContext &FromContext, FileManager &FromFileManager, + bool MinimalImport, ASTImporterLookupTable *LookupTabl) { + return static_cast( + new RedirectingImporter(ToContext, ToFileManager, FromContext, + FromFileManager, MinimalImport, LookupTabl)); + }; + } + +protected: + llvm::Expected ImportInternal(Decl *FromD) override { + auto *ND = dyn_cast(FromD); + if (!ND) + return ASTImporter::ImportInternal(FromD); + if (ND->getName() != "shouldNotBeImported") + return ASTImporter::ImportInternal(FromD); + for (Decl *D : getToContext().getTranslationUnitDecl()->decls()) { + if (auto ND = dyn_cast(D)) + if (ND->getName() == "realDecl") + return ND; + } + return ASTImporter::ImportInternal(FromD); + } +}; +} // namespace + +// Test that an ASTImporter subclass can intercept an import call. +TEST_P(CustomImporter, InterceptImport) { + Decl *From, *To; + std::tie(From, To) = getImportedDecl( + "class shouldNotBeImported {};", Lang_CXX, "class realDecl {};", Lang_CXX, + "shouldNotBeImported", RedirectingImporter::getConstructor()); + auto *Imported = cast(To); + EXPECT_EQ(Imported->getQualifiedNameAsString(), "realDecl"); + + // Make sure our importer prevented the importing of the decl. + auto *ToTU = Imported->getTranslationUnitDecl(); + auto Pattern = functionDecl(hasName("shouldNotBeImported")); + unsigned count = + DeclCounterWithPredicate().match(ToTU, Pattern); + EXPECT_EQ(0U, count); +} + +// Test that when we indirectly import a declaration the custom ASTImporter +// is still intercepting the import. 
+TEST_P(CustomImporter, InterceptIndirectImport) {
+  Decl *From, *To;
+  std::tie(From, To) =
+      getImportedDecl("class shouldNotBeImported {};"
+                      "class F { shouldNotBeImported f; };",
+                      Lang_CXX, "class realDecl {};", Lang_CXX, "F",
+                      RedirectingImporter::getConstructor());
+
+  // Make sure our ASTImporter prevented the importing of the decl.
+  auto *ToTU = To->getTranslationUnitDecl();
+  auto Pattern = cxxRecordDecl(hasName("shouldNotBeImported"));
+  unsigned count =
+      DeclCounterWithPredicate<CXXRecordDecl>().match(ToTU, Pattern);
+  EXPECT_EQ(0U, count);
+}
+
 TEST_P(ImportExpr, ImportStringLiteral) {
   MatchVerifier<Decl> Verifier;
   testImport(
@@ -5512,6 +5611,9 @@
 INSTANTIATE_TEST_CASE_P(ParameterizedTests, ASTImporterOptionSpecificTestBase,
                         DefaultTestValuesForRunOptions, );
 
+INSTANTIATE_TEST_CASE_P(ParameterizedTests, CustomImporter,
+                        DefaultTestValuesForRunOptions, );
+
 INSTANTIATE_TEST_CASE_P(ParameterizedTests, ImportFunctions,
                         DefaultTestValuesForRunOptions, );
 