diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -755,6 +755,12 @@
 /// operand if there are no uses.
 MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
 
+/// Replace all uses of 'of' value with the 'with' value, updating anything in
+/// the IR that uses 'of' to use the other value instead. When this returns
+/// there are zero uses of 'of'.
+MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
+                                                      MlirValue with);
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -13,11 +13,9 @@
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
-#include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
-//#include "mlir-c/Registration.h"
 
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
@@ -154,6 +152,11 @@
 equivalent to printing the operation that produced it.
 )";
 
+static const char kValueReplaceAllUsesWithDocstring[] =
+    R"(Replace all uses of 'self' with the given value, so that anything in the
+IR that used 'self' now uses the given value and 'self' is left with no uses.
+)";
+
 //------------------------------------------------------------------------------
 // Utilities.
 //------------------------------------------------------------------------------
@@ -3316,10 +3319,18 @@
             return printAccum.join();
           },
           kValueDunderStrDocstring)
-      .def_property_readonly("type", [](PyValue &self) {
-        return PyType(self.getParentOperation()->getContext(),
-                      mlirValueGetType(self.get()));
-      });
+      .def_property_readonly("type",
+                             [](PyValue &self) {
+                               return PyType(
+                                   self.getParentOperation()->getContext(),
+                                   mlirValueGetType(self.get()));
+                             })
+      .def(
+          "replace_all_uses_with",
+          [](PyValue &self, PyValue &with) {
+            mlirValueReplaceAllUsesOfWith(self.get(), with.get());
+          },
+          kValueReplaceAllUsesWithDocstring);
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
   PyOpOperand::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -751,6 +751,10 @@
   return wrap(opOperand);
 }
 
+void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
+  unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
+}
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1873,9 +1873,61 @@
     return 3;
   }
 
+  MlirOperationState op2State =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
+  MlirValue initialOperands2[] = {constOneValue};
+  mlirOperationStateAddOperands(&op2State, 1, initialOperands2);
+  MlirOperation op2 = mlirOperationCreate(&op2State);
+
+  MlirOpOperand use3 = mlirValueGetFirstUse(constOneValue);
+  fprintf(stderr, "First use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: First use owner: "dummy.op2"
+
+  use3 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constOneValue));
+  fprintf(stderr, "Second use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: Second use owner: "dummy.op"
+
+  MlirAttribute indexTwoLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("2 : index"));
+  MlirNamedAttribute indexTwoValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexTwoLiteral);
+  MlirOperationState constTwoState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.constant"), loc);
+  mlirOperationStateAddResults(&constTwoState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constTwoState, 1, &indexTwoValueAttr);
+  MlirOperation constTwo = mlirOperationCreate(&constTwoState);
+  MlirValue constTwoValue = mlirOperationGetResult(constTwo, 0);
+
+  mlirValueReplaceAllUsesOfWith(constOneValue, constTwoValue);
+
+  use3 = mlirValueGetFirstUse(constOneValue);
+  if (!mlirOpOperandIsNull(use3)) {
+    fprintf(stderr, "ERROR: Use should be null\n");
+    return 4;
+  }
+
+  MlirOpOperand use4 = mlirValueGetFirstUse(constTwoValue);
+  fprintf(stderr, "First replacement use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: First replacement use owner: "dummy.op"
+
+  use4 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constTwoValue));
+  fprintf(stderr, "Second replacement use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: Second replacement use owner: "dummy.op2"
+
   mlirOperationDestroy(op);
+  mlirOperationDestroy(op2);
   mlirOperationDestroy(constZero);
   mlirOperationDestroy(constOne);
+  mlirOperationDestroy(constTwo);
   mlirContextDestroy(ctx);
 
   return 0;
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -111,3 +111,29 @@
       assert use.owner in [op1, op2]
       print(f"Use owner: {use.owner}")
       print(f"Use operand_number: {use.operand_number}")
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
+@run
+def testValueReplaceAllUsesWith():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      value = Operation.create("custom.op1", results=[i32]).results[0]
+      op1 = Operation.create("custom.op2", operands=[value])
+      op2 = Operation.create("custom.op2", operands=[value])
+      value2 = Operation.create("custom.op3", results=[i32]).results[0]
+      value.replace_all_uses_with(value2)
+
+    assert len(list(value.uses)) == 0
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+      assert use.owner in [op1, op2]
+      print(f"Use owner: {use.owner}")
+      print(f"Use operand_number: {use.operand_number}")