diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -602,7 +602,14 @@
 
 /// The kind of the block action performed during the rewrite. Actions can be
 /// undone if the conversion fails.
-enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
+enum class BlockActionKind {
+  Create,
+  Erase,
+  Merge,
+  Move,
+  Split,
+  TypeConversion
+};
 
 /// Original position of the given block in its parent region. We cannot use
 /// a region iterator because it could have been invalidated by other region
@@ -612,6 +619,15 @@
   Region::iterator::difference_type position;
 };
 
+/// Information needed to undo a merge action:
+/// - the source block whose operations were moved into the dest block, and
+/// - the Operation that was the last operation in the dest block before the
+///   merge (null if the dest block was empty).
+struct MergeInfo {
+  Block *sourceBlock;
+  Operation *destBlockLastInst;
+};
+
 /// The storage class for an undoable block action (one of BlockActionKind),
 /// contains the information necessary to undo this action.
 struct BlockAction {
@@ -621,6 +637,11 @@
   static BlockAction getErase(Block *block, BlockPosition originalPos) {
     return {BlockActionKind::Erase, block, {originalPos}};
   }
+  static BlockAction getMerge(Block *block, Block *sourceBlock) {
+    BlockAction action{BlockActionKind::Merge, block, {}};
+    action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
+    return action;
+  }
   static BlockAction getMove(Block *block, BlockPosition originalPos) {
     return {BlockActionKind::Move, block, {originalPos}};
   }
@@ -647,6 +668,9 @@
     // In use if kind == BlockActionKind::Split and contains a pointer to the
     // block that was split into two parts.
     Block *originalBlock;
+    // In use if kind == BlockActionKind::Merge, and contains the information
+    // needed to undo the merge.
+    MergeInfo mergeInfo;
   };
 };
 } // end anonymous namespace
@@ -738,6 +762,9 @@
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
 
+  /// Notifies that `srcBlock` is about to be merged into `block`.
+  void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
+
   /// Notifies that the blocks of a region are about to be moved.
   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
                                         Region::iterator before);
@@ -966,6 +993,20 @@
                        action.block);
       break;
     }
+    // Split the block at the position which was originally the end of the
+    // destination block (owned by action), and put the instructions back into
+    // the block used before the merge.
+    case BlockActionKind::Merge: {
+      Block *sourceBlock = action.mergeInfo.sourceBlock;
+      Block::iterator splitPoint =
+          (action.mergeInfo.destBlockLastInst
+               ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
+               : action.block->begin());
+      sourceBlock->getOperations().splice(sourceBlock->begin(),
+                                          action.block->getOperations(),
+                                          splitPoint, action.block->end());
+      break;
+    }
     // Move the block back to its original position.
     case BlockActionKind::Move: {
       Region *originalRegion = action.originalPosition.region;
@@ -1161,6 +1202,11 @@
   blockActions.push_back(BlockAction::getSplit(continuation, block));
 }
 
