diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h
--- a/mlir/include/mlir-c/Bindings/Python/Interop.h
+++ b/mlir/include/mlir-c/Bindings/Python/Interop.h
@@ -86,6 +86,16 @@
   return PyCapsule_New(ptr, MLIR_PYTHON_CAPSULE_MODULE, NULL);
 }
 
+/** Extracts an MlirModule from a capsule as produced from
+ * mlirPythonModuleToCapsule. If the capsule is not of the right type, then
+ * a null module is returned (as checked via mlirModuleIsNull). In such a
+ * case, the Python APIs will have already set an error. */
+inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) {
+  void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_MODULE);
+  MlirModule module = {ptr};
+  return module;
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -113,7 +113,8 @@
 
   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
   /// Note that PyMlirContext instances are uniqued, so the returned object
-  /// may be a pre-existing object.
+  /// may be a pre-existing object. Ownership of the underlying MlirContext
+  /// is taken by calling this function.
   static pybind11::object createFromCapsule(pybind11::object capsule);
 
   /// Gets the count of live context objects. Used for testing.
@@ -123,6 +124,10 @@
   /// Used for testing.
   size_t getLiveOperationCount();
 
+  /// Gets the count of live modules associated with this context.
+  /// Used for testing.
+  size_t getLiveModuleCount();
+
   /// Creates an operation. See corresponding python docstring.
   pybind11::object
   createOperation(std::string name, PyLocation location,
@@ -142,6 +147,14 @@
   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
   static LiveContextMap &getLiveContexts();
 
+  // Interns all live modules associated with this context. An entry is
+  // added when a PyModule wrapper is created (PyModule::forModule) and
+  // removed when that PyModule is destroyed, so each MlirModule maps to at
+  // most one live Python wrapper.
+  using LiveModuleMap =
+      llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
+  LiveModuleMap liveModules;
+
   // Interns all live operations associated with this context. Operations
   // tracked in this map are valid. When an operation is invalidated, it is
   // removed from this map, and while it still exists as an instance, any
@@ -151,6 +164,7 @@
   LiveOperationMap liveOperations;
 
   MlirContext context;
+  friend class PyModule;
   friend class PyOperation;
 };
 
@@ -186,13 +200,12 @@
 using PyModuleRef = PyObjectRef<PyModule>;
 class PyModule : public BaseContextObject {
 public:
-  /// Creates a reference to the module
-  static PyModuleRef create(PyMlirContextRef contextRef, MlirModule module);
+  /// Returns a PyModule reference for the given MlirModule. This may return
+  /// a pre-existing or new object.
+  static PyModuleRef forModule(MlirModule module);
   PyModule(PyModule &) = delete;
-  ~PyModule() {
-    if (module.ptr)
-      mlirModuleDestroy(module);
-  }
+  PyModule(PyModule &&) = delete;
+  ~PyModule();
 
   /// Gets the backing MlirModule.
   MlirModule get() { return module; }
@@ -209,9 +222,14 @@
   /// instances, which is not currently done.
   pybind11::object getCapsule();
 
+  /// Creates a PyModule from the MlirModule wrapped by a capsule.
+  /// Note that PyModule instances are uniqued, so the returned object
+  /// may be a pre-existing object. Ownership of the underlying MlirModule
+  /// is taken by calling this function.
+  static pybind11::object createFromCapsule(pybind11::object capsule);
+
 private:
-  PyModule(PyMlirContextRef contextRef, MlirModule module)
-      : BaseContextObject(std::move(contextRef)), module(module) {}
+  PyModule(PyMlirContextRef contextRef, MlirModule module);
   MlirModule module;
   pybind11::handle handle;
 };
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -497,6 +497,8 @@
 
 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
 
