diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -207,7 +207,10 @@ } /// Return true if this block has no predecessors. - bool hasNoPredecessors(); + bool hasNoPredecessors() { return pred_begin() == pred_end(); } + + /// Return true if this block has no successors. + bool hasNoSuccessors() { return succ_begin() == succ_end(); } /// If this block has exactly one predecessor, return it. Otherwise, return /// null. diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -326,6 +326,11 @@ virtual void mergeBlocks(Block *source, Block *dest, ValueRange argValues = llvm::None); + /// Merge the operations of block 'source' before the operation 'op'. Source + /// block should not have existing predecessors or successors. + void mergeBlockBefore(Block *source, Operation *op, + ValueRange argValues = llvm::None); + /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. virtual Block *splitBlock(Block *block, Block::iterator before); diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -201,9 +201,6 @@ return &back(); } -/// Return true if this block has no predecessors. -bool Block::hasNoPredecessors() { return pred_begin() == pred_end(); } - // Indexed successor access. unsigned Block::getNumSuccessors() { return empty() ? 0 : back().getNumSuccessors(); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -128,6 +128,28 @@ source->erase(); } +/// Merge the operations of block 'source' before the operation 'op'. Source +/// block should not have existing predecessors or successors. 
+void PatternRewriter::mergeBlockBefore(Block *source, Operation *op, + ValueRange argValues) { + assert(source->hasNoPredecessors() && + "expected 'source' to have no predecessors"); + assert(source->hasNoSuccessors() && + "expected 'source' to have no successors"); + + // Split the block containing 'op' into two, one containing all operations + // before 'op' (prologue) and another (epilogue) containing 'op' and all + // operations after it. + Block *prologue = op->getBlock(); + Block *epilogue = splitBlock(prologue, op->getIterator()); + + // Merge the source block at the end of the prologue. + mergeBlocks(source, prologue, argValues); + + // Merge the epilogue at the end of the prologue. + mergeBlocks(epilogue, prologue); +} + /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) { diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -889,16 +889,12 @@ op.getParentOfType(); if (!parentOp) return failure(); - Block &parentBlock = parentOp.region().front(); Block &innerBlock = op.region().front(); TerminatorOp innerTerminator = cast(innerBlock.getTerminator()); - Block *parentPrologue = - rewriter.splitBlock(&parentBlock, Block::iterator(op)); + rewriter.mergeBlockBefore(&innerBlock, op); rewriter.eraseOp(innerTerminator); - rewriter.mergeBlocks(&innerBlock, &parentBlock, {}); rewriter.eraseOp(op); - rewriter.mergeBlocks(parentPrologue, &parentBlock, {}); rewriter.updateRootInPlace(op, [] {}); return success(); }