diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -611,12 +611,12 @@
   TypeConversion
 };
 
-/// Original position of the given block in its parent region. We cannot use
-/// a region iterator because it could have been invalidated by other region
-/// operations since the position was stored.
+/// Original position of the given block in its parent region. During undo
+/// actions, the block is re-inserted right after `insertAfterBlock`, or at the
+/// front of `region` when `insertAfterBlock` is null (it was the first block).
 struct BlockPosition {
   Region *region;
-  Region::iterator::difference_type position;
+  Block *insertAfterBlock;
 };
 
 /// Information needed to undo the merge actions.
@@ -634,16 +634,16 @@
   static BlockAction getCreate(Block *block) {
     return {BlockActionKind::Create, block, {}};
   }
-  static BlockAction getErase(Block *block, BlockPosition originalPos) {
-    return {BlockActionKind::Erase, block, {originalPos}};
+  static BlockAction getErase(Block *block, BlockPosition originalPosition) {
+    return {BlockActionKind::Erase, block, {originalPosition}};
   }
   static BlockAction getMerge(Block *block, Block *sourceBlock) {
     BlockAction action{BlockActionKind::Merge, block, {}};
     action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
     return action;
   }
-  static BlockAction getMove(Block *block, BlockPosition originalPos) {
-    return {BlockActionKind::Move, block, {originalPos}};
+  static BlockAction getMove(Block *block, BlockPosition originalPosition) {
+    return {BlockActionKind::Move, block, {originalPosition}};
   }
   static BlockAction getSplit(Block *block, Block *originalBlock) {
     BlockAction action{BlockActionKind::Split, block, {}};
@@ -988,9 +988,12 @@
     // Put the block (owned by action) back into its original position.
     case BlockActionKind::Erase: {
       auto &blockList = action.originalPosition.region->getBlocks();
-      blockList.insert(
-          std::next(blockList.begin(), action.originalPosition.position),
-          action.block);
+      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
+      // A null `insertAfterBlock` means the block was the first in its region.
+      blockList.insert((insertAfterBlock
+                            ? std::next(Region::iterator(insertAfterBlock))
+                            : blockList.begin()),
+                       action.block);
       break;
     }
     // Split the block at the position which was originally the end of the
@@ -1010,8 +1013,10 @@
     // Move the block back to its original position.
     case BlockActionKind::Move: {
       Region *originalRegion = action.originalPosition.region;
+      Block *insertAfterBlock = action.originalPosition.insertAfterBlock;
       originalRegion->getBlocks().splice(
-          std::next(originalRegion->begin(), action.originalPosition.position),
+          (insertAfterBlock ? std::next(Region::iterator(insertAfterBlock))
+                            : originalRegion->begin()),
           action.block->getParent()->getBlocks(), action.block);
       break;
     }
@@ -1189,8 +1194,8 @@
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
   Region *region = block->getParent();
-  auto position = std::distance(region->begin(), Region::iterator(block));
-  blockActions.push_back(BlockAction::getErase(block, {region, position}));
+  Block *origPrevBlock = block->getPrevNode();
+  blockActions.push_back(BlockAction::getErase(block, {region, origPrevBlock}));
 }
 
 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
@@ -1209,10 +1214,17 @@
 
 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
     Region &region, Region &parent, Region::iterator before) {
-  for (auto &pair : llvm::enumerate(region)) {
-    Block &block = pair.value();
-    Region::iterator::difference_type position = pair.index();
-    blockActions.push_back(BlockAction::getMove(&block, {&region, position}));
+  if (region.empty())
+    return;
+  // Undo actions are replayed in reverse, so record the moves from the last
+  // block to the first. On undo, the first block is then restored first and
+  // every later block is re-inserted after its (already restored) predecessor.
+  Block *laterBlock = &region.back();
+  for (auto &earlierBlock : llvm::drop_begin(llvm::reverse(region), 1)) {
+    blockActions.push_back(
+        BlockAction::getMove(laterBlock, {&region, &earlierBlock}));
+    laterBlock = &earlierBlock;
   }
+  blockActions.push_back(BlockAction::getMove(laterBlock, {&region, nullptr}));
 }
 