+size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
+
 py::object PyMlirContext::createOperation(
     std::string name, PyLocation location,
     llvm::Optional<std::vector<PyType *>> results,
@@ -582,15 +584,49 @@
 // PyModule
 //------------------------------------------------------------------------------
 
-PyModuleRef PyModule::create(PyMlirContextRef contextRef, MlirModule module) {
-  PyModule *unownedModule = new PyModule(std::move(contextRef), module);
-  // Note that the default return value policy on cast is automatic_reference,
-  // which does not take ownership (delete will not be called).
-  // Just be explicit.
-  py::object pyRef =
-      py::cast(unownedModule, py::return_value_policy::take_ownership);
-  unownedModule->handle = pyRef;
-  return PyModuleRef(unownedModule, std::move(pyRef));
+PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
+    : BaseContextObject(std::move(contextRef)), module(module) {}
+
+PyModule::~PyModule() {
+  py::gil_scoped_acquire acquire;
+  auto &liveModules = getContext()->liveModules;
+  assert(liveModules.count(module.ptr) == 1 &&
+         "destroying module not in live map");
+  liveModules.erase(module.ptr);
+  mlirModuleDestroy(module);
+}
+
+PyModuleRef PyModule::forModule(MlirModule module) {
+  MlirContext context = mlirModuleGetContext(module);
+  PyMlirContextRef contextRef = PyMlirContext::forContext(context);
+
+  py::gil_scoped_acquire acquire;
+  auto &liveModules = contextRef->liveModules;
+  auto it = liveModules.find(module.ptr);
+  if (it == liveModules.end()) {
+    // Create.
+    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
+    // Note that the default return value policy on cast is automatic_reference,
+    // which does not take ownership (delete will not be called).
+    // Just be explicit.
+    py::object pyRef =
+        py::cast(unownedModule, py::return_value_policy::take_ownership);
+    unownedModule->handle = pyRef;
+    liveModules[module.ptr] =
+        std::make_pair(unownedModule->handle, unownedModule);
+    return PyModuleRef(unownedModule, std::move(pyRef));
+  }
+  // Use existing.
+  PyModule *existing = it->second.second;
+  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
+  return PyModuleRef(existing, std::move(pyRef));
+}
+
+py::object PyModule::createFromCapsule(py::object capsule) {
+  MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
+  if (mlirModuleIsNull(rawModule))
+    throw py::error_already_set();
+  return forModule(rawModule).releaseObject();
 }
 
 py::object PyModule::getCapsule() {
@@ -1461,6 +1497,7 @@
              return ref.releaseObject();
            })
       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
+      .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyMlirContext::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
@@ -1489,9 +1526,16 @@
                   PyExc_ValueError,
                   "Unable to parse module assembly (see diagnostics)");
             }
-            return PyModule::create(self.getRef(), module).releaseObject();
+            return PyModule::forModule(module).releaseObject();
           },
           kContextParseDocstring)
