diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -753,6 +753,9 @@ : BaseContextObject(std::move(contextRef)), operation(operation) {} PyOperation::~PyOperation() { + // If the operation has already been invalidated there is nothing to do. + if (!valid) + return; auto &liveOperations = getContext()->liveOperations; assert(liveOperations.count(operation.ptr) == 1 && "destroying operation not in live map"); @@ -869,6 +872,7 @@ } PyOperationRef PyOperation::getParentOperation() { + checkValid(); if (!isAttached()) throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); @@ -878,6 +882,7 @@ } PyBlock PyOperation::getBlock() { + checkValid(); PyOperationRef parentOperation = getParentOperation(); MlirBlock block = mlirOperationGetBlock(get()); assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); @@ -885,6 +890,7 @@ } py::object PyOperation::getCapsule() { + checkValid(); return py::reinterpret_steal(mlirPythonOperationToCapsule(get())); } @@ -1032,6 +1038,7 @@ } py::object PyOperation::createOpView() { + checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); auto opViewClass = PyGlobals::get().lookupRawOpViewClass( @@ -1041,6 +1048,18 @@ return py::cast(PyOpView(getRef().getObject())); } +void PyOperation::erase() { + checkValid(); + // TODO: Fix memory hazards when erasing a tree of operations for which a deep + // Python reference to a child operation is live. All children should also + // have their `valid` bit set to false. 
+ auto &liveOperations = getContext()->liveOperations; + // Erasing a key that is absent is a no-op, so no membership check is needed. + liveOperations.erase(operation.ptr); + mlirOperationDestroy(operation); + valid = false; +} + //------------------------------------------------------------------------------ // PyOpView //------------------------------------------------------------------------------ @@ -2094,11 +2113,13 @@ py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), kOperationCreateDocstring) + .def("erase", &PyOperation::erase) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) .def_property_readonly("name", [](PyOperation &self) { + self.checkValid(); MlirOperation operation = self.get(); MlirStringRef name = mlirIdentifierStr( mlirOperationGetName(operation)); @@ -2106,7 +2127,10 @@ }) .def_property_readonly( "context", - [](PyOperation &self) { return self.getContext().getObject(); }, + [](PyOperation &self) { + self.checkValid(); + return self.getContext().getObject(); + }, "Context that owns the Operation") .def_property_readonly("opview", &PyOperation::createOpView); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -473,6 +473,10 @@ /// Creates an OpView suitable for this operation. pybind11::object createOpView(); + /// Erases the underlying MlirOperation, removes its pointer from the + /// parent context's live operations map, and sets the valid bit false. 
+ void erase(); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py --- a/mlir/test/Bindings/Python/ir_operation.py +++ b/mlir/test/Bindings/Python/ir_operation.py @@ -646,3 +646,25 @@ assert m2 is m run(testCapsuleConversions) + +# CHECK-LABEL: TEST: testOperationErase +def testOperationErase(): + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + m = Module.create() + with InsertionPoint(m.body): + op = Operation.create("custom.op1") + + # CHECK: "custom.op1" + print(m) + + op.operation.erase() + + # CHECK-NOT: "custom.op1" + print(m) + + # Ensure we can create another operation + Operation.create("custom.op2") + +run(testOperationErase)