Index: include/llvm/Passes/PassBuilder.h
===================================================================
--- include/llvm/Passes/PassBuilder.h
+++ include/llvm/Passes/PassBuilder.h
@@ -100,6 +100,147 @@
                                bool VerifyEachPass, bool DebugLogging);
 };
 
-}
+/// \brief Convenience type used to install a default hook, adding an analysis
+/// or transform pass using its default constructor.
+template <typename PassT> struct DefaultHook {
+  template <typename IRUnitT> void operator()(AnalysisManager<IRUnitT> &AM) {
+    AM.registerPass(PassT());
+  }
+  template <typename IRUnitT> void operator()(PassManager<IRUnitT> &PM) {
+    PM.addPass(PassT());
+  }
+};
+
+/// \brief This class manages the registration of hooks through which
+/// out-of-tree or plugin passes are able to register themselves with a
+/// \c PassManager or \c AnalysisManager.
+///
+/// Such a hook is a callable with a single argument, the Pass- or
+/// AnalysisManager which passes can be added to. Hooks are registered with the
+/// manager by name, which is the name used in the pass pipeline text. If the
+/// name matches a name of a built-in pass, this name is effectively "stolen",
+/// thus overriding the built-in pass.
+///
+/// To install a hook in an instance of this class, two interfaces are supplied
+/// to handle analysis and transform passes, respectively. Installing a hook for
+/// an analysis pass also automatically generates default hooks for the
+/// require<> and invalidate<> utility passes.
+template <typename IRUnitT> class HookManager {
+public:
+  typedef std::function<void(AnalysisManager<IRUnitT> &)> AnalysisPassHook;
+  typedef std::function<void(PassManager<IRUnitT> &)> TransformPassHook;
+
+  /// \brief Register a hook for a transform pass.
+  ///
+  /// If a hook has already been registered with the same Name, that hook is
+  /// overwritten.
+  void registerTransformPassHook(StringRef Name, TransformPassHook H) {
+    transformHooks[Name] = H; // insert() would keep an existing hook.
+  }
+
+  /// \brief Register a hook for an analysis pass.
+  ///
+  /// Besides the pass Name, this interface also takes the Pass type as template
+  /// argument, used to install default hooks for the require<> and invalidate<>
+  /// utilities.
+  /// If a hook has already been registered with the same Name, that hook and
+  /// the utility passes for that name are overwritten.
+  template <typename PassT>
+  void registerAnalysisPassHook(StringRef Name, AnalysisPassHook H) {
+    analysisHooks[Name] = H; // Keyed by name so re-registration replaces.
+    requireUtilityHooks[Name] =
+        DefaultHook<RequireAnalysisPass<PassT>>();
+    invalidateUtilityHooks[Name] =
+        DefaultHook<InvalidateAnalysisPass<PassT>>();
+  }
+
+  //@{
+  /// Public interface for the \c PassBuilder to query or invoke registered
+  /// hooks.
+  void registerAnalyses(AnalysisManager<IRUnitT> &AM) {
+    for (auto &Entry : analysisHooks)
+      Entry.getValue()(AM);
+  }
+  /// Return X for a Name of the form "require<X>", else an empty StringRef.
+  StringRef getRequireUtilityName(StringRef Name) {
+    if (Name.endswith(">") && Name.startswith("require<"))
+      return Name.slice(8, Name.size() - 1); // Drop "require<" and ">".
+    return StringRef();
+  }
+  StringRef getInvalidateUtilityName(StringRef Name) {
+    if (Name.endswith(">") && Name.startswith("invalidate<"))
+      return Name.slice(11, Name.size() - 1); // Drop "invalidate<" and ">".
+    return StringRef();
+  }
+  /// Return true if Name selects a transform hook or a registered utility.
+  bool hasPassName(StringRef Name) {
+    if (transformHooks.count(Name) > 0)
+      return true;
+    StringRef slice = getRequireUtilityName(Name);
+    if (!slice.empty())
+      if (requireUtilityHooks.count(slice))
+        return true;
+    slice = getInvalidateUtilityName(Name);
+    if (!slice.empty())
+      if (invalidateUtilityHooks.count(slice))
+        return true;
+    return false;
+  }
+  /// Add the pass selected by Name to PM; return false if no hook matches.
+  bool parsePassName(StringRef Name, PassManager<IRUnitT> &PM) {
+    if (transformHooks.count(Name) > 0) {
+      transformHooks[Name](PM);
+      return true;
+    }
+    StringRef slice = getRequireUtilityName(Name);
+    if (!slice.empty()) {
+      if (requireUtilityHooks.count(slice)) {
+        requireUtilityHooks[slice](PM);
+        return true;
+      }
+    }
+    slice = getInvalidateUtilityName(Name);
+    if (!slice.empty()) {
+      if (invalidateUtilityHooks.count(slice)) {
+        invalidateUtilityHooks[slice](PM);
+        return true;
+      }
+    }
+    return false;
+  }
+  //@}
+private:
+  StringMap<AnalysisPassHook> analysisHooks;
+  StringMap<TransformPassHook> transformHooks;
+  StringMap<TransformPassHook> requireUtilityHooks;
+  StringMap<TransformPassHook> invalidateUtilityHooks;
+};
+
+/// \brief Return the global hook manager for this IRUnit type.
+template <typename IRUnitT> extern HookManager<IRUnitT> &getGlobalHookManager();
+
+/// \brief Convenience type to install the hooks for an analysis pass:
+/// constructing an instance registers Hook under Name with the global
+/// HookManager. NOTE(review): template parameter order is reconstructed.
+template <typename IRUnitT, typename PassT> struct RegisterAnalysisPassHook {
+  RegisterAnalysisPassHook(
+      StringRef Name, typename HookManager<IRUnitT>::AnalysisPassHook Hook) {
+    getGlobalHookManager<IRUnitT>().template registerAnalysisPassHook<PassT>(
+        Name, Hook);
+  }
+};
+
+/// \brief Convenience type to install the hook for a transform pass:
+/// constructing an instance registers Hook under Name with the global
+/// HookManager for IRUnitT.
+template <typename IRUnitT> struct RegisterTransformPassHook {
+  RegisterTransformPassHook(
+      StringRef Name, typename HookManager<IRUnitT>::TransformPassHook Hook) {
+    getGlobalHookManager<IRUnitT>().registerTransformPassHook(Name, Hook);
+  }
+};
+
+
+}
 
 #endif
