Index: include/llvm/Passes/PassBuilder.h =================================================================== --- include/llvm/Passes/PassBuilder.h +++ include/llvm/Passes/PassBuilder.h @@ -17,6 +17,7 @@ #define LLVM_PASSES_PASSBUILDER_H #include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/LoopPassManager.h" #include "llvm/IR/PassManager.h" @@ -249,6 +250,193 @@ bool parseModulePassPipeline(ModulePassManager &MPM, StringRef &PipelineText, bool VerifyEachPass, bool DebugLogging); }; -} +/// \brief This class manages the registration of hooks through which +/// out-of-tree or plugin passes are able to register themselves with a +/// \c PassManager or \c AnalysisManager. +/// +/// Such a hook is a Callable with a single argument, the Pass- or +/// AnalysisManager which passes can be added to. Hooks are registered with the +/// manager by name, which is the name used in the pass pipeline text. If the +/// name matches a name of a built-in pass, this name is effectively "stolen", +/// thus overriding the built-in pass. +/// +/// To install a hook in an instance of this class, interfaces are supplied +/// to handle analysis, transform and AA passes, respectively. Installing a hook +/// for an analysis pass also automatically generates default hooks for the +/// require<> and invalidate<> utility passes. +template class HookManager { +public: + typedef std::function &)> AnalysisPassHook; + typedef std::function &)> TransformPassHook; + typedef std::function AAPassHook; + + /// \brief Register a hook for a transform pass. + /// + /// If a hook has already been registered with the same Name, it is replaced + /// by H. + void registerTransformPassHook(StringRef Name, TransformPassHook H) { + transformHooks[Name] = H; + } + + /// \brief Register a hook for an analysis pass. 
+ /// + /// Besides the pass Name, this interface also takes the Pass type as template + /// argument, used to install default hooks for the require<> and invalidate<> + /// utilities. H itself is always appended, so analysis hooks accumulate. If + /// hooks were already registered with the same Name, the require<> and + /// invalidate<> utility hooks for that name are overwritten. + template <typename PassT> + void registerAnalysisPassHook(StringRef Name, AnalysisPassHook H) { + analysisHooks.push_back(H); + requireUtilityHooks[Name] = [](PassManager<IRUnitT> &PM) { + PM.addPass(RequireAnalysisPass<PassT>()); + }; + invalidateUtilityHooks[Name] = [](PassManager<IRUnitT> &PM) { + PM.addPass(InvalidateAnalysisPass<PassT>()); + }; + } + + /// \brief Register a hook for an alias analysis pass. + /// + /// If a hook has already been registered with the same Name, it is replaced + /// by H. + void registerAAPassHook(StringRef Name, AAPassHook H) { + AAHooks[Name] = H; + } + + //@{ + /// Public interface for the \c PassBuilder to query or invoke registered + /// hooks. + void registerAnalyses(AnalysisManager &AM) { + for (auto &H : analysisHooks) + H(AM); + } + + inline StringRef getRequireUtilityName(StringRef Name) { + if (Name.endswith(">") && Name.startswith("require<")) + return Name.slice(8, Name.size() - 1); + return StringRef(); + } + + inline StringRef getInvalidateUtilityName(StringRef Name) { + if (Name.endswith(">") && Name.startswith("invalidate<")) + return Name.slice(11, Name.size() - 1); + return StringRef(); + } + + bool hasPassName(StringRef Name) { + if (transformHooks.count(Name) > 0) + return true; + StringRef slice = getRequireUtilityName(Name); + if (!slice.empty()) + if (requireUtilityHooks.count(slice)) + return true; + slice = getInvalidateUtilityName(Name); + if (!slice.empty()) + if (invalidateUtilityHooks.count(slice)) + return true; + return false; + } + + bool parsePassName(StringRef Name, PassManager &PM) { + if (transformHooks.count(Name) > 0) { + transformHooks[Name](PM); + return true; + } + StringRef slice = getRequireUtilityName(Name); + if 
(!slice.empty()) { + if (requireUtilityHooks.count(slice)) { + requireUtilityHooks[slice](PM); + return true; + } + } + slice = getInvalidateUtilityName(Name); + if (!slice.empty()) { + if (invalidateUtilityHooks.count(slice)) { + invalidateUtilityHooks[slice](PM); + return true; + } + } + return false; + } + + bool parseAAPassName(StringRef Name, AAManager &AA) { + if (!AAHooks.count(Name)) + return false; + AAHooks[Name](AA); + return true; + } + //@} +private: + std::vector analysisHooks; + StringMap transformHooks; + StringMap requireUtilityHooks; + StringMap invalidateUtilityHooks; + StringMap AAHooks; +}; + +/// \brief Return the global hook manager for this IRUnit type. +template extern HookManager &getGlobalHookManager(); +//@{ +/// \brief Convenience types to install the hooks for a pass. Constructing an +/// instance of this will perform the registration of a hook with the given +/// arguments. +template struct RegisterAnalysisPassHook { + RegisterAnalysisPassHook( + StringRef Name, typename HookManager::AnalysisPassHook Hook) { + getGlobalHookManager().template registerAnalysisPassHook( + Name, Hook); + } +}; + +template struct RegisterTransformPassHook { + RegisterTransformPassHook( + StringRef Name, typename HookManager::TransformPassHook Hook) { + getGlobalHookManager().registerTransformPassHook(Name, Hook); + } +}; + +template struct RegisterAAPassHook { + RegisterAAPassHook(StringRef Name, + typename HookManager::AnalysisPassHook Hook, + typename HookManager::AAPassHook AAHook) { + getGlobalHookManager().template registerAnalysisPassHook( + Name, Hook); + getGlobalHookManager().registerAAPassHook(Name, AAHook); + } +}; + +template +struct RegisterAnalysisPass : RegisterAnalysisPassHook { + struct Hook { + void operator()(AnalysisManager &AM) { + AM.registerPass([] { return PassT(); }); + } + }; + RegisterAnalysisPass(StringRef Name) + : RegisterAnalysisPassHook(Name, Hook()) {} +}; + +template +struct RegisterTransformPass : 
RegisterTransformPassHook { + struct Hook { + void operator()(PassManager &PM) { PM.addPass(PassT()); } + }; + RegisterTransformPass(StringRef Name) + : RegisterTransformPassHook(Name, Hook()) {} +}; + +template struct RegisterAAPass : RegisterAAPassHook { + struct Hook { + void operator()(AnalysisManager &AM) { + AM.registerPass([] { return PassT(); }); + } + void operator()(AAManager &AA) { AA.registerFunctionAnalysis(); } + }; + RegisterAAPass(StringRef Name) + : RegisterAAPassHook(Name, Hook(), Hook()) {} +}; +//@} +} #endif Index: lib/Passes/PassBuilder.cpp =================================================================== --- lib/Passes/PassBuilder.cpp +++ lib/Passes/PassBuilder.cpp @@ -42,6 +42,7 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" @@ -140,25 +141,59 @@ } // End anonymous namespace. +//@{ +/// Global instances of HookManager for various IRUnit types. 
+static ManagedStatic> ModuleHookManager; +static ManagedStatic> CGSCCHookManager; +static ManagedStatic> FunctionHookManager; +static ManagedStatic> LoopHookManager; +//@} + +//@{ +/// Explicit getGlobalHookManager() specializations returning the globals above. +namespace llvm { +template <> HookManager &getGlobalHookManager() { + return *ModuleHookManager; +} +template <> HookManager &getGlobalHookManager() { + return *CGSCCHookManager; +} +template <> HookManager &getGlobalHookManager() { + return *FunctionHookManager; +} +template <> HookManager &getGlobalHookManager() { + return *LoopHookManager; +} +} // end namespace llvm +//@} + void PassBuilder::registerModuleAnalyses(ModuleAnalysisManager &MAM) { + getGlobalHookManager().registerAnalyses(MAM); // Plugin analyses are registered before the built-in ones below. + #define MODULE_ANALYSIS(NAME, CREATE_PASS) \ MAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" } void PassBuilder::registerCGSCCAnalyses(CGSCCAnalysisManager &CGAM) { + getGlobalHookManager().registerAnalyses(CGAM); + #define CGSCC_ANALYSIS(NAME, CREATE_PASS) \ CGAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" } void PassBuilder::registerFunctionAnalyses(FunctionAnalysisManager &FAM) { + getGlobalHookManager().registerAnalyses(FAM); + #define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \ FAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" } void PassBuilder::registerLoopAnalyses(LoopAnalysisManager &LAM) { + getGlobalHookManager().registerAnalyses(LAM); + #define LOOP_ANALYSIS(NAME, CREATE_PASS) \ LAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" @@ -201,6 +236,9 @@ if (Name.startswith("default") || Name.startswith("lto")) return DefaultAliasRegex.match(Name); + if (getGlobalHookManager().hasPassName(Name)) + return true; // NOTE(review): "default"/"lto" prefixed names return above, so plugin hooks never see them. + #define MODULE_PASS(NAME, CREATE_PASS) if (Name == NAME) return true; #define MODULE_ANALYSIS(NAME, CREATE_PASS) \ if (Name == "require<" NAME ">" || Name == "invalidate<" NAME ">") \ @@ -212,6 +250,9 @@ #endif static bool isCGSCCPassName(StringRef Name) { + if 
(getGlobalHookManager().hasPassName(Name)) + return true; + #define CGSCC_PASS(NAME, CREATE_PASS) if (Name == NAME) return true; #define CGSCC_ANALYSIS(NAME, CREATE_PASS) \ if (Name == "require<" NAME ">" || Name == "invalidate<" NAME ">") \ @@ -222,6 +263,9 @@ } static bool isFunctionPassName(StringRef Name) { + if (getGlobalHookManager().hasPassName(Name)) + return true; + #define FUNCTION_PASS(NAME, CREATE_PASS) if (Name == NAME) return true; #define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \ if (Name == "require<" NAME ">" || Name == "invalidate<" NAME ">") \ @@ -232,6 +276,9 @@ } static bool isLoopPassName(StringRef Name) { + if (getGlobalHookManager().hasPassName(Name)) + return true; + #define LOOP_PASS(NAME, CREATE_PASS) if (Name == NAME) return true; #define LOOP_ANALYSIS(NAME, CREATE_PASS) \ if (Name == "require<" NAME ">" || Name == "invalidate<" NAME ">") \ @@ -269,6 +316,9 @@ return true; } + if (getGlobalHookManager().parsePassName(Name, MPM)) + return true; // Plugin hooks are consulted before same-named built-in passes. + #define MODULE_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ MPM.addPass(CREATE_PASS); \ @@ -291,6 +341,9 @@ } bool PassBuilder::parseCGSCCPassName(CGSCCPassManager &CGPM, StringRef Name) { + if (getGlobalHookManager().parsePassName(Name, CGPM)) + return true; + #define CGSCC_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ CGPM.addPass(CREATE_PASS); \ @@ -314,6 +367,9 @@ bool PassBuilder::parseFunctionPassName(FunctionPassManager &FPM, StringRef Name) { + if (getGlobalHookManager().parsePassName(Name, FPM)) + return true; + #define FUNCTION_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ FPM.addPass(CREATE_PASS); \ @@ -337,6 +393,9 @@ bool PassBuilder::parseLoopPassName(LoopPassManager &FPM, StringRef Name) { + if (getGlobalHookManager().parsePassName(Name, FPM)) + return true; + #define LOOP_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ FPM.addPass(CREATE_PASS); \ @@ -359,12 +418,16 @@ } bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) { + if 
(getGlobalHookManager().parseAAPassName(Name, AA)) + return true; + #define MODULE_ALIAS_ANALYSIS(NAME, CREATE_PASS) \ if (Name == NAME) { \ AA.registerModuleAnalysis< \ std::remove_reference::type>(); \ return true; \ } + #define FUNCTION_ALIAS_ANALYSIS(NAME, CREATE_PASS) \ if (Name == NAME) { \ AA.registerFunctionAnalysis< \ Index: unittests/IR/CMakeLists.txt =================================================================== --- unittests/IR/CMakeLists.txt +++ unittests/IR/CMakeLists.txt @@ -3,6 +3,7 @@ AsmParser Core Support + Passes ) set(IRSources @@ -12,6 +13,7 @@ ConstantsTest.cpp DebugInfoTest.cpp DominatorTreeTest.cpp + HookManagerTest.cpp IRBuilderTest.cpp InstructionsTest.cpp IntrinsicsTest.cpp Index: unittests/IR/HookManagerTest.cpp =================================================================== --- /dev/null +++ unittests/IR/HookManagerTest.cpp @@ -0,0 +1,75 @@ +//===- llvm/unittest/IR/HookManagerTest.cpp - HookManager unit tests ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +using namespace llvm; + +namespace { +template +struct TestAnalysisPass : public AnalysisBase> { + typedef int Result; + int run(IRUnitT &M) { return 7; } +}; + +template +struct TestTransformPass : public PassBase> { + PreservedAnalyses run(IRUnitT &M, AnalysisManager *AM) { + int R = AM->template getResult>(M); + EXPECT_EQ(R, 7); + PreservedAnalyses P; + P.preserve>(); + return P; + } +}; + +template class HookManagerTest : public testing::Test {}; + +typedef testing::Types PMTypes; + +TYPED_TEST_CASE(HookManagerTest, PMTypes); + +TYPED_TEST(HookManagerTest, TransformPass) { + Module M("TestModule", getGlobalContext()); // NOTE(review): M is unused; the pipeline is only parsed, never run. + PassBuilder PB; + ModulePassManager PM; // NOTE(review): always Module level regardless of the typed test parameter; confirm intended. + ModuleAnalysisManager AM; + + auto &HM = getGlobalHookManager(); // Process wide manager: hooks registered here persist into later tests. + HM.template registerAnalysisPassHook>( + "test-analysis", + typename RegisterAnalysisPass>::Hook()); + HM.registerTransformPassHook( + "test-transform", + typename RegisterTransformPass>::Hook()); + PB.registerModuleAnalyses(AM); + ASSERT_TRUE(PB.parsePassPipeline(PM, "test-transform", true)); // Parse only: TestTransformPass::run is never executed. +} + +TYPED_TEST(HookManagerTest, AnalysisPass) { + Module M("TestModule", getGlobalContext()); + PassBuilder PB; + ModulePassManager PM; + ModuleAnalysisManager AM; + + auto &HM = getGlobalHookManager(); + HM.template registerAnalysisPassHook>( + "test-analysis", + typename RegisterAnalysisPass>::Hook()); + PB.registerModuleAnalyses(AM); + ASSERT_TRUE(PB.parsePassPipeline( + PM, "require,invalidate", true)); // Parse only: does not check that the utility hooks add any pass to PM. +} +} // end anonymous namespace