diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -23,9 +23,11 @@ #include <Python.h> +#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "mlir-c/Pass.h" +#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr" #define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" #define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" #define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr" @@ -198,6 +200,25 @@ return type; } +/** Creates a capsule object encapsulating the raw C-API MlirAffineMap. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the affine map in any way. + */ +static inline PyObject *mlirPythonAffineMapToCapsule(MlirAffineMap affineMap) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(affineMap), + MLIR_PYTHON_CAPSULE_AFFINE_MAP, NULL); +} + +/** Extracts an MlirAffineMap from a capsule as produced from + * mlirPythonAffineMapToCapsule. If the capsule is not of the right type, then + * a null affine map is returned (as checked via mlirAffineMapIsNull). In such a + * case, the Python APIs will have already set an error.
*/ +static inline MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_AFFINE_MAP); + MlirAffineMap affineMap = {ptr}; + return affineMap; +} + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModules.h @@ -13,6 +13,7 @@ #include "PybindUtils.h" +#include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" #include "llvm/ADT/DenseMap.h" @@ -667,6 +668,27 @@ MlirValue value; }; +class PyAffineMap : public BaseContextObject { +public: + PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) + : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} + bool operator==(const PyAffineMap &other) const; + operator MlirAffineMap() const { return affineMap; } + MlirAffineMap get() const { return affineMap; } + + /// Gets a capsule wrapping the void* within the MlirAffineMap. + pybind11::object getCapsule(); + + /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. + /// Throws pybind11::error_already_set if `capsule` is not an AffineMap + /// capsule. A new wrapper is returned by value; it does not free the map + /// and refers to the context reported by mlirAffineMapGetContext. + static PyAffineMap createFromCapsule(pybind11::object capsule); + +private: + MlirAffineMap affineMap; +}; + void populateIRSubmodule(pybind11::module &m); } // namespace python diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2710,6 +2710,27 @@ } // namespace +//------------------------------------------------------------------------------ +// PyAffineMap.
+//------------------------------------------------------------------------------ + +bool PyAffineMap::operator==(const PyAffineMap &other) const { + return mlirAffineMapEqual(affineMap, other.affineMap); +} + +py::object PyAffineMap::getCapsule() { + return py::reinterpret_steal<py::object>(mlirPythonAffineMapToCapsule(*this)); +} + +PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { + MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(rawAffineMap)) + throw py::error_already_set(); + return PyAffineMap( + PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), + rawAffineMap); +} + //------------------------------------------------------------------------------ // Populates the pybind11 IR submodule. //------------------------------------------------------------------------------ @@ -3392,4 +3413,29 @@ PyOpResultList::bind(m); PyRegionIterator::bind(m); PyRegionList::bind(m); + + //---------------------------------------------------------------------------- + // Mapping of PyAffineMap.
+ //---------------------------------------------------------------------------- + py::class_<PyAffineMap>(m, "AffineMap") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAffineMap::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) + .def_static( + "get_empty", + [](DefaultingPyMlirContext context) { + MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("context") = py::none(), "Gets an empty affine map.") + .def_property_readonly( + "context", + [](PyAffineMap &self) { return self.getContext().getObject(); }, + "Context that owns the Affine Map") + .def("__eq__", + [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) + .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) + .def( + "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, + kDumpDocstring); } diff --git a/mlir/test/Bindings/Python/ir_affine_map.py b/mlir/test/Bindings/Python/ir_affine_map.py new file mode 100644 --- /dev/null +++ b/mlir/test/Bindings/Python/ir_affine_map.py @@ -0,0 +1,24 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + + +# CHECK-LABEL: TEST: testAffineMapCapsule +def testAffineMapCapsule(): + with Context() as ctx: + am1 = AffineMap.get_empty(ctx) + # CHECK: mlir.ir.AffineMap._CAPIPtr + affine_map_capsule = am1._CAPIPtr + print(affine_map_capsule) + am2 = AffineMap._CAPICreate(affine_map_capsule) + assert am2 == am1 + assert am2.context is ctx + +run(testAffineMapCapsule)