diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -558,6 +558,12 @@
 /// Takes a block owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirBlockDestroy(MlirBlock block);
 
+/// Detaches `block` from the region that owns it and transfers ownership to
+/// the caller, who must then either destroy it with `mlirBlockDestroy` or hand
+/// it to another region (e.g. with `mlirRegionAppendOwnedBlock`). Does nothing
+/// if the block is not owned by a region.
+MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block);
+
 /// Checks whether a block is null.
 static inline bool mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2755,6 +2755,15 @@
           py::arg("parent"), py::arg("arg_types") = py::list(),
           "Creates and returns a new Block at the beginning of the given "
           "region (with given argument types).")
+      .def(
+          "append_to",
+          [](PyBlock &self, PyRegion &region) {
+            MlirBlock b = self.get();
+            if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
+              mlirBlockDetach(b);
+            mlirRegionAppendOwnedBlock(region.get(), b);
+          },
+          "Append this block to a region, transferring ownership if necessary")
       .def(
           "create_before",
           [](PyBlock &self, py::args pyArgTypes) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -634,6 +634,13 @@
 
 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
 
+void mlirBlockDetach(MlirBlock block) {
+  Block *b = unwrap(block);
+  // An unowned block already belongs to the caller; nothing to detach from.
+  if (Region *parent = b->getParent())
+    parent->getBlocks().remove(b);
+}
+
 intptr_t mlirBlockGetNumArguments(MlirBlock block) {
   return static_cast<intptr_t>(unwrap(block)->getNumArguments());
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -510,15 +510,18 @@
   MlirType i2 = mlirIntegerTypeGet(ctx, 2);
   MlirType i3 = mlirIntegerTypeGet(ctx, 3);
   MlirType i4 = mlirIntegerTypeGet(ctx, 4);
+  MlirType i5 = mlirIntegerTypeGet(ctx, 5);
   MlirBlock block1 = mlirBlockCreate(1, &i1, &loc);
   MlirBlock block2 = mlirBlockCreate(1, &i2, &loc);
   MlirBlock block3 = mlirBlockCreate(1, &i3, &loc);
   MlirBlock block4 = mlirBlockCreate(1, &i4, &loc);
+  MlirBlock block5 = mlirBlockCreate(1, &i5, &loc);
   // Insert blocks so as to obtain the 1-2-3-4 order,
   mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
   mlirRegionInsertOwnedBlockBefore(region, block3, block2);
   mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
   mlirRegionInsertOwnedBlockAfter(region, block3, block4);
+  mlirRegionInsertOwnedBlockBefore(region, block3, block5);
 
   MlirOperationState op1State =
       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op1"), loc);
@@ -534,6 +537,8 @@
       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op6"), loc);
   MlirOperationState op7State =
       mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op7"), loc);
+  MlirOperationState op8State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op8"), loc);
   MlirOperation op1 = mlirOperationCreate(&op1State);
   MlirOperation op2 = mlirOperationCreate(&op2State);
   MlirOperation op3 = mlirOperationCreate(&op3State);
@@ -541,6 +546,7 @@
   MlirOperation op5 = mlirOperationCreate(&op5State);
   MlirOperation op6 = mlirOperationCreate(&op6State);
   MlirOperation op7 = mlirOperationCreate(&op7State);
+  MlirOperation op8 = mlirOperationCreate(&op8State);
 
   // Insert operations in the first block so as to obtain the 1-2-3-4 order.
   MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
@@ -555,6 +561,13 @@
   mlirBlockAppendOwnedOperation(block2, op5);
   mlirBlockAppendOwnedOperation(block3, op6);
   mlirBlockAppendOwnedOperation(block4, op7);
+  mlirBlockAppendOwnedOperation(block5, op8);
+
+  // Remove block5.
+  mlirBlockDetach(block5);
+  // Detaching a block that has no parent region is a no-op.
+  mlirBlockDetach(block5);
+  mlirBlockDestroy(block5);
 
   mlirOperationDump(op);
   mlirOperationDestroy(op);
@@ -568,6 +581,8 @@
   // CHECK-NEXT: "dummy.op4"
   // CHECK: ^{{.*}}(%{{.*}}: i2
   // CHECK: "dummy.op5"
+  // CHECK-NOT: ^{{.*}}(%{{.*}}: i5
+  // CHECK-NOT: "dummy.op8"
   // CHECK: ^{{.*}}(%{{.*}}: i3
   // CHECK: "dummy.op6"
   // CHECK: ^{{.*}}(%{{.*}}: i4
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -70,3 +70,27 @@
     print(module)
     assert module.operation.verify()
     assert f.body.blocks[0] == entry_block
+
+
+# CHECK-LABEL: TEST: testBlockMove
+# CHECK: %0 = "realop"() ({
+# CHECK: ^bb0([[ARG0:%.+]]: f32):
+# CHECK: "ret"([[ARG0]]) : (f32) -> ()
+# CHECK: }) : () -> f32
+@run
+def testBlockMove():
+  with Context() as ctx, Location.unknown():
+    ctx.allow_unregistered_dialects = True
+    module = Module.create()
+    f32 = F32Type.get()
+    with InsertionPoint(module.body):
+      dummy = Operation.create("dummy", regions=1)
+      block = Block.create_at_start(dummy.operation.regions[0], [f32])
+      with InsertionPoint(block):
+        ret_op = Operation.create("ret", operands=[block.arguments[0]])
+      realop = Operation.create("realop",
+                                results=[r.type for r in ret_op.operands],
+                                regions=1)
+      block.append_to(realop.operation.regions[0])
+      dummy.operation.erase()
+    print(module)