diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -346,6 +346,10 @@
 /// Takes an operation owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);
 
+/// Removes the given operation from its parent block. The operation is not
+/// destroyed. The ownership of the operation is transferred to the caller.
+MLIR_CAPI_EXPORTED void mlirOperationRemoveFromParent(MlirOperation op);
+
 /// Checks whether the underlying operation is null.
 static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
 
@@ -455,6 +459,20 @@
 /// Verify the operation and return true if it passes, false if it fails.
 MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op);
 
+/// Moves the given operation immediately after the other operation in its
+/// parent block. The given operation may be owned by the caller or by its
+/// current block. The other operation must belong to a block. In any case, the
+/// ownership is transferred to the block of the other operation.
+MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op,
+                                               MlirOperation other);
+
+/// Moves the given operation immediately before the other operation in its
+/// parent block. The given operation may be owned by the caller or by its
+/// current block. The other operation must belong to a block. In any case, the
+/// ownership is transferred to the block of the other operation.
+MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op,
+                                                MlirOperation other);
+
 //===----------------------------------------------------------------------===//
 // Region API.
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -875,6 +875,24 @@ return fileObject.attr("getvalue")(); } +void PyOperationBase::moveAfter(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + mlirOperationMoveAfter(operation, otherOp); + operation.parentKeepAlive = otherOp.parentKeepAlive; +} + +void PyOperationBase::moveBefore(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + mlirOperationMoveBefore(operation, otherOp); + operation.parentKeepAlive = otherOp.parentKeepAlive; +} + llvm::Optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) @@ -2185,7 +2203,25 @@ return mlirOperationVerify(self.getOperation()); }, "Verify the operation and return true if it passes, false if it " - "fails."); + "fails.") + .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), + "Puts self immediately after the other operation in its parent " + "block.") + .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), + "Puts self immediately before the other operation in its parent " + "block.") + .def( + "detach_from_parent", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + operation.checkValid(); + if (!operation.isAttached()) + throw py::value_error("Detached operation has no parent."); + + operation.detachFromParent(); + return operation.createOpView(); + }, + "Detaches the operation from its parent block."); py::class_(m, "Operation", py::module_local()) .def_static("create", &PyOperation::create, py::arg("name"), @@ -2380,7 +2416,20 @@ printAccum.getUserData()); return 
printAccum.join(); }, - "Returns the assembly form of the block."); + "Returns the assembly form of the block.") + .def( + "append", + [](PyBlock &self, PyOperationBase &operation) { + if (operation.getOperation().isAttached()) + operation.getOperation().detachFromParent(); + + MlirOperation mlirOperation = operation.getOperation().get(); + mlirBlockAppendOwnedOperation(self.get(), mlirOperation); + operation.getOperation().setAttached( + self.getParentOperation().getObject()); + }, + "Appends an operation to this block. If the operation is currently " + "in another block, it will be moved."); //---------------------------------------------------------------------------- // Mapping of PyInsertionPoint. diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -399,6 +399,10 @@ bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); + /// Moves the operation before or after the other operation. + void moveAfter(PyOperationBase &other); + void moveBefore(PyOperationBase &other); + /// Each must provide access to the raw Operation. virtual PyOperation &getOperation() = 0; }; @@ -428,6 +432,14 @@ createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); + /// Detaches the operation from its parent block and updates its state + /// accordingly. + void detachFromParent() { + mlirOperationRemoveFromParent(getOperation()); + setDetached(); + parentKeepAlive = pybind11::object(); + } + /// Gets the backing operation. 
operator MlirOperation() const { return get(); } MlirOperation get() const { @@ -441,10 +453,14 @@ } bool isAttached() { return attached; } - void setAttached() { + void setAttached(pybind11::object parent = pybind11::object()) { assert(!attached && "operation already attached"); attached = true; } + void setDetached() { + assert(attached && "operation already detached"); + attached = false; + } void checkValid() const; /// Gets the owning block or raises an exception if the operation has no @@ -495,6 +511,8 @@ pybind11::object parentKeepAlive; bool attached = true; bool valid = true; + + friend class PyOperationBase; }; /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -338,6 +338,8 @@ void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } +void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } + bool mlirOperationEqual(MlirOperation op, MlirOperation other) { return unwrap(op) == unwrap(other); } @@ -451,6 +453,14 @@ return succeeded(verify(unwrap(op))); } +void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { + return unwrap(op)->moveAfter(unwrap(other)); +} + +void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { + return unwrap(op)->moveBefore(unwrap(other)); +} + //===----------------------------------------------------------------------===// // Region API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -505,7 +505,7 @@ void Operation::moveAfter(Block *block, llvm::iplist::iterator iterator) { assert(iterator != block->end() && "cannot move after end of block"); - moveBefore(&*std::next(iterator)); + moveBefore(block, std::next(iterator)); } /// This drops all operand uses from this operation, which is an essential diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -740,3 +740,66 @@ op = Operation.create("custom.op", loc=loc) assert op.location == loc assert op.operation.location == loc + +# CHECK-LABEL: TEST: testModuleMerge +@run +def testModuleMerge(): + with Context(): + m1 = Module.parse("func private @foo()") + m2 = Module.parse(""" + func private @bar() + func private @qux() + """) + foo = m1.body.operations[0] + bar = m2.body.operations[0] + qux = m2.body.operations[1] + bar.move_before(foo) + qux.move_after(foo) + + # CHECK: module + # CHECK: func private @bar + # CHECK: func private @foo + # CHECK: func private @qux + print(m1) + + # CHECK: module { + # CHECK-NEXT: } + print(m2) + + +# CHECK-LABEL: TEST: testAppendMoveFromAnotherBlock +@run +def testAppendMoveFromAnotherBlock(): + with Context(): + m1 = Module.parse("func private @foo()") + m2 = Module.parse("func private @bar()") + func = m1.body.operations[0] + m2.body.append(func) + + # CHECK: module + # CHECK: func private @bar + # CHECK: func private @foo + + print(m2) + # CHECK: module { + # CHECK-NEXT: } + print(m1) + + +# CHECK-LABEL: TEST: testDetachFromParent +@run +def testDetachFromParent(): + with Context(): + m1 = Module.parse("func private @foo()") + func = m1.body.operations[0].detach_from_parent() + + try: + func.detach_from_parent() + except ValueError as e: + if 
"has no parent" not in str(e): + raise + else: + assert False, "expected ValueError when detaching a detached operation" + + print(m1) + # CHECK-NOT: func private @foo