diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -42,6 +42,7 @@ #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" #define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr" #define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr" +#define MLIR_PYTHON_CAPSULE_VALUE "mlir.ir.Value._CAPIPtr" /** Attribute on MLIR Python objects that expose their C-API pointer. * This will be a type-specific capsule created as per one of the helpers @@ -285,6 +286,25 @@ return jit; } +/** Creates a capsule object encapsulating the raw C-API MlirValue. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the value in any way. + */ +static inline PyObject *mlirPythonValueToCapsule(MlirValue value) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(value), + MLIR_PYTHON_CAPSULE_VALUE, NULL); +} + +/** Extracts an MlirValue from a capsule as produced from + * mlirPythonValueToCapsule. If the capsule is not of the right type, then a + * null value is returned (as checked via mlirValueIsNull). In such a case, the + * Python APIs will have already set an error. */ +static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_VALUE); + MlirValue value = {ptr}; + return value; +} + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -15,6 +15,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" +#include "mlir-c/IR.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include @@ -1467,6 +1468,27 @@ // PyValue and subclases. 
//------------------------------------------------------------------------------ +pybind11::object PyValue::getCapsule() { + return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get())); +} + +PyValue PyValue::createFromCapsule(pybind11::object capsule) { + MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); + if (mlirValueIsNull(value)) + throw py::error_already_set(); // Capsule API already set the error. + MlirOperation owner = {nullptr}; // Stays null when no owner is found. + if (mlirValueIsAOpResult(value)) + owner = mlirOpResultGetOwner(value); + else if (mlirValueIsABlockArgument(value)) + owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); + if (mlirOperationIsNull(owner)) // No Python error is pending here. + throw py::value_error("Value is not owned by an operation"); + MlirContext ctx = mlirOperationGetContext(owner); + PyOperationRef ownerRef = + PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); + return PyValue(ownerRef, value); +} + namespace { /// CRTP base class for Python MLIR values that subclass Value and should be /// castable from it. The value hierarchy is one level deep and is not supposed @@ -2353,6 +2375,8 @@ // Mapping of Value. //---------------------------------------------------------------------------- py::class_<PyValue>(m, "Value") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) .def_property_readonly( "context", [](PyValue &self) { return self.getParentOperation()->getContext(); }, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -721,6 +721,15 @@ void checkValid() { return parentOperation->checkValid(); } + /// Gets a capsule wrapping the void* within the MlirValue. + pybind11::object getCapsule(); + + /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of + /// the underlying MlirValue is still tied to the owning operation. Throws if + /// the capsule does not wrap an MlirValue or the value has no owning + /// operation. 
+ static PyValue createFromCapsule(pybind11::object capsule); + private: PyOperationRef parentOperation; MlirValue value; diff --git a/mlir/test/Bindings/Python/ir_value.py b/mlir/test/Bindings/Python/ir_value.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_value.py @@ -0,0 +1,29 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + + +def run(f): + """Runs `f` under a TEST banner and checks that no Context stays live.""" + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +# CHECK-LABEL: TEST: testCapsuleConversions +def testCapsuleConversions(): + """Round-trips an op result Value through its C-API capsule.""" + context = Context() + context.allow_unregistered_dialects = True + with Location.unknown(context): + i32_type = IntegerType.get_signless(32) + value = Operation.create("custom.op1", results=[i32_type]).result + capsule = value._CAPIPtr + assert '"mlir.ir.Value._CAPIPtr"' in repr(capsule) + roundtripped = Value._CAPICreate(capsule) + assert roundtripped == value + + +run(testCapsuleConversions)