diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -868,6 +868,21 @@
   return PyBlock{std::move(parentOperation), block};
 }
 
+py::object PyOperation::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
+}
+
+py::object PyOperation::createFromCapsule(py::object capsule) {
+  MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
+  // A null operation means the capsule could not be unwrapped. This relies on
+  // the C API having set a Python error, which error_already_set rethrows.
+  if (mlirOperationIsNull(rawOperation))
+    throw py::error_already_set();
+  MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
+  return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
+      .releaseObject();
+}
+
 py::object PyOperation::create(
     std::string name, llvm::Optional<std::vector<PyType *>> results,
     llvm::Optional<std::vector<PyValue *>> operands,
@@ -2031,6 +2046,11 @@
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                   kOperationCreateDocstring)
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyOperation::getCapsule)
+      // Bound as static so the capsule is never consumed as `self`.
+      .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+                  &PyOperation::createFromCapsule)
       .def_property_readonly("name",
                              [](PyOperation &self) {
                                MlirOperation operation = self.get();
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -454,6 +454,15 @@
   /// no parent.
   PyOperationRef getParentOperation();
 
+  /// Gets a capsule wrapping the void* within the MlirOperation.
+  pybind11::object getCapsule();
+
+  /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
+  /// Ownership of the underlying MlirOperation is taken by calling this
+  /// function. Throws pybind11::error_already_set if the capsule cannot be
+  /// unwrapped.
+  static pybind11::object createFromCapsule(pybind11::object capsule);
+
   /// Creates an operation. See corresponding python docstring.
   static pybind11::object
   create(std::string name, llvm::Optional<std::vector<PyType *>> results,
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -601,3 +601,18 @@
     print(op.operation.name)
 
 run(testOperationName)
+
+# CHECK-LABEL: TEST: testCapsuleConversions
+def testCapsuleConversions():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    m = Operation.create("custom.op1").operation
+    m_capsule = m._CAPIPtr
+    assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule)
+    m2 = Operation._CAPICreate(m_capsule)
+    assert m2 is m
+    # The factory is static, so it can also be reached through an instance.
+    assert m._CAPICreate(m_capsule) is m
+
+run(testCapsuleConversions)