diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -550,14 +550,15 @@ /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. - void replaceUseIf(Value from, Value to, + void + replaceUsesWithIf(Value from, Value to, llvm::unique_function functor); /// Find uses of `from` and replace them with `to` except if the user is /// `exceptedUser`. It also marks every modified uses and notifies the /// rewriter that an in-place operation modification is about to happen. void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser) { - return replaceUseIf(from, to, [&](OpOperand &use) { + return replaceUsesWithIf(from, to, [&](OpOperand &use) { Operation *user = use.getOwner(); return user != exceptedUser; }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -428,7 +428,7 @@ return rewriter.notifyMatchFailure(genericOp, "fusion failed"); Operation *producer = opOperand.get().getDefiningOp(); for (auto [origVal, replacement] : fusionResult->replacements) { - rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) { + rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) { // Only replace consumer uses. return use.get().getDefiningOp() != producer; }); diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -326,7 +326,7 @@ /// Find uses of `from` and replace them with `to` if the `functor` returns /// true. It also marks every modified uses and notifies the rewriter that an /// in-place operation modification is about to happen. -void RewriterBase::replaceUseIf( +void RewriterBase::replaceUsesWithIf( Value from, Value to, llvm::unique_function functor) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -74,7 +74,7 @@ if (!fusionResult) return rewriter.notifyMatchFailure(genericOp, "fusion failed"); for (auto [origValue, replacement] : fusionResult->replacements) { - rewriter.replaceUseIf(origValue, replacement, [&](OpOperand &use) { + rewriter.replaceUsesWithIf(origValue, replacement, [&](OpOperand &use) { return use.getOwner() != genericOp.getOperation(); }); } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -335,7 +335,7 @@ scf::ForOp outermostLoop = tilingResult->loops.front(); for (auto [index, origVal] : llvm::enumerate(yieldedValuesToOrigValues)) { Value replacement = outermostLoop.getResult(index); - rewriter.replaceUseIf(origVal, replacement, [&](OpOperand &use) { + rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) { return !isIgnoredUser(use.getOwner(), outermostLoop); }); }