[mlir] Track BlockPosition by predecessor block instead of by index

BlockPosition now records the block that preceded the erased/moved block
(`insertAfterBlock`) rather than a numeric index, because indices go stale
when other region operations run in between. A null `insertAfterBlock`
means the block was the first block of its region.

Fixes relative to the previous revision of this patch:

* Erase undo put a formerly-first block (null `insertAfterBlock`) at the
  *end* of the region. It now goes back to `begin()`. Move undo also uses
  `begin()`, so both actions interpret a null predecessor the same way.
* notifyRegionIsBeingInlinedBefore recorded the moves front to back. Block
  actions are undone in reverse order, so the last block was restored
  first, relative to a predecessor that was still in the parent region;
  the splice position then pointed into the wrong block list. The moves
  are now recorded back to front, so each predecessor is already back in
  the original region when its successor is spliced in after it.
* notifyBlockIsBeingErased uses Block::getPrevNode() instead of manual
  iterator arithmetic.
* HTML-entity-mangled `&region` / `&block` tokens are restored.

NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Leading whitespace on context lines follows LLVM clang-format style but
should be verified (or the diff regenerated) before running `git apply`.

diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -611,12 +611,12 @@
   TypeConversion
 };
 
-/// Original position of the given block in its parent region. We cannot use
-/// a region iterator because it could have been invalidated by other region
-/// operations since the position was stored.
+/// Original position of the given block in its parent region. During undo
+/// actions, the block needs to be placed after `insertAfterBlock`; a null
+/// `insertAfterBlock` means the block was the first block of `region`.
 struct BlockPosition {
   Region *region;
-  Region::iterator::difference_type position;
+  Block *insertAfterBlock;
 };
 
 /// Information needed to undo the merge actions.
@@ -634,8 +634,8 @@
   static BlockAction getCreate(Block *block) {
     return {BlockActionKind::Create, block, {}};
   }
-  static BlockAction getErase(Block *block, BlockPosition originalPos) {
-    return {BlockActionKind::Erase, block, {originalPos}};
+  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
+    return {BlockActionKind::Erase, block, {originalPosition}};
   }
   static BlockAction getMerge(Block *block, Block *sourceBlock) {
     BlockAction action{BlockActionKind::Merge, block, {}};
@@ -644,8 +644,8 @@
                                                     : &(block->back())};
     return action;
   }
-  static BlockAction getMove(Block *block, BlockPosition originalPos) {
-    return {BlockActionKind::Move, block, {originalPos}};
+  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
+    return {BlockActionKind::Move, block, {originalPosition}};
   }
   static BlockAction getSplit(Block *block, Block *originalBlock) {
     BlockAction action{BlockActionKind::Split, block, {}};
@@ -990,9 +990,13 @@
     // Put the block (owned by action) back into its original position.
     case BlockActionKind::Erase: {
       auto &blockList = action.originalPosition.region->getBlocks();
-      blockList.insert(
-          std::next(blockList.begin(), action.originalPosition.position),
-          action.block);
+      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
+      // A null `insertAfterBlock` means the block used to be the first block
+      // of the region, so it goes back to the front of the list.
+      blockList.insert(insertAfterBlock
+                           ? std::next(Region::iterator(insertAfterBlock))
+                           : blockList.begin(),
+                       action.block);
       break;
     }
     // Split the block at the position which was originally the end of the
@@ -1013,7 +1017,10 @@
     case BlockActionKind::Move: {
       Region *originalRegion = action.originalPosition.region;
       originalRegion->getBlocks().splice(
-          std::next(originalRegion->begin(), action.originalPosition.position),
+          (action.originalPosition.insertAfterBlock
+               ? std::next(
+                     Region::iterator(action.originalPosition.insertAfterBlock))
+               : originalRegion->begin()),
           action.block->getParent()->getBlocks(), action.block);
       break;
     }
@@ -1191,8 +1198,10 @@
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   Region *region = block->getParent();
-  auto position = std::distance(region->begin(), Region::iterator(block));
-  blockActions.push_back(BlockAction::getErase(block, {region, position}));
+  // Remember the predecessor (null if `block` is the first block of `region`)
+  // so that undoing the erasure can reinsert `block` right after it.
+  Block *origPrevBlock = block->getPrevNode();
+  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
 }
 
 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
@@ -1211,10 +1220,13 @@
 
 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
     Region &region, Region &parent, Region::iterator before) {
-  for (auto &pair : llvm::enumerate(region)) {
-    Block &block = pair.value();
-    Region::iterator::difference_type position = pair.index();
-    blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
+  // Block actions are undone in reverse order, so record the moves back to
+  // front: on undo, each block's predecessor has then already been moved back
+  // into `region` when the block itself is spliced in after it.
+  for (Block &block : llvm::reverse(region)) {
+    Block *origPrevBlock = block.getPrevNode();
+    blockActions.push_back(
+        BlockAction::getMove(&block, {&region, origPrevBlock}));
   }
 }
 