diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -257,6 +257,15 @@
     getResults().replaceAllUsesWith(std::forward<ValuesT>(values));
   }
 
+  /// Replace uses of results of this operation with the provided `values` if
+  /// the given callback returns true.
+  template <typename ValuesT>
+  void replaceUsesWithIf(ValuesT &&values,
+                         function_ref<bool(OpOperand &)> shouldReplace) {
+    getResults().replaceUsesWithIf(std::forward<ValuesT>(values),
+                                   shouldReplace);
+  }
+
   /// Destroys this operation and its subclass data.
   void destroy();
 
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -279,6 +279,26 @@
   /// Replace all uses of results of this range with results of 'op'.
   void replaceAllUsesWith(Operation *op);
 
+  /// Replace uses of results of this range with the provided `values` if the
+  /// given callback returns true. The size of `values` must match the size of
+  /// this range.
+  template <typename ValuesT>
+  std::enable_if_t<!std::is_convertible<ValuesT, Operation *>::value>
+  replaceUsesWithIf(ValuesT &&values,
+                    function_ref<bool(OpOperand &)> shouldReplace) {
+    assert(static_cast<size_t>(std::distance(values.begin(), values.end())) ==
+               size() &&
+           "expected 'values' to correspond 1-1 with the number of results");
+
+    for (auto it : llvm::zip(*this, values))
+      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), shouldReplace);
+  }
+
+  /// Replace uses of results of this range with results of `op` if the given
+  /// callback returns true.
+  void replaceUsesWithIf(Operation *op,
+                         function_ref<bool(OpOperand &)> shouldReplace);
+
   //===--------------------------------------------------------------------===//
   // Users
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -589,6 +589,11 @@
   replaceAllUsesWith(op->getResults());
 }
 
+void ResultRange::replaceUsesWithIf(
+    Operation *op, function_ref<bool(OpOperand &)> shouldReplace) {
+  replaceUsesWithIf(op->getResults(), shouldReplace);
+}
+
 //===----------------------------------------------------------------------===//
 // ValueRange
 
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -124,12 +124,9 @@
     } else {
       // When the region does not have SSA dominance, we need to check if we
       // have visited a use before replacing any use.
-      for (auto it : llvm::zip(op->getResults(), existing->getResults())) {
-        std::get<0>(it).replaceUsesWithIf(
-            std::get<1>(it), [&](OpOperand &operand) {
-              return !knownValues.count(operand.getOwner());
-            });
-      }
+      op->replaceUsesWithIf(existing->getResults(), [&](OpOperand &operand) {
+        return !knownValues.count(operand.getOwner());
+      });
 
       // There may be some remaining uses of the operation.
       if (op->use_empty())