diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -390,12 +390,12 @@
   ];
 
   let extraClassDeclaration = [{
-    OpBuilder getThenBodyBuilder(OpBuilder::Listener *listener = nullptr) {
+    OpBuilder getThenBodyBuilder(RewriteListener *listener = nullptr) {
       Block* body = getBody(0);
       return getResults().empty() ? OpBuilder::atBlockTerminator(body, listener)
                                   : OpBuilder::atBlockEnd(body, listener);
     }
-    OpBuilder getElseBodyBuilder(OpBuilder::Listener *listener = nullptr) {
+    OpBuilder getElseBodyBuilder(RewriteListener *listener = nullptr) {
       Block* body = getBody(1);
       return getResults().empty() ? OpBuilder::atBlockTerminator(body, listener)
                                   : OpBuilder::atBlockEnd(body, listener);
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -44,6 +44,10 @@
 class AffineMap;
 class UnitAttr;
 
+//===----------------------------------------------------------------------===//
+// Builder
+//===----------------------------------------------------------------------===//
+
 /// This class is a general helper class for creating context-global objects
 /// like types, attributes, and affine expressions.
 class Builder {
@@ -174,54 +178,138 @@
   MLIRContext *context;
 };
 
+//===----------------------------------------------------------------------===//
+// RewriteListener
+//===----------------------------------------------------------------------===//
+
+/// This class represents a listener that can be used to hook on to various
+/// rewrite events in an `OpBuilder` or `PatternRewriter`. The listener is
+/// notified when:
+///
+/// - an operation is removed
+/// - an operation is inserted
+/// - an operation is replaced
+/// - a block is created
+/// - a pattern match failed
+///
+/// Listeners can be used to track IR mutations throughout pattern rewrites.
+struct RewriteListener {
+  virtual ~RewriteListener();
+
+  /// These are the callback methods that subclasses can choose to implement if
+  /// they would like to be notified about certain types of mutations.
+
+  /// Notification handler for when an operation is inserted into the builder.
+  /// `op` is the operation that was inserted.
+  virtual void notifyOperationInserted(Operation *op) {}
+
+  /// Notification handler for when a block is created using the builder.
+  /// `block` is the block that was created.
+  virtual void notifyBlockCreated(Block *block) {}
+
+  /// Notification handler for when the specified operation is about to be
+  /// replaced with another set of operations. This is called before the uses of
+  /// the operation have been changed.
+  virtual void notifyRootReplaced(Operation *op) {}
+
+  /// Notification handler for when the specified operation is about to be
+  /// deleted. At this point, the operation has zero uses.
+  virtual void notifyOperationRemoved(Operation *op) {}
+
+  /// Notify the listener that a pattern failed to match the given operation,
+  /// and provide a callback to populate a diagnostic with the reason why the
+  /// failure occurred. This method allows for derived listeners to optionally
+  /// hook into the reason why a rewrite failed, and display it to users.
+  virtual void
+  notifyMatchFailure(Operation *op,
+                     function_ref<void(Diagnostic &)> reasonCallback) {}
+};
+
+//===----------------------------------------------------------------------===//
+// ListenerList
+//===----------------------------------------------------------------------===//
+
+/// This class contains multiple listeners to which rewrite events can be sent.
+class ListenerList : public RewriteListener {
+public:
+  /// Add a listener to the list.
+  void addListener(RewriteListener *listener) { listeners.push_back(listener); }
+
+  /// Send notification of an operation being inserted to all listeners.
+  void notifyOperationInserted(Operation *op) override;
+
+  /// Send notification of a block being created to all listeners.
+  void notifyBlockCreated(Block *block) override;
+
+  /// Send notification that an operation has been replaced to all listeners.
+  void notifyRootReplaced(Operation *op) override;
+
+  /// Send notification that an operation is about to be deleted to all
+  /// listeners.
+  void notifyOperationRemoved(Operation *op) override;
+
+  /// Notify all listeners that a pattern match failed.
+  void
+  notifyMatchFailure(Operation *op,
+                     function_ref<void(Diagnostic &)> reasonCallback) override;
+
+private:
+  /// The list of listeners to send events to.
+  SmallVector<RewriteListener *> listeners;
+};
+
+//===----------------------------------------------------------------------===//
+// OpBuilder
+//===----------------------------------------------------------------------===//
+
 /// This class helps build Operations. Operations that are created are
 /// automatically inserted at an insertion point. The builder is copyable.
 class OpBuilder : public Builder {
 public:
-  struct Listener;
-
   /// Create a builder with the given context.
-  explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr)
+  explicit OpBuilder(MLIRContext *ctx, RewriteListener *listener = nullptr)
       : Builder(ctx), listener(listener) {}
 
   /// Create a builder and set the insertion point to the start of the region.
-  explicit OpBuilder(Region *region, Listener *listener = nullptr)
+  explicit OpBuilder(Region *region, RewriteListener *listener = nullptr)
       : OpBuilder(region->getContext(), listener) {
     if (!region->empty())
       setInsertionPoint(&region->front(), region->front().begin());
   }
-  explicit OpBuilder(Region &region, Listener *listener = nullptr)
+  explicit OpBuilder(Region &region, RewriteListener *listener = nullptr)
       : OpBuilder(&region, listener) {}
 
   /// Create a builder and set insertion point to the given operation, which
   /// will cause subsequent insertions to go right before it.
-  explicit OpBuilder(Operation *op, Listener *listener = nullptr)
+  explicit OpBuilder(Operation *op, RewriteListener *listener = nullptr)
       : OpBuilder(op->getContext(), listener) {
     setInsertionPoint(op);
   }
 
   OpBuilder(Block *block, Block::iterator insertPoint,
-            Listener *listener = nullptr)
+            RewriteListener *listener = nullptr)
       : OpBuilder(block->getParent()->getContext(), listener) {
     setInsertionPoint(block, insertPoint);
   }
 
   /// Create a builder and set the insertion point to before the first operation
   /// in the block but still inside the block.
-  static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) {
+  static OpBuilder atBlockBegin(Block *block,
+                                RewriteListener *listener = nullptr) {
     return OpBuilder(block, block->begin(), listener);
   }
 
   /// Create a builder and set the insertion point to after the last operation
   /// in the block but still inside the block.
-  static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) {
+  static OpBuilder atBlockEnd(Block *block,
+                              RewriteListener *listener = nullptr) {
     return OpBuilder(block, block->end(), listener);
   }
 
   /// Create a builder and set the insertion point to before the block
   /// terminator.
   static OpBuilder atBlockTerminator(Block *block,
-                                     Listener *listener = nullptr) {
+                                     RewriteListener *listener = nullptr) {
     auto *terminator = block->getTerminator();
     assert(terminator != nullptr && "the block has no terminator");
     return OpBuilder(block, Block::iterator(terminator), listener);
@@ -231,26 +319,12 @@
   // Listeners
   //===--------------------------------------------------------------------===//
 
-  /// This class represents a listener that may be used to hook into various
-  /// actions within an OpBuilder.
-  struct Listener {
-    virtual ~Listener();
-
-    /// Notification handler for when an operation is inserted into the builder.
-    /// `op` is the operation that was inserted.
-    virtual void notifyOperationInserted(Operation *op) {}
-
-    /// Notification handler for when a block is created using the builder.
-    /// `block` is the block that was created.
-    virtual void notifyBlockCreated(Block *block) {}
-  };
-
   /// Sets the listener of this builder to the one provided.
-  void setListener(Listener *newListener) { listener = newListener; }
+  void setListener(RewriteListener *newListener) { listener = newListener; }
 
   /// Returns the current listener of this builder, or nullptr if this builder
   /// doesn't have a listener.
-  Listener *getListener() const { return listener; }
+  RewriteListener *getListener() const { return listener; }
 
   //===--------------------------------------------------------------------===//
   // Insertion Point Management
@@ -509,7 +583,7 @@
   /// before.
   Block::iterator insertPoint;
   /// The optional listener for events of this builder.
-  Listener *listener;
+  RewriteListener *listener;
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h
--- a/mlir/include/mlir/IR/ImplicitLocOpBuilder.h
+++ b/mlir/include/mlir/IR/ImplicitLocOpBuilder.h
@@ -30,22 +30,24 @@
 
   /// Create a builder and set the insertion point to before the first operation
   /// in the block but still inside the block.
-  static ImplicitLocOpBuilder atBlockBegin(Location loc, Block *block,
-                                           Listener *listener = nullptr) {
+  static ImplicitLocOpBuilder
+  atBlockBegin(Location loc, Block *block,
+               RewriteListener *listener = nullptr) {
     return ImplicitLocOpBuilder(loc, block, block->begin(), listener);
   }
 
   /// Create a builder and set the insertion point to after the last operation
   /// in the block but still inside the block.
static ImplicitLocOpBuilder atBlockEnd(Location loc, Block *block, - Listener *listener = nullptr) { + RewriteListener *listener = nullptr) { return ImplicitLocOpBuilder(loc, block, block->end(), listener); } /// Create a builder and set the insertion point to before the block /// terminator. - static ImplicitLocOpBuilder atBlockTerminator(Location loc, Block *block, - Listener *listener = nullptr) { + static ImplicitLocOpBuilder + atBlockTerminator(Location loc, Block *block, + RewriteListener *listener = nullptr) { auto *terminator = block->getTerminator(); assert(terminator != nullptr && "the block has no terminator"); return ImplicitLocOpBuilder(loc, block, Block::iterator(terminator), diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -685,7 +685,7 @@ /// This class serves as a common API for IR mutation between pattern rewrites /// and non-pattern rewrites, and facilitates the development of shared /// IR transformation utilities. -class RewriterBase : public OpBuilder, public OpBuilder::Listener { +class RewriterBase : public OpBuilder { public: /// Move the blocks that belong to "region" before the given position in /// another region "parent". The two regions must be different. The caller @@ -801,11 +801,11 @@ std::enable_if_t::value, LogicalResult> notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { #ifndef NDEBUG - return notifyMatchFailure(op, - function_ref(reasonCallback)); -#else - return failure(); + if (RewriteListener *listener = getListener()) + listener->notifyMatchFailure( + op, function_ref(reasonCallback)); #endif + return failure(); } LogicalResult notifyMatchFailure(Operation *op, const Twine &msg) { return notifyMatchFailure(op, [&](Diagnostic &diag) { diag << msg; }); @@ -815,35 +815,10 @@ } protected: - /// Initialize the builder with this rewriter as the listener. 
- explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {} - explicit RewriterBase(const OpBuilder &otherBuilder) - : OpBuilder(otherBuilder) { - setListener(this); - } - ~RewriterBase() override; - - /// These are the callback methods that subclasses can choose to implement if - /// they would like to be notified about certain types of mutations. - - /// Notify the rewriter that the specified operation is about to be replaced - /// with another set of operations. This is called before the uses of the - /// operation have been changed. - virtual void notifyRootReplaced(Operation *op) {} - - /// This is called on an operation that a rewrite is removing, right before - /// the operation is deleted. At this point, the operation has zero uses. - virtual void notifyOperationRemoved(Operation *op) {} - - /// Notify the rewriter that the pattern failed to match the given operation, - /// and provide a callback to populate a diagnostic with the reason why the - /// failure occurred. This method allows for derived rewriters to optionally - /// hook into the reason why a rewrite failed, and display it to users. - virtual LogicalResult - notifyMatchFailure(Operation *op, - function_ref reasonCallback) { - return failure(); - } + /// Inherit constructors. + using OpBuilder::OpBuilder; + + virtual ~RewriterBase(); private: void operator=(const RewriterBase &) = delete; @@ -864,8 +839,7 @@ /// such as a `PatternRewriter`, is not available. 
class IRRewriter : public RewriterBase { public: - explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} - explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} + using RewriterBase::RewriterBase; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -523,7 +523,8 @@ /// This class implements a pattern rewriter for use with ConversionPatterns. It /// extends the base PatternRewriter and provides special conversion specific /// hooks. -class ConversionPatternRewriter final : public PatternRewriter { +class ConversionPatternRewriter final : public PatternRewriter, + public RewriteListener { public: explicit ConversionPatternRewriter(MLIRContext *ctx); ~ConversionPatternRewriter() override; @@ -634,7 +635,7 @@ void cancelRootUpdate(Operation *op) override; /// PatternRewriter hook for notifying match failure reasons. - LogicalResult + void notifyMatchFailure(Operation *op, function_ref reasonCallback) override; using PatternRewriter::notifyMatchFailure; diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -59,15 +59,22 @@ /// These methods also perform folding and simple dead-code elimination /// before attempting to match any of the provided patterns. /// -/// You may configure several aspects of this with GreedyRewriteConfig. -LogicalResult applyPatternsAndFoldGreedily( - MutableArrayRef regions, const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config = GreedyRewriteConfig()); +/// You may configure several aspects of this with GreedyRewriteConfig. 
A +/// rewrite listener can be supplied to hook on to rewrite events, such as +/// operations being removed or replaced. +LogicalResult +applyPatternsAndFoldGreedily(MutableArrayRef regions, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + RewriteListener *listener = nullptr); -/// Rewrite the given regions, which must be isolated from above. -inline LogicalResult applyPatternsAndFoldGreedily( - Operation *op, const FrozenRewritePatternSet &patterns, - GreedyRewriteConfig config = GreedyRewriteConfig()) { +/// Rewrite the regions of the given operation, which must be isolated from +/// above. +inline LogicalResult +applyPatternsAndFoldGreedily(Operation *op, + const FrozenRewritePatternSet &patterns, + GreedyRewriteConfig config = GreedyRewriteConfig(), + RewriteListener *listener = nullptr) { return applyPatternsAndFoldGreedily(op->getRegions(), patterns, config); } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -339,10 +339,44 @@ } //===----------------------------------------------------------------------===// -// OpBuilder +// RewriteListener +//===----------------------------------------------------------------------===// + +RewriteListener::~RewriteListener() {} + +//===----------------------------------------------------------------------===// +// ListenerList //===----------------------------------------------------------------------===// -OpBuilder::Listener::~Listener() {} +void ListenerList::notifyOperationInserted(Operation *op) { + for (RewriteListener *listener : listeners) + listener->notifyOperationInserted(op); +} + +void ListenerList::notifyBlockCreated(Block *block) { + for (RewriteListener *listener : listeners) + listener->notifyBlockCreated(block); +} + +void ListenerList::notifyRootReplaced(Operation *op) { + for (RewriteListener *listener : listeners) + listener->notifyRootReplaced(op); +} + +void 
ListenerList::notifyOperationRemoved(Operation *op) { + for (RewriteListener *listener : listeners) + listener->notifyOperationRemoved(op); +} + +void ListenerList::notifyMatchFailure( + Operation *op, function_ref reasonCallback) { + for (RewriteListener *listener : listeners) + listener->notifyMatchFailure(op, reasonCallback); +} + +//===----------------------------------------------------------------------===// +// OpBuilder +//===----------------------------------------------------------------------===// /// Insert the given operation at the current insertion point and return it. Operation *OpBuilder::insert(Operation *op) { diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -216,7 +216,8 @@ "incorrect number of values to replace operation"); // Notify the rewriter subclass that we're about to replace this root. - notifyRootReplaced(op); + if (RewriteListener *listener = getListener()) + listener->notifyRootReplaced(op); // Replace each use of the results when the functor is true. bool replacedAllUses = true; @@ -244,13 +245,15 @@ /// the operation. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { // Notify the rewriter subclass that we're about to replace this root. - notifyRootReplaced(op); + if (RewriteListener *listener = getListener()) + listener->notifyRootReplaced(op); assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); op->replaceAllUsesWith(newValues); - notifyOperationRemoved(op); + if (RewriteListener *listener = getListener()) + listener->notifyOperationRemoved(op); op->erase(); } @@ -258,7 +261,8 @@ /// the given operation *must* be known to be dead. 
void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); - notifyOperationRemoved(op); + if (RewriteListener *listener = getListener()) + listener->notifyOperationRemoved(op); op->erase(); } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1490,7 +1490,7 @@ //===----------------------------------------------------------------------===// ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx) - : PatternRewriter(ctx), + : PatternRewriter(ctx, /*listener=*/this), impl(new detail::ConversionPatternRewriterImpl(*this)) {} ConversionPatternRewriter::~ConversionPatternRewriter() {} @@ -1669,9 +1669,9 @@ rootUpdates.erase(rootUpdates.begin() + updateIdx); } -LogicalResult ConversionPatternRewriter::notifyMatchFailure( +void ConversionPatternRewriter::notifyMatchFailure( Operation *op, function_ref reasonCallback) { - return impl->notifyMatchFailure(op->getLoc(), reasonCallback); + (void)impl->notifyMatchFailure(op->getLoc(), reasonCallback); } detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -32,11 +32,12 @@ namespace { /// This is a worklist-driven driver for the PatternMatcher, which repeatedly /// applies the locally optimal patterns in a roughly "bottom up" way. 
-class GreedyPatternRewriteDriver : public PatternRewriter { +class GreedyPatternRewriteDriver : public RewriteListener { public: explicit GreedyPatternRewriteDriver(MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config); + const GreedyRewriteConfig &config, + RewriteListener *listener); /// Simplify the operations within the given regions. bool simplify(MutableArrayRef regions); @@ -71,11 +72,8 @@ // before the root is changed. void notifyRootReplaced(Operation *op) override; - /// PatternRewriter hook for erasing a dead operation. - void eraseOp(Operation *op) override; - /// PatternRewriter hook for notifying match failure reasons. - LogicalResult + void notifyMatchFailure(Operation *op, function_ref reasonCallback) override; @@ -92,10 +90,17 @@ /// Non-pattern based folder for operations. OperationFolder folder; + /// The pattern rewriter instance to use to perform rewrite operations. + PatternRewriter rewriter; + private: /// Configuration information for how to simplify. GreedyRewriteConfig config; + /// The listeners attached to the pattern rewriter, including this and an + /// optional user-provided listener to notify of rewrite events. + ListenerList listeners; + #ifndef NDEBUG /// A logger used to emit information during the application process. llvm::ScopedPrinter logger{llvm::dbgs()}; @@ -105,8 +110,15 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, - const GreedyRewriteConfig &config) - : PatternRewriter(ctx), matcher(patterns), folder(ctx), config(config) { + const GreedyRewriteConfig &config, RewriteListener *listener) + : matcher(patterns), folder(ctx), rewriter(ctx, &listeners), + config(config) { + // Attach ourselves as a listener. + listeners.addListener(this); + // Attach the user-provided listener, if there is one. 
+  if (listener)
+    listeners.addListener(listener);
+
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -247,16 +259,15 @@
       };
 
       LogicalResult matchResult =
-          matcher.matchAndRewrite(op, *this, canApply, onFailure, onSuccess);
+          matcher.matchAndRewrite(op, rewriter, canApply, onFailure, onSuccess);
 
       if (succeeded(matchResult))
         LLVM_DEBUG(logResultWithLine("success", "pattern matched"));
       else
         LLVM_DEBUG(logResultWithLine("failure", "pattern failed to match"));
 #else
-      LogicalResult matchResult = matcher.matchAndRewrite(op, *this);
+      LogicalResult matchResult = matcher.matchAndRewrite(op, rewriter);
 #endif
-
 
 #ifndef NDEBUG
 #endif
@@ -266,7 +277,7 @@
     // After applying patterns, make sure that the CFG of each of the regions
     // is kept up to date.
     if (config.enableRegionSimplification)
-      changed |= succeeded(simplifyRegions(*this, regions));
+      changed |= succeeded(simplifyRegions(rewriter, regions));
   } while (changed &&
            (++iteration < config.maxIterations ||
             config.maxIterations == GreedyRewriteConfig::kNoIterationLimit));
@@ -327,6 +338,10 @@
 }
 
 void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
+  LLVM_DEBUG({
+    logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
+                       << ")\n";
+  });
   addToWorklist(op->getOperands());
   op->walk([this](Operation *operation) {
     removeFromWorklist(operation);
@@ -344,22 +359,13 @@
       addToWorklist(user);
 }
 
-void GreedyPatternRewriteDriver::eraseOp(Operation *op) {
-  LLVM_DEBUG({
-    logger.startLine() << "** Erase   : '" << op->getName() << "'(" << op
-                       << ")\n";
-  });
-  PatternRewriter::eraseOp(op);
-}
-
-LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(
+void GreedyPatternRewriteDriver::notifyMatchFailure(
     Operation *op, function_ref<void(Diagnostic &)> reasonCallback) {
   LLVM_DEBUG({
     Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
   });
-  return failure();
 }
 
 /// Rewrite the regions of the specified operation, which must be isolated from
@@ -368,10 +374,9 @@
 /// in the result operation regions. Note: This does not apply patterns to the
 /// top-level operation itself.
 ///
-LogicalResult
-mlir::applyPatternsAndFoldGreedily(MutableArrayRef<Region> regions,
-                                   const FrozenRewritePatternSet &patterns,
-                                   GreedyRewriteConfig config) {
+LogicalResult mlir::applyPatternsAndFoldGreedily(
+    MutableArrayRef<Region> regions, const FrozenRewritePatternSet &patterns,
+    GreedyRewriteConfig config, RewriteListener *listener) {
   if (regions.empty())
     return success();
 
@@ -386,7 +391,8 @@
          "patterns can only be applied to operations IsolatedFromAbove");
 
   // Start the pattern driver.
-  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config);
+  GreedyPatternRewriteDriver driver(regions[0].getContext(), patterns, config,
+                                    listener);
   bool converged = driver.simplify(regions);
   LLVM_DEBUG(if (!converged) {
     llvm::dbgs() << "The pattern rewrite doesn't converge after scanning "
@@ -402,11 +408,11 @@
 namespace {
 /// This is a simple driver for the PatternMatcher to apply patterns and perform
 /// folding on a single op. It repeatedly applies locally optimal patterns.
-class OpPatternRewriteDriver : public PatternRewriter {
+class OpPatternRewriteDriver : public RewriteListener {
 public:
   explicit OpPatternRewriteDriver(MLIRContext *ctx,
                                   const FrozenRewritePatternSet &patterns)
-      : PatternRewriter(ctx), matcher(patterns), folder(ctx) {
+      : matcher(patterns), folder(ctx), rewriter(ctx, this) {
     // Apply a simple cost model based solely on pattern benefit.
     matcher.applyDefaultCostModel();
   }
@@ -432,6 +438,9 @@
   /// Non-pattern based folder for operations.
   OperationFolder folder;
 
+  /// The pattern rewriter instance to use to perform rewrite operations.
+  PatternRewriter rewriter;
+
   /// Set to true if the operation has been erased via pattern rewrites.
   bool opErasedViaPatternRewrites = false;
 };
@@ -476,7 +485,7 @@
 
     // Try to match one of the patterns. The rewriter is automatically
     // notified of any necessary changes, so there is nothing else to do here.
-    changed |= succeeded(matcher.matchAndRewrite(op, *this));
+    changed |= succeeded(matcher.matchAndRewrite(op, rewriter));
     if ((erased = opErasedViaPatternRewrites))
       return success();
   } while (changed &&
@@ -502,7 +511,8 @@
   explicit MultiOpPatternRewriteDriver(MLIRContext *ctx,
                                        const FrozenRewritePatternSet &patterns,
                                        bool strict)
-      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig()),
+      : GreedyPatternRewriteDriver(ctx, patterns, GreedyRewriteConfig(),
+                                   /*listener=*/nullptr),
         strictMode(strict) {}
 
   bool simplifyLocally(ArrayRef<Operation *> op);
@@ -631,7 +641,7 @@
     // Try to match one of the patterns. The rewriter is automatically
     // notified of any necessary changes, so there is nothing else to do
     // here.
-    changed |= succeeded(matcher.matchAndRewrite(op, *this));
+    changed |= succeeded(matcher.matchAndRewrite(op, rewriter));
   }
 
   return changed;