diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -414,20 +414,15 @@
 };
 
 //===----------------------------------------------------------------------===//
-// PatternRewriter
+// RewriterBase
 //===----------------------------------------------------------------------===//
 
-/// This class coordinates the application of a pattern to the current function,
-/// providing a way to create operations and keep track of what gets deleted.
-///
-/// These class serves two purposes:
-/// 1) it is the interface that patterns interact with to make mutations to the
-/// IR they are being applied to.
-/// 2) It is a base class that clients of the PatternMatcher use when they want
-/// to apply patterns and observe their effects (e.g. to keep worklists or
-/// other data structures up to date).
-///
-class PatternRewriter : public OpBuilder, public OpBuilder::Listener {
+/// This class coordinates the application of a rewrite on a set of IR,
+/// providing a way for clients to track mutations and create new operations.
+/// This class serves as a common API for IR mutation between pattern rewrites
+/// and non-pattern rewrites, and facilitates the development of shared
+/// IR transformation utilities.
+class RewriterBase : public OpBuilder, public OpBuilder::Listener {
 public:
   /// Move the blocks that belong to "region" before the given position in
   /// another region "parent". The two regions must be different. The caller
@@ -452,10 +447,10 @@
   /// `newValues` when the provided `functor` returns true for a specific use.
   /// The number of values in `newValues` is required to match the number of
   /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
-  /// the uses of `op` were replaced. Note that in some pattern rewriters, the
-  /// given 'functor' may be stored beyond the lifetime of the pattern being
-  /// applied. As such, the function should not capture by reference and instead
-  /// use value capture as necessary.
+  /// the uses of `op` were replaced. Note that in some rewriters, the given
+  /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
+  /// As such, the function should not capture by reference and instead use
+  /// value capture as necessary.
   virtual void
   replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
                   llvm::unique_function<bool(OpOperand &)> functor);
@@ -472,9 +467,9 @@
   void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
                             bool *allUsesReplaced = nullptr);
 
-  /// This method performs the final replacement for a pattern, where the
-  /// results of the operation are updated to use the specified list of SSA
-  /// values.
+  /// This method replaces the results of the operation with the specified list
+  /// of values. The number of provided values must match the number of results
+  /// of the operation.
   virtual void replaceOp(Operation *op, ValueRange newValues);
 
   /// Replaces the result op with a new op that is created without verification.
@@ -534,10 +529,10 @@
     finalizeRootUpdate(root);
   }
 
-  /// Notify the pattern rewriter that the pattern is failing to match the given
-  /// operation, and provide a callback to populate a diagnostic with the reason
-  /// why the failure occurred. This method allows for derived rewriters to
-  /// optionally hook into the reason why a pattern failed, and display it to
+  /// Used to notify the rewriter that the IR failed to be rewritten because of
+  /// a match failure, and provide a callback to populate a diagnostic with the
+  /// reason why the failure occurred. This method allows for derived rewriters
+  /// to optionally hook into the reason why a rewrite failed, and display it to
   /// users.
   template <typename CallbackT>
   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
@@ -558,28 +553,29 @@
 
 protected:
   /// Initialize the builder with this rewriter as the listener.
-  explicit PatternRewriter(MLIRContext *ctx)
-      : OpBuilder(ctx, /*listener=*/this) {}
-  ~PatternRewriter() override;
+  explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
+  explicit RewriterBase(const OpBuilder &otherBuilder)
+      : OpBuilder(otherBuilder) {
+    setListener(this);
+  }
+  ~RewriterBase() override;
 
   /// These are the callback methods that subclasses can choose to implement if
   /// they would like to be notified about certain types of mutations.
 
-  /// Notify the pattern rewriter that the specified operation is about to be
-  /// replaced with another set of operations. This is called before the uses
-  /// of the operation have been changed.
+  /// Notify the rewriter that the specified operation is about to be replaced
+  /// with another set of operations. This is called before the uses of the
+  /// operation have been changed.
   virtual void notifyRootReplaced(Operation *op) {}
 
-  /// This is called on an operation that a pattern match is removing, right
-  /// before the operation is deleted. At this point, the operation has zero
-  /// uses.
+  /// This is called on an operation that a rewrite is removing, right before
+  /// the operation is deleted. At this point, the operation has zero uses.
   virtual void notifyOperationRemoved(Operation *op) {}
 
-  /// Notify the pattern rewriter that the pattern is failing to match the given
-  /// operation, and provide a callback to populate a diagnostic with the reason
-  /// why the failure occurred. This method allows for derived rewriters to
-  /// optionally hook into the reason why a pattern failed, and display it to
-  /// users.
+  /// Notify the rewriter that the pattern failed to match the given operation,
+  /// and provide a callback to populate a diagnostic with the reason why the
+  /// failure occurred. This method allows for derived rewriters to optionally
+  /// hook into the reason why a rewrite failed, and display it to users.
   virtual LogicalResult
   notifyMatchFailure(Operation *op,
                      function_ref<void(Diagnostic &)> reasonCallback) {
@@ -592,6 +588,35 @@
   void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
 };
 
