Index: clang/include/clang/Frontend/FrontendAction.h
===================================================================
--- clang/include/clang/Frontend/FrontendAction.h
+++ clang/include/clang/Frontend/FrontendAction.h
@@ -326,6 +326,31 @@
   bool hasCodeCompletionSupport() const override;
 };
 
+/// Class for adding a custom ASTConsumer to a FrontendAction.
+/// Wraps the base action and forwards all method calls except
+/// CreateASTConsumer.
+class ASTConsumerInjector : public WrapperFrontendAction {
+  std::function<std::unique_ptr<ASTConsumer>(CompilerInstance &CI,
+                                             StringRef InFile)>
+      ASTConsumerFactory;
+
+public:
+  /// \p BaseActionInput is the base action to which all method calls are
+  /// forwarded. \p ASTConsumerFactory is a factory that produces the custom
+  /// ASTConsumer in the CreateASTConsumer method; a null result from it
+  /// signals failure, as it does for CreateASTConsumer itself.
+  ASTConsumerInjector(std::unique_ptr<FrontendAction> BaseActionInput,
+                      std::function<std::unique_ptr<ASTConsumer>(
+                          CompilerInstance &CI, StringRef InFile)>
+                          ASTConsumerFactory);
+
+  /// Returns a MultiplexConsumer whose first element is produced by
+  /// ASTConsumerFactory, followed by the ASTConsumer from the base action.
+  /// Returns null if either of those consumers could not be created.
+  std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
+                                                 StringRef InFile) override;
+};
+
 } // end namespace clang
 
 #endif
Index: clang/lib/Frontend/FrontendAction.cpp
===================================================================
--- clang/lib/Frontend/FrontendAction.cpp
+++ clang/lib/Frontend/FrontendAction.cpp
@@ -1100,3 +1100,29 @@
     std::unique_ptr<FrontendAction> WrappedAction)
   : WrappedAction(std::move(WrappedAction)) {}
 
+ASTConsumerInjector::ASTConsumerInjector(
+    std::unique_ptr<FrontendAction> BaseActionInput,
+    std::function<std::unique_ptr<ASTConsumer>(CompilerInstance &CI,
+                                               StringRef InFile)>
+        ASTConsumerFactory)
+    : WrapperFrontendAction(std::move(BaseActionInput)),
+      ASTConsumerFactory(std::move(ASTConsumerFactory)) {}
+
+std::unique_ptr<ASTConsumer>
+ASTConsumerInjector::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
+  // A null result from either the factory or the base action means the
+  // consumer could not be created. Propagate that failure instead of handing
+  // a null element to MultiplexConsumer, which expects non-null consumers.
+  std::unique_ptr<ASTConsumer> Injected = ASTConsumerFactory(CI, InFile);
+  if (!Injected)
+    return nullptr;
+  std::unique_ptr<ASTConsumer> Base =
+      WrapperFrontendAction::CreateASTConsumer(CI, InFile);
+  if (!Base)
+    return nullptr;
+
+  std::vector<std::unique_ptr<ASTConsumer>> Consumers;
+  Consumers.push_back(std::move(Injected));
+  Consumers.push_back(std::move(Base));
+  return std::make_unique<MultiplexConsumer>(std::move(Consumers));
+}
Index: clang/unittests/Frontend/FrontendActionTest.cpp
===================================================================
--- clang/unittests/Frontend/FrontendActionTest.cpp
+++ clang/unittests/Frontend/FrontendActionTest.cpp
@@ -22,6 +22,7 @@
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/ToolOutputFile.h"
 #include "gtest/gtest.h"
+#include <string>
 
 using namespace llvm;
 using namespace clang;
@@ -292,4 +293,62 @@
   }
 }
 
+TEST(ASTFrontendAction, ASTConsumerInjector) {
+  // Consumer that appends LogValue to TestLog for every top-level decl group.
+  class InjectTestASTConsumer : public ASTConsumer {
+    std::string &TestLog;
+    const std::string LogValue;
+
+  public:
+    InjectTestASTConsumer(std::string &TestLog, const std::string &LogValue)
+        : TestLog(TestLog), LogValue(LogValue) {}
+
+    bool HandleTopLevelDecl(DeclGroupRef) override {
+      TestLog += LogValue;
+      return true;
+    }
+  };
+
+  // Base action whose own consumer logs LogValue.
+  class InjectTestASTFrontendAction : public ASTFrontendAction {
+    std::string &TestLog;
+    const std::string LogValue;
+
+  public:
+    InjectTestASTFrontendAction(std::string &TestLog,
+                                const std::string &LogValue)
+        : TestLog(TestLog), LogValue(LogValue) {}
+
+    std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &,
+                                                   StringRef) override {
+      return std::make_unique<InjectTestASTConsumer>(TestLog, LogValue);
+    }
+  };
+
+  auto invocation = std::make_shared<CompilerInvocation>();
+  invocation->getLangOpts()->CPlusPlus = true;
+  invocation->getPreprocessorOpts().addRemappedFile(
+      "test.cpp",
+      MemoryBuffer::getMemBuffer("int main() { return 0;}\n").release());
+  invocation->getFrontendOpts().Inputs.push_back(
+      FrontendInputFile("test.cpp", Language::CXX));
+  invocation->getFrontendOpts().ProgramAction = frontend::ParseSyntaxOnly;
+  invocation->getTargetOpts().Triple = "i386-unknown-linux-gnu";
+  CompilerInstance compiler;
+  compiler.setInvocation(std::move(invocation));
+  compiler.createDiagnostics();
+
+  std::string TestLog;
+  ASTConsumerInjector WrappedTestAction(
+      std::make_unique<InjectTestASTFrontendAction>(TestLog, "A"),
+      [&TestLog](CompilerInstance &, StringRef) {
+        return std::make_unique<InjectTestASTConsumer>(TestLog, "B");
+      });
+
+  ASSERT_TRUE(compiler.ExecuteAction(WrappedTestAction));
+  // One top-level decl: the injected consumer ("B") runs before the base
+  // action's consumer ("A").
+  EXPECT_EQ("BA", TestLog);
+}
+
 } // anonymous namespace