diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -0,0 +1,94 @@
+/*===-- mlir-c/Interop.h - Constants for Python/C-API interop -----*- C -*-===*\
+|*                                                                            *|
+|* Part of the LLVM Project, under the Apache License v2.0 with LLVM          *|
+|* Exceptions.                                                                *|
+|* See https://llvm.org/LICENSE.txt for license information.                  *|
+|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception                    *|
+|*                                                                            *|
+|*===----------------------------------------------------------------------===*|
+|*                                                                            *|
+|* This header declares constants and helpers necessary for C-level           *|
+|* interop with the MLIR Python extension module. Since the Python bindings   *|
+|* are a thin wrapper around the MLIR C-API, a further C-API is not provided  *|
+|* specifically for the Python extension. Instead, simple facilities are      *|
+|* provided for translating between Python types and corresponding MLIR C-API *|
+|* types.                                                                     *|
+|*                                                                            *|
+|* This header is standalone, requiring nothing beyond normal linking against *|
+|* the Python implementation.                                                 *|
+\*===----------------------------------------------------------------------===*/
+
+#ifndef MLIR_C_BINDINGS_PYTHON_INTEROP_H
+#define MLIR_C_BINDINGS_PYTHON_INTEROP_H
+
+#include <Python.h>
+
+#include "mlir-c/IR.h"
+
+#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
+#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
+
+/** Attribute on MLIR Python objects that expose their C-API pointer.
+ * This will be a type-specific capsule created as per one of the helpers
+ * below.
+ *
+ * Ownership is not transferred by acquiring a capsule in this way: the
+ * validity of the pointer wrapped by the capsule will be bounded by the
+ * lifetime of the Python object that produced it. Only the name and pointer
+ * of the capsule are set. The caller is free to set a destructor and context
+ * as needed to manage anything further. */
+#define MLIR_PYTHON_CAPI_PTR_ATTR "_CAPIPtr"
+
+/** Attribute on MLIR Python objects that exposes a factory function for
+ * constructing the corresponding Python object from a type-specific
+ * capsule wrapping the C-API pointer. The signature of the function is:
+ *   def _CAPICreate(capsule) -> object
+ * Calling such a function implies a transfer of ownership of the object the
+ * capsule wraps: after such a call, the capsule should be considered invalid,
+ * and its wrapped pointer must not be destroyed.
+ *
+ * Only a very small number of Python objects can be created in such a fashion
+ * (i.e. top-level types such as Context where the lifetime can be cleanly
+ * delineated). */
+#define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/** Creates a capsule object encapsulating the raw C-API MlirContext.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the context in any way.
+ */
+inline PyObject *mlirPythonContextToCapsule(MlirContext context) {
+  return PyCapsule_New(context.ptr, MLIR_PYTHON_CAPSULE_CONTEXT, NULL);
+}
+
+/** Extracts a MlirContext from a capsule as produced from
+ * mlirPythonContextToCapsule. If the capsule is not of the right type, then
+ * a null context is returned (as checked via mlirContextIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_CONTEXT);
+  MlirContext context = {ptr};
+  return context;
+}
+
+/** Creates a capsule object encapsulating the raw C-API MlirModule.
+ * The returned capsule does not extend or affect ownership of any Python
+ * objects that reference the module in any way. */
+inline PyObject *mlirPythonModuleToCapsule(MlirModule module) {
+  /* PyCapsule_New takes a plain void *, hence the cast on module.ptr. */
+#ifdef __cplusplus
+  void *ptr = const_cast<void *>(module.ptr);
+#else
+  void *ptr = (void *)module.ptr;
+#endif
+  return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
+}
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_BINDINGS_PYTHON_INTEROP_H
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -91,6 +91,9 @@
 /** Checks if two contexts are equal. */
 int mlirContextEqual(MlirContext ctx1, MlirContext ctx2);
 
