diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -801,6 +801,10 @@
 /// Returns the type of the value.
 MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value);
 
+/// Sets the type of the value in place. No verification is performed; the
+/// caller is responsible for keeping the surrounding IR valid.
+MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type);
+
 /// Prints the value to the standard error stream.
 MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value);
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3431,6 +3431,12 @@
           py::arg("use_local_scope") = false, kGetNameAsOperand)
       .def_property_readonly(
           "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
+      .def(
+          "set_type",
+          [](PyValue &self, const PyType &type) {
+            mlirValueSetType(self.get(), type);
+          },
+          py::arg("type"), "Sets the type of the value.")
       .def(
           "replace_all_uses_with",
           [](PyValue &self, PyValue &with) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -823,6 +823,10 @@
   return wrap(unwrap(value).getType());
 }
 
+void mlirValueSetType(MlirValue value, MlirType type) {
+  unwrap(value).setType(unwrap(type));
+}
+
 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
 
 void mlirValuePrint(MlirValue value, MlirStringCallback callback,
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -238,3 +238,25 @@
             value2.owner.detach_from_parent()
             # CHECK: %0
             print(value2.get_name())
+
+
+# CHECK-LABEL: TEST: testValueSetType
+@run
+def testValueSetType():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+            print(value)
+
+            value.set_type(i64)
+            # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
+            print(value)
+
+            # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
+            print(value.owner)