diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -550,9 +550,8 @@ /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. - void - replaceUsesWithIf(Value from, Value to, - llvm::unique_function functor); + void replaceUsesWithIf(Value from, Value to, + function_ref functor); /// Find uses of `from` and replace them with `to` except if the user is /// `exceptedUser`. It also marks every modified uses and notifies the diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -235,14 +235,14 @@ assert(op->getNumResults() == newValues.size() && "incorrect number of values to replace operation"); - // Notify the rewriter subclass that we're about to replace this root. + // Notify the listener that we're about to replace this op. if (auto *rewriteListener = dyn_cast_if_present(listener)) rewriteListener->notifyOperationReplaced(op, newValues); // Replace each use of the results when the functor is true. bool replacedAllUses = true; for (auto it : llvm::zip(op->getResults(), newValues)) { - std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor); + replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor); replacedAllUses &= std::get<0>(it).use_empty(); } if (allUsesReplaced) @@ -264,17 +264,19 @@ /// values. The number of provided values must match the number of results of /// the operation. void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { - // Notify the rewriter subclass that we're about to replace this root. - if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationReplaced(op, newValues); - assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - op->replaceAllUsesWith(newValues); + // Notify the listener that we're about to remove this op. if (auto *rewriteListener = dyn_cast_if_present(listener)) - rewriteListener->notifyOperationRemoved(op); - op->erase(); + rewriteListener->notifyOperationReplaced(op, newValues); + + // Replace results one-by-one. Also notifies the listener of modifications. + for (auto it : llvm::zip(op->getResults(), newValues)) + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); + + // Erase the op. Also notifies the listener. + eraseOp(op); } /// This method erases an operation that is known to have no uses. The uses of @@ -314,7 +316,7 @@ // Replace all of the successor arguments with the provided values. for (auto it : llvm::zip(source->getArguments(), argValues)) - std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); // Splice the operations of the 'source' block into the 'dest' block and erase // it. @@ -326,9 +328,8 @@ /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. -void RewriterBase::replaceUsesWithIf( - Value from, Value to, - llvm::unique_function functor) { +void RewriterBase::replaceUsesWithIf(Value from, Value to, + function_ref functor) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { if (functor(operand)) updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); }); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -383,9 +383,6 @@ }); if (config.listener) config.listener->notifyOperationReplaced(op, replacement); - for (auto result : op->getResults()) - for (auto *user : result.getUsers()) - addToWorklist(user); } LogicalResult GreedyPatternRewriteDriver::notifyMatchFailure(