[mlir][tblgen] Declare generated pass base classes inside an `impl` namespace

The `GEN_PASS_DEF_<PASS>` section emitted by `-gen-pass-defs` now wraps the
generated `<Pass>Base` class template in `namespace impl`. When a pass does not
specify a custom `constructor`, the generator additionally emits:
  * forward declarations of `impl::create<Pass>()` (and, for passes with
    options, of the overload taking `const <Pass>Options &`) ahead of the base
    class, whose friend definitions provide the bodies;
  * `create<Pass>()` wrappers in the enclosing namespace that forward to the
    `impl` versions.

Pass implementations therefore derive from `impl::<Pass>Base<Derived>`. The
private friend-definition emission reuses the `emitDefaultConstructorWithOptions`
flag so all three emission sites test the same condition. The PassManagement
documentation and the PassGen unit test are updated accordingly.

diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md
--- a/mlir/docs/PassManagement.md
+++ b/mlir/docs/PassManagement.md
@@ -810,7 +810,9 @@
   }];
 
-  // A constructor must be provided to specify how to create a default instance
-  // of MyPass.
+  // A constructor may be provided to specify how to create a default instance
+  // of MyPass. It can be omitted in this example, because the generated
+  // default constructor and the registration methods both live in the same
+  // namespace.
   let constructor = "foo::createMyPass()";
 
   // Specify any options.
@@ -842,17 +844,23 @@
 `-name` input parameter, that registers all of the passes present.
 
 ```c++
-// gen-pass-decls -name="Example"
+// Tablegen options: -gen-pass-decls -name="Example"
 
+// Passes.h
+
+namespace foo {
 #define GEN_PASS_REGISTRATION
 #include "Passes.h.inc"
+} // namespace foo
 
 void registerMyPasses() {
   // Register all of the passes.
-  registerExamplePasses();
+  foo::registerExamplePasses();
+
+  // Or
 
   // Register `MyPass` specifically.
-  registerMyPass();
+  foo::registerMyPass();
 }
 ```
 
@@ -900,20 +908,22 @@
 
 It generates a base class for each of the passes, containing most of the boiler
 plate related to pass definitions. These classes are named in the form of
-`MyPassBase`, where `MyPass` is the name of the pass definition in tablegen. We
-can update the original C++ pass definition as so:
+`MyPassBase`, where `MyPass` is the name of the pass definition in tablegen,
+and are declared inside the `impl` namespace. We can update the original C++
+pass definition as follows:
 
 ```c++
+// MyPass.cpp
+
 /// Include the generated base pass class definitions.
+namespace foo {
 #define GEN_PASS_DEF_MYPASS
 #include "Passes.cpp.inc"
+} // namespace foo
 
 /// Define the main class as deriving from the generated base class.
-struct MyPass : MyPassBase<MyPass> {
-  /// The explicit constructor is no longer explicitly necessary when defining
-  /// pass options and statistics, the base class takes care of that
-  /// automatically.
-  ...
+struct MyPass : foo::impl::MyPassBase<MyPass> {
+  using MyPassBase::MyPassBase;
 
   /// The definitions of the options and statistics are now generated within
   /// the base class, but are accessible in the same way.
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -192,7 +192,7 @@
 /// {1}: The base class for the pass.
 /// {2): The command line argument for the pass.
 /// {3}: The dependent dialects registration.
-const char *const passDeclBegin = R"(
+const char *const baseClassBegin = R"(
 template <typename DerivedT>
 class {0}Base : public {1} {
 public:
@@ -243,18 +243,42 @@
   registry.insert<{0}>();
 )";
 
-const char *const friendDefaultConstructorTemplate = R"(
+const char *const friendDefaultConstructorDeclTemplate = R"(
+namespace impl {{
+  std::unique_ptr<::mlir::Pass> create{0}();
+} // namespace impl
+)";
+
+const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
+namespace impl {{
+  std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
+} // namespace impl
+)";
+
+const char *const friendDefaultConstructorDefTemplate = R"(
   friend std::unique_ptr<::mlir::Pass> create{0}() {{
     return std::make_unique<DerivedT>();
   }
 )";
 
-const char *const friendDefaultConstructorWithOptionsTemplate = R"(
+const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
   friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
     return std::make_unique<DerivedT>(options);
   }
 )";
 
+const char *const defaultConstructorDefTemplate = R"(
+std::unique_ptr<::mlir::Pass> create{0}() {{
+  return impl::create{0}();
+}
+)";
+
+const char *const defaultConstructorWithOptionsDefTemplate = R"(
+std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
+  return impl::create{0}(options);
+}
+)";
+
 /// Emit the declarations for each of the pass options.
 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
   for (const PassOption &opt : pass.getOptions()) {
@@ -285,10 +309,20 @@
 static void emitPassDefs(const Pass &pass, raw_ostream &os) {
   StringRef passName = pass.getDef()->getName();
   std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
+  bool emitDefaultConstructors = pass.getConstructor().empty();
+  bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
 
   os << "#ifdef " << enableVarName << "\n";
   os << llvm::formatv(passHeader, passName);
 
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName);
+
+    if (emitDefaultConstructorWithOptions)
+      os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate,
+                          passName);
+  }
+
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
@@ -297,7 +331,8 @@
                                   dependentDialect);
   }
 
-  os << llvm::formatv(passDeclBegin, passName, pass.getBaseClass(),
+  os << "namespace impl {\n";
+  os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(),
                       pass.getArgument(), pass.getSummary(),
                       dependentDialectRegistrations);
 
@@ -320,15 +355,23 @@
 
   // Private content
   os << "private:\n";
-  if (pass.getConstructor().empty()) {
-    os << llvm::formatv(friendDefaultConstructorTemplate, passName);
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(friendDefaultConstructorDefTemplate, passName);
 
-    if (!pass.getOptions().empty())
-      os << llvm::formatv(friendDefaultConstructorWithOptionsTemplate,
+    if (emitDefaultConstructorWithOptions)
+      os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate,
                           passName);
   }
 
   os << "};\n";
+  os << "} // namespace impl\n";
+
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(defaultConstructorDefTemplate, passName);
+
+    if (emitDefaultConstructorWithOptions)
+      os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName);
+  }
 
   os << "#undef " << enableVarName << "\n";
   os << "#endif // " << enableVarName << "\n";
diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
--- a/mlir/unittests/TableGen/PassGenTest.cpp
+++ b/mlir/unittests/TableGen/PassGenTest.cpp
@@ -24,7 +24,7 @@
 #define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
 #include "PassGenTest.cpp.inc"
 
-struct TestPass : public TestPassBase<TestPass> {
+struct TestPass : public impl::TestPassBase<TestPass> {
   using TestPassBase::TestPassBase;
 
   void runOnOperation() override {}
@@ -54,7 +54,7 @@
 }
 
 struct TestPassWithOptions
-    : public TestPassWithOptionsBase<TestPassWithOptions> {
+    : public impl::TestPassWithOptionsBase<TestPassWithOptions> {
   using TestPassWithOptionsBase::TestPassWithOptionsBase;
 
   void runOnOperation() override {}
@@ -89,7 +89,8 @@
 }
 
 struct TestPassWithCustomConstructor
-    : public TestPassWithCustomConstructorBase<TestPassWithCustomConstructor> {
+    : public impl::TestPassWithCustomConstructorBase<
+          TestPassWithCustomConstructor> {
   explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
 
   void runOnOperation() override {}