diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -546,9 +546,9 @@ /// they would like to be notified about certain types of mutations. /// Notify the rewriter that the specified operation is about to be replaced - /// with another set of operations. This is called before the uses of the - /// operation have been changed. - virtual void notifyRootReplaced(Operation *op) {} + /// with the set of values potentially produced by new operations. This is + /// called before the uses of the operation have been changed. + virtual void notifyRootReplaced(Operation *op, ValueRange replacement) {} /// This is called on an operation that a rewrite is removing, right before /// the operation is deleted. At this point, the operation has zero uses. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -216,7 +216,7 @@ "incorrect number of values to replace operation"); // Notify the rewriter subclass that we're about to replace this root. - notifyRootReplaced(op); + notifyRootReplaced(op, newValues); // Replace each use of the results when the functor is true. bool replacedAllUses = true; @@ -244,7 +244,7 @@ /// the operation. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { // Notify the rewriter subclass that we're about to replace this root. - notifyRootReplaced(op); + notifyRootReplaced(op, newValues); assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -69,7 +69,7 @@ // When the root of a pattern is about to be replaced, it can trigger // simplifications to its users - make sure to add them to the worklist // before the root is changed. - void notifyRootReplaced(Operation *op) override; + void notifyRootReplaced(Operation *op, ValueRange replacement) override; /// PatternRewriter hook for erasing a dead operation. void eraseOp(Operation *op) override; @@ -348,7 +348,8 @@ }); } -void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op) { +void GreedyPatternRewriteDriver::notifyRootReplaced(Operation *op, + ValueRange replacement) { LLVM_DEBUG({ logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; @@ -437,7 +438,7 @@ // When a root is going to be replaced, its removal will be notified as well. // So there is nothing to do here. - void notifyRootReplaced(Operation *op) override {} + void notifyRootReplaced(Operation *op, ValueRange replacement) override {} private: /// The low-level pattern applicator.