diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -755,6 +755,12 @@
 /// operand if there are no uses.
 MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
 
+/// Replace all uses of 'of' value with the 'with' value, updating anything in
+/// the IR that uses 'of' to use the other value instead. When this returns
+/// there are zero uses of 'of'.
+MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
+                                                      MlirValue with);
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -13,11 +13,9 @@
 
 #include "mlir-c/Bindings/Python/Interop.h"
 #include "mlir-c/BuiltinAttributes.h"
-#include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Debug.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
-//#include "mlir-c/Registration.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallVector.h"
 
@@ -154,6 +152,11 @@
 equivalent to printing the operation that produced it.
 )";
 
+static const char kValueReplaceAllUsesWithDocstring[] =
+    R"(Replace all uses of value with the new value, updating anything in
+the IR that uses 'self' to use the other value instead.
+)";
+
 //------------------------------------------------------------------------------
 // Utilities.
//------------------------------------------------------------------------------ @@ -3316,10 +3319,18 @@ return printAccum.join(); }, kValueDunderStrDocstring) - .def_property_readonly("type", [](PyValue &self) { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }); + .def_property_readonly("type", + [](PyValue &self) { + return PyType( + self.getParentOperation()->getContext(), + mlirValueGetType(self.get())); + }) + .def( + "replace_all_uses_with", + [](PyValue &self, PyValue &with) { + mlirValueReplaceAllUsesOfWith(self.get(), with.get()); + }, + kValueReplaceAllUsesWithDocstring); PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -751,6 +751,10 @@ return wrap(opOperand); } +void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { + unwrap(oldValue).replaceAllUsesWith(unwrap(newValue)); +} + //===----------------------------------------------------------------------===// // OpOperand API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -28,6 +28,25 @@ #include #include +MlirValue makeConstantLiteral(MlirContext ctx, const char *literalStr, + const char *typeStr) { + MlirLocation loc = mlirLocationUnknownGet(ctx); + char attrStr[50]; + sprintf(attrStr, "%s : %s", literalStr, typeStr); + MlirAttribute literal = + mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(attrStr)); + MlirNamedAttribute valueAttr = mlirNamedAttributeGet( + mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), literal); + MlirOperationState constState = mlirOperationStateGet( + mlirStringRefCreateFromCString("arith.constant"), loc); + MlirType type = + mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(typeStr)); + mlirOperationStateAddResults(&constState, 1, &type); + mlirOperationStateAddAttributes(&constState, 1, &valueAttr); + MlirOperation constOp = mlirOperationCreate(&constState); + return mlirOperationGetResult(constOp, 0); +} + static void registerAllUpstreamDialects(MlirContext ctx) { MlirDialectRegistry registry = mlirDialectRegistryCreate(); mlirRegisterAllDialects(registry); @@ -115,26 +134,17 @@ MlirOperation func = mlirOperationCreate(&funcState); mlirBlockInsertOwnedOperation(moduleBody, 0, func); - MlirType indexType = - mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index")); - MlirAttribute indexZeroLiteral = - mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index")); - MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( - mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), - indexZeroLiteral); - MlirOperationState constZeroState = mlirOperationStateGet( - mlirStringRefCreateFromCString("arith.constant"), location); - mlirOperationStateAddResults(&constZeroState, 1, &indexType); - mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr); 
-  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+  MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
+  MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
   mlirBlockAppendOwnedOperation(funcBody, constZero);
 
   MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
-  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
   MlirValue dimOperands[] = {funcArg0, constZeroValue};
   MlirOperationState dimState = mlirOperationStateGet(
       mlirStringRefCreateFromCString("memref.dim"), location);
   mlirOperationStateAddOperands(&dimState, 2, dimOperands);
+  MlirType indexType =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
   mlirOperationStateAddResults(&dimState, 1, &indexType);
   MlirOperation dim = mlirOperationCreate(&dimState);
   mlirBlockAppendOwnedOperation(funcBody, dim);
@@ -153,11 +163,11 @@
       mlirStringRefCreateFromCString("arith.constant"), location);
   mlirOperationStateAddResults(&constOneState, 1, &indexType);
   mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
-  MlirOperation constOne = mlirOperationCreate(&constOneState);
+  MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
+  MlirOperation constOne = mlirOpResultGetOwner(constOneValue);
   mlirBlockAppendOwnedOperation(funcBody, constOne);
 
   MlirValue dimValue = mlirOperationGetResult(dim, 0);
-  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
   MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
   MlirOperationState loopState = mlirOperationStateGet(
       mlirStringRefCreateFromCString("scf.for"), location);
@@ -820,11 +830,6 @@
   return 0;
 }
 
