diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -259,10 +259,23 @@
 void mlirRegionAppendOwnedBlock(MlirRegion region, MlirBlock block);
 
 /** Takes a block owned by the caller and inserts it at `pos` to the given
- * region. */
+ * region. This is an expensive operation that linearly scans the region; prefer
+ * mlirRegionInsertOwnedBlockAfter/Before instead. */
 void mlirRegionInsertOwnedBlock(MlirRegion region, intptr_t pos,
                                 MlirBlock block);
 
+/** Takes a block owned by the caller and inserts it after the (non-owned)
+ * reference block in the given region. The reference block must belong to the
+ * region. If the reference block is null, prepends the block to the region. */
+void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
+                                     MlirBlock block);
+
+/** Takes a block owned by the caller and inserts it before the (non-owned)
+ * reference block in the given region. The reference block must belong to the
+ * region. If the reference block is null, appends the block to the region. */
+void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
+                                      MlirBlock block);
+
 /*============================================================================*/
 /* Block API.                                                                 */
 /*============================================================================*/
@@ -288,10 +301,25 @@
 void mlirBlockAppendOwnedOperation(MlirBlock block, MlirOperation operation);
 
 /** Takes an operation owned by the caller and inserts it as `pos` to the block.
- */
+   This is an expensive operation that scans the block linearly; prefer
+   mlirBlockInsertOwnedOperationBefore/After instead. */
 void mlirBlockInsertOwnedOperation(MlirBlock block, intptr_t pos,
                                    MlirOperation operation);
 
+/** Takes an operation owned by the caller and inserts it after the (non-owned)
+ * reference operation in the given block. If the reference is null, prepends
+ * the operation. Otherwise, the reference must belong to the block. */
+void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
+                                        MlirOperation reference,
+                                        MlirOperation operation);
+
+/** Takes an operation owned by the caller and inserts it before the (non-owned)
+ * reference operation in the given block. If the reference is null, appends the
+ * operation. Otherwise, the reference must belong to the block. */
+void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
+                                         MlirOperation reference,
+                                         MlirOperation operation);
+
 /** Returns the number of arguments of the block. */
 intptr_t mlirBlockGetNumArguments(MlirBlock block);
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -255,6 +255,31 @@
   blockList.insert(std::next(blockList.begin(), pos), unwrap(block));
 }
 
+void mlirRegionInsertOwnedBlockAfter(MlirRegion region, MlirBlock reference,
+                                     MlirBlock block) {
+  Region *cppRegion = unwrap(region);
+  if (mlirBlockIsNull(reference)) {
+    cppRegion->getBlocks().insert(cppRegion->begin(), unwrap(block));
+    return;
+  }
+
+  assert(unwrap(reference)->getParent() == unwrap(region) &&
+         "expected reference block to belong to the region");
+  cppRegion->getBlocks().insertAfter(Region::iterator(unwrap(reference)),
+                                     unwrap(block));
+}
+
+void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference,
+                                      MlirBlock block) {
+  if (mlirBlockIsNull(reference))
+    return mlirRegionAppendOwnedBlock(region, block);
+
+  assert(unwrap(reference)->getParent() == unwrap(region) &&
+         "expected reference block to belong to the region");
+  unwrap(region)->getBlocks().insert(Region::iterator(unwrap(reference)),
+                                     unwrap(block));
+}
+
 void mlirRegionDestroy(MlirRegion region) {
   delete static_cast<Region *>(region.ptr);
 }
@@ -293,6 +318,33 @@
   opList.insert(std::next(opList.begin(), pos), unwrap(operation));
 }
 
+void mlirBlockInsertOwnedOperationAfter(MlirBlock block,
+                                        MlirOperation reference,
+                                        MlirOperation operation) {
+  Block *cppBlock = unwrap(block);
+  if (mlirOperationIsNull(reference)) {
+    cppBlock->getOperations().insert(cppBlock->begin(), unwrap(operation));
+    return;
+  }
+
+  assert(unwrap(reference)->getBlock() == unwrap(block) &&
+         "expected reference operation to belong to the block");
+  cppBlock->getOperations().insertAfter(Block::iterator(unwrap(reference)),
+                                        unwrap(operation));
+}
+
+void mlirBlockInsertOwnedOperationBefore(MlirBlock block,
+                                         MlirOperation reference,
+                                         MlirOperation operation) {
+  if (mlirOperationIsNull(reference))
+    return mlirBlockAppendOwnedOperation(block, operation);
+
+  assert(unwrap(reference)->getBlock() == unwrap(block) &&
+         "expected reference operation to belong to the block");
+  unwrap(block)->getOperations().insert(Block::iterator(unwrap(reference)),
+                                        unwrap(operation));
+}
+
 void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); }
 
 int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -245,6 +245,68 @@
   fprintf(stderr, "\n");
 }
 
