diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -243,18 +243,30 @@
   registry.insert<{0}>();
 )";
 
-const char *const friendDefaultConstructorTemplate = R"(
-  friend std::unique_ptr<::mlir::Pass> create{0}() {{
+const char *const friendDefaultConstructorDefTemplate = R"(
+  friend std::unique_ptr<::mlir::Pass> create{0}Impl() {{
     return std::make_unique<DerivedT>();
   }
 )";
 
-const char *const friendDefaultConstructorWithOptionsTemplate = R"(
-  friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
+const char *const friendDefaultConstructorWithOptionsDefTemplate = R"(
+  friend std::unique_ptr<::mlir::Pass> create{0}Impl(const {0}Options &options) {{
     return std::make_unique<DerivedT>(options);
   }
 )";
 
+const char *const defaultConstructorDefTemplate = R"(
+std::unique_ptr<::mlir::Pass> create{0}() {{
+  return create{0}Impl();
+}
+)";
+
+const char *const defaultConstructorWithOptionsDefTemplate = R"(
+std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{
+  return create{0}Impl(options);
+}
+)";
+
 /// Emit the declarations for each of the pass options.
 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
   for (const PassOption &opt : pass.getOptions()) {
@@ -285,10 +297,22 @@
 static void emitPassDefs(const Pass &pass, raw_ostream &os) {
   StringRef passName = pass.getDef()->getName();
   std::string enableVarName = "GEN_PASS_DEF_" + passName.upper();
+  bool emitDefaultConstructors = pass.getConstructor().empty();
+  bool emitDefaultConstructorWithOptions = !pass.getOptions().empty();
 
   os << "#ifdef " << enableVarName << "\n";
   os << llvm::formatv(passHeader, passName);
 
+  if (emitDefaultConstructors) {
+    os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}Impl();",
+                        passName);
+
+    if (emitDefaultConstructorWithOptions)
+      os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}Impl(const "
+                          "{0}Options &options);",
+                          passName);
+  }
+
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
@@ -320,16 +344,23 @@
 
   // Private content
   os << "private:\n";
-  if (pass.getConstructor().empty()) {
-    os << llvm::formatv(friendDefaultConstructorTemplate, passName);
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(friendDefaultConstructorDefTemplate, passName);
 
     if (!pass.getOptions().empty())
-      os << llvm::formatv(friendDefaultConstructorWithOptionsTemplate,
+      os << llvm::formatv(friendDefaultConstructorWithOptionsDefTemplate,
                           passName);
   }
 
   os << "};\n";
 
+  if (emitDefaultConstructors) {
+    os << llvm::formatv(defaultConstructorDefTemplate, passName);
+
+    if (emitDefaultConstructorWithOptions)
+      os << llvm::formatv(defaultConstructorWithOptionsDefTemplate, passName);
+  }
+
   os << "#undef " << enableVarName << "\n";
   os << "#endif // " << enableVarName << "\n";
 }