diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -447,6 +447,67 @@
   MlirBlock block;
 };
 
+/// Python-facing wrapper around a single MlirOpOperand, i.e. one use of a
+/// value by an operation.
+class PyOpOperand {
+public:
+  PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+  /// Returns the operation owning this operand, wrapped via createOpView().
+  py::object getOwner() {
+    MlirOperation owner = mlirOpOperandGetOwner(opOperand);
+    PyMlirContextRef context =
+        PyMlirContext::forContext(mlirOperationGetContext(owner));
+    return PyOperation::forOperation(context, owner)->createOpView();
+  }
+
+  /// Returns the operand number reported by the C API for this use.
+  size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); }
+
+  /// Registers the `OpOperand` Python class on module `m`.
+  static void bind(py::module &m) {
+    py::class_<PyOpOperand>(m, "OpOperand", py::module_local())
+        .def_property_readonly("owner", &PyOpOperand::getOwner)
+        .def_property_readonly("operand_number",
+                               &PyOpOperand::getOperandNumber);
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
+/// Python iterator over a chain of uses: starts at the MlirOpOperand it was
+/// constructed with and advances with mlirOpOperandGetNextUse until a null
+/// operand is reached.
+class PyOpOperandIterator {
+public:
+  PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {}
+
+  /// `__iter__`: the iterator object is returned as its own iterable.
+  PyOpOperandIterator &dunderIter() { return *this; }
+
+  /// `__next__`: yields the current use and steps to the next one; throws
+  /// py::stop_iteration once the current operand is null.
+  PyOpOperand dunderNext() {
+    if (mlirOpOperandIsNull(opOperand))
+      throw py::stop_iteration();
+
+    PyOpOperand returnOpOperand(opOperand);
+    opOperand = mlirOpOperandGetNextUse(opOperand);
+    return returnOpOperand;
+  }
+
+  /// Registers the `OpOperandIterator` Python class on module `m`.
+  static void bind(py::module &m) {
+    py::class_<PyOpOperandIterator>(m, "OpOperandIterator", py::module_local())
+        .def("__iter__", &PyOpOperandIterator::dunderIter)
+        .def("__next__", &PyOpOperandIterator::dunderNext);
+  }
+
+private:
+  MlirOpOperand opOperand;
+};
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -3156,6 +3217,11 @@
             assert(false && "Value must be a block argument or an op result");
             return py::none();
           })
+      .def_property_readonly("uses",
+                             [](PyValue &self) {
+                               return PyOpOperandIterator(
+                                   mlirValueGetFirstUse(self.get()));
+                             })
       .def("__eq__",
            [](PyValue &self, PyValue &other) {
              return self.get().ptr == other.get().ptr;
@@ -3182,6 +3248,7 @@
           });
   PyBlockArgument::bind(m);
   PyOpResult::bind(m);
+  PyOpOperand::bind(m);
 
   //----------------------------------------------------------------------------
   // Mapping of SymbolTable.
@@ -3220,6 +3287,7 @@
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
   PyOpAttributeMap::bind(m);
+  PyOpOperandIterator::bind(m);
   PyOpOperandList::bind(m);
   PyOpResultList::bind(m);
   PyRegionIterator::bind(m);
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -89,3 +89,26 @@
   op, ret = block.operations
   assert hash(block.arguments[0]) == hash(op.operands[0])
   assert hash(op.result) == hash(ret.operands[0])
+
+# CHECK-LABEL: TEST: testValueUses
+@run
+def testValueUses():
+  ctx = Context()
+  ctx.allow_unregistered_dialects = True
+  with Location.unknown(ctx):
+    i32 = IntegerType.get_signless(32)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      value = Operation.create("custom.op1", results=[i32]).results[0]
+      op1 = Operation.create("custom.op2", operands=[value])
+      op2 = Operation.create("custom.op2", operands=[value])
+
+  # Both custom.op2 operations take `value` as operand 0: two uses expected.
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  # CHECK: Use owner: "custom.op2"
+  # CHECK: Use operand_number: 0
+  for use in value.uses:
+    assert use.owner in [op1, op2]
+    print(f"Use owner: {use.owner}")
+    print(f"Use operand_number: {use.operand_number}")