diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td --- a/mlir/include/mlir/Pass/PassBase.td +++ b/mlir/include/mlir/Pass/PassBase.td @@ -76,7 +76,7 @@ string description = ""; // A C++ constructor call to create an instance of this pass. - code constructor = [{}]; + code constructor = ?; // A list of dialects this pass may produce entities in. list<string> dependentDialects = []; diff --git a/mlir/include/mlir/TableGen/Pass.h b/mlir/include/mlir/TableGen/Pass.h --- a/mlir/include/mlir/TableGen/Pass.h +++ b/mlir/include/mlir/TableGen/Pass.h @@ -92,7 +92,7 @@ StringRef getDescription() const; /// Return the C++ constructor call to create an instance of this pass. - StringRef getConstructor() const; + Optional<StringRef> getConstructor() const; /// Return the dialects this pass needs to be registered. ArrayRef<StringRef> getDependentDialects() const; diff --git a/mlir/lib/TableGen/Pass.cpp b/mlir/lib/TableGen/Pass.cpp --- a/mlir/lib/TableGen/Pass.cpp +++ b/mlir/lib/TableGen/Pass.cpp @@ -87,9 +87,10 @@ return def->getValueAsString("description"); } -StringRef Pass::getConstructor() const { - return def->getValueAsString("constructor"); +Optional<StringRef> Pass::getConstructor() const { + return def->getValueAsOptionalString("constructor"); } + ArrayRef<StringRef> Pass::getDependentDialects() const { return dependentDialects; } diff --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp --- a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp +++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp @@ -90,6 +90,15 @@ } )"; +/// Emit the call to create an instance of the pass. +static void emitConstructorCall(const Pass &pass, raw_ostream &os) { + if (auto constructor = pass.getConstructor()) { + os << *constructor; + } else { + os << llvm::formatv("create{0}Pass()", pass.getDef()->getName()); + } +} + static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) { os << "/* Autogenerated by mlir-tblgen; don't manually edit.
*/"; os << llvm::formatv(passGroupRegistrationCode, groupName); @@ -97,8 +106,12 @@ for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { Pass pass(def); StringRef defName = pass.getDef()->getName(); - os << llvm::formatv(passCreateDef, groupName, defName, - pass.getConstructor()); + + std::string constructorCall; + llvm::raw_string_ostream constructorCallOs(constructorCall); + emitConstructorCall(pass, constructorCallOs); + + os << llvm::formatv(passCreateDef, groupName, defName, constructorCall); } return false; } diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -27,6 +27,147 @@ groupName("name", llvm::cl::desc("The name of this group of passes"), llvm::cl::cat(passGenCat)); +static void emitOldPassDecl(const Pass &pass, raw_ostream &os); + +/// Extract the list of passes from the TableGen records. +static std::vector<Pass> getPasses(const llvm::RecordKeeper &recordKeeper) { + std::vector<Pass> passes; + + for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase")) + passes.emplace_back(def); + + return passes; +} + +const char *const passHeader = R"( +//===----------------------------------------------------------------------===// +// {0} +//===----------------------------------------------------------------------===// +)"; + +//===----------------------------------------------------------------------===// +// GEN: Pass registration generation +//===----------------------------------------------------------------------===// + +/// The code snippet used to generate a pass registration. +/// +/// {0}: The def name of the pass record. +/// {1}: The pass constructor call.
+const char *const passRegistrationCode = R"( +//===----------------------------------------------------------------------===// +// {0} Registration +//===----------------------------------------------------------------------===// + +inline void register{0}Pass() {{ + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ + return {1}; + }); +} +)"; + +/// The code snippet used to generate a function to register all passes in a +/// group. +/// +/// {0}: The name of the pass group. +const char *const passGroupRegistrationCode = R"( +//===----------------------------------------------------------------------===// +// {0} Registration +//===----------------------------------------------------------------------===// + +inline void register{0}Passes() {{ +)"; + +/// Emits the definition of the struct to be used to control the pass options. +static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { + StringRef passName = pass.getDef()->getName(); + ArrayRef<PassOption> options = pass.getOptions(); + + os << llvm::formatv("struct {0}PassOptions {{\n", passName); + + for (const PassOption &opt : options) { + os.indent(2) << llvm::formatv("{0} {1}", opt.getType(), + opt.getCppVariableName()); + + if (Optional<StringRef> defaultVal = opt.getDefaultValue()) + os << " = " << *defaultVal; + + os << ";\n"; + } + + os << "};\n"; +} + +/// Emit the code to be included in the public header of the pass. +static void emitPassDecls(const Pass &pass, raw_ostream &os) { + StringRef passName = pass.getDef()->getName(); + + os << "#ifdef GEN_PASS_DECL_" << passName << "\n"; + os << llvm::formatv(passHeader, passName); + + emitPassOptionsStruct(pass, os); + + if (!pass.getConstructor()) { + // Default constructor declaration. + os << "std::unique_ptr<::mlir::Pass> create" << passName << "Pass();\n"; + + // Declaration of the constructor with options.
+ os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}Pass(const {0}PassOptions &options);\n", passName); + } + + os << "#undef GEN_PASS_DECL_" << passName << "\n"; + os << "#endif // GEN_PASS_DECL_" << passName << "\n"; +} + +/// Emit the call to create an instance of the pass. +static void emitConstructorCall(const Pass &pass, raw_ostream &os) { + if (auto constructor = pass.getConstructor()) { + os << *constructor; + } else { + os << llvm::formatv("create{0}Pass()", pass.getDef()->getName()); + } +} + +/// Emit the code for registering each of the given passes with the global +/// PassRegistry. +static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) { + os << "#ifdef GEN_PASS_REGISTRATION\n"; + + for (const Pass &pass : passes) { + std::string constructorCall; + llvm::raw_string_ostream constructorCallOs(constructorCall); + emitConstructorCall(pass, constructorCallOs); + + os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), + constructorCall); + } + + os << llvm::formatv(passGroupRegistrationCode, groupName); + + for (const Pass &pass : passes) + os << " register" << pass.getDef()->getName() << "Pass();\n"; + + os << "}\n"; + os << "#undef GEN_PASS_REGISTRATION\n"; + os << "#endif // GEN_PASS_REGISTRATION\n"; +} + +static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { + std::vector<Pass> passes = getPasses(recordKeeper); + os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; + + for (const Pass &pass : passes) + emitPassDecls(pass, os); + + emitRegistrations(passes, os); + + // Emit the old code until all the passes have switched to the new design.
+ os << "#ifdef GEN_PASS_CLASSES\n"; + for (const Pass &pass : passes) + emitOldPassDecl(pass, os); + os << "#undef GEN_PASS_CLASSES\n"; + os << "#endif // GEN_PASS_CLASSES\n"; +} + //===----------------------------------------------------------------------===// // GEN: Pass base class generation //===----------------------------------------------------------------------===// @@ -38,10 +179,6 @@ /// {2): The command line argument for the pass. /// {3}: The dependent dialects registration. const char *const passDeclBegin = R"( -//===----------------------------------------------------------------------===// -// {0} -//===----------------------------------------------------------------------===// - template <typename DerivedT> class {0}Base : public {1} { public: @@ -79,6 +216,10 @@ {4} } + {0}Base(const {0}PassOptions &options) : {0}Base() {{ + {5} + } + /// Explicitly declare the TypeID for this class. We declare an explicit private /// instantiation because Pass classes should only be visible by the current /// library. @@ -119,8 +260,13 @@ } } -static void emitPassDecl(const Pass &pass, raw_ostream &os) { +/// Emit the code to be used in the implementation of the pass.
+static void emitPassDefs(const Pass &pass, raw_ostream &os) { StringRef defName = pass.getDef()->getName(); + + os << "#ifdef GEN_PASS_DEF_" << defName << "\n"; + os << llvm::formatv(passHeader, defName); + std::string dependentDialectRegistrations; { llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); @@ -128,90 +274,110 @@ dialectsOs << llvm::formatv(dialectRegistrationTemplate, dependentDialect); } + + std::string optionsConstructor; + { + llvm::raw_string_ostream optionsConstructorOs(optionsConstructor); + for (const PassOption &opt : pass.getOptions()) + optionsConstructorOs << llvm::formatv("{0} = options.{0};\n", opt.getCppVariableName()); + } + os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(), pass.getArgument(), pass.getSummary(), - dependentDialectRegistrations); + dependentDialectRegistrations, optionsConstructor); + emitPassOptionDecls(pass, os); emitPassStatisticDecls(pass, os); os << "};\n"; + + os << "#undef GEN_PASS_DEF_" << defName << "\n"; + os << "#endif // GEN_PASS_DEF_" << defName << "\n"; } -/// Emit the code for registering each of the given passes with the global -/// PassRegistry. -static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) { - os << "#ifdef GEN_PASS_CLASSES\n"; +static void emitDefs(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { + std::vector<Pass> passes = getPasses(recordKeeper); + os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; + for (const Pass &pass : passes) - emitPassDecl(pass, os); - os << "#undef GEN_PASS_CLASSES\n"; - os << "#endif // GEN_PASS_CLASSES\n"; + emitPassDefs(pass, os); } -//===----------------------------------------------------------------------===// -// GEN: Pass registration generation -//===----------------------------------------------------------------------===// +// The old pass base class is being kept until all the passes have switched to +// the new decls/defs design.
+const char *const oldPassDeclBegin = R"( +template <typename DerivedT> +class {0}Base : public {1} { +public: + using Base = {0}Base; -/// The code snippet used to generate a pass registration. -/// -/// {0}: The def name of the pass record. -/// {1}: The pass constructor call. -const char *const passRegistrationCode = R"( -//===----------------------------------------------------------------------===// -// {0} Registration -//===----------------------------------------------------------------------===// + {0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{} + {0}Base(const {0}Base &other) : {1}(other) {{} -inline void register{0}Pass() {{ - ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {{ - return {1}; - }); -} -)"; + /// Returns the command-line argument attached to this pass. + static constexpr ::llvm::StringLiteral getArgumentName() { + return ::llvm::StringLiteral("{2}"); + } + ::llvm::StringRef getArgument() const override { return "{2}"; } -/// The code snippet used to generate a function to register all passes in a -/// group. -/// -/// {0}: The name of the pass group. -const char *const passGroupRegistrationCode = R"( -//===----------------------------------------------------------------------===// -// {0} Registration -//===----------------------------------------------------------------------===// + ::llvm::StringRef getDescription() const override { return "{3}"; } -inline void register{0}Passes() {{ -)"; + /// Returns the derived pass name. + static constexpr ::llvm::StringLiteral getPassName() { + return ::llvm::StringLiteral("{0}"); + } + ::llvm::StringRef getName() const override { return "{0}"; } -/// Emit the code for registering each of the given passes with the global -/// PassRegistry.
-static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) { - os << "#ifdef GEN_PASS_REGISTRATION\n"; - for (const Pass &pass : passes) { - os << llvm::formatv(passRegistrationCode, pass.getDef()->getName(), - pass.getConstructor()); + /// Support isa/dyn_cast functionality for the derived pass class. + static bool classof(const ::mlir::Pass *pass) {{ + return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>(); } - os << llvm::formatv(passGroupRegistrationCode, groupName); - for (const Pass &pass : passes) - os << " register" << pass.getDef()->getName() << "Pass();\n"; - os << "}\n"; - os << "#undef GEN_PASS_REGISTRATION\n"; - os << "#endif // GEN_PASS_REGISTRATION\n"; -} + /// A clone method to create a copy of this pass. + std::unique_ptr<::mlir::Pass> clonePass() const override {{ + return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this)); + } -//===----------------------------------------------------------------------===// -// GEN: Registration hooks -//===----------------------------------------------------------------------===// + /// Return the dialect that must be loaded in the context before this pass. + void getDependentDialects(::mlir::DialectRegistry &registry) const override { + {4} + } -static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) { - os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n"; - std::vector<Pass> passes; - for (const auto *def : recordKeeper.getAllDerivedDefinitions("PassBase")) - passes.emplace_back(def); + /// Explicitly declare the TypeID for this class. We declare an explicit private + /// instantiation because Pass classes should only be visible by the current + /// library. + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID({0}Base) + +protected: +)"; - emitPassDecls(passes, os); - emitRegistration(passes, os); +/// Emit a backward-compatible declaration of the pass base class.
+static void emitOldPassDecl(const Pass &pass, raw_ostream &os) { + StringRef defName = pass.getDef()->getName(); + std::string dependentDialectRegistrations; + { + llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations); + for (StringRef dependentDialect : pass.getDependentDialects()) + dialectsOs << llvm::formatv(dialectRegistrationTemplate, + dependentDialect); + } + os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(), + pass.getArgument(), pass.getSummary(), + dependentDialectRegistrations); + emitPassOptionDecls(pass, os); + emitPassStatisticDecls(pass, os); + os << "};\n"; } static mlir::GenRegistration - genRegister("gen-pass-decls", "Generate pass declarations", + genPassDecls("gen-pass-decls", "Generate pass declarations", + [](const llvm::RecordKeeper &records, raw_ostream &os) { + emitDecls(records, os); + return false; + }); + +static mlir::GenRegistration + genPassDefs("gen-pass-defs", "Generate pass definitions", [](const llvm::RecordKeeper &records, raw_ostream &os) { - emitDecls(records, os); + emitDefs(records, os); return false; }); diff --git a/mlir/unittests/TableGen/CMakeLists.txt b/mlir/unittests/TableGen/CMakeLists.txt --- a/mlir/unittests/TableGen/CMakeLists.txt +++ b/mlir/unittests/TableGen/CMakeLists.txt @@ -5,6 +5,7 @@ set(LLVM_TARGET_DEFINITIONS passes.td) mlir_tablegen(PassGenTest.h.inc -gen-pass-decls -name TableGenTest) +mlir_tablegen(PassGenTest.cpp.inc -gen-pass-defs -name TableGenTest) add_public_tablegen_target(MLIRTableGenTestPassIncGen) add_mlir_unittest(MLIRTableGenTests diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp --- a/mlir/unittests/TableGen/PassGenTest.cpp +++ b/mlir/unittests/TableGen/PassGenTest.cpp @@ -12,21 +12,27 @@ std::unique_ptr<mlir::Pass> createTestPass(int v = 0); +#define GEN_PASS_DECL_TestPass #define GEN_PASS_REGISTRATION #include "PassGenTest.h.inc" -#define GEN_PASS_CLASSES -#include "PassGenTest.h.inc" +#define GEN_PASS_DEF_TestPass +#include
"PassGenTest.cpp.inc" struct TestPass : public TestPassBase<TestPass> { explicit TestPass(int v) : extraVal(v) {} + TestPass(int v, const TestPassPassOptions &options) + : TestPassBase(options), extraVal(v) {} + void runOnOperation() override {} std::unique_ptr<mlir::Pass> clone() const { return TestPassBase<TestPass>::clone(); } + unsigned getTestOption() const { return testOption; } + int extraVal; }; @@ -34,6 +40,11 @@ return std::make_unique<TestPass>(v); } +std::unique_ptr<mlir::Pass> createTestPass(int v, + const TestPassPassOptions &options) { + return std::make_unique<TestPass>(v, options); +} + TEST(PassGenTest, PassClone) { mlir::MLIRContext context; @@ -46,3 +57,18 @@ EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal); } + +TEST(PassGenTest, PassOptions) { + mlir::MLIRContext context; + + TestPassPassOptions options; + options.testOption = 57; + + const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) { + return static_cast<const TestPass *>(pass.get()); + }; + + const auto pass = createTestPass(10, options); + + EXPECT_EQ(unwrap(pass)->getTestOption(), 57u); +} diff --git a/mlir/unittests/TableGen/passes.td b/mlir/unittests/TableGen/passes.td --- a/mlir/unittests/TableGen/passes.td +++ b/mlir/unittests/TableGen/passes.td @@ -15,5 +15,7 @@ let constructor = "::createTestPass()"; - let options = RewritePassUtils.options; + let options = [ + Option<"testOption", "testOption", "unsigned", "0", "Test option"> + ]; }