diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -505,6 +505,14 @@ size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } +size_t PyMlirContext::clearLiveOperations() { + for (auto &op : liveOperations) + op.second.second->setInvalid(); + size_t numInvalidated = liveOperations.size(); + liveOperations.clear(); + return numInvalidated; +} + size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } pybind11::object PyMlirContext::contextEnter() { @@ -2208,6 +2216,7 @@ return ref.releaseObject(); }) .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) + .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -201,6 +201,12 @@ /// Used for testing. size_t getLiveOperationCount(); + /// Clears the live operations map, marking each entry invalid, and returns + /// the number of entries invalidated. Intended as a safety mechanism so + /// that API end-users can't cause corruption by holding references they + /// shouldn't have accessed in the first place. + size_t clearLiveOperations(); + /// Gets the count of live modules associated with this context. /// Used for testing. size_t getLiveModuleCount(); @@ -575,6 +581,9 @@ /// parent context's live operations map, and sets the valid bit false. void erase(); + /// Sets the valid bit false without removing the live operations map entry. + void setInvalid() { valid = false; } + /// Clones this operation.
  pybind11::object clone(const pybind11::object &ip); diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -104,6 +104,16 @@ assert ctx._get_live_operation_count() == 1 assert op1 is op2 + # Test live operation clearing, then re-fetching the operation afterwards. + op1 = module.operation + assert ctx._get_live_operation_count() == 1 + num_invalidated = ctx._clear_live_operations() + assert num_invalidated == 1 + assert ctx._get_live_operation_count() == 0 + op1 = None + gc.collect() + op1 = module.operation + # Ensure that if module is de-referenced, the operations are still valid. module = None gc.collect()