diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -110,6 +110,30 @@
   }
 }
 
+/// Emit the declaration of the struct to be used for providing custom options
+/// through the C++ APIs. The struct is named after the pass definition with a
+/// `BaseOptions` suffix and holds one field per pass option; a field gets an
+/// initializer only when the option declares a default value.
+// NOTE(review): presumably getOptions() also yields list options - confirm
+// that a plain `opt.getType()` field is the right shape for them.
+static void emitPassOptionStructDecl(const Pass &pass, raw_ostream &os) {
+  StringRef defName = pass.getDef()->getName();
+  os << llvm::formatv("struct {0}BaseOptions {{\n", defName);
+
+  for (const PassOption &opt : pass.getOptions()) {
+    os.indent(2) << llvm::formatv("{0} {1}", opt.getType(),
+                                  opt.getCppVariableName());
+
+    if (Optional<StringRef> defaultVal = opt.getDefaultValue())
+      os << " = " << *defaultVal;
+
+    os << ";\n";
+  }
+
+  // The struct is emitted at column 0, so its closing brace is too.
+  os << "};\n";
+}
+
 /// Emit the declarations for each of the pass statistics.
 static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
   for (const PassStatistic &stat : pass.getStatistics()) {
@@ -121,6 +145,10 @@
 
 static void emitPassDecl(const Pass &pass, raw_ostream &os) {
   StringRef defName = pass.getDef()->getName();
+
+  // Declare the user-editable structure containing the options.
+  emitPassOptionStructDecl(pass, os);
+
   std::string dependentDialectRegistrations;
   {
     llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
@@ -128,10 +156,24 @@
       dialectsOs << llvm::formatv(dialectRegistrationTemplate,
                                   dependentDialect);
   }
+
   os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
                       pass.getArgument(), pass.getSummary(),
                       dependentDialectRegistrations);
+
+  // Create the constructor accepting custom options.
+  os.indent(2) << llvm::formatv(
+      "{0}Base(const {0}BaseOptions& options) : {0}Base() {{\n", defName);
+
+  for (const PassOption &opt : pass.getOptions())
+    os.indent(4) << llvm::formatv("{0} = options.{0};\n",
+                                  opt.getCppVariableName());
+
+  os.indent(2) << "}\n";
+
+  // Declare the options.
   emitPassOptionDecls(pass, os);
+
   emitPassStatisticDecls(pass, os);
   os << "};\n";
 }
diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp
--- a/mlir/unittests/TableGen/PassGenTest.cpp
+++ b/mlir/unittests/TableGen/PassGenTest.cpp
@@ -21,12 +21,17 @@
 struct TestPass : public TestPassBase<TestPass> {
   explicit TestPass(int v) : extraVal(v) {}
 
+  TestPass(int v, const TestPassBaseOptions &options)
+      : TestPassBase(options), extraVal(v) {}
+
   void runOnOperation() override {}
 
   std::unique_ptr<mlir::Pass> clone() const {
     return TestPassBase<TestPass>::clone();
   }
 
+  unsigned getTestOption() const { return testOption; }
+
   int extraVal;
 };
 
@@ -34,6 +39,11 @@
   return std::make_unique<TestPass>(v);
 }
 
+std::unique_ptr<mlir::Pass> createTestPass(int v,
+                                           const TestPassBaseOptions &options) {
+  return std::make_unique<TestPass>(v, options);
+}
+
 TEST(PassGenTest, PassClone) {
   mlir::MLIRContext context;
 
@@ -46,3 +56,18 @@
 
   EXPECT_EQ(unwrap(origPass)->extraVal, unwrap(clonePass)->extraVal);
 }
+
+TEST(PassGenTest, PassOptions) {
+  mlir::MLIRContext context;
+
+  TestPassBaseOptions options;
+  options.testOption = 57;
+
+  const auto unwrap = [](const std::unique_ptr<mlir::Pass> &pass) {
+    return static_cast<const TestPass *>(pass.get());
+  };
+
+  const auto pass = createTestPass(10, options);
+
+  EXPECT_EQ(unwrap(pass)->getTestOption(), 57u);
+}
diff --git a/mlir/unittests/TableGen/passes.td b/mlir/unittests/TableGen/passes.td
--- a/mlir/unittests/TableGen/passes.td
+++ b/mlir/unittests/TableGen/passes.td
@@ -15,5 +15,7 @@
 
   let constructor = "::createTestPass()";
 
-  let options = RewritePassUtils.options;
+  let options = [
+    Option<"testOption", "testOption", "unsigned", "0", "Test option">
+  ];
 }