diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -491,17 +491,36 @@ /// This method erases all operations in a block. virtual void eraseBlock(Block *block); - /// Merge the operations of block 'source' into the end of block 'dest'. - /// 'source's predecessors must either be empty or only contain 'dest`. - /// 'argValues' is used to replace the block arguments of 'source' after - /// merging. - virtual void mergeBlocks(Block *source, Block *dest, - ValueRange argValues = std::nullopt); - - // Merge the operations of block 'source' before the operation 'op'. Source - // block should not have existing predecessors or successors. - void mergeBlockBefore(Block *source, Operation *op, - ValueRange argValues = std::nullopt); + /// Inline the operations of block 'source' into block 'dest' before the given + /// position. The source block will be deleted and must have no uses. + /// 'argValues' is used to replace the block arguments of 'source'. + /// + /// If the source block is inserted at the end of the dest block, the dest + /// block must have no successors. Similarly, if the source block is inserted + /// somewhere in the middle (or beginning) of the dest block, the source block + /// must have no successors. Otherwise, the resulting IR would have + /// unreachable operations. + virtual void inlineBlockBefore(Block *source, Block *dest, + Block::iterator before, + ValueRange argValues = std::nullopt); + + /// Inline the operations of block 'source' before the operation 'op'. The + /// source block will be deleted and must have no uses. 'argValues' is used to + /// replace the block arguments of 'source'. + /// + /// The source block must have no successors. Otherwise, the resulting IR + /// would have unreachable operations. 
+ void inlineBlockBefore(Block *source, Operation *op, + ValueRange argValues = std::nullopt); + + /// Inline the operations of block 'source' into the end of block 'dest'. The + /// source block will be deleted and must have no uses. 'argValues' is used to + /// replace the block arguments of 'source'. + /// + /// The dest block must have no successors. Otherwise, the resulting IR would + /// have unreachable operations. + void mergeBlocks(Block *source, Block *dest, + ValueRange argValues = std::nullopt); /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -702,8 +702,10 @@ /// PatternRewriter hook for splitting a block into two parts. Block *splitBlock(Block *block, Block::iterator before) override; - /// PatternRewriter hook for merging a block into another. - void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override; + /// PatternRewriter hook for inlining the ops of a block into another block. + void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, + ValueRange argValues = std::nullopt) override; + using PatternRewriter::inlineBlockBefore; /// PatternRewriter hook for moving blocks out of a region.
void inlineRegionBefore(Region &region, Region &parent, diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -506,7 +506,7 @@ Value arg = iterArgs[yieldOperands.size()]; yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0)); rewriter.eraseOp(reduceBlock.getTerminator()); - rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()}); + rewriter.inlineBlockBefore(&reduceBlock, &op, {arg, reduce.getOperand()}); rewriter.eraseOp(reduce); } @@ -516,8 +516,8 @@ if (newBody->empty()) rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs); else - rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(), - ivs); + rewriter.inlineBlockBefore(parallelOp.getBody(), newBody->getTerminator(), + ivs); // Finally, create the terminator if required (for loops with no results, it // has been already created in loop construction). diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2724,7 +2724,7 @@ Operation *blockToMoveTerminator = blockToMove->getTerminator(); // Promote the "blockToMove" block to the parent operation block between the // prologue and epilogue of "op". - rewriter.mergeBlockBefore(blockToMove, op); + rewriter.inlineBlockBefore(blockToMove, op); // Replace the "op" operation with the operands of the // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is // the affine.yield operation present in the "blockToMove" block. 
It has no diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp --- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp @@ -147,8 +147,9 @@ return failure(); // Merge the successor into the current block and erase the branch. - rewriter.mergeBlocks(succ, opParent, op.getOperands()); + SmallVector<Value> brOperands(op.getOperands()); rewriter.eraseOp(op); + rewriter.mergeBlocks(succ, opParent, brOperands); return success(); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -553,7 +553,7 @@ Block *block = &op.getRegion().front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); - rewriter.mergeBlockBefore(block, op); + rewriter.inlineBlockBefore(block, op); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); return success(); diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -107,7 +107,7 @@ Block *block = &region.front(); Operation *terminator = block->getTerminator(); ValueRange results = terminator->getOperands(); - rewriter.mergeBlockBefore(block, op, blockArgs); + rewriter.inlineBlockBefore(block, op, blockArgs); rewriter.replaceOp(op, results); rewriter.eraseOp(terminator); } @@ -623,7 +623,7 @@ // the ForOp region and can just be forwarded after simplifying the op inits, // yields and returns. // -// The implementation uses `mergeBlockBefore` to steal the content of the +// The implementation uses `inlineBlockBefore` to steal the content of the // original ForOp and avoid cloning. 
struct ForOpIterArgsFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -638,7 +638,7 @@ // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to // transformed block argument mappings. This plays the role of a // IRMapping for the particular use case of calling into - // `mergeBlockBefore`. + // `inlineBlockBefore`. SmallVector keepMask; keepMask.reserve(yieldOp.getNumOperands()); SmallVector newBlockTransferArgs, newIterArgs, newYieldValues, @@ -708,7 +708,7 @@ // original terminator that has been merged in. if (newIterArgs.empty()) { auto newYieldOp = cast(newBlock.getTerminator()); - rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs); + rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs); rewriter.eraseOp(newBlock.getTerminator()->getPrevNode()); rewriter.replaceOp(forOp, newResultValues); return success(); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -168,7 +168,7 @@ auto cloned = cast(rewriter.clone(*op.getOperation())); assert(args.size() == cloned.getBody()->getNumArguments()); Operation *yield = cloned.getBody()->getTerminator(); - rewriter.mergeBlockBefore(cloned.getBody(), op, args); + rewriter.inlineBlockBefore(cloned.getBody(), op, args); // clean up rewriter.eraseOp(cloned); reduc = yield->getOperands(); @@ -988,7 +988,8 @@ // This is annoying, since scf.for inserts a implicit yield op when // there is no reduction variable upon creation, in this case we need to // merge the block *before* the yield op. 
- rewriter.mergeBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), args); + rewriter.inlineBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), + args); } for (Dimension d = 0; d < dimRank; d++) { diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -1211,7 +1211,7 @@ YieldOp clonedYield = cast(clonedBlock.getTerminator()); // Merge cloned block and return yield value. Operation *placeholder = rewriter.create(loc, 0); - rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals); + rewriter.inlineBlockBefore(&tmpRegion.front(), placeholder, vals); Value val = clonedYield.getResult(); rewriter.eraseOp(clonedYield); rewriter.eraseOp(placeholder); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -339,29 +339,6 @@ rewriteListener->notifyOperationModified(op); } -/// Merge the operations of block 'source' into the end of block 'dest'. -/// 'source's predecessors must be empty or only contain 'dest`. -/// 'argValues' is used to replace the block arguments of 'source' after -/// merging. -void RewriterBase::mergeBlocks(Block *source, Block *dest, - ValueRange argValues) { - assert(llvm::all_of(source->getPredecessors(), - [dest](Block *succ) { return succ == dest; }) && - "expected 'source' to have no predecessors or only 'dest'"); - assert(argValues.size() == source->getNumArguments() && - "incorrect # of argument replacement values"); - - // Replace all of the successor arguments with the provided values. - for (auto it : llvm::zip(source->getArguments(), argValues)) - replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - - // Splice the operations of the 'source' block into the 'dest' block and erase - // it. 
- dest->getOperations().splice(dest->end(), source->getOperations()); - source->dropAllUses(); - source->erase(); -} - /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. @@ -373,26 +350,48 @@ } } -// Merge the operations of block 'source' before the operation 'op'. Source -// block should not have existing predecessors or successors. -void RewriterBase::mergeBlockBefore(Block *source, Operation *op, - ValueRange argValues) { +void RewriterBase::inlineBlockBefore(Block *source, Block *dest, + Block::iterator before, + ValueRange argValues) { + assert(argValues.size() == source->getNumArguments() && + "incorrect # of argument replacement values"); + + // The source block will be deleted, so it should not have any users (i.e., + // there should be no predecessors). assert(source->hasNoPredecessors() && "expected 'source' to have no predecessors"); - assert(source->hasNoSuccessors() && - "expected 'source' to have no successors"); - // Split the block containing 'op' into two, one containing all operations - // before 'op' (prologue) and another (epilogue) containing 'op' and all - // operations after it. - Block *prologue = op->getBlock(); - Block *epilogue = splitBlock(prologue, op->getIterator()); + if (dest->end() != before) { + // The source block will be inserted in the middle of the dest block, so + // the source block should have no successors. Otherwise, the remainder of + // the dest block would be unreachable. + assert(source->hasNoSuccessors() && + "expected 'source' to have no successors"); + } else { + // The source block will be inserted at the end of the dest block, so the + // dest block should have no successors. Otherwise, the inserted operations + // will be unreachable. 
+ assert(dest->hasNoSuccessors() && "expected 'dest' to have no successors"); + } - // Merge the source block at the end of the prologue. - mergeBlocks(source, prologue, argValues); + // Replace all uses of the source block arguments with the provided values. + for (auto it : llvm::zip(source->getArguments(), argValues)) + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); - // Merge the epilogue at the end the prologue. - mergeBlocks(epilogue, prologue); + // Move operations from the source block to the dest block and erase the + // source block. + dest->getOperations().splice(before, source->getOperations()); + source->erase(); +} + +void RewriterBase::inlineBlockBefore(Block *source, Operation *op, + ValueRange argValues) { + inlineBlockBefore(source, op->getBlock(), op->getIterator(), argValues); +} + +void RewriterBase::mergeBlocks(Block *source, Block *dest, + ValueRange argValues) { + inlineBlockBefore(source, dest, dest->end(), argValues); } /// Split the operations starting at "before" (inclusive) out of the given diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -241,7 +241,7 @@ enum class BlockActionKind { Create, Erase, - Merge, + Inline, Move, Split, TypeConversion @@ -254,13 +254,14 @@ Block *insertAfterBlock; }; -/// Information needed to undo the merge actions. -/// - the source block, and -/// - the Operation that was the last operation in the dest block before the -/// merge (could be null if the dest block was empty). -struct MergeInfo { +/// Information needed to undo inlining actions. 
+/// - the source block +/// - the first inlined operation (could be null if the source block was empty) +/// - the last inlined operation (could be null if the source block was empty) +struct InlineInfo { Block *sourceBlock; - Operation *destBlockLastInst; + Operation *firstInlinedInst; + Operation *lastInlinedInst; }; /// The storage class for an undoable block action (one of BlockActionKind), @@ -272,9 +273,12 @@ static BlockAction getErase(Block *block, BlockPosition originalPosition) { return {BlockActionKind::Erase, block, {originalPosition}}; } - static BlockAction getMerge(Block *block, Block *sourceBlock) { - BlockAction action{BlockActionKind::Merge, block, {}}; - action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()}; + static BlockAction getInline(Block *block, Block *srcBlock, + Block::iterator before) { + BlockAction action{BlockActionKind::Inline, block, {}}; + action.inlineInfo = {srcBlock, + srcBlock->empty() ? nullptr : &srcBlock->front(), + srcBlock->empty() ? nullptr : &srcBlock->back()}; return action; } static BlockAction getMove(Block *block, BlockPosition originalPosition) { @@ -296,16 +300,16 @@ Block *block; union { // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and // contains a pointer to the region that originally contained the block as // well as the position of the block in that region. BlockPosition originalPosition; // In use if kind == BlockActionKind::Split and contains a pointer to the // block that was split into two parts. Block *originalBlock; - // In use if kind == BlockActionKind::Merge, and contains the information - // needed to undo the merge. - MergeInfo mergeInfo; + // In use if kind == BlockActionKind::Inline, and contains the information + // needed to undo the inlining. + InlineInfo inlineInfo; }; }; @@ -897,8 +901,9 @@ /// Notifies that a block was split. 
void notifySplitBlock(Block *block, Block *continuation); - /// Notifies that `block` is being merged with `srcBlock`. - void notifyBlocksBeingMerged(Block *block, Block *srcBlock); + /// Notifies that `srcBlock` is being inlined into `block` before `before`. + void notifyBlockBeingInlined(Block *block, Block *srcBlock, + Block::iterator before); /// Notifies that the blocks of a region are about to be moved. void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent, @@ -1161,18 +1166,17 @@ action.block); break; } - // Split the block at the position which was originally the end of the - // destination block (owned by action), and put the instructions back into - // the block used before the merge. - case BlockActionKind::Merge: { - Block *sourceBlock = action.mergeInfo.sourceBlock; - Block::iterator splitPoint = - (action.mergeInfo.destBlockLastInst - ? ++Block::iterator(action.mergeInfo.destBlockLastInst) - : action.block->begin()); - sourceBlock->getOperations().splice(sourceBlock->begin(), - action.block->getOperations(), - splitPoint, action.block->end()); + // Put the instructions from the destination block (owned by the action) + // back into the source block. + case BlockActionKind::Inline: { + Block *sourceBlock = action.inlineInfo.sourceBlock; + if (action.inlineInfo.firstInlinedInst) { + assert(action.inlineInfo.lastInlinedInst && "expected operation"); + sourceBlock->getOperations().splice( + sourceBlock->begin(), action.block->getOperations(), + Block::iterator(action.inlineInfo.firstInlinedInst), + ++Block::iterator(action.inlineInfo.lastInlinedInst)); + } break; } // Move the block back to its original position. 
@@ -1393,9 +1397,9 @@ blockActions.push_back(BlockAction::getSplit(continuation, block)); } -void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block, - Block *srcBlock) { - blockActions.push_back(BlockAction::getMerge(block, srcBlock)); +void ConversionPatternRewriterImpl::notifyBlockBeingInlined( + Block *block, Block *srcBlock, Block::iterator before) { + blockActions.push_back(BlockAction::getInline(block, srcBlock, before)); } void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( @@ -1537,17 +1541,23 @@ return continuation; } -void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest, - ValueRange argValues) { - impl->notifyBlocksBeingMerged(dest, source); - assert(llvm::all_of(source->getPredecessors(), - [dest](Block *succ) { return succ == dest; }) && - "expected 'source' to have no predecessors or only 'dest'"); +void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, + Block::iterator before, + ValueRange argValues) { assert(argValues.size() == source->getNumArguments() && "incorrect # of argument replacement values"); +#ifndef NDEBUG + auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); }; +#endif // NDEBUG + // The source block will be deleted, so it should not have any users (i.e., + // there should be no predecessors). 
+ assert(llvm::all_of(source->getUsers(), opIgnored) && + "expected 'source' to have no predecessors"); + + impl->notifyBlockBeingInlined(dest, source, before); for (auto it : llvm::zip(source->getArguments(), argValues)) replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); - dest->getOperations().splice(dest->end(), source->getOperations()); + dest->getOperations().splice(before, source->getOperations()); eraseBlock(source); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1592,7 +1592,7 @@ Block &innerBlock = op.getRegion().front(); TerminatorOp innerTerminator = cast(innerBlock.getTerminator()); - rewriter.mergeBlockBefore(&innerBlock, op); + rewriter.inlineBlockBefore(&innerBlock, op); rewriter.eraseOp(innerTerminator); rewriter.eraseOp(op); rewriter.updateRootInPlace(op, [] {});