diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -346,6 +346,10 @@
 /// Takes an operation owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);
 
+/// Removes the given operation from its parent block, but does not destroy
+/// it. After this, the operation passed as argument is owned by the caller.
+MLIR_CAPI_EXPORTED void mlirOperationRemoveFromParent(MlirOperation op);
+
 /// Checks whether the underlying operation is null.
 static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2185,7 +2185,20 @@
             return mlirOperationVerify(self.getOperation());
           },
           "Verify the operation and return true if it passes, false if it "
-          "fails.");
+          "fails.")
+      .def(
+          "remove_from_parent",
+          [](PyOperationBase &self) {
+            PyOperation &operation = self.getOperation();
+            operation.checkValid();
+            if (!operation.isAttached()) {
+              throw py::value_error("Detached operation has no parent.");
+            }
+            operation.detachFromParent();
+            return operation.createOpView();
+          },
+          "Detaches the operation from its parent block and returns it; the "
+          "returned operation is no longer owned by any block.");
 
   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
       .def_static("create", &PyOperation::create, py::arg("name"),
@@ -2380,7 +2393,21 @@
                            printAccum.getUserData());
             return printAccum.join();
           },
-          "Returns the assembly form of the block.");
+          "Returns the assembly form of the block.")
+      .def(
+          "append",
+          [](PyBlock &self, PyOperationBase &operation) {
+            PyOperation &op = operation.getOperation();
+            op.checkValid();
+            if (op.isAttached()) {
+              throw py::value_error("Operation is already in a block.");
+            }
+            // The block takes ownership of the operation from here on.
+            mlirBlockAppendOwnedOperation(self.get(), op.get());
+            op.setAttached(self.getParentOperation().getObject());
+          },
+          "Appends a detached operation to this block, which takes ownership "
+          "of it. Raises ValueError if the operation is already in a block.");
 
   //----------------------------------------------------------------------------
   // Mapping of PyInsertionPoint.
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -428,6 +428,14 @@
   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
                  pybind11::object parentKeepAlive = pybind11::object());
 
+  /// Removes the operation from its parent block without destroying it, marks
+  /// it as detached and releases the keep-alive reference on its old parent.
+  void detachFromParent() {
+    mlirOperationRemoveFromParent(getOperation());
+    setDetached();
+    parentKeepAlive = pybind11::object();
+  }
+
   /// Gets the backing operation.
   operator MlirOperation() const { return get(); }
   MlirOperation get() const {
@@ -441,10 +449,19 @@
   }
 
   bool isAttached() { return attached; }
-  void setAttached() {
+  /// Marks the operation as attached. A non-null `parent` is retained as the
+  /// keep-alive reference to the operation's new parent.
+  void setAttached(pybind11::object parent = pybind11::object()) {
     assert(!attached && "operation already attached");
     attached = true;
+    if (parent)
+      parentKeepAlive = std::move(parent);
   }
+  /// Marks the operation as detached.
+  void setDetached() {
+    assert(attached && "operation already detached");
+    attached = false;
+  }
   void checkValid() const;
 
   /// Gets the owning block or raises an exception if the operation has no
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -338,6 +338,8 @@
 
 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
 
+void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); }
+
 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
   return unwrap(op) == unwrap(other);
 }
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -740,3 +740,20 @@
     op = Operation.create("custom.op", loc=loc)
     assert op.location == loc
     assert op.operation.location == loc
+
+
+# CHECK-LABEL: TEST: testModuleMerge
+@run
+def testModuleMerge():
+  with Context():
+    m1 = Module.parse("func private @foo()")
+    m2 = Module.parse("func private @bar()")
+    func = m1.body.operations[0].remove_from_parent()
+    m2.body.append(func)
+    # CHECK: module
+    # CHECK: func private @bar
+    # CHECK: func private @foo
+    print(m2)
+    # CHECK: module {
+    # CHECK-NEXT: }
+    print(m1)