+//===----------------------------------------------------------------------===//
+// IRRewriter
+//===----------------------------------------------------------------------===//
+
+/// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
+/// providing a way to keep track of the mutations made to the IR. This class
+/// should only be used in situations where another `RewriterBase` instance,
+/// such as a `PatternRewriter`, is not available.
+class IRRewriter : public RewriterBase {
+public:
+  explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
+  explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
+};
+
+//===----------------------------------------------------------------------===//
+// PatternRewriter
+//===----------------------------------------------------------------------===//
+
+/// A special type of `RewriterBase` that coordinates the application of a
+/// rewrite pattern on the current IR being matched, providing a way to keep
+/// track of any mutations made. This class should be used to perform all
+/// necessary IR mutations within a rewrite pattern, as the pattern driver may
+/// be tracking various state that would be invalidated when a mutation takes
+/// place.
+class PatternRewriter : public RewriterBase {
+public:
+  using RewriterBase::RewriterBase;
+};
+
 //===----------------------------------------------------------------------===//
 // OwningRewritePatternList
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -148,10 +148,10 @@
 }
 
 //===----------------------------------------------------------------------===//
-// PatternRewriter
+// RewriterBase
 //===----------------------------------------------------------------------===//
 
-PatternRewriter::~PatternRewriter() {
+RewriterBase::~RewriterBase() {
   // Out of line to provide a vtable anchor for the class.
 }
 
@@ -159,7 +159,7 @@
 /// `newValues` when the provided `functor` returns true for a specific use.
 /// The number of values in `newValues` is required to match the number of
 /// results of `op`.
-void PatternRewriter::replaceOpWithIf(
+void RewriterBase::replaceOpWithIf(
     Operation *op, ValueRange newValues, bool *allUsesReplaced,
     llvm::unique_function<bool(OpOperand &)> functor) {
   assert(op->getNumResults() == newValues.size() &&
@@ -182,18 +182,17 @@
 /// `newValues` when a use is nested within the given `block`. The number of
 /// values in `newValues` is required to match the number of results of `op`.
 /// If all uses of this operation are replaced, the operation is erased.
-void PatternRewriter::replaceOpWithinBlock(Operation *op, ValueRange newValues,
-                                           Block *block,
-                                           bool *allUsesReplaced) {
+void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
+                                        Block *block, bool *allUsesReplaced) {
   replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
     return block->getParentOp()->isProperAncestor(use.getOwner());
   });
 }
 
-/// This method performs the final replacement for a pattern, where the
-/// results of the operation are updated to use the specified list of SSA
-/// values.
-void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
+/// This method replaces the results of the operation with the specified list of
+/// values. The number of provided values must match the number of results of
+/// the operation.
+void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
   // Notify the rewriter subclass that we're about to replace this root.
   notifyRootReplaced(op);
 
@@ -207,13 +206,13 @@
 
 /// This method erases an operation that is known to have no uses. The uses of
 /// the given operation *must* be known to be dead.
-void PatternRewriter::eraseOp(Operation *op) {
+void RewriterBase::eraseOp(Operation *op) {
   assert(op->use_empty() && "expected 'op' to have no uses");
   notifyOperationRemoved(op);
   op->erase();
 }
 
-void PatternRewriter::eraseBlock(Block *block) {
+void RewriterBase::eraseBlock(Block *block) {
   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
     assert(op.use_empty() && "expected 'op' to have no uses");
     eraseOp(&op);
@@ -225,8 +224,8 @@
 /// 'source's predecessors must be empty or only contain 'dest`.
 /// 'argValues' is used to replace the block arguments of 'source' after
 /// merging.
-void PatternRewriter::mergeBlocks(Block *source, Block *dest,
-                                  ValueRange argValues) {
+void RewriterBase::mergeBlocks(Block *source, Block *dest,
+                               ValueRange argValues) {
   assert(llvm::all_of(source->getPredecessors(),
                       [dest](Block *succ) { return succ == dest; }) &&
          "expected 'source' to have no predecessors or only 'dest'");
@@ -246,8 +245,8 @@
 
 // Merge the operations of block 'source' before the operation 'op'. Source
 // block should not have existing predecessors or successors.
-void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
-                                       ValueRange argValues) {
+void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
+                                    ValueRange argValues) {
   assert(source->hasNoPredecessors() &&
          "expected 'source' to have no predecessors");
   assert(source->hasNoSuccessors() &&
@@ -268,14 +267,14 @@
 
 /// Split the operations starting at "before" (inclusive) out of the given
 /// block into a new block, and return it.
-Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
+Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
   return block->splitBlock(before);
 }
 
 /// 'op' and 'newOp' are known to have the same number of results, replace the
 /// uses of op with uses of newOp
-void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
-                                                      Operation *newOp) {
+void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
+                                                   Operation *newOp) {
   assert(op->getNumResults() == newOp->getNumResults() &&
          "replacement op doesn't match results of original op");
   if (op->getNumResults() == 1)
@@ -287,11 +286,11 @@
 /// another region. The two regions must be different. The caller is in
 /// charge to update create the operation transferring the control flow to the
 /// region and pass it the correct block arguments.
-void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
-                                         Region::iterator before) {
+void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
+                                      Region::iterator before) {
   parent.getBlocks().splice(before, region.getBlocks());
 }
-void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
+void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
   inlineRegionBefore(region, *before->getParent(), before->getIterator());
 }
 
@@ -299,17 +298,16 @@
 /// another region "parent". The two regions must be different. The caller is
 /// responsible for creating or updating the operation transferring flow of
 /// control to the region and passing it the correct block arguments.
-void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
-                                        Region::iterator before,
-                                        BlockAndValueMapping &mapping) {
+void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
+                                     Region::iterator before,
+                                     BlockAndValueMapping &mapping) {
   region.cloneInto(&parent, before, mapping);
 }
-void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
-                                        Region::iterator before) {
+void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
+                                     Region::iterator before) {
   BlockAndValueMapping mapping;
   cloneRegionBefore(region, parent, before, mapping);
 }
-void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
+void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
-