+/// Creates an operation with a region containing multiple blocks with
+/// operations and dumps it. The blocks and operations are inserted using
+/// block/operation-relative API and their final order is checked.
+static void buildWithInsertionsAndPrint(MlirContext ctx) {
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+
+  MlirRegion owningRegion = mlirRegionCreate();
+  MlirBlock nullBlock = mlirRegionGetFirstBlock(owningRegion);
+  MlirOperationState state = mlirOperationStateGet("insertion.order.test", loc);
+  mlirOperationStateAddOwnedRegions(&state, 1, &owningRegion);
+  MlirOperation op = mlirOperationCreate(&state);
+  MlirRegion region = mlirOperationGetRegion(op, 0);
+
+  // Use integer types of different bitwidth as block arguments in order to
+  // differentiate blocks.
+  MlirType i1 = mlirIntegerTypeGet(ctx, 1);
+  MlirType i2 = mlirIntegerTypeGet(ctx, 2);
+  MlirType i3 = mlirIntegerTypeGet(ctx, 3);
+  MlirType i4 = mlirIntegerTypeGet(ctx, 4);
+  MlirBlock block1 = mlirBlockCreate(1, &i1);
+  MlirBlock block2 = mlirBlockCreate(1, &i2);
+  MlirBlock block3 = mlirBlockCreate(1, &i3);
+  MlirBlock block4 = mlirBlockCreate(1, &i4);
+  // Insert blocks so as to obtain the 1-2-3-4 order.
+  mlirRegionInsertOwnedBlockBefore(region, nullBlock, block3);
+  mlirRegionInsertOwnedBlockBefore(region, block3, block2);
+  mlirRegionInsertOwnedBlockAfter(region, nullBlock, block1);
+  mlirRegionInsertOwnedBlockAfter(region, block3, block4);
+
+  MlirOperationState op1State = mlirOperationStateGet("dummy.op1", loc);
+  MlirOperationState op2State = mlirOperationStateGet("dummy.op2", loc);
+  MlirOperationState op3State = mlirOperationStateGet("dummy.op3", loc);
+  MlirOperationState op4State = mlirOperationStateGet("dummy.op4", loc);
+  MlirOperationState op5State = mlirOperationStateGet("dummy.op5", loc);
+  MlirOperationState op6State = mlirOperationStateGet("dummy.op6", loc);
+  MlirOperationState op7State = mlirOperationStateGet("dummy.op7", loc);
+  MlirOperation op1 = mlirOperationCreate(&op1State);
+  MlirOperation op2 = mlirOperationCreate(&op2State);
+  MlirOperation op3 = mlirOperationCreate(&op3State);
+  MlirOperation op4 = mlirOperationCreate(&op4State);
+  MlirOperation op5 = mlirOperationCreate(&op5State);
+  MlirOperation op6 = mlirOperationCreate(&op6State);
+  MlirOperation op7 = mlirOperationCreate(&op7State);
+
+  // Insert operations in the first block so as to obtain the 1-2-3-4 order.
+  MlirOperation nullOperation = mlirBlockGetFirstOperation(block1);
+  assert(mlirOperationIsNull(nullOperation));
+  mlirBlockInsertOwnedOperationBefore(block1, nullOperation, op3);
+  mlirBlockInsertOwnedOperationBefore(block1, op3, op2);
+  mlirBlockInsertOwnedOperationAfter(block1, nullOperation, op1);
+  mlirBlockInsertOwnedOperationAfter(block1, op3, op4);
+
+  // Append operations to the rest of blocks to make them non-empty and thus
+  // printable.
+  mlirBlockAppendOwnedOperation(block2, op5);
+  mlirBlockAppendOwnedOperation(block3, op6);
+  mlirBlockAppendOwnedOperation(block4, op7);
+
+  mlirOperationDump(op);
+  mlirOperationDestroy(op);
+}
+
 /// Dumps instances of all standard types to check that C API works correctly.
 /// Additionally, performs simple identity checks that a standard type
 /// constructed with C API can be inspected and has the expected type. The
@@ -763,6 +825,21 @@
 
   mlirModuleDestroy(moduleOp);
 
+  buildWithInsertionsAndPrint(ctx);
+  // CHECK-LABEL: "insertion.order.test"
+  // CHECK: ^{{.*}}(%{{.*}}: i1
+  // CHECK: "dummy.op1"
+  // CHECK-NEXT: "dummy.op2"
+  // CHECK-NEXT: "dummy.op3"
+  // CHECK-NEXT: "dummy.op4"
+  // CHECK: ^{{.*}}(%{{.*}}: i2
+  // CHECK: "dummy.op5"
+  // CHECK: ^{{.*}}(%{{.*}}: i3
+  // CHECK: "dummy.op6"
+  // CHECK: ^{{.*}}(%{{.*}}: i4
+  // CHECK: "dummy.op7"
+
+
   // clang-format off
   // CHECK-LABEL: @types
   // CHECK: i32