diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -988,7 +988,7 @@ Block *insertAfterBlock = action.originalPosition.insertAfterBlock; blockList.insert((insertAfterBlock ? std::next(Region::iterator(insertAfterBlock)) - : blockList.end()), + : blockList.begin()), action.block); break; } diff --git a/mlir/test/Transforms/test-legalizer-full.mlir b/mlir/test/Transforms/test-legalizer-full.mlir --- a/mlir/test/Transforms/test-legalizer-full.mlir +++ b/mlir/test/Transforms/test-legalizer-full.mlir @@ -84,3 +84,18 @@ "test.return"() : () -> () } + +// ----- + +// Test that multiple block erases can be properly undone. +func @test_undo_block_erase() { + // expected-error@+1 {{failed to legalize operation 'test.region'}} + "test.region"() ({ + ^bb1(%i0: i64): + br ^bb2(%i0 : i64) + ^bb2(%i1: i64): + "test.invalid"(%i1) : (i64) -> () + }) {legalizer.should_clone, legalizer.erase_old_blocks} : () -> () + + "test.return"() : () -> () +} diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -203,12 +203,16 @@ ConversionPatternRewriter &rewriter) const final { // Inline this region into the parent region. auto &parentRegion = *op->getParentRegion(); + auto &opRegion = op->getRegion(0); if (op->getAttr("legalizer.should_clone")) - rewriter.cloneRegionBefore(op->getRegion(0), parentRegion, - parentRegion.end()); + rewriter.cloneRegionBefore(opRegion, parentRegion, parentRegion.end()); else - rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, - parentRegion.end()); + rewriter.inlineRegionBefore(opRegion, parentRegion, parentRegion.end()); + + if (op->getAttr("legalizer.erase_old_blocks")) { + while (!opRegion.empty()) + rewriter.eraseBlock(&opRegion.front()); + } // Drop this operation. rewriter.eraseOp(op);