# [mlir] GreedyPatternRewriteDriver: forward IR-modification notifications to
# an optional, user-supplied listener.
#
# GreedyRewriteConfig gains a `listener` field (RewriterBase::Listener *,
# default nullptr). When it is set, the driver forwards notifyBlockCreated,
# notifyOperationInserted, notifyOperationRemoved, notifyOperationReplaced and
# notifyMatchFailure to it. Every forwarding call is guarded by
# `if (config.listener)`, so with the default nullptr the driver's existing
# behavior is unchanged.
#
# NOTE(review): `listener` is a raw pointer held by the config; presumably it
# is non-owning and must outlive any rewrite that uses the config -- confirm,
# and consider saying so in the field's doc comment.
#
# NOTE(review): the line structure and indentation of this patch were
# reconstructed from a whitespace-collapsed copy. Context lines (including the
# spacing inside the "** Insert/Erase/Replace" log string literals) may not
# match the upstream files byte-for-byte; verify before applying, or use
# `git apply --ignore-whitespace`.
diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -76,6 +76,9 @@
   /// were on the worklist at the very beginning) enqueued. All other ops are
   /// excluded.
   GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
+
+  /// An optional listener that should be notified about IR modifications.
+  RewriterBase::Listener *listener = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
# GreedyPatternRewriteDriver.cpp:
#  - notifyBlockCreated is a new override that only forwards to the listener;
#    the driver does no worklist bookkeeping of its own for new blocks.
#  - notifyOperationInserted / notifyOperationRemoved / notifyOperationReplaced
#    call the listener after the LLVM_DEBUG log and before the driver updates
#    its own worklist (and, for inserts, strictModeFilteredOps).
#  - notifyOperationRemoved forwards only `op` itself. NOTE(review): nested ops
#    visited by the op->walk that follows are not forwarded by this patch --
#    confirm that matches the intended listener contract.
#  - notifyMatchFailure returns the listener's result instead of the
#    unconditional failure() when a listener is set. The LLVM_DEBUG block
#    already calls reasonCallback once, so it may run a second time if the
#    listener also invokes it.
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -103,6 +103,9 @@
   /// Pop the next operation from the worklist.
   Operation *popFromWorklist();
 
+  /// Notify the driver that the given block was created.
+  void notifyBlockCreated(Block *block) override;
+
   /// For debugging only: Notify the driver of a pattern match failure.
   LogicalResult
   notifyMatchFailure(Location loc,
@@ -315,11 +318,18 @@
   }
 }
 
+void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
+  if (config.listener)
+    config.listener->notifyBlockCreated(block);
+}
+
 void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+  if (config.listener)
+    config.listener->notifyOperationInserted(op);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
     strictModeFilteredOps.insert(op);
   addToWorklist(op);
@@ -352,6 +362,8 @@
     logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+  if (config.listener)
+    config.listener->notifyOperationRemoved(op);
 
   addOperandsToWorklist(op->getOperands());
   op->walk([this](Operation *operation) {
@@ -369,6 +381,8 @@
     logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+  if (config.listener)
+    config.listener->notifyOperationReplaced(op, replacement);
   for (auto result : op->getResults())
     for (auto *user : result.getUsers())
       addToWorklist(user);
@@ -381,6 +395,8 @@
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
   });
+  if (config.listener)
+    return config.listener->notifyMatchFailure(loc, reasonCallback);
   return failure();
 }
 