diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -366,6 +366,11 @@
 MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
                                                      intptr_t pos);
 
+/// Sets the `pos`-th operand of the operation to `newValue`. The position is
+/// not bounds-checked by this function; it must be less than the operand count.
+MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+                                                MlirValue newValue);
+
 /// Returns the number of results of the operation.
 MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op);
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1640,6 +1640,20 @@
     return PyOpOperandList(operation, startIndex, length, step);
   }
 
+  /// Replaces the operand at `index` of this (possibly sliced) operand list
+  /// with `value`. Negative indices count from the end of the slice; throws
+  /// IndexError if the index is out of bounds.
+  void dunderSetItem(intptr_t index, PyValue value) {
+    // Map the slice-relative index to the operand position in the operation,
+    // so that e.g. `op.operands[1:][0] = v` updates operand 1, not operand 0.
+    intptr_t pos = linearizeIndex(wrapIndex(index));
+    mlirOperationSetOperand(operation->get(), pos, value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
+  }
+
 private:
   PyOperationRef operation;
 };
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -215,6 +215,25 @@
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
+  /// Transforms `index` into a legal position within the current slice.
+  /// Negative indices count from the end of the slice. Throws a Python
+  /// IndexError if the index is out of bounds.
+  intptr_t wrapIndex(intptr_t index) const {
+    if (index < 0)
+      index = length + index;
+    if (index < 0 || index >= length) {
+      throw python::SetPyError(PyExc_IndexError,
+                               "attempt to access out of bounds");
+    }
+    return index;
+  }
+
+  /// Maps an already-wrapped slice index to the corresponding position in
+  /// the underlying sequence, taking the slice start and step into account.
+  intptr_t linearizeIndex(intptr_t index) const {
+    return index * step + startIndex;
+  }
+
 public:
   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
       : startIndex(startIndex), length(length), step(step) {
@@ -228,12 +247,7 @@
   /// by taking elements in inverse order. Throws if the index is out of bounds.
   ElementTy dunderGetItem(intptr_t index) {
     // Negative indices mean we count from the end.
-    if (index < 0)
-      index = length + index;
-    if (index < 0 || index >= length) {
-      throw python::SetPyError(PyExc_IndexError,
-                               "attempt to access out of bounds");
-    }
+    index = wrapIndex(index);
 
     // Compute the linear index given the current slice properties.
     int linearIndex = index * step + startIndex;
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -351,6 +351,11 @@
   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
 }
 
+void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+                             MlirValue newValue) {
+  unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
+}
+
 intptr_t mlirOperationGetNumResults(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumResults());
 }
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -215,6 +215,42 @@
 run(testOperationOperandsSlice)
 
 
+# CHECK-LABEL: TEST: testOperationOperandsSet
+def testOperationOperandsSet():
+  with Context() as ctx, Location.unknown(ctx):
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+      func @f1() {
+        %0 = "test.producer0"() : () -> i64
+        %1 = "test.producer1"() : () -> i64
+        %2 = "test.producer2"() : () -> i64
+        "test.consumer"(%0, %0) : (i64, i64) -> ()
+        return
+      }""")
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    producer1 = entry_block.operations[1]
+    producer2 = entry_block.operations[2]
+    consumer = entry_block.operations[3]
+    assert len(consumer.operands) == 2
+
+    # CHECK: test.producer1
+    consumer.operands[0] = producer1.result
+    print(consumer.operands[0])
+
+    # CHECK: test.producer2
+    consumer.operands[-1] = producer2.result
+    print(consumer.operands[1])
+
+    # Assigning through a slice must update the operand the slice refers to.
+    # CHECK: test.producer1
+    consumer.operands[1:][0] = producer1.result
+    print(consumer.operands[1])
+
+
+run(testOperationOperandsSet)
+
+
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
   ctx = Context()