+void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
+                                                            Block *srcBlock) {
+  blockActions.push_back(BlockAction::getMerge(block, srcBlock));
+}
+
 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
     Region &region, Region &parent, Region::iterator before) {
   for (auto &pair : llvm::enumerate(region)) {
@@ -1283,9 +1329,16 @@
 /// PatternRewriter hook for merging a block into another.
 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
                                             ValueRange argValues) {
-  // TODO: This requires fixing the implementation of
-  // 'replaceUsesOfBlockArgument', which currently isn't undoable.
-  llvm_unreachable("block merging updates are currently not supported");
+  impl->notifyBlocksBeingMerged(dest, source);
+  assert(llvm::all_of(source->getPredecessors(),
+                      [dest](Block *pred) { return pred == dest; }) &&
+         "expected 'source' to have no predecessors or only 'dest'");
+  assert(argValues.size() == source->getNumArguments() &&
+         "incorrect # of argument replacement values");
+  for (auto it : llvm::zip(source->getArguments(), argValues))
+    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+  dest->getOperations().splice(dest->end(), source->getOperations());
+  eraseBlock(source);
 }
 
 /// PatternRewriter hook for moving blocks out of a region.
diff --git a/mlir/test/Transforms/test-merge-blocks.mlir b/mlir/test/Transforms/test-merge-blocks.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/test-merge-blocks.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-merge-blocks -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @merge_blocks
+func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
+  // CHECK: "test.merge_blocks"() ( {
+  // CHECK-NEXT: "test.return"
+  // CHECK-NEXT: })
+  // CHECK-NEXT: "test.return"
+  %0:2 = "test.merge_blocks"() ({
+  ^bb0:
+    "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
+  ^bb1(%arg3 : i32, %arg4 : i32):
+    "test.return"(%arg3, %arg4) : (i32, i32) -> ()
+  }) : () -> (i32, i32)
+  "test.return"(%0#0, %0#1) : (i32, i32) -> ()
+}
+
+// -----
+
+// The op in this function is rewritten to itself (and thus remains
+// illegal) by a pattern that merges the second block with the first
+// after adding an operation into it. Check that we can undo block
+// removal successfully.
+// CHECK-LABEL: @undo_blocks_merge
+func @undo_blocks_merge(%arg0: i32) {
+  "test.undo_blocks_merge"() ({
+    // expected-remark@-1 {{op 'test.undo_blocks_merge' is not legalizable}}
+    // CHECK: "unregistered.return"(%{{.*}})[^[[BB:.*]]]
+    "unregistered.return"(%arg0)[^bb1] : (i32) -> ()
+    // expected-remark@-1 {{op 'unregistered.return' is not legalizable}}
+  // CHECK: ^[[BB]]
+  ^bb1(%arg1 : i32):
+    // CHECK: "unregistered.return"
+    "unregistered.return"(%arg1) : (i32) -> ()
+    // expected-remark@-1 {{op 'unregistered.return' is not legalizable}}
+  }) : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @inline_regions()
+func @inline_regions() -> ()
+{
+  // CHECK: test.SingleBlockImplicitTerminator
+  // CHECK-NEXT: %[[T0:.*]] = "test.type_producer"
+  // CHECK-NEXT: "test.type_consumer"(%[[T0]])
+  // CHECK-NEXT: "test.finish"
+  "test.SingleBlockImplicitTerminator"() ({
+  ^bb0:
+    %0 = "test.type_producer"() : () -> i32
+    "test.SingleBlockImplicitTerminator"() ({
+    ^bb1:
+      "test.type_consumer"(%0) : (i32) -> ()
+      "test.finish"() : () -> ()
+    }) : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1163,6 +1163,16 @@
 def TestValidOp : TEST_Op<"valid", [Terminator]>,
   Arguments<(ins Variadic)>;
 
+def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
+  let summary = "merge_blocks operation";
+  let description = [{
+    Test op with multiple blocks that are merged with Dialect Conversion.
+  }];
+
+  let regions = (region AnyRegion:$body);
+  let results = (outs Variadic:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // Test parser.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -834,6 +834,118 @@
 };
 } // end anonymous namespace
 
+namespace {
+/// A rewrite pattern that tests that blocks can be merged.
+struct TestMergeBlock : public OpConversionPattern {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestMergeBlocksOp op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Block &firstBlock = op.body().front();
+    Operation *branchOp = firstBlock.getTerminator();
+    Block *secondBlock = &*(std::next(op.body().begin()));
+    auto succOperands = branchOp->getOperands();
+    SmallVector replacements(succOperands);
+    rewriter.eraseOp(branchOp);
+    rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
+/// A rewrite pattern to test the undo mechanism of blocks being merged.
+struct TestUndoBlocksMerge : public ConversionPattern {
+  TestUndoBlocksMerge(MLIRContext *ctx)
+      : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Block &firstBlock = op->getRegion(0).front();
+    Operation *branchOp = firstBlock.getTerminator();
+    Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
+    rewriter.setInsertionPointToStart(secondBlock);
+    rewriter.create(op->getLoc(), rewriter.getF32Type());
+    auto succOperands = branchOp->getOperands();
+    SmallVector replacements(succOperands);
+    rewriter.eraseOp(branchOp);
+    rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
+/// A rewrite pattern that inlines the body of the op into its parent, when
+/// both ops can have a single block.
+struct TestMergeSingleBlockOps
+    : public OpConversionPattern {
+  using OpConversionPattern<
+      SingleBlockImplicitTerminatorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SingleBlockImplicitTerminatorOp parentOp =
+        op.getParentOfType();
+    if (!parentOp)
+      return failure();
+    Block &parentBlock = parentOp.region().front();
+    Block &innerBlock = op.region().front();
+    TerminatorOp innerTerminator =
+        cast(innerBlock.getTerminator());
+    Block *parentPrologue =
+        rewriter.splitBlock(&parentBlock, Block::iterator(op));
+    rewriter.eraseOp(innerTerminator);
+    rewriter.mergeBlocks(&innerBlock, &parentBlock, {});
+    rewriter.eraseOp(op);
+    rewriter.mergeBlocks(parentPrologue, &parentBlock, {});
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
+struct TestMergeBlocksPatternDriver
+    : public PassWrapper> {
+  void runOnOperation() override {
+    mlir::OwningRewritePatternList patterns;
+    MLIRContext *context = &getContext();
+    patterns
+        .insert(
+            context);
+    ConversionTarget target(*context);
+    target.addLegalOp();
+    target.addIllegalOp();
+
+    /// Expect the op to have a single block after legalization.
+    target.addDynamicallyLegalOp(
+        [&](TestMergeBlocksOp op) -> bool {
+          return llvm::hasSingleElement(op.body());
+        });
+
+    /// Only allow `test.br` within test.merge_blocks op.
+    target.addDynamicallyLegalOp([&](TestBranchOp op) -> bool {
+      return op.getParentOfType();
+    });
+
+    /// Expect that all nested test.SingleBlockImplicitTerminator ops are
+    /// inlined.
+    target.addDynamicallyLegalOp(
+        [&](SingleBlockImplicitTerminatorOp op) -> bool {
+          return !op.getParentOfType();
+        });
+
+    DenseSet unlegalizedOps;
+    (void)applyPartialConversion(getOperation(), target, patterns,
+                                 &unlegalizedOps);
+    for (auto *op : unlegalizedOps)
+      op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // PassRegistration
 //===----------------------------------------------------------------------===//
@@ -866,5 +978,9 @@
   PassRegistration(
       "test-legalize-type-conversion",
       "Test various type conversion functionalities in DialectConversion");
+
+  PassRegistration{
+      "test-merge-blocks",
+      "Test Merging operation in ConversionPatternRewriter"};
 }
 } // namespace mlir