[mlir][python] Let the Python PassManager take an anchor op name

PassManager(anchor_op="any", context=None) now forwards `anchor_op` to
mlirPassManagerCreateOnOperation instead of calling mlirPassManagerCreate.
The new testConstruct test checks how a default-constructed PassManager
('any()') and a "builtin.module"-anchored one ('builtin.module()') print.

diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -56,11 +56,15 @@
   // Mapping of the top-level PassManager
   //----------------------------------------------------------------------------
   py::class_<PyPassManager>(m, "PassManager", py::module_local())
-      .def(py::init<>([](DefaultingPyMlirContext context) {
-             MlirPassManager passManager =
-                 mlirPassManagerCreate(context->get());
+      .def(py::init<>([](const std::string &anchorOp,
+                         DefaultingPyMlirContext context) {
+             MlirPassManager passManager = mlirPassManagerCreateOnOperation(
+                 context->get(),
+                 mlirStringRefCreate(anchorOp.data(), anchorOp.size()));
              return new PyPassManager(passManager);
            }),
+           py::arg("anchor_op") = py::str("any"),
            py::arg("context") = py::none(),
-           "Create a new PassManager for the current (or provided) Context.")
+           "Create a new PassManager for the current (or provided) Context, "
+           "anchored on the operation named `anchor_op` ('any' by default).")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py
--- a/mlir/test/python/pass_manager.py
+++ b/mlir/test/python/pass_manager.py
@@ -28,6 +28,17 @@
     assert pm1 is not None  # And does not crash.
 run(testCapsule)
 
+# CHECK-LABEL: TEST: testConstruct
+def testConstruct():
+  with Context():
+    # CHECK: pm1: 'any()'
+    # CHECK: pm2: 'builtin.module()'
+    pm1 = PassManager()
+    pm2 = PassManager("builtin.module")
+    log(f"pm1: '{pm1}'")
+    log(f"pm2: '{pm2}'")
+run(testConstruct)
+
 
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSuccess