diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -507,10 +507,22 @@
   /// modification is about to happen.
   void replaceAllUsesWith(Value from, Value to);
 
+  /// Find uses of `from` and replace them with `to` if `functor` returns true
+  /// for the use. It also marks every modified use and notifies the rewriter
+  /// that an in-place operation modification is about to happen. `functor` is
+  /// only invoked during this call and is not retained.
+  void replaceUseIf(Value from, Value to,
+                    function_ref<bool(OpOperand &)> functor);
+
   /// Find uses of `from` and replace them with `to` except if the user is
   /// `exceptedUser`. It also marks every modified uses and notifies the
   /// rewriter that an in-place operation modification is about to happen.
-  void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser);
+  void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) {
+    replaceUseIf(from, to, [&](OpOperand &use) {
+      Operation *user = use.getOwner();
+      return user != exceptedUser;
+    });
+  }
 
   /// Used to notify the rewriter that the IR failed to be rewritten because of
   /// a match failure, and provide a callback to populate a diagnostic with the
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -317,15 +317,14 @@
   }
 }
 
-/// Find uses of `from` and replace them with `to` except if the user is
-/// `exceptedUser`. It also marks every modified uses and notifies the
-/// rewriter that an in-place operation modification is about to happen.
-void RewriterBase::replaceAllUsesExcept(Value from, Value to,
-                                        Operation *exceptedUser) {
+/// Find uses of `from` and replace them with `to` if `functor` returns true
+/// for the use. It also marks every modified use and notifies the rewriter
+/// that an in-place operation modification is about to happen.
+void RewriterBase::replaceUseIf(Value from, Value to,
+                                function_ref<bool(OpOperand &)> functor) {
   for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
-    Operation *user = operand.getOwner();
-    if (user != exceptedUser)
-      updateRootInPlace(user, [&]() { operand.set(to); });
+    if (functor(operand))
+      updateRootInPlace(operand.getOwner(), [&]() { operand.set(to); });
   }
 }
 