[mlir][IR] Route RewriterBase value replacement through the rewriter

RewriterBase::replaceOp, RewriterBase::replaceOpWithIf and the block-argument
replacement performed before splicing a 'source' block into 'dest' now call
the rewriter's own replaceAllUsesWith / replaceUsesWithIf instead of the
Value/Operation replace* methods. Each modified user is therefore updated via
updateRootInPlace, so the listener is told about the in-place modification.

replaceUsesWithIf now takes a non-owning function_ref instead of an
llvm::unique_function: the functor is only invoked while iterating the uses
of `from`, so no ownership is needed.

replaceOp now asserts the replacement-value count before notifying the
listener, matching replaceOpWithIf.

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -555,9 +555,8 @@
   /// Find uses of `from` and replace them with `to` if the `functor` returns
   /// true. It also marks every modified uses and notifies the rewriter that an
   /// in-place operation modification is about to happen.
-  void
-  replaceUsesWithIf(Value from, Value to,
-                    llvm::unique_function<bool(OpOperand &) const> functor);
+  void replaceUsesWithIf(Value from, Value to,
+                         function_ref<bool(OpOperand &)> functor);
 
   /// Find uses of `from` and replace them with `to` except if the user is
   /// `exceptedUser`. It also marks every modified uses and notifies the
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -235,14 +235,14 @@
   assert(op->getNumResults() == newValues.size() &&
          "incorrect number of values to replace operation");
 
-  // Notify the rewriter subclass that we're about to replace this root.
+  // Notify the listener that we're about to replace this op.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newValues);
 
   // Replace each use of the results when the functor is true.
   bool replacedAllUses = true;
   for (auto it : llvm::zip(op->getResults(), newValues)) {
-    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
+    replaceUsesWithIf(std::get<0>(it), std::get<1>(it), functor);
     replacedAllUses &= std::get<0>(it).use_empty();
   }
   if (allUsesReplaced)
@@ -264,13 +264,16 @@
 /// values. The number of provided values must match the number of results of
 /// the operation.
 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
-  // Notify the rewriter subclass that we're about to replace this root.
+  assert(op->getNumResults() == newValues.size() &&
+         "incorrect # of replacement values");
+
+  // Notify the listener that we're about to replace this op.
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationReplaced(op, newValues);
 
-  assert(op->getNumResults() == newValues.size() &&
-         "incorrect # of replacement values");
-  op->replaceAllUsesWith(newValues);
+  // Replace results one-by-one. Also notifies the listener of modifications.
+  for (auto it : llvm::zip(op->getResults(), newValues))
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   if (auto *rewriteListener = dyn_cast_if_present<Listener>(listener))
     rewriteListener->notifyOperationRemoved(op);
@@ -314,7 +317,7 @@
 
   // Replace all of the successor arguments with the provided values.
   for (auto it : llvm::zip(source->getArguments(), argValues))
-    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+    replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
   // Splice the operations of the 'source' block into the 'dest' block and erase
   // it.
@@ -326,9 +329,8 @@
 /// Find uses of `from` and replace them with `to` if the `functor` returns
 /// true. It also marks every modified uses and notifies the rewriter that an
 /// in-place operation modification is about to happen.
-void RewriterBase::replaceUsesWithIf(
-    Value from, Value to,
-    llvm::unique_function<bool(OpOperand &) const> functor) {
+void RewriterBase::replaceUsesWithIf(Value from, Value to,
+                                     function_ref<bool(OpOperand &)> functor) {
   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
     if (functor(operand))
       updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });