diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -491,18 +491,36 @@ /// This method erases all operations in a block. virtual void eraseBlock(Block *block); - /// Merge the operations of block 'source' into the end of block 'dest'. - /// 'source's predecessors must either be empty or only contain 'dest`. - /// 'argValues' is used to replace the block arguments of 'source' after - /// merging. - virtual void mergeBlocks(Block *source, Block *dest, - ValueRange argValues = std::nullopt); + /// Merge the operations of block 'source' to block 'dest' before the given + /// position. The source block will be deleted and must have no uses. + /// 'argValues' is used to replace the block arguments of 'source'. + /// + /// If the source block is inserted at the end of the dest block, the dest + /// block must have no successors. Similarly, if the source block is inserted + /// somewhere in the middle (or beginning) of the dest block, the source block + /// must have no successors. Otherwise, the resulting IR would have + /// unreachable operations. + void mergeBlockBefore(Block *source, Block *dest, Block::iterator before, + ValueRange argValues = std::nullopt); - // Merge the operations of block 'source' before the operation 'op'. Source - // block should not have existing predecessors or successors. + /// Merge the operations of block 'source' before the operation 'op'. The + /// source block will be deleted and must have no uses. 'argValues' is used to + /// replace the block arguments of 'source'. + /// + /// The source block must have no successors. Otherwise, the resulting IR + /// would have unreachable operations. void mergeBlockBefore(Block *source, Operation *op, ValueRange argValues = std::nullopt); + /// Merge the operations of block 'source' into the end of block 'dest'. The + /// source block will be deleted and must have no uses. 
'argValues' is used to + /// replace the block arguments of 'source'. + /// + /// The dest block must have no successors. Otherwise, the resulting IR would + /// have unreachable operations. + virtual void mergeBlocks(Block *source, Block *dest, + ValueRange argValues = std::nullopt); + /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. virtual Block *splitBlock(Block *block, Block::iterator before); diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -147,8 +147,9 @@ return failure(); // Merge the successor into the current block and erase the branch. - rewriter.mergeBlocks(succ, opParent, op.getOperands()); + SmallVector brOperands(op.getOperands()); rewriter.eraseOp(op); + rewriter.mergeBlocks(succ, opParent, brOperands); return success(); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -339,29 +339,6 @@ rewriteListener->notifyOperationModified(op); } -/// Merge the operations of block 'source' into the end of block 'dest'. -/// 'source's predecessors must be empty or only contain 'dest`. -/// 'argValues' is used to replace the block arguments of 'source' after -/// merging. -void RewriterBase::mergeBlocks(Block *source, Block *dest, - ValueRange argValues) { - assert(llvm::all_of(source->getPredecessors(), - [dest](Block *succ) { return succ == dest; }) && - "expected 'source' to have no predecessors or only 'dest'"); - assert(argValues.size() == source->getNumArguments() && - "incorrect # of argument replacement values"); - - // Replace all of the successor arguments with the provided values. 
- for (auto it : llvm::zip(source->getArguments(), argValues)) - replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - - // Splice the operations of the 'source' block into the 'dest' block and erase - // it. - dest->getOperations().splice(dest->end(), source->getOperations()); - source->dropAllUses(); - source->erase(); -} - /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. @@ -373,26 +350,48 @@ } } -// Merge the operations of block 'source' before the operation 'op'. Source -// block should not have existing predecessors or successors. -void RewriterBase::mergeBlockBefore(Block *source, Operation *op, +void RewriterBase::mergeBlockBefore(Block *source, Block *dest, + Block::iterator before, ValueRange argValues) { + assert(argValues.size() == source->getNumArguments() && + "incorrect # of argument replacement values"); + + // The source block will be deleted, so it should not have any users (i.e., + // there should be no predecessors). assert(source->hasNoPredecessors() && "expected 'source' to have no predecessors"); - assert(source->hasNoSuccessors() && - "expected 'source' to have no successors"); - // Split the block containing 'op' into two, one containing all operations - // before 'op' (prologue) and another (epilogue) containing 'op' and all - // operations after it. - Block *prologue = op->getBlock(); - Block *epilogue = splitBlock(prologue, op->getIterator()); + if (dest->end() != before) { + // The source block will be inserted in the middle of the dest block, so + // the source block should have no successors. Otherwise, the remainder of + // the dest block would be unreachable. + assert(source->hasNoSuccessors() && + "expected 'source' to have no successors"); + } else { + // The source block will be inserted at the end of the dest block, so the + // dest block should have no successors. 
Otherwise, the inserted operations + // will be unreachable. + assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors"); + } + + // Replace all of the successor arguments with the provided values. + for (auto it : llvm::zip(source->getArguments(), argValues)) + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); + + // Move operations from the source block to the dest block and erase the + // source block. + dest->getOperations().splice(before, source->getOperations()); + source->erase(); +} - // Merge the source block at the end of the prologue. - mergeBlocks(source, prologue, argValues); +void RewriterBase::mergeBlockBefore(Block *source, Operation *op, + ValueRange argValues) { + mergeBlockBefore(source, op->getBlock(), op->getIterator(), argValues); +} - // Merge the epilogue at the end the prologue. - mergeBlocks(epilogue, prologue); +void RewriterBase::mergeBlocks(Block *source, Block *dest, + ValueRange argValues) { + mergeBlockBefore(source, dest, dest->end(), argValues); } /// Split the operations starting at "before" (inclusive) out of the given diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1539,12 +1539,13 @@ void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest, ValueRange argValues) { - impl->notifyBlocksBeingMerged(dest, source); - assert(llvm::all_of(source->getPredecessors(), - [dest](Block *succ) { return succ == dest; }) && - "expected 'source' to have no predecessors or only 'dest'"); assert(argValues.size() == source->getNumArguments() && "incorrect # of argument replacement values"); + assert(llvm::all_of(source->getUsers(), + [&](Operation *op) { return impl->isOpIgnored(op); }) && + "expected 'source' to have no predecessors"); + + impl->notifyBlocksBeingMerged(dest, source); for (auto it : llvm::zip(source->getArguments(), 
argValues)) replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); dest->getOperations().splice(dest->end(), source->getOperations());