Add RewriterBase::Listener::notifyOperationModified.

RewriterBase::finalizeRootUpdate is no longer an empty inline method. It now
notifies an attached RewriterBase::Listener (if one is set) through the new
notifyOperationModified hook. The dyn_cast_if_present call spells out its
target type (Listener), because that type cannot be deduced from the argument.

ConversionPatternRewriter::finalizeRootUpdate now calls the inherited
implementation, so listeners are also notified there. GreedyPatternRewriteDriver
no longer overrides finalizeRootUpdate. It overrides notifyOperationModified,
which is where it enqueues in-place modified operations.

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -402,6 +402,9 @@
     Listener()
         : OpBuilder::Listener(ListenerBase::Kind::RewriterBaseListener) {}
 
+    /// Notify the listener that the specified operation was modified in-place.
+    virtual void notifyOperationModified(Operation *op) {}
+
     /// Notify the listener that the specified operation is about to be replaced
     /// with the set of values potentially produced by new operations. This is
     /// called before the uses of the operation have been changed.
@@ -514,7 +517,7 @@
   /// This method is used to signal the end of a root update on the given
   /// operation. This can only be called on operations that were provided to a
   /// call to `startRootUpdate`.
-  virtual void finalizeRootUpdate(Operation *op) {}
+  virtual void finalizeRootUpdate(Operation *op);
 
   /// This method cancels a pending root update. This can only be called on
   /// operations that were provided to a call to `startRootUpdate`.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -294,6 +294,12 @@
   block->erase();
 }
 
+void RewriterBase::finalizeRootUpdate(Operation *op) {
+  // Notify the listener that the operation was modified.
+  if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
+    rewriteListener->notifyOperationModified(op);
+}
+
 /// Merge the operations of block 'source' into the end of block 'dest'.
 /// 'source's predecessors must be empty or only contain 'dest`.
 /// 'argValues' is used to replace the block arguments of 'source' after
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1654,6 +1654,7 @@
 }
 
 void ConversionPatternRewriter::finalizeRootUpdate(Operation *op) {
+  PatternRewriter::finalizeRootUpdate(op);
   // There is nothing to do here, we only need to track the operation at the
   // start of the update.
 #ifndef NDEBUG
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -54,7 +54,7 @@
 
   /// Notify the driver that the specified operation may have been modified
   /// in-place. The operation is added to the worklist.
-  void finalizeRootUpdate(Operation *op) override;
+  void notifyOperationModified(Operation *op) override;
 
   /// Notify the driver that the specified operation was inserted. Update the
   /// worklist as needed: The operation is enqueued depending on scope and
@@ -335,7 +335,7 @@
   addToWorklist(op);
 }
 
-void GreedyPatternRewriteDriver::finalizeRootUpdate(Operation *op) {
+void GreedyPatternRewriteDriver::notifyOperationModified(Operation *op) {
   LLVM_DEBUG({
     logger.startLine() << "** Modified: '" << op->getName() << "'(" << op
                        << ")\n";