diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -24,9 +24,11 @@
 #include <Python.h>
 
 #include "mlir-c/IR.h"
+#include "mlir-c/Pass.h"
 
 #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
 #define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr"
 
 /** Attribute on MLIR Python objects that expose their C-API pointer.
  * This will be a type-specific capsule created as per one of the helpers
@@ -52,6 +54,14 @@
  * delineated). */
 #define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
 
+/// Gets a void* from a wrapped struct. Needed because const cast is different
+/// between C/C++.
+#ifdef __cplusplus
+#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) const_cast<void *>((object).ptr)
+#else
+#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) (void *)((object).ptr)
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -60,7 +70,7 @@
 /** Creates a capsule object encapsulating the raw C-API MlirContext.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the context in any way. */
-inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
+static inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
   return PyCapsule_New(context.ptr, MLIR_PYTHON_CAPSULE_CONTEXT, NULL);
 }
 
@@ -68,7 +78,7 @@
  * mlirPythonContextToCapsule. If the capsule is not of the right type, then
  * a null context is returned (as checked via mlirContextIsNull). In such a
  * case, the Python APIs will have already set an error. */
-inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
+static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
   void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_CONTEXT);
   MlirContext context = {ptr};
   return context;
@@ -77,25 +87,40 @@
 /** Creates a capsule object encapsulating the raw C-API MlirModule.
  * The returned capsule does not extend or affect ownership of any Python
  * objects that reference the module in any way. */
-inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
-#ifdef __cplusplus
-  void *ptr = const_cast<void *>(module.ptr);
-#else
-  void *ptr = (void *)ptr;
-#endif
-  return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
+static inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(module),
+                       MLIR_PYTHON_CAPSULE_MODULE, NULL);
 }
 
 /** Extracts an MlirModule from a capsule as produced from
  * mlirPythonModuleToCapsule. If the capsule is not of the right type, then
  * a null module is returned (as checked via mlirModuleIsNull). In such a
  * case, the Python APIs will have already set an error. */
-inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
+static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
   void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_MODULE);
   MlirModule module = {ptr};
   return module;
 }
 
+/** Creates a capsule object encapsulating the raw C-API MlirPassManager.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the pass manager in any way. */
+static inline PyObject *mlirPythonPassManagerToCapsule(MlirPassManager pm) {
+  return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm),
+                       MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL);
+}
+
+/** Extracts an MlirPassManager from a capsule as produced from
+ * mlirPythonPassManagerToCapsule. If the capsule is not of the right type, then
+ * a null pass manager is returned (as checked via mlirPassManagerIsNull). In
+ * such a case, the Python APIs will have already set an error. */
+static inline MlirPassManager
+mlirPythonCapsuleToPassManager(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER);
+  MlirPassManager pm = {ptr};
+  return pm;
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -53,6 +53,11 @@
 /// Destroy the provided PassManager.
 MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager);
 
+/// Checks if a PassManager is null.
+static inline int mlirPassManagerIsNull(MlirPassManager passManager) {
+  return !passManager.ptr;
+}
+
 /// Cast a top-level PassManager to a generic OpPassManager.
 MLIR_CAPI_EXPORTED MlirOpPassManager
 mlirPassManagerGetAsOpPassManager(MlirPassManager passManager);
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp
--- a/mlir/lib/Bindings/Python/Pass.cpp
+++ b/mlir/lib/Bindings/Python/Pass.cpp
@@ -9,6 +9,7 @@
 #include "Pass.h"
 
 #include "IRModules.h"
+#include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Pass.h"
 
 namespace py = pybind11;
@@ -21,9 +22,35 @@
 class PyPassManager {
 public:
   PyPassManager(MlirPassManager passManager) : passManager(passManager) {}
-  ~PyPassManager() { mlirPassManagerDestroy(passManager); }
+  /// Transfers ownership out of `other`, leaving it null.
+  PyPassManager(PyPassManager &&other) noexcept
+      : passManager(other.passManager) {
+    other.passManager.ptr = nullptr;
+  }
+  ~PyPassManager() {
+    if (!mlirPassManagerIsNull(passManager))
+      mlirPassManagerDestroy(passManager);
+  }
   MlirPassManager get() { return passManager; }
 
+  /// Drops ownership without destroying (leaks unless re-adopted elsewhere).
+  void release() { passManager.ptr = nullptr; }
+  /// Returns a capsule with the raw pointer; ownership is not transferred.
+  pybind11::object getCapsule() {
+    PyObject *capsule = mlirPythonPassManagerToCapsule(get());
+    if (!capsule)
+      throw py::error_already_set();
+    return py::reinterpret_steal<py::object>(capsule);
+  }
+
+  /// Wraps a capsule's pass manager in a new owning PyPassManager.
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr());
+    if (mlirPassManagerIsNull(rawPm))
+      throw py::error_already_set();
+    return py::cast(PyPassManager(rawPm), py::return_value_policy::move);
+  }
+
 private:
   MlirPassManager passManager;
 };
@@ -43,6 +70,12 @@
            }),
            py::arg("context") = py::none(),
            "Create a new PassManager for the current (or provided) Context.")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyPassManager::getCapsule)
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyPassManager::createFromCapsule)
+      .def("_testing_release", &PyPassManager::release,
+           "Releases (leaks) the backing pass manager (testing)")
       .def_static(
           "parse",
           [](const std::string pipeline, DefaultingPyMlirContext context) {
diff --git a/mlir/test/Bindings/Python/pass_manager.py b/mlir/test/Bindings/Python/pass_manager.py
--- a/mlir/test/Bindings/Python/pass_manager.py
+++ b/mlir/test/Bindings/Python/pass_manager.py
@@ -16,6 +16,25 @@
   gc.collect()
   assert Context._get_live_count() == 0
 
+# Verify capsule interop.
+# CHECK-LABEL: TEST: testCapsule
+def testCapsule():
+  with Context():
+    pm = PassManager()
+    pm_capsule = pm._CAPIPtr
+    assert '"mlir.passmanager.PassManager._CAPIPtr"' in repr(pm_capsule)
+    pm._testing_release()
+    # A released PassManager no longer exposes a pointer.
+    try:
+      pm._CAPIPtr
+      assert False, "expected ValueError"
+    except ValueError:
+      pass
+    pm1 = PassManager._CAPICreate(pm_capsule)
+    assert pm1 is not None  # And does not crash.
+run(testCapsule)
+
+
 # Verify successful round-trip.
 # CHECK-LABEL: TEST: testParseSuccess
 def testParseSuccess():