diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -509,7 +509,7 @@ /// The kind of the block action performed during the rewrite. Actions can be /// undone if the conversion fails. - enum class BlockActionKind { Create, Move, Split, TypeConversion }; + enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion }; /// Original position of the given block in its parent region. We cannot use /// a region iterator because it could have been invalidated by other region @@ -525,6 +525,9 @@ static BlockAction getCreate(Block *block) { return {BlockActionKind::Create, block, {}}; } + static BlockAction getErase(Block *block, BlockPosition originalPos) { + return {BlockActionKind::Erase, block, {originalPos}}; + } static BlockAction getMove(Block *block, BlockPosition originalPos) { return {BlockActionKind::Move, block, {originalPos}}; } @@ -544,9 +547,9 @@ Block *block; union { - // In use if kind == BlockActionKind::Move and contains a pointer to the - // region that originally contained the block as well as the position of - // the block in that region. + // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and + // contains a pointer to the region that originally contained the block as + // well as the position of the block in that region. BlockPosition originalPosition; // In use if kind == BlockActionKind::Split and contains a pointer to the // block that was split into two parts. @@ -564,6 +567,10 @@ /// Reset the state of the rewriter to a previously saved point. void resetState(RewriterState state); + /// Erase any blocks that were unlinked from their regions and stored in block + /// actions. + void eraseDanglingBlocks(); + /// Undo the block actions (motions, splits) one by one in reverse order until /// "numActionsToKeep" actions remains.
void undoBlockActions(unsigned numActionsToKeep = 0); @@ -587,6 +594,9 @@ /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ValueRange newValues); + /// Notifies that a block is about to be erased. + void notifyBlockIsBeingErased(Block *block); + /// Notifies that a block was created. void notifyCreatedBlock(Block *block); @@ -711,6 +721,14 @@ ignoredOps.pop_back(); } +void ConversionPatternRewriterImpl::eraseDanglingBlocks() { + for (auto &action : blockActions) { + if (action.kind != BlockActionKind::Erase) + continue; + delete action.block; + } +} + void ConversionPatternRewriterImpl::undoBlockActions( unsigned numActionsToKeep) { for (auto &action : @@ -727,6 +745,14 @@ action.block->erase(); break; } + // Put the block (owned by action) back into its original position. + case BlockActionKind::Erase: { + auto &blockList = action.originalPosition.region->getBlocks(); + blockList.insert( + std::next(blockList.begin(), action.originalPosition.position), + action.block); + break; + } // Move the block back to its original position. case BlockActionKind::Move: { Region *originalRegion = action.originalPosition.region; @@ -806,6 +832,9 @@ repl.op->erase(); argConverter.applyRewrites(mapping); + + // Now that the ops have been erased, also erase dangling blocks.
+ eraseDanglingBlocks(); } LogicalResult @@ -853,6 +882,12 @@ markNestedOpsIgnored(op); } +void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) { + Region *region = block->getParent(); + auto position = std::distance(region->begin(), Region::iterator(block)); + blockActions.push_back(BlockAction::getErase(block, {region, position})); +} + void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) { blockActions.push_back(BlockAction::getCreate(block)); } @@ -942,7 +977,17 @@ } void ConversionPatternRewriter::eraseBlock(Block *block) { - llvm_unreachable("erasing blocks for dialect conversion not implemented"); + impl->notifyBlockIsBeingErased(block); + + // Mark all ops for erasure. + for (Operation &op : *block) + eraseOp(&op); + + // Unlink the block from its parent region. The block is kept in the block + // action and will be actually destroyed when rewrites are applied. This + // allows us to keep the operations in the block live and undo the removal by + // re-inserting the block. + block->getParent()->getBlocks().remove(block); } /// Apply a signature conversion to the entry block of the given region. @@ -1334,7 +1379,8 @@ i != e; ++i) { auto &action = rewriterImpl.blockActions[i]; if (action.kind == - ConversionPatternRewriterImpl::BlockActionKind::TypeConversion) + ConversionPatternRewriterImpl::BlockActionKind::TypeConversion || + action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase) continue; // Convert the block signature. diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -236,6 +236,27 @@ // ----- +// The op in this function is rewritten to itself (and thus remains illegal) by +// a pattern that removes its second block after adding an operation into it. +// Check that we can undo block removal successfully.
+// CHECK-LABEL: @undo_block_erase +func @undo_block_erase() { + // CHECK: test.undo_block_erase + "test.undo_block_erase"() ({ + // expected-remark@-1 {{not legalizable}} + // CHECK: "unregistered.return"()[^[[BB:.*]]] + "unregistered.return"()[^bb1] : () -> () + // expected-remark@-1 {{not legalizable}} + // CHECK: ^[[BB]] + ^bb1: + // CHECK: unregistered.return + "unregistered.return"() : () -> () + // expected-remark@-1 {{not legalizable}} + }) : () -> () +} + +// ----- + // The op in this function is attempted to be rewritten to another illegal op // with an attached region containing an invalid terminator. The terminator is // created before the parent op. The deletion should not crash when deleting diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -281,6 +281,23 @@ } }; +/// A rewrite pattern that tests the undo mechanism when erasing a block.
+struct TestUndoBlockErase : public ConversionPattern { + TestUndoBlockErase(MLIRContext *ctx) + : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {} + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Block *secondBlock = &*std::next(op->getRegion(0).begin()); + rewriter.setInsertionPointToStart(secondBlock); + rewriter.create(op->getLoc(), rewriter.getF32Type()); + rewriter.eraseBlock(secondBlock); + rewriter.updateRootInPlace(op, [] {}); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Type-Conversion Rewrite Testing @@ -504,14 +521,14 @@ TestTypeConverter converter; mlir::OwningRewritePatternList patterns; populateWithGenerated(&getContext(), &patterns); - patterns.insert(&getContext()); + patterns.insert< + TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock, + TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase, + TestPassthroughInvalidOp, TestSplitReturnType, + TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64, + TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType, + TestNonRootReplacement, TestBoundedRecursiveRewrite, + TestNestedOpCreationUndoRewrite>(&getContext()); patterns.insert(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);