diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -675,6 +675,19 @@
 MLIR_CAPI_EXPORTED void
 mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
 
+/// Replaces all uses of `value` with `newValue`, updating everything in the IR
+/// that uses `value` to use `newValue` instead. When this returns, `value` has
+/// no remaining uses.
+MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesWith(MlirValue value,
+                                                    MlirValue newValue);
+
+/// Replaces all uses of `value` with `newValue`, like
+/// mlirValueReplaceAllUsesWith, except for the uses owned by the operation
+/// `except`, which keep using `value`.
+MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesExcept(MlirValue value,
+                                                      MlirValue newValue,
+                                                      MlirOperation except);
+
 //===----------------------------------------------------------------------===//
 // Type API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3103,10 +3103,29 @@
             return printAccum.join();
           },
           kValueDunderStrDocstring)
-      .def_property_readonly("type", [](PyValue &self) {
-        return PyType(self.getParentOperation()->getContext(),
-                      mlirValueGetType(self.get()));
-      });
+      .def_property_readonly("type",
+                             [](PyValue &self) {
+                               return PyType(
+                                   self.getParentOperation()->getContext(),
+                                   mlirValueGetType(self.get()));
+                             })
+      .def(
+          "replace_all_uses_with",
+          [](PyValue &self, PyValue &newValue) {
+            mlirValueReplaceAllUsesWith(self.get(), newValue.get());
+          },
+          py::arg("new_value"),
+          "Replaces all uses of this value with `new_value`. When this "
+          "returns, this value has no remaining uses.")
+      .def(
+          "replace_all_uses_except",
+          [](PyValue &self, PyValue &newValue, PyOperationBase &exceptedOp) {
+            mlirValueReplaceAllUsesExcept(self.get(), newValue.get(),
+                                          exceptedOp.getOperation().get());
+          },
+          py::arg("new_value"), py::arg("exception"),
+          "Replaces all uses of this value with `new_value`, except for the "
+          "uses owned by the operation `exception`.");
 
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -708,6 +708,15 @@
   unwrap(value).print(stream);
 }
 
+void mlirValueReplaceAllUsesWith(MlirValue value, MlirValue newValue) {
+  unwrap(value).replaceAllUsesWith(unwrap(newValue));
+}
+
+void mlirValueReplaceAllUsesExcept(MlirValue value, MlirValue newValue,
+                                   MlirOperation except) {
+  unwrap(value).replaceAllUsesExcept(unwrap(newValue), unwrap(except));
+}
+
 //===----------------------------------------------------------------------===//
 // Type API.
 //===----------------------------------------------------------------------===//