diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -447,6 +447,65 @@
   MlirBlock block;
 };
 
+/// Python wrapper around a single use of a value (an MlirOpOperand).
+class PyOpOperand {
+public:
+  PyOpOperand(PyMlirContextRef context, MlirOpOperand opOperand)
+      : context(std::move(context)), opOperand(opOperand) {}
+
+  /// Returns the operation that owns this operand, as an OpView.
+  py::object getOwner() {
+    return PyOperation::forOperation(context, mlirOpOperandGetOwner(opOperand))
+        ->createOpView();
+  }
+
+  /// Returns the index of this operand in its owner's operand list.
+  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+
+  static void bind(py::module &m) {
+    py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
+        .def_property_readonly("owner", &PyOpOperand::getOwner)
+        .def_property_readonly("operand_number",
+                               &PyOpOperand::getOperandNumber);
+  }
+
+private:
+  PyMlirContextRef context;
+  MlirOpOperand opOperand;
+};
+
+/// Python iterator over the uses of a value, following the use-list chain
+/// via mlirOpOperandGetNextUse starting from the operand it was built with.
+class PyOpOperandIterator {
+public:
+  explicit PyOpOperandIterator(MlirOpOperand opOperand)
+      : opOperand(opOperand) {}
+
+  PyOpOperandIterator &dunderIter() { return *this; }
+
+  PyOpOperand dunderNext() {
+    if (mlirOpOperandIsNull(opOperand))
+      throw py::stop_iteration();
+    MlirOpOperand returnOpOperand = opOperand;
+    opOperand = mlirOpOperandGetNextUse(opOperand);
+    // Resolve the context from the owner only once a use is known to exist:
+    // a value without uses starts from a null operand that has no owner, so
+    // doing this eagerly in the constructor would query a null operand.
+    PyMlirContextRef context = PyMlirContext::forContext(
+        mlirOperationGetContext(mlirOpOperandGetOwner(returnOpOperand)));
+    return PyOpOperand(std::move(context), returnOpOperand);
+  }
+
+  static void bind(py::module &m) {
+    py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
+        .def("__iter__", &PyOpOperandIterator::dunderIter)
+        .def("__next__", &PyOpOperandIterator::dunderNext);
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -3152,6 +3211,11 @@
             assert(false && "Value must be a block argument or an op result");
             return py::none();
           })
+      .def_property_readonly("uses",
+                             [](PyValue &self) {
+                               return PyOpOperandIterator(
+                                   mlirValueGetFirstUse(self.get()));
+                             })
       .def("__eq__",
            [](PyValue &self, PyValue &other) {
              return self.get().ptr == other.get().ptr;
@@ -3178,6 +3242,7 @@
       });
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
+  PyOpOperand::bind(m);
 
   //----------------------------------------------------------------------------
   // Mapping of SymbolTable.
@@ -3216,6 +3281,7 @@
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
   PyOpAttributeMap::bind(m);
+  PyOpOperandIterator::bind(m);
   PyOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -89,3 +89,30 @@
     op, ret = block.operations
     assert hash(block.arguments[0]) == hash(op.operands[0])
     assert hash(op.result) == hash(ret.operands[0])
+
+# CHECK-LABEL: TEST: testValueUses
+# CHECK: Use owner: "custom.op2"
+# CHECK: Use operand_number: 0
+# CHECK: Use owner: "custom.op2"
+# CHECK: Use operand_number: 0
+# CHECK: Unused value uses: 0
+@run
+def testValueUses():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      value = Operation.create("custom.op1", results=[i32]).results[0]
+      op1 = Operation.create("custom.op2", operands=[value])
+      op2 = Operation.create("custom.op2", operands=[value])
+      unused = Operation.create("custom.op3", results=[i32]).results[0]
+
+    for use in value.uses:
+      assert use.owner in [op1, op2]
+      print(f"Use owner: {use.owner}")
+      print(f"Use operand_number: {use.operand_number}")
+
+    # A value without uses must yield an empty iterator rather than crash.
+    print(f"Unused value uses: {len(list(unused.uses))}")