Index: lib/Passes/PassBuilder.cpp
===================================================================
--- lib/Passes/PassBuilder.cpp
+++ lib/Passes/PassBuilder.cpp
@@ -27,6 +27,7 @@
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ManagedStatic.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/InstCombine/InstCombine.h"
 #include "llvm/Transforms/Scalar/EarlyCSE.h"
@@ -95,19 +96,47 @@
 
 } // End anonymous namespace.
 
+//@{
+/// Global instances of HookManager for various IRUnit types.
+static ManagedStatic<HookManager<Module>> ModuleHookManager;
+static ManagedStatic<HookManager<LazyCallGraph::SCC>> CGSCCHookManager;
+static ManagedStatic<HookManager<Function>> FunctionHookManager;
+//@}
+
+//@{
+/// Getters for the global HookManager instances (IRUnitT deduced from return).
+namespace llvm {
+template <> HookManager<Module> &getGlobalHookManager() {
+  return *ModuleHookManager;
+}
+template <> HookManager<LazyCallGraph::SCC> &getGlobalHookManager() {
+  return *CGSCCHookManager;
+}
+template <> HookManager<Function> &getGlobalHookManager() {
+  return *FunctionHookManager;
+}
+} // End namespace llvm.
+//@}
+
 void PassBuilder::registerModuleAnalyses(ModuleAnalysisManager &MAM) {
+  getGlobalHookManager<Module>().registerAnalyses(MAM);
+
 #define MODULE_ANALYSIS(NAME, CREATE_PASS) \
   MAM.registerPass(CREATE_PASS);
 #include "PassRegistry.def"
 }
 
 void PassBuilder::registerCGSCCAnalyses(CGSCCAnalysisManager &CGAM) {
+  getGlobalHookManager<LazyCallGraph::SCC>().registerAnalyses(CGAM);
+
 #define CGSCC_ANALYSIS(NAME, CREATE_PASS) \
   CGAM.registerPass(CREATE_PASS);
 #include "PassRegistry.def"
 }
 
 void PassBuilder::registerFunctionAnalyses(FunctionAnalysisManager &FAM) {
+  getGlobalHookManager<Function>().registerAnalyses(FAM);
+
 #define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \
   FAM.registerPass(CREATE_PASS);
 #include "PassRegistry.def"
@@ -115,6 +144,9 @@
 
 #ifndef NDEBUG
 static bool isModulePassName(StringRef Name) {
+  if (getGlobalHookManager<Module>().hasPassName(Name))
+    return true; // Hook names are checked before the built-in names below.
+
 #define MODULE_PASS(NAME, CREATE_PASS) if (Name == NAME) return true;
 #define MODULE_ANALYSIS(NAME, CREATE_PASS) \
   if (Name == "require<" NAME ">" || Name == "invalidate<" NAME ">") \
@@ -126,6 +158,9 @@
 #endif
 
 static bool isCGSCCPassName(StringRef Name) {
+  if (getGlobalHookManager<LazyCallGraph::SCC>().hasPassName(Name))
+    return true; // Hook names are checked before the built-in names below.
+
 #define CGSCC_PASS(NAME, CREATE_PASS) if (Name == NAME) return true;
 #define CGSCC_ANALYSIS(NAME, CREATE_PASS) \
   if (Name == "require<" NAME ">" || Name == "invalidate<" NAME ">") \
@@ -136,6 +171,9 @@
 }
 
 static bool isFunctionPassName(StringRef Name) {
+  if (getGlobalHookManager<Function>().hasPassName(Name))
+    return true; // Hook names are checked before the built-in names below.
+
 #define FUNCTION_PASS(NAME, CREATE_PASS) if (Name == NAME) return true;
 #define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \
   if (Name == "require<" NAME ">" || Name == "invalidate<" NAME ">") \
@@ -146,6 +184,9 @@
 }
 
 bool PassBuilder::parseModulePassName(ModulePassManager &MPM, StringRef Name) {
+  if (getGlobalHookManager<Module>().parsePassName(Name, MPM))
+    return true; // A registered hook shadows a built-in pass of that name.
+
 #define MODULE_PASS(NAME, CREATE_PASS) \
   if (Name == NAME) { \
     MPM.addPass(CREATE_PASS); \
@@ -166,6 +207,9 @@
 }
 
 bool PassBuilder::parseCGSCCPassName(CGSCCPassManager &CGPM, StringRef Name) {
+  if (getGlobalHookManager<LazyCallGraph::SCC>().parsePassName(Name, CGPM))
+    return true; // A registered hook shadows a built-in pass of that name.
+
 #define CGSCC_PASS(NAME, CREATE_PASS) \
   if (Name == NAME) { \
     CGPM.addPass(CREATE_PASS); \
@@ -187,6 +231,9 @@
 
 bool PassBuilder::parseFunctionPassName(FunctionPassManager &FPM,
                                         StringRef Name) {
+  if (getGlobalHookManager<Function>().parsePassName(Name, FPM))
+    return true; // A registered hook shadows a built-in pass of that name.
+
 #define FUNCTION_PASS(NAME, CREATE_PASS) \
   if (Name == NAME) { \
     FPM.addPass(CREATE_PASS); \
Index: unittests/IR/CMakeLists.txt
===================================================================
--- unittests/IR/CMakeLists.txt
+++ unittests/IR/CMakeLists.txt
@@ -4,6 +4,7 @@
   Core
   IPA
   Support
+  Passes
   )
 
 set(IRSources
@@ -12,6 +13,7 @@
   ConstantsTest.cpp
   DebugInfoTest.cpp
   DominatorTreeTest.cpp
+  HookManagerTest.cpp
   IRBuilderTest.cpp
   InstructionsTest.cpp
   LegacyPassManagerTest.cpp
Index: unittests/IR/HookManagerTest.cpp
===================================================================
--- /dev/null
+++ unittests/IR/HookManagerTest.cpp
@@ -0,0 +1,66 @@
+//===- llvm/unittest/IR/HookManagerTest.cpp - HookManager unit tests ------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+// NOTE(review): the three include targets below are reconstructed; confirm.
+#include <llvm/IR/PassManager.h>
+#include <llvm/Passes/PassBuilder.h>
+#include <gtest/gtest.h>
+
+using namespace llvm;
+
+namespace {
+struct TestAnalysisPass {
+  typedef int Result;
+  static char PassID;
+  static StringRef name() { return "TestAnalysisPass"; }
+  static void *ID() { return &PassID; }
+  int run(Module &M) { return 7; } // Sentinel checked by TestTransformPass.
+};
+struct TestTransformPass {
+  static StringRef name() { return "TestTransformPass"; }
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager *AM) {
+    int R = AM->getResult<TestAnalysisPass>(M);
+    EXPECT_EQ(R, 7);
+    PreservedAnalyses P;
+    P.preserve<TestAnalysisPass>();
+    return P;
+  }
+};
+
+char TestAnalysisPass::PassID;
+
+TEST(HookManager, TransformPass) {
+  PassBuilder PB;
+  ModuleAnalysisManager AM;
+  ModulePassManager PM;
+  Module M("TestM", getGlobalContext());
+
+  auto &HM = getGlobalHookManager<Module>();
+  HM.registerAnalysisPassHook<TestAnalysisPass>(
+      "test-analysis", DefaultHook<TestAnalysisPass>());
+  HM.registerTransformPassHook("test-transform",
+                               DefaultHook<TestTransformPass>());
+  PB.registerModuleAnalyses(AM);
+  ASSERT_TRUE(PB.parsePassPipeline(PM, "test-transform", true));
+  ASSERT_TRUE(PM.run(M, &AM).preserved<TestAnalysisPass>());
+}
+
+TEST(HookManager, AnalysisPass) {
+  PassBuilder PB;
+  ModuleAnalysisManager AM;
+  ModulePassManager PM;
+  Module M("TestM", getGlobalContext());
+  // Reuses the global manager and name registered in TransformPass above.
+  auto &HM = getGlobalHookManager<Module>();
+  HM.registerAnalysisPassHook<TestAnalysisPass>(
+      "test-analysis", DefaultHook<TestAnalysisPass>());
+  PB.registerModuleAnalyses(AM);
+  ASSERT_TRUE(PB.parsePassPipeline(
+      PM, "require<test-analysis>,invalidate<test-analysis>", true));
+}
+} // end anonymous namespace