diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1075,6 +1075,25 @@
       .releaseObject();
 }
 
+/// Inserts `op` at the insertion point selected by `maybeIp`:
+///   - `False`: no insertion; `op` stays detached.
+///   - `None`: the default insertion point from PyThreadContextEntry, if set.
+///   - otherwise: `maybeIp` is cast to a PyInsertionPoint and used.
+static void maybeInsertOperation(PyOperationRef &op,
+                                 const py::object &maybeIp) {
+  // InsertPoint active?
+  if (!maybeIp.is(py::cast(false))) {
+    PyInsertionPoint *ip;
+    if (maybeIp.is_none()) {
+      ip = PyThreadContextEntry::getDefaultInsertionPoint();
+    } else {
+      ip = py::cast<PyInsertionPoint *>(maybeIp);
+    }
+    if (ip)
+      ip->insert(*op.get());
+  }
+}
+
 py::object PyOperation::create(
     const std::string &name, llvm::Optional<std::vector<PyType *>> results,
     llvm::Optional<std::vector<PyValue *>> operands,
@@ -1192,22 +1211,22 @@
   MlirOperation operation = mlirOperationCreate(&state);
   PyOperationRef created =
       PyOperation::createDetached(location->getContext(), operation);
-
-  // InsertPoint active?
-  if (!maybeIp.is(py::cast(false))) {
-    PyInsertionPoint *ip;
-    if (maybeIp.is_none()) {
-      ip = PyThreadContextEntry::getDefaultInsertionPoint();
-    } else {
-      ip = py::cast<PyInsertionPoint *>(maybeIp);
-    }
-    if (ip)
-      ip->insert(*created.get());
-  }
+  maybeInsertOperation(created, maybeIp);
 
   return created->createOpView();
 }
 
+py::object PyOperation::clone(const py::object &maybeIp) {
+  // Refuse to hand an operation invalidated by erase() to the C API.
+  checkValid();
+  MlirOperation clonedOperation = mlirOperationClone(operation);
+  PyOperationRef cloned =
+      PyOperation::createDetached(getContext(), clonedOperation);
+  maybeInsertOperation(cloned, maybeIp);
+
+  return cloned->createOpView();
+}
+
 py::object PyOperation::createOpView() {
   checkValid();
   MlirIdentifier ident = mlirOperationGetName(get());
@@ -2616,6 +2635,7 @@
                              return py::none();
                            })
       .def("erase", &PyOperation::erase)
+      .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                              &PyOperation::getCapsule)
       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -575,6 +575,13 @@
   /// parent context's live operations map, and sets the valid bit false.
   void erase();
 
+  /// Clones this operation and inserts the clone at the insertion point
+  /// selected by `ip`: `False` leaves it detached, `None` uses the current
+  /// default insertion point (if any), otherwise `ip` must be an
+  /// InsertionPoint. Fails if this operation has been invalidated (e.g. by
+  /// erase()). Returns the clone's OpView.
+  pybind11::object clone(const pybind11::object &ip);
+
 private:
   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
   static PyOperationRef createInstance(PyMlirContextRef contextRef,
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -767,6 +767,35 @@
       Operation.create("custom.op2")
 
 
+# CHECK-LABEL: TEST: testOperationClone
+@run
+def testOperationClone():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    m = Module.create()
+    with InsertionPoint(m.body):
+      op = Operation.create("custom.op1")
+
+      # CHECK: "custom.op1"
+      print(m)
+
+      # The clone is inserted at the active insertion point (m.body), so it
+      # survives erasing the original.
+      clone = op.operation.clone()
+      op.operation.erase()
+
+      # CHECK: "custom.op1"
+      print(m)
+
+      # Cloning an erased (invalidated) operation must be rejected.
+      try:
+        op.operation.clone()
+      except RuntimeError:
+        # CHECK: Clone of erased operation rejected
+        print("Clone of erased operation rejected")
+
+
 # CHECK-LABEL: TEST: testOperationLoc
 @run
 def testOperationLoc():