diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1208,6 +1208,33 @@
   return created->createOpView();
 }
 
+/// Clones this operation and returns the clone wrapped in its op view.
+///
+/// `maybeIp` selects where the initially detached clone is inserted:
+///   - `False`: no insertion is attempted;
+///   - `None`: the thread's default insertion point is used, if one exists;
+///   - anything else is cast to a `PyInsertionPoint *` and used directly.
+py::object PyOperation::clone(const py::object &maybeIp) {
+  // Same validity guard createOpView() runs before using the handle; erase()
+  // marks the operation invalid, and such a handle must not reach the C API.
+  checkValid();
+  MlirOperation clonedOperation = mlirOperationClone(operation);
+  PyOperationRef cloned =
+      PyOperation::createDetached(getContext(), clonedOperation);
+
+  // An explicit `False` opts out of insertion.
+  if (!maybeIp.is(py::cast(false))) {
+    PyInsertionPoint *ip =
+        maybeIp.is_none() ? PyThreadContextEntry::getDefaultInsertionPoint()
+                          : py::cast<PyInsertionPoint *>(maybeIp);
+    // The default insertion point may be null; nothing is inserted then.
+    if (ip)
+      ip->insert(*cloned.get());
+  }
+
+  return cloned->createOpView();
+}
+
 py::object PyOperation::createOpView() {
   checkValid();
   MlirIdentifier ident = mlirOperationGetName(get());
@@ -2616,6 +2643,7 @@
             return py::none();
           })
       .def("erase", &PyOperation::erase)
+      .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -575,6 +575,11 @@
   /// parent context's live operations map, and sets the valid bit false.
   void erase();
 
+  /// Clones this operation and returns the clone's op view. `ip` may be an
+  /// insertion point, `None` (use the default insertion point, if any), or
+  /// `False` (do not insert the clone).
+  pybind11::object clone(const pybind11::object &ip);
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -767,6 +767,28 @@
       Operation.create("custom.op2")
 
 
+# CHECK-LABEL: TEST: testOperationClone
+@run
+def testOperationClone():
+  context = Context()
+  context.allow_unregistered_dialects = True
+  with Location.unknown(context):
+    module = Module.create()
+    with InsertionPoint(module.body):
+      original = Operation.create("custom.op1")
+
+      # CHECK: "custom.op1"
+      print(module)
+
+      # clone() is called without an `ip`, inside the InsertionPoint scope.
+      cloned = original.operation.clone()
+      # Erase the original; the second print below still expects the op name.
+      original.operation.erase()
+
+      # CHECK: "custom.op1"
+      print(module)
+
+
 # CHECK-LABEL: TEST: testOperationLoc
 @run
 def testOperationLoc():