-void callbackSetFixedLengthString(const char *data, intptr_t len,
-                                  void *userData) {
-  strncpy(userData, data, len);
-}
-
 bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
   if (strlen(lhs) != rhs.length) {
     return false;
@@ -1794,32 +1799,10 @@
   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
   mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
   MlirLocation loc = mlirLocationUnknownGet(ctx);
-  MlirType indexType = mlirIndexTypeGet(ctx);
 
   // Create some constants to use as operands.
-  MlirAttribute indexZeroLiteral =
-      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
-  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
-      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
-      indexZeroLiteral);
-  MlirOperationState constZeroState = mlirOperationStateGet(
-      mlirStringRefCreateFromCString("arith.constant"), loc);
-  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
-  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
-  MlirOperation constZero = mlirOperationCreate(&constZeroState);
-  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
-
-  MlirAttribute indexOneLiteral =
-      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
-  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
-      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
-      indexOneLiteral);
-  MlirOperationState constOneState = mlirOperationStateGet(
-      mlirStringRefCreateFromCString("arith.constant"), loc);
-  mlirOperationStateAddResults(&constOneState, 1, &indexType);
-  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
-  MlirOperation constOne = mlirOperationCreate(&constOneState);
-  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
+  MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
+  MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
 
   // Create the operation under test.
mlirContextSetAllowUnregisteredDialects(ctx, true); @@ -1873,9 +1856,49 @@ return 3; } + MlirOperationState op2State = + mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc); + MlirValue initialOperands2[] = {constOneValue}; + mlirOperationStateAddOperands(&op2State, 1, initialOperands2); + (void)mlirOperationCreate(&op2State); + + MlirOpOperand use3 = mlirValueGetFirstUse(constOneValue); + fprintf(stderr, "First use owner: "); + mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL); + fprintf(stderr, "\n"); + // CHECK: First use owner: "dummy.op2" + + use3 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constOneValue)); + fprintf(stderr, "Second use owner: "); + mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL); + fprintf(stderr, "\n"); + // CHECK: Second use owner: "dummy.op" + + MlirValue constTwoValue = makeConstantLiteral(ctx, "2", "index"); + mlirValueReplaceAllUsesOfWith(constOneValue, constTwoValue); + + use3 = mlirValueGetFirstUse(constOneValue); + if (!mlirOpOperandIsNull(use3)) { + fprintf(stderr, "ERROR: Use should be null\n"); + return 4; + } + + MlirOpOperand use4 = mlirValueGetFirstUse(constTwoValue); + fprintf(stderr, "First replacement use owner: "); + mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL); + fprintf(stderr, "\n"); + // CHECK: First replacement use owner: "dummy.op" + + use4 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constTwoValue)); + fprintf(stderr, "Second replacement use owner: "); + mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL); + fprintf(stderr, "\n"); + // CHECK: Second replacement use owner: "dummy.op2" + mlirOperationDestroy(op); - mlirOperationDestroy(constZero); - mlirOperationDestroy(constOne); + mlirOperationDestroy(mlirOpResultGetOwner(constZeroValue)); + mlirOperationDestroy(mlirOpResultGetOwner(constOneValue)); + mlirOperationDestroy(mlirOpResultGetOwner(constTwoValue)); mlirContextDestroy(ctx); return 0; @@ -1891,18 
+1914,10 @@ mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func")); MlirLocation loc = mlirLocationUnknownGet(ctx); - MlirType indexType = mlirIndexTypeGet(ctx); MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value"); - MlirAttribute indexZeroLiteral = - mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index")); - MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet( - mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral); - MlirOperationState constZeroState = mlirOperationStateGet( - mlirStringRefCreateFromCString("arith.constant"), loc); - mlirOperationStateAddResults(&constZeroState, 1, &indexType); - mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr); - MlirOperation constZero = mlirOperationCreate(&constZeroState); + MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index"); + MlirOperation constZero = mlirOpResultGetOwner(constZeroValue); MlirAttribute indexOneLiteral = mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index")); @@ -1980,19 +1995,10 @@ } MlirLocation loc = mlirLocationUnknownGet(ctx); - MlirType indexType = mlirIndexTypeGet(ctx); - MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value"); // Create a registered operation, which should have a type id. 
-  MlirAttribute indexZeroLiteral =
-      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
-  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
-      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
-  MlirOperationState constZeroState = mlirOperationStateGet(
-      mlirStringRefCreateFromCString("arith.constant"), loc);
-  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
-  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
-  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+  MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
+  MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
 
   if (!mlirOperationVerify(constZero)) {
     fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -111,3 +111,29 @@
     assert use.owner in [op1, op2]
     print(f"Use owner: {use.owner}")
     print(f"Use operand_number: {use.operand_number}")
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
+@run
+def testValueReplaceAllUsesWith():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      value = Operation.create("custom.op1", results=[i32]).results[0]
+      op1 = Operation.create("custom.op2", operands=[value])
+      op2 = Operation.create("custom.op2", operands=[value])
+      value2 = Operation.create("custom.op3", results=[i32]).results[0]
+      value.replace_all_uses_with(value2)
+
+  assert len(list(value.uses)) == 0
+
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  for use in value2.uses:
+    assert use.owner in [op1, op2]
+    print(f"Use owner: {use.owner}")
+    print(f"Use operand_number: {use.operand_number}")