diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -100,6 +100,23 @@
           "Parse a textual pass-pipeline and return a top-level PassManager "
           "that can be applied on a Module. Throw a ValueError if the pipeline "
           "can't be parsed")
+      .def(
+          "add",
+          [](PyPassManager &passManager, const std::string &pipeline) {
+            // Append the parsed elements to the top-level OpPassManager. Error
+            // text reported through the callback is accumulated in `errorMsg`
+            // and re-raised as a Python ValueError.
+            PyPrintAccumulator errorMsg;
+            MlirLogicalResult status = mlirOpPassManagerAddPipeline(
+                mlirPassManagerGetAsOpPassManager(passManager.get()),
+                mlirStringRefCreate(pipeline.data(), pipeline.size()),
+                errorMsg.getCallback(), errorMsg.getUserData());
+            if (mlirLogicalResultIsFailure(status))
+              throw SetPyError(PyExc_ValueError, std::string(errorMsg.join()));
+          },
+          py::arg("pipeline"),
+          "Add textual pipeline elements to the pass manager. Throws a "
+          "ValueError if the pipeline can't be parsed.")
       .def(
           "run",
           [](PyPassManager &passManager, PyModule &module) {
diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py
--- a/mlir/test/python/integration/dialects/linalg/opsrun.py
+++ b/mlir/test/python/integration/dialects/linalg/opsrun.py
@@ -191,11 +191,15 @@
   ops = module.operation.regions[0].blocks[0].operations
   mod = Module.parse("\n".join([str(op) for op in ops]) + boilerplate)
 
-  pm = PassManager.parse(
-      "builtin.module(func.func(convert-linalg-to-loops, lower-affine, " +
-      "convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), "
-      + "convert-vector-to-llvm, convert-memref-to-llvm, convert-func-to-llvm," +
-      "reconcile-unrealized-casts)")
+  pm = PassManager("builtin.module")
+  # Group the function-level passes in one nested func.func pipeline.
+  pm.add("func.func(convert-linalg-to-loops, lower-affine, "
+         "convert-math-to-llvm, convert-scf-to-cf, "
+         "arith-expand, memref-expand)")
+  pm.add("convert-vector-to-llvm")
+  pm.add("convert-memref-to-llvm")
+  pm.add("convert-func-to-llvm")
+  pm.add("reconcile-unrealized-casts")
   pm.run(mod)
   return mod
 
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -75,6 +75,28 @@
       log("Exception not produced")
 run(testParseFail)
 
+# Check that adding to a pass manager works
+# CHECK-LABEL: TEST: testAdd
+def testAdd():
+  pm = PassManager("any", Context())
+  # CHECK: pm: 'any()'
+  log(f"pm: '{pm}'")
+  # CHECK: pm: 'any(cse)'
+  pm.add("cse")
+  log(f"pm: '{pm}'")
+  # CHECK: pm: 'any(cse,cse)'
+  pm.add("cse")
+  log(f"pm: '{pm}'")
+  # An unparsable pipeline element is reported as a ValueError.
+  try:
+    pm.add("unknown-pass")
+  except ValueError as e:
+    # CHECK: Exception: {{.*}}unknown-pass
+    log("Exception: ", e)
+  else:
+    log("Exception not produced")
+run(testAdd)
+
 
 # Verify failure on incorrect level of nesting.
 # CHECK-LABEL: TEST: testInvalidNesting