Index: include/llvm/Passes/PassBuilder.h =================================================================== --- include/llvm/Passes/PassBuilder.h +++ include/llvm/Passes/PassBuilder.h @@ -17,6 +17,7 @@ #define LLVM_PASSES_PASSBUILDER_H #include "llvm/ADT/Optional.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/IR/PassManager.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" @@ -319,6 +320,219 @@ ArrayRef Pipeline, bool VerifyEachPass, bool DebugLogging); }; -} +/// \brief This class manages the registration of hooks through which +/// out-of-tree or plugin passes are able to register themselves with a +/// \c PassManager or \c AnalysisManager. +/// +/// Such a hook is a Callable with a single argument, the Pass- or +/// AnalysisManager which passes can be added to. Hooks are registered with the +/// manager by name, which is the name used in the pass pipeline text. If the +/// name matches a name of a built-in pass, this name is effectively "stolen", +/// thus overriding the built-in pass. +/// +/// To install a hook in an instance of this class, interfaces are supplied +/// to handle analysis, transform and AA passes, respectively. Installing a hook +/// for an analysis pass also automatically generates default hooks for the +/// require<> and invalidate<> utility passes. +template , + typename PassManagerT = PassManager> +class HookManager { +public: + typedef std::function AnalysisPassHook; + typedef std::function TransformPassHook; + typedef std::function AAPassHook; + + /// \brief Register a hook for a transform pass. + /// + /// If a hook has already been registered with the same Name, it is + /// replaced by H. + void registerTransformPassHook(StringRef Name, TransformPassHook H) { + transformHooks[Name] = H; + } + + /// \brief Register a hook for an analysis pass.
+ /// + /// Besides the pass Name, this interface also takes the Pass type as template + /// argument, used to install default hooks for the require<> and invalidate<> + /// utilities. + /// If a hook has already been registered with the same Name, the + /// require<>/invalidate<> hooks for that name are replaced; H is always appended. + template + void registerAnalysisPassHook(StringRef Name, AnalysisPassHook H) { + analysisHooks.push_back(H); + requireUtilityHooks[Name] = [](PassManagerT &PM) { // Must add to PM: hook results are ignored. + PM.addPass(RequireAnalysisPass()); + }; + invalidateUtilityHooks[Name] = [](PassManagerT &PM) { + PM.addPass(InvalidateAnalysisPass()); + }; + } + + /// \brief Register a hook for an alias analysis pass. + /// + /// If a hook has already been registered with the same Name, it is + /// replaced by H. + void registerAAPassHook(StringRef Name, AAPassHook H) { + AAHooks[Name] = H; + } + + //@{ + /// Public interface for the \c PassBuilder to query or invoke registered + /// hooks. + void registerAnalyses(AnalysisManagerT &AM) { + for (auto &H : analysisHooks) + H(AM); + } + + inline StringRef getRequireUtilityName(StringRef Name) { + if (Name.endswith(">") && Name.startswith("require<")) + return Name.slice(8, Name.size() - 1); + return StringRef(); + } + + inline StringRef getInvalidateUtilityName(StringRef Name) { + if (Name.endswith(">") && Name.startswith("invalidate<")) + return Name.slice(11, Name.size() - 1); + return StringRef(); + } + + bool hasPassName(StringRef Name) { + if (transformHooks.count(Name) > 0) + return true; + StringRef slice = getRequireUtilityName(Name); + if (!slice.empty()) + if (requireUtilityHooks.count(slice)) + return true; + slice = getInvalidateUtilityName(Name); + if (!slice.empty()) + if (invalidateUtilityHooks.count(slice)) + return true; + return false; + } + + bool parsePassName(StringRef Name, PassManagerT &PM) { + if (transformHooks.count(Name) > 0) { + transformHooks[Name](PM); + return true; + } + StringRef slice = getRequireUtilityName(Name); + if 
(!slice.empty()) { + if (requireUtilityHooks.count(slice)) { + requireUtilityHooks[slice](PM); + return true; + } + } + slice = getInvalidateUtilityName(Name); + if (!slice.empty()) { + if (invalidateUtilityHooks.count(slice)) { + invalidateUtilityHooks[slice](PM); + return true; + } + } + return false; + } + + bool parseAAPassName(StringRef Name, AAManager &AA) { + if (!AAHooks.count(Name)) + return false; + AAHooks[Name](AA); + return true; + } + //@} +private: + std::vector analysisHooks; + StringMap transformHooks; + StringMap requireUtilityHooks; + StringMap invalidateUtilityHooks; + StringMap AAHooks; +}; + +/// \brief Return the global hook manager for this IRUnit type. Only the explicit specializations in lib/Passes/PassBuilder.cpp are defined. +template , + typename PassManagerT = PassManager> +extern HookManager & +getGlobalHookManager(); +//@{ +/// \brief Convenience types to install the hooks for a pass. Constructing an +/// instance of this will perform the registration of a hook with the given +/// arguments. +template , + typename PassManagerT = PassManager> +struct RegisterAnalysisPassHook { + RegisterAnalysisPassHook( + StringRef Name, + typename HookManager::AnalysisPassHook Hook) { + getGlobalHookManager() + .template registerAnalysisPassHook(Name, Hook); + } +}; + +template , + typename PassManagerT = PassManager> +struct RegisterTransformPassHook { + RegisterTransformPassHook( + StringRef Name, + typename HookManager::TransformPassHook Hook) { + getGlobalHookManager() + .registerTransformPassHook(Name, Hook); + } +}; + +template struct RegisterAAPassHook { + RegisterAAPassHook(StringRef Name, + typename HookManager::AnalysisPassHook Hook, + typename HookManager::AAPassHook AAHook) { + getGlobalHookManager().template registerAnalysisPassHook( + Name, Hook); + getGlobalHookManager().registerAAPassHook(Name, AAHook); + } +}; + +template , + typename PassManagerT = PassManager> +struct RegisterAnalysisPass + : RegisterAnalysisPassHook { + struct Hook { + void operator()(AnalysisManagerT &AM) { + AM.registerPass([] { 
return PassT(); }); + } + }; + RegisterAnalysisPass(StringRef Name) + : RegisterAnalysisPassHook(Name, Hook()) {} +}; + +template , + typename PassManagerT = PassManager> +struct RegisterTransformPass + : RegisterTransformPassHook { + struct Hook { + void operator()(PassManagerT &PM) { PM.addPass(PassT()); } + }; + RegisterTransformPass(StringRef Name) + : RegisterTransformPassHook( + Name, Hook()) {} +}; + +template struct RegisterAAPass : RegisterAAPassHook { + struct Hook { + void operator()(AnalysisManager &AM) { + AM.registerPass([] { return PassT(); }); + } + void operator()(AAManager &AA) { AA.registerFunctionAnalysis(); } + }; + RegisterAAPass(StringRef Name) + : RegisterAAPassHook(Name, Hook(), Hook()) {} +}; +//@} +} // end namespace llvm #endif Index: lib/Passes/PassBuilder.cpp =================================================================== --- lib/Passes/PassBuilder.cpp +++ lib/Passes/PassBuilder.cpp @@ -57,6 +57,7 @@ #include "llvm/IR/PassManager.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" #include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/GCOVProfiler.h" @@ -258,26 +259,69 @@ } // End anonymous namespace. +//@{ +/// Global instances of HookManager, one per IRUnit type.
+static ManagedStatic> ModuleHookManager; +static ManagedStatic< + HookManager> + CGSCCHookManager; +static ManagedStatic> FunctionHookManager; +static ManagedStatic> + LoopHookManager; +//@} + +//@{ +/// Explicit specializations of getGlobalHookManager(), one per IRUnit type, each returning the matching ManagedStatic instance declared above. +namespace llvm { +template <> HookManager &getGlobalHookManager() { + return *ModuleHookManager; +} +template <> +HookManager & +getGlobalHookManager() { + return *CGSCCHookManager; +} +template <> HookManager &getGlobalHookManager() { + return *FunctionHookManager; +} +template <> +HookManager & +getGlobalHookManager() { + return *LoopHookManager; +} +} +//@} + void PassBuilder::registerModuleAnalyses(ModuleAnalysisManager &MAM) { -#define MODULE_ANALYSIS(NAME, CREATE_PASS) \ + getGlobalHookManager().registerAnalyses(MAM); // Hook analyses are registered before the built-in ones from PassRegistry.def. + +#define MODULE_ANALYSIS(NAME, CREATE_PASS) \ MAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" } void PassBuilder::registerCGSCCAnalyses(CGSCCAnalysisManager &CGAM) { -#define CGSCC_ANALYSIS(NAME, CREATE_PASS) \ + getGlobalHookManager() + .registerAnalyses(CGAM); + +#define CGSCC_ANALYSIS(NAME, CREATE_PASS) \ CGAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" } void PassBuilder::registerFunctionAnalyses(FunctionAnalysisManager &FAM) { -#define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \ + getGlobalHookManager().registerAnalyses(FAM); + +#define FUNCTION_ANALYSIS(NAME, CREATE_PASS) \ FAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" } void PassBuilder::registerLoopAnalyses(LoopAnalysisManager &LAM) { -#define LOOP_ANALYSIS(NAME, CREATE_PASS) \ + getGlobalHookManager().registerAnalyses(LAM); + +#define LOOP_ANALYSIS(NAME, CREATE_PASS) \ LAM.registerPass([&] { return CREATE_PASS; }); #include "PassRegistry.def" } @@ -868,6 +912,9 @@ if (Name == "function") return true; + if (getGlobalHookManager().hasPassName(Name)) // Accepts hooked transform names and require<>/invalidate<> of hooked analyses. + return true; + // Explicitly handle custom-parsed pass names.
if (parseRepeatPassName(Name)) return true; @@ -890,6 +937,11 @@ if (Name == "function") return true; + if (getGlobalHookManager() + .hasPassName(Name)) + return true; + // Explicitly handle custom-parsed pass names. if (parseRepeatPassName(Name)) return true; @@ -914,6 +966,9 @@ if (Name == "loop") return true; + if (getGlobalHookManager().hasPassName(Name)) + return true; + // Explicitly handle custom-parsed pass names. if (parseRepeatPassName(Name)) return true; @@ -934,6 +989,10 @@ if (Name == "loop") return true; + if (getGlobalHookManager() + .hasPassName(Name)) + return true; + // Explicitly handle custom-parsed pass names. if (parseRepeatPassName(Name)) return true; @@ -1080,6 +1139,9 @@ return true; } + if (getGlobalHookManager().parsePassName(Name, MPM)) // Checked before the built-in .def passes below, so a hook can take over a built-in name. + return true; + // Finally expand the basic registered passes from the .inc file. #define MODULE_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ @@ -1151,6 +1213,11 @@ return false; } + if (getGlobalHookManager() + .parsePassName(Name, CGPM)) + return true; + // Now expand the basic registered passes from the .inc file. #define CGSCC_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ @@ -1213,6 +1280,9 @@ return false; } + if (getGlobalHookManager().parsePassName(Name, FPM)) + return true; + // Now expand the basic registered passes from the .inc file. #define FUNCTION_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ @@ -1264,6 +1334,10 @@ return false; } + if (getGlobalHookManager() + .parsePassName(Name, LPM)) + return true; + // Now expand the basic registered passes from the .inc file.
#define LOOP_PASS(NAME, CREATE_PASS) \ if (Name == NAME) { \ @@ -1289,12 +1363,16 @@ } bool PassBuilder::parseAAPassName(AAManager &AA, StringRef Name) { + if (getGlobalHookManager().parseAAPassName(Name, AA)) // Hooked AA names are tried before the built-in alias analyses below. + return true; + #define MODULE_ALIAS_ANALYSIS(NAME, CREATE_PASS) \ if (Name == NAME) { \ AA.registerModuleAnalysis< \ std::remove_reference::type>(); \ return true; \ } + #define FUNCTION_ALIAS_ANALYSIS(NAME, CREATE_PASS) \ if (Name == NAME) { \ AA.registerFunctionAnalysis< \ Index: unittests/IR/CMakeLists.txt =================================================================== --- unittests/IR/CMakeLists.txt +++ unittests/IR/CMakeLists.txt @@ -3,6 +3,7 @@ AsmParser Core Support + Passes ) set(IRSources @@ -14,6 +15,7 @@ DebugTypeODRUniquingTest.cpp DominatorTreeTest.cpp FunctionTest.cpp + HookManagerTest.cpp IRBuilderTest.cpp InstructionsTest.cpp IntrinsicsTest.cpp Index: unittests/IR/HookManagerTest.cpp =================================================================== --- /dev/null +++ unittests/IR/HookManagerTest.cpp @@ -0,0 +1,114 @@ +//===- llvm/unittest/IR/HookManagerTest.cpp - HookManager unit tests ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +using namespace llvm; + +namespace { +template +struct TestAnalysisPass : public AnalysisInfoMixin> { + struct Result {}; + template + Result run(IRUnitT &M, AnalysisManagerT &, ExtraTs &&... ExtraArgs) { + return Result(); + } + static AnalysisKey Key; +}; + +template +struct TestTransformPass : public PassInfoMixin> { + template + PreservedAnalyses run(IRUnitT &M, AnalysisManagerT &AM, ResultT &&R, + ExtraTs &&... 
ExtraArgs) { + AM.template getResult>(M, R); + PreservedAnalyses P; + P.preserve>(); + return P; + } + template + PreservedAnalyses run(IRUnitT &M, AnalysisManagerT &AM) { + AM.template getResult>(M); + PreservedAnalyses P; + P.preserve>(); + return P; + } +}; + +template AnalysisKey TestAnalysisPass::Key; + +template class HookManagerTest : public testing::Test {}; + +typedef testing::Types PMTypes; + +TYPED_TEST_CASE(HookManagerTest, PMTypes); + +template bool TransformPassTest() { + LLVMContext Ctx; + Module M("TestModule", Ctx); + PassBuilder PB; + ModulePassManager PM; + ModuleAnalysisManager AM; + + auto &HM = getGlobalHookManager(); // NOTE(review): process-wide manager; registrations persist across tests. + HM.template registerAnalysisPassHook>( + "test-analysis", + typename RegisterAnalysisPass, + Extra...>::Hook()); + HM.registerTransformPassHook( + "test-transform", + typename RegisterTransformPass, + Extra...>::Hook()); + PB.registerModuleAnalyses(AM); + return PB.parsePassPipeline(PM, "test-transform", true); // Only checks parsing; PM is never run on M. +} + +TYPED_TEST(HookManagerTest, TransformPass) { + ASSERT_TRUE(TransformPassTest()); +} +TEST(HookManagerTest, LoopTransformPass) { + ASSERT_TRUE( + (TransformPassTest())); +} +TEST(HookManagerTest, SCCTransformPass) { + ASSERT_TRUE((TransformPassTest())); +} + +template bool AnalysisPassTest() { + LLVMContext Ctx; + Module M("TestModule", Ctx); + PassBuilder PB; + ModulePassManager PM; + ModuleAnalysisManager AM; + + auto &HM = getGlobalHookManager(); // NOTE(review): each call appends another analysis hook to the global manager. + HM.template registerAnalysisPassHook>( + "test-analysis", + typename RegisterAnalysisPass>::Hook()); + PB.registerModuleAnalyses(AM); + return PB.parsePassPipeline( + PM, "require,invalidate", true); // Only checks parsing; cannot detect whether the utility hooks actually add a pass to PM. 
+} +TYPED_TEST(HookManagerTest, AnalysisPass) { + ASSERT_TRUE(AnalysisPassTest()); +} +TEST(HookManagerTest, LoopAnalysisPass) { // FIXME(review): calls TransformPassTest, so require<>/invalidate<> is never tested for Loop; AnalysisPassTest presumably needs TransformPassTest's Extra... parameters first. + ASSERT_TRUE( + (TransformPassTest())); +} +TEST(HookManagerTest, SCCAnalysisPass) { // FIXME(review): same copy/paste as LoopAnalysisPass; this re-runs the SCC transform test. + ASSERT_TRUE((TransformPassTest())); +} +} // end anonymous namespace