+/** Checks whether a context is null. */
+inline int mlirContextIsNull(MlirContext context) { return !context.ptr; }
+
 /** Takes an MLIR context owned by the caller and destroys it. */
 void mlirContextDestroy(MlirContext context);
 
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -108,6 +108,14 @@
     return PyMlirContextRef(this, pybind11::cast(this));
   }
 
+  /// Gets a capsule wrapping the void* within the MlirContext.
+  pybind11::object getCapsule();
+
+  /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
+  /// Note that PyMlirContext instances are uniqued, so the returned object
+  /// may be a pre-existing object.
+  static pybind11::object createFromCapsule(pybind11::object capsule);
+
   /// Gets the count of live context objects. Used for testing.
   static size_t getLiveCount();
 
@@ -195,6 +203,12 @@
                        pybind11::reinterpret_borrow<pybind11::object>(handle));
   }
 
+  /// Gets a capsule wrapping the void* within the MlirModule.
+  /// Note that the module does not (yet) provide a corresponding factory for
+  /// constructing from a capsule as that would require uniquing PyModule
+  /// instances, which is not currently done.
+  pybind11::object getCapsule();
+
 private:
   PyModule(PyMlirContextRef contextRef, MlirModule module)
       : BaseContextObject(std::move(contextRef)), module(module) {}
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -9,6 +9,7 @@
 #include "IRModules.h"
 
 #include "PybindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/StandardAttributes.h"
 #include "mlir-c/StandardTypes.h"
@@ -453,6 +454,17 @@
   mlirContextDestroy(context);
 }
 
+py::object PyMlirContext::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
+}
+
+py::object PyMlirContext::createFromCapsule(py::object capsule) {
+  MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
+  if (mlirContextIsNull(rawContext))
+    throw py::error_already_set();
+  return forContext(rawContext).releaseObject();
+}
+
 PyMlirContext *PyMlirContext::createNewContextForInit() {
   MlirContext context = mlirContextCreate();
   mlirRegisterAllDialects(context);
@@ -581,6 +593,10 @@
   return PyModuleRef(unownedModule, std::move(pyRef));
 }
 
+py::object PyModule::getCapsule() {
+  return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
+}
+
 //------------------------------------------------------------------------------
 // PyOperation
 //------------------------------------------------------------------------------
@@ -1345,6 +1361,9 @@
             return ref.releaseObject();
           })
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyMlirContext::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
       .def_property(
           "allow_unregistered_dialects",
           [](PyMlirContext &self) -> bool {
@@ -1428,6 +1447,7 @@
 
   // Mapping of Module
   py::class_<PyModule>(m, "Module")
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
       .def_property_readonly(
           "operation",
           [](PyModule &self) {
diff --git a/mlir/test/Bindings/Python/context_lifecycle.py b/mlir/test/Bindings/Python/context_lifecycle.py
--- a/mlir/test/Bindings/Python/context_lifecycle.py
+++ b/mlir/test/Bindings/Python/context_lifecycle.py
@@ -40,3 +40,10 @@
 c2 = None
 gc.collect()
 assert mlir.ir.Context._get_live_count() == 0
+
+# Create a context, get its capsule and create from capsule.
+c4 = mlir.ir.Context()
+c4_capsule = c4._CAPIPtr
+assert '"mlir.ir.Context._CAPIPtr"' in repr(c4_capsule)
+c5 = mlir.ir.Context._CAPICreate(c4_capsule)
+assert c4 is c5
diff --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py
--- a/mlir/test/Bindings/Python/ir_module.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -84,3 +84,13 @@
   assert ctx._get_live_operation_count() == 0
 
 run(testModuleOperation)
+
+
+# CHECK-LABEL: TEST: testModuleCapsule
+def testModuleCapsule():
+  ctx = mlir.ir.Context()
+  module = ctx.parse_module(r"""module @successfulParse {}""")
+  # CHECK: "mlir.ir.Module._CAPIPtr"
+  print(module._CAPIPtr)
+
+run(testModuleCapsule)