+      .def(
+          "create_module",
+          [](PyMlirContext &self, PyLocation &loc) {
+            MlirModule module = mlirModuleCreateEmpty(loc.loc);
+            return PyModule::forModule(module).releaseObject();
+          },
+          py::arg("loc"), "Creates an empty module")
       .def(
           "parse_attr",
           [](PyMlirContext &self, std::string attrSpec) {
@@ -1538,16 +1582,26 @@
           kContextGetFileLocationDocstring, py::arg("filename"),
           py::arg("line"), py::arg("col"));
 
-  py::class_<PyLocation>(m, "Location").def("__repr__", [](PyLocation &self) {
-    PyPrintAccumulator printAccum;
-    mlirLocationPrint(self.loc, printAccum.getCallback(),
-                      printAccum.getUserData());
-    return printAccum.join();
-  });
+  py::class_<PyLocation>(m, "Location")
+      .def_property_readonly(
+          "context",
+          [](PyLocation &self) { return self.getContext().getObject(); },
+          "Context that owns the Location")
+      .def("__repr__", [](PyLocation &self) {
+        PyPrintAccumulator printAccum;
+        mlirLocationPrint(self.loc, printAccum.getCallback(),
+                          printAccum.getUserData());
+        return printAccum.join();
+      });
 
   // Mapping of Module
   py::class_<PyModule>(m, "Module")
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
+      .def_property_readonly(
+          "context",
+          [](PyModule &self) { return self.getContext().getObject(); },
+          "Context that created the Module")
       .def_property_readonly(
           "operation",
           [](PyModule &self) {
@@ -1576,6 +1630,10 @@
 
   // Mapping of Operation.
   py::class_<PyOperation>(m, "Operation")
+      .def_property_readonly(
+          "context",
+          [](PyOperation &self) { return self.getContext().getObject(); },
+          "Context that owns the Operation")
       .def_property_readonly(
           "regions",
           [](PyOperation &self) { return PyRegionList(self.getRef()); })
@@ -1657,6 +1715,10 @@
 
   // Mapping of Type.
   py::class_<PyAttribute>(m, "Attribute")
+      .def_property_readonly(
+          "context",
+          [](PyAttribute &self) { return self.getContext().getObject(); },
+          "Context that owns the Attribute")
       .def(
           "get_named",
           [](PyAttribute &self, std::string name) {
@@ -1737,6 +1799,9 @@
 
   // Mapping of Type.
   py::class_<PyType>(m, "Type")
+      .def_property_readonly(
+          "context", [](PyType &self) { return self.getContext().getObject(); },
+          "Context that owns the Type")
       .def("__eq__",
            [](PyType &self, py::object &other) {
              try {
diff --git a/mlir/test/Bindings/Python/ir_attributes.py b/mlir/test/Bindings/Python/ir_attributes.py
--- a/mlir/test/Bindings/Python/ir_attributes.py
+++ b/mlir/test/Bindings/Python/ir_attributes.py
@@ -14,6 +14,7 @@
 def testParsePrint():
   ctx = mlir.ir.Context()
   t = ctx.parse_attr('"hello"')
+  assert t.context is ctx
   ctx = None
   gc.collect()
   # CHECK: "hello"
diff --git a/mlir/test/Bindings/Python/ir_location.py b/mlir/test/Bindings/Python/ir_location.py
--- a/mlir/test/Bindings/Python/ir_location.py
+++ b/mlir/test/Bindings/Python/ir_location.py
@@ -14,6 +14,7 @@
 def testUnknown():
   ctx = mlir.ir.Context()
   loc = ctx.get_unknown_location()
+  assert loc.context is ctx
   ctx = None
   gc.collect()
   # CHECK: unknown str: loc(unknown)
diff --git a/mlir/test/Bindings/Python/ir_module.py b/mlir/test/Bindings/Python/ir_module.py
--- a/mlir/test/Bindings/Python/ir_module.py
+++ b/mlir/test/Bindings/Python/ir_module.py
@@ -16,6 +16,7 @@
 def testParseSuccess():
   ctx = mlir.ir.Context()
   module = ctx.parse_module(r"""module @successfulParse {}""")
+  assert module.context is ctx
   print("CLEAR CONTEXT")
   ctx = None # Ensure that module captures the context.
   gc.collect()
@@ -40,6 +41,21 @@
 run(testParseError)
 
 
+# Verify that an empty module can be created.
+# CHECK-LABEL: TEST: testCreateEmpty
+# CHECK: module {
+def testCreateEmpty():
+  ctx = mlir.ir.Context()
+  loc = ctx.get_unknown_location()
+  module = ctx.create_module(loc)
+  print("CLEAR CONTEXT")
+  ctx = None # Ensure that module captures the context.
+  gc.collect()
+  print(str(module))
+
+run(testCreateEmpty)
+
+
 # Verify round-trip of ASM that contains unicode.
 # Note that this does not test that the print path converts unicode properly
 # because MLIR asm always normalizes it to the hex encoding.
@@ -61,6 +77,7 @@
 def testModuleOperation():
   ctx = mlir.ir.Context()
   module = ctx.parse_module(r"""module @successfulParse {}""")
+  assert ctx._get_live_module_count() == 1
   op1 = module.operation
   assert ctx._get_live_operation_count() == 1
   # CHECK: module @successfulParse
@@ -82,6 +99,7 @@
   gc.collect()
   print("LIVE OPERATIONS:", ctx._get_live_operation_count())
   assert ctx._get_live_operation_count() == 0
+  assert ctx._get_live_module_count() == 0
 
 run(testModuleOperation)
 
@@ -90,7 +108,19 @@
 def testModuleCapsule():
   ctx = mlir.ir.Context()
   module = ctx.parse_module(r"""module @successfulParse {}""")
+  assert ctx._get_live_module_count() == 1
   # CHECK: "mlir.ir.Module._CAPIPtr"
-  print(module._CAPIPtr)
+  module_capsule = module._CAPIPtr
+  print(module_capsule)
+  module_dup = mlir.ir.Module._CAPICreate(module_capsule)
+  assert module is module_dup
+  assert module_dup.context is ctx
+  # Gc and verify destructed.
+  module = None
+  module_capsule = None
+  module_dup = None
+  gc.collect()
+  assert ctx._get_live_module_count() == 0
+
 
 run(testModuleCapsule)
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -23,6 +23,7 @@
     }
   """)
   op = module.operation
+  assert op.context is ctx
   # Get the block using iterators off of the named collections.
   regions = list(op.regions)
   blocks = list(regions[0].blocks)
diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py
--- a/mlir/test/Bindings/Python/ir_types.py
+++ b/mlir/test/Bindings/Python/ir_types.py
@@ -14,6 +14,7 @@
 def testParsePrint():
   ctx = mlir.ir.Context()
   t = ctx.parse_type("i32")
+  assert t.context is ctx
   ctx = None
   gc.collect()
   # CHECK: i32