diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -192,7 +192,7 @@
 /// {1}: The base class for the pass.
 /// {2): The command line argument for the pass.
 /// {3}: The dependent dialects registration.
-const char *const passDeclBegin = R"(
+const char *const baseClassBegin = R"(
 template <typename DerivedT>
 class {0}Base : public {1} {
 public:
@@ -243,18 +243,42 @@
   registry.insert<{0}>();
 )";
 
-const char *const friendDefaultConstructorTemplate = R"(
+const char *const friendDefaultConstructorDeclTemplate = R"(
+namespace impl {{
+  std::unique_ptr<::mlir::Pass> create{0}();
+} // namespace impl
+)";
+
+const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"(
+namespace impl {{
+  std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options);
+} // namespace impl
+)";
+
+const char *const friendDefaultConstructorDefTemplate = R"(
   friend std::unique_ptr<::mlir::Pass> create{0}() {{
     return std::make_unique<DerivedT>();
   }
 )";
 
-const char *const friendDefaultConstructorWithOptionsTemplate = R"(
+const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
   friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
     return std::make_unique<DerivedT>(options);
   }
 )";
 
+const char *const defaultConstructorDefTemplate = R"(
+std::unique_ptr<::mlir::Pass> create{0}() {{
+  return impl::create{0}();
+}
+)";
+
+const char *const defaultConstructorWithOptionsDefTemplate = R"(
+std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
+  return impl::create{0}(options);
+}
+)";
+
 /// Emit the declarations for each of the pass options.
 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
   for (const PassOption &opt : pass.getOptions()) {
@@ -285,10 +309,20 @@
 static void emitPassDefs(const Pass &pass, raw_ostream &os) {
   StringRef passName = pass.getDef()->getName();
   std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
+  bool emitDefaultConstructors = pass.getConstructor().empty();
+  bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
 
   os << "#ifdef " << enableVarName << "\n";
   os << llvm::formatv(passHeader, passName);
 
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(friendDefaultConstructorDeclTemplate, passName);
+
+    if (emitDefaultConstructorWithOptions)
+      os << llvm::formatv(friendDefaultConstructorWithOptionsDeclTemplate,
+                          passName);
+  }
+
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
@@ -297,7 +331,8 @@
                                   dependentDialect);
   }
 
-  os << llvm::formatv(passDeclBegin, passName, pass.getBaseClass(),
+  os << "namespace impl {\n";
+  os << llvm::formatv(baseClassBegin, passName, pass.getBaseClass(),
                       pass.getArgument(), pass.getSummary(),
                       dependentDialectRegistrations);
 
@@ -320,15 +355,23 @@
 
   // Private content
   os << "private:\n";
-  if (pass.getConstructor().empty()) {
-    os << llvm::formatv(friendDefaultConstructorTemplate, passName);
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(friendDefaultConstructorDefTemplate, passName);
 
     if (!pass.getOptions().empty())
-      os << llvm::formatv(friendDefaultConstructorWithOptionsTemplate,
+      os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate,
                           passName);
   }
 
   os << "};\n";
+  os << "} // namespace impl\n";
+
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(defaultConstructorDefTemplate, passName);
+
+    if (emitDefaultConstructorWithOptions)
+      os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName);
+  }
 
   os << "#undef " << enableVarName << "\n";
   os << "#endif // " << enableVarName << "\n";
diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
--- a/mlir/unittests/TableGen/PassGenTest.cpp
+++ b/mlir/unittests/TableGen/PassGenTest.cpp
@@ -24,7 +24,7 @@
 #define GEN_PASS_DEF_TESTPASSWITHCUSTOMCONSTRUCTOR
 #include "PassGenTest.cpp.inc"
 
-struct TestPass : public TestPassBase<TestPass> {
+struct TestPass : public impl::TestPassBase<TestPass> {
   using TestPassBase::TestPassBase;
 
   void runOnOperation() override {}
@@ -54,7 +54,7 @@
 }
 
 struct TestPassWithOptions
-    : public TestPassWithOptionsBase<TestPassWithOptions> {
+    : public impl::TestPassWithOptionsBase<TestPassWithOptions> {
   using TestPassWithOptionsBase::TestPassWithOptionsBase;
 
   void runOnOperation() override {}
@@ -89,7 +89,8 @@
 }
 
 struct TestPassWithCustomConstructor
-    : public TestPassWithCustomConstructorBase<TestPassWithCustomConstructor> {
+    : public impl::TestPassWithCustomConstructorBase<
+          TestPassWithCustomConstructor> {
   explicit TestPassWithCustomConstructor(int v) : extraVal(v) {}
 
   void runOnOperation() override {}