diff --git a/llvm/include/llvm/CodeGen/TargetPassConfig.h b/llvm/include/llvm/CodeGen/TargetPassConfig.h --- a/llvm/include/llvm/CodeGen/TargetPassConfig.h +++ b/llvm/include/llvm/CodeGen/TargetPassConfig.h @@ -16,10 +16,12 @@ #include "llvm/Pass.h" #include "llvm/Support/CodeGen.h" #include +#include #include namespace llvm { +class Target; class LLVMTargetMachine; struct MachineSchedContext; class PassConfigImpl; @@ -82,6 +84,25 @@ /// This is an ImmutablePass solely for the purpose of exposing CodeGen options /// to the internals of other CodeGen passes. class TargetPassConfig : public ImmutablePass { +public: + /// Enum identifying when in the machine pass setup to add an + /// extension via a callback. + enum MachinePassExtensionPointTy { + MPEP_EarlyAsPossible, + MPEP_PreRegAlloc, + MPEP_PostRegAlloc, + MPEP_PreSched2, + MPEP_PreEmitPass, + MPEP_PreEmitPass2, + MPEP_LateAsPossible + }; + + /// Callback function type for added extensions to a target. + using ExtensionFn = std::function; + + /// ID used for removing a previously added target extension + using ExtensionID = int; + private: PassManagerBase *PM = nullptr; AnalysisID StartBefore = nullptr; @@ -343,6 +364,24 @@ /// Returns the CSEConfig object to use for the current optimization level. virtual std::unique_ptr getCSEConfig() const; + /// Add an extension to be applied for the given Target, returning a + /// nonzero ID for use with removeExtension. The preferred way to + /// use this is via the RegisterTargetExtension class defined below. + static ExtensionID addExtension(const Target *Target, + MachinePassExtensionPointTy MPEP, + ExtensionFn Fn); + + /// Remove a previously added extension. The preferred way to use + /// this is via the RegisterTargetExtension class defined below. + static void removeExtension(ExtensionID ID); + + /// Add a pass to the PassManager if that pass is supposed to be run, as + /// determined by the StartAfter and StopAfter options. Takes ownership of the + /// pass. 
+ /// @p verifyAfter if true and adding a machine function pass add an extra + /// machine verification pass afterwards. + void addPass(Pass *P, bool verifyAfter = true); + protected: // Helper to verify the analysis is really immutable. void setOpt(bool &Opt, bool Val); @@ -449,13 +488,6 @@ /// machine verification pass afterwards. AnalysisID addPass(AnalysisID PassID, bool verifyAfter = true); - /// Add a pass to the PassManager if that pass is supposed to be run, as - /// determined by the StartAfter and StopAfter options. Takes ownership of the - /// pass. - /// @p verifyAfter if true and adding a machine function pass add an extra - /// machine verification pass afterwards. - void addPass(Pass *P, bool verifyAfter = true); - /// addMachinePasses helper to create the target-selected or overriden /// regalloc pass. virtual FunctionPass *createRegAllocPass(bool Optimized); @@ -464,11 +496,47 @@ /// and rewriting. \returns true if any passes were added. virtual bool addRegAssignAndRewriteFast(); virtual bool addRegAssignAndRewriteOptimized(); + +private: + /// Scan the set of extensions and call any registered for our + /// Target at the given extension point. 
+ void applyAnyExtensions(MachinePassExtensionPointTy MPEP); }; void registerCodeGenCallback(PassInstrumentationCallbacks &PIC, LLVMTargetMachine &); +/// Registers an extension function for a given target +class RegisterTargetExtension { + TargetPassConfig::ExtensionID ExtensionID; + +public: + RegisterTargetExtension(const Target *Target, + TargetPassConfig::MachinePassExtensionPointTy MPEP, + TargetPassConfig::ExtensionFn Fn) + : + + ExtensionID( + TargetPassConfig::addExtension(Target, MPEP, std::move(Fn))) {} + + ~RegisterTargetExtension() { + if (ExtensionID) + TargetPassConfig::removeExtension(ExtensionID); + } + + // Move-constructible only; copying and all assignment are deleted + RegisterTargetExtension(RegisterTargetExtension &&Other) + : + + ExtensionID(Other.ExtensionID) { + Other.ExtensionID = 0; + } + + RegisterTargetExtension(const RegisterTargetExtension &) = delete; + RegisterTargetExtension &operator=(const RegisterTargetExtension &) = delete; + RegisterTargetExtension &operator=(RegisterTargetExtension &&) = delete; +}; + } // end namespace llvm #endif // LLVM_CODEGEN_TARGETPASSCONFIG_H diff --git a/llvm/include/llvm/Target/TargetMachine.h b/llvm/include/llvm/Target/TargetMachine.h --- a/llvm/include/llvm/Target/TargetMachine.h +++ b/llvm/include/llvm/Target/TargetMachine.h @@ -33,6 +33,7 @@ class Function; class GlobalValue; +class LLVMTargetMachine; class MachineFunctionPassManager; class MachineFunctionAnalysisManager; class MachineModuleInfoWrapperPass; @@ -155,6 +156,15 @@ return false; } + /// If this is an LLVMTargetMachine then return the downcast + /// pointer, otherwise return nullptr. + virtual LLVMTargetMachine *asLLVMTargetMachine() { return nullptr; } + + /// Const version of asLLVMTargetMachine + const LLVMTargetMachine *asLLVMTargetMachine() const { + return const_cast(this)->asLLVMTargetMachine(); + } + /// This method returns a pointer to the specified type of /// TargetSubtargetInfo.
In debug builds, it verifies that the object being /// returned is of the correct type. @@ -393,6 +403,9 @@ void initAsmInfo(); public: + /// Provide safe downcast + LLVMTargetMachine *asLLVMTargetMachine() override { return this; } + /// Get a TargetTransformInfo implementation for the target. /// /// The TTI returned uses the common code generator to answer queries about diff --git a/llvm/lib/CodeGen/TargetPassConfig.cpp b/llvm/lib/CodeGen/TargetPassConfig.cpp --- a/llvm/lib/CodeGen/TargetPassConfig.cpp +++ b/llvm/lib/CodeGen/TargetPassConfig.cpp @@ -40,6 +40,7 @@ #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ManagedStatic.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/Threading.h" #include "llvm/Target/CGPassBuilderOption.h" @@ -49,6 +50,7 @@ #include "llvm/Transforms/Utils/SymbolRewriter.h" #include #include +#include using namespace llvm; @@ -230,6 +232,12 @@ "disable-expand-reductions", cl::init(false), cl::Hidden, cl::desc("Disable the expand reduction intrinsics pass from running")); +using GlobalExtensionTy = + std::tuple; + +static ManagedStatic> GlobalExtensions; + /// Allow standard passes to be disabled by command line options. This supports /// simple binary flags that either suppress the pass or do nothing. /// i.e. -disable-mypass=false has no effect. 
@@ -671,6 +679,37 @@ FinalPtr.getID() != ID; } +auto TargetPassConfig::addExtension(const Target *Target, + MachinePassExtensionPointTy MPEP, + ExtensionFn Fn) -> ExtensionID { + static ExtensionID Counter = 0; + ExtensionID ID = ++Counter; + GlobalExtensions->push_back(std::make_tuple(Target, MPEP, Fn, ID)); + return ID; +} + +void TargetPassConfig::removeExtension(ExtensionID ID) { + if (!GlobalExtensions.isConstructed()) + return; + + auto I = llvm::find_if(*GlobalExtensions, [ID](const auto &elem) { + return std::get<3>(elem) == ID; + }); + + if (I != GlobalExtensions->end()) + GlobalExtensions->erase(I); +} + +void TargetPassConfig::applyAnyExtensions(MachinePassExtensionPointTy MPEP) { + if (GlobalExtensions.isConstructed() && !GlobalExtensions->empty()) { + const Target *Target = &TM->getTarget(); + for (const auto &Ext : *GlobalExtensions) { + if (std::get<0>(Ext) == Target && std::get<1>(Ext) == MPEP) + std::get<2>(Ext)(*this); + } + } +} + /// Add a pass to the PassManager if that pass is supposed to be run. If the /// Started/Stopped flags indicate either that the compilation should start at /// a later pass or that it should stop after an earlier pass, then do not add @@ -1074,11 +1113,16 @@ /// tied to a common pass. But if it has subtle dependencies on multiple passes, /// the target should override the stage instead. /// +/// External code can customize the passes added here for any target +/// via TargetPassConfig::addExtension or RegisterTargetExtension. +/// /// TODO: We could use a single addPre/Post(ID) hook to allow pass injection /// before/after any target-independent pass. But it's currently overkill. void TargetPassConfig::addMachinePasses() { AddingMachinePasses = true; + applyAnyExtensions(MPEP_EarlyAsPossible); + // Add passes that optimize machine instructions in SSA form. if (getOptLevel() != CodeGenOpt::None) { addMachineSSAOptimization(); @@ -1093,6 +1137,7 @@ // Run pre-ra passes.
addPreRegAlloc(); + applyAnyExtensions(MPEP_PreRegAlloc); // Debugifying the register allocator passes seems to provoke some // non-determinism that affects CodeGen and there doesn't seem to be a point @@ -1107,6 +1152,7 @@ addFastRegAlloc(); // Run post-ra passes. + applyAnyExtensions(MPEP_PostRegAlloc); addPostRegAlloc(); addPass(&FixupStatepointCallerSavedID); @@ -1131,6 +1177,7 @@ // Run pre-sched2 passes. addPreSched2(); + applyAnyExtensions(MPEP_PreSched2); if (EnableImplicitNullChecks) addPass(&ImplicitNullChecksID); @@ -1163,6 +1210,7 @@ addPass(&PatchableFunctionID); addPreEmitPass(); + applyAnyExtensions(MPEP_PreEmitPass); if (TM->Options.EnableIPRA) // Collect register usage information and produce a register mask of @@ -1199,11 +1247,14 @@ // Add passes that directly emit MI after all other MI passes. addPreEmitPass2(); + applyAnyExtensions(MPEP_PreEmitPass2); // Insert pseudo probe annotation for callsite profiling if (TM->Options.PseudoProbeForProfiling) addPass(createPseudoProbeInserter()); + applyAnyExtensions(MPEP_LateAsPossible); + AddingMachinePasses = false; } diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt --- a/llvm/unittests/CodeGen/CMakeLists.txt +++ b/llvm/unittests/CodeGen/CMakeLists.txt @@ -28,6 +28,7 @@ ScalableVectorMVTsTest.cpp SelectionDAGAddressAnalysisTest.cpp TypeTraitsTest.cpp + TargetExtensionTest.cpp TargetOptionsTest.cpp TestAsmPrinter.cpp ) diff --git a/llvm/unittests/CodeGen/TargetExtensionTest.cpp b/llvm/unittests/CodeGen/TargetExtensionTest.cpp new file mode 100644 --- /dev/null +++ b/llvm/unittests/CodeGen/TargetExtensionTest.cpp @@ -0,0 +1,122 @@ +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include 
"gtest/gtest.h" +#include +#include + +using namespace llvm; + +static std::unique_ptr makeModule(LLVMContext &Context, + const Target &Target) { + auto M = std::make_unique("target-extensions", Context); + M->setTargetTriple(Target.getName()); + return M; +} + +namespace { +struct TargetMachineTest { + const Target &T; + std::unique_ptr Mod; + std::unique_ptr TM; + LLVMTargetMachine *LTM; + unsigned CallCount; + std::vector Registrars; +}; +} // namespace + +static void addMachinePasses(LLVMTargetMachine *LLVMTM) { + legacy::PassManager PM; + TargetPassConfig *TPC(LLVMTM->createPassConfig(PM)); + + auto MMIWMP = new MachineModuleInfoWrapperPass(LLVMTM); + PM.add(TPC); + PM.add(MMIWMP); + TPC->addMachinePasses(); +} + +TEST(TargetExtensionTest, ExtensionOrdering) { + InitializeAllTargets(); + InitializeAllTargetMCs(); + InitializeAllAsmPrinters(); + InitializeAllAsmParsers(); + + LLVMContext Context; + + std::vector TestCases; + + for (const Target &T : TargetRegistry::targets()) { + std::unique_ptr Mod(makeModule(Context, T)); + ASSERT_TRUE(Mod.get() != nullptr); + std::unique_ptr TM(T.createTargetMachine( + Mod->getTargetTriple(), "", "", TargetOptions{}, None)); + ASSERT_TRUE(TM.get() != nullptr); + LLVMTargetMachine *LTM(TM->asLLVMTargetMachine()); + + if (LTM) { + TestCases.emplace_back(TargetMachineTest{ + T, + std::move(Mod), + std::move(TM), + LTM, + static_cast(TargetPassConfig::MPEP_EarlyAsPossible), + {}}); + } + } + + // Register all extensions for all targets + for (TargetMachineTest &TestCase : TestCases) { + for (unsigned MPEP = TargetPassConfig::MPEP_EarlyAsPossible; + MPEP <= TargetPassConfig::MPEP_LateAsPossible; ++MPEP) { + + // Register an extension at each extension point for this machine + // to allow checks for any enums we missed in the implementation + // and they are called in the correct order + TestCase.Registrars.emplace_back( + &TestCase.T, + static_cast(MPEP), + [MPEP, &TestCase](TargetPassConfig &tpc) { + // Check that we're called 
for the right target machine + EXPECT_EQ(TestCase.TM.get(), &tpc.getTM()) + << "TargetMachine pointer mismatch for target " + << TestCase.T.getName(); + // Check that we are called in expected order for the machine + EXPECT_EQ(TestCase.CallCount, MPEP) + << "Calls out of order for target " << TestCase.T.getName(); + ++TestCase.CallCount; + }); + } + } + + // Check that all extensions are called for all targets + for (TargetMachineTest &TestCase : TestCases) { + EXPECT_EQ(TestCase.CallCount, + static_cast(TargetPassConfig::MPEP_EarlyAsPossible)) + << "Initial callcount mismatch for target " << TestCase.T.getName(); + + addMachinePasses(TestCase.LTM); + + EXPECT_EQ(TestCase.CallCount, + static_cast(TargetPassConfig::MPEP_LateAsPossible) + 1) + << "Check1 callcount mismatch for target " << TestCase.T.getName(); + } + + // Unregister and check extensions are not called again + for (TargetMachineTest &TestCase : TestCases) { + EXPECT_EQ(TestCase.CallCount, + static_cast(TargetPassConfig::MPEP_LateAsPossible) + 1) + << "Check2 callcount mismatch for target " << TestCase.T.getName(); + + TestCase.Registrars.clear(); + + addMachinePasses(TestCase.LTM); + EXPECT_EQ(TestCase.CallCount, + static_cast(TargetPassConfig::MPEP_LateAsPossible) + 1) + << "Final callcount mismatch for target " << TestCase.T.getName(); + } +}