diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -366,6 +366,11 @@
 MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op,
                                                      intptr_t pos);
 
+/// Sets the `pos`-th operand of the operation to `newValue`. `pos` is not
+/// bounds-checked: it must be in [0, mlirOperationGetNumOperands(op)).
+MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+                                                MlirValue newValue);
+
 /// Returns the number of results of the operation.
 MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op);
 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1640,6 +1640,19 @@
     return PyOpOperandList(operation, startIndex, length, step);
   }
 
+  /// Replaces the `index`-th operand visible through this (possibly sliced)
+  /// list with `value`. Negative indices count from the end of the list.
+  void dunderSetItem(intptr_t index, PyValue value) {
+    index = wrapIndex(index);
+    // Translate the slice-relative index into an operand position.
+    mlirOperationSetOperand(operation->get(), linearizeIndex(index),
+                            value.get());
+  }
+
+  static void bindDerived(ClassTy &c) {
+    c.def("__setitem__", &PyOpOperandList::dunderSetItem);
+  }
+
 private:
   PyOperationRef operation;
 };
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -215,6 +215,22 @@
 protected:
   using ClassTy = pybind11::class_<Derived>;
 
+  /// Normalizes a Python-style `index` into [0, length): negative indices
+  /// count from the end. Throws IndexError if the result is out of bounds.
+  intptr_t wrapIndex(intptr_t index) {
+    if (index < 0)
+      index = length + index;
+    if (index < 0 || index >= length) {
+      throw python::SetPyError(PyExc_IndexError,
+                               "attempt to access out of bounds");
+    }
+    return index;
+  }
+
+  /// Maps an already-wrapped, slice-relative `index` onto the corresponding
+  /// position in the underlying sequence (accounting for start and step).
+  intptr_t linearizeIndex(intptr_t index) { return index * step + startIndex; }
+
 public:
   explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step)
       : startIndex(startIndex), length(length), step(step) {
@@ -228,12 +244,7 @@
   /// by taking elements in inverse order. Throws if the index is out of bounds.
   ElementTy dunderGetItem(intptr_t index) {
     // Negative indices mean we count from the end.
-    if (index < 0)
-      index = length + index;
-    if (index < 0 || index >= length) {
-      throw python::SetPyError(PyExc_IndexError,
-                               "attempt to access out of bounds");
-    }
+    index = wrapIndex(index);
 
     // Compute the linear index given the current slice properties.
     int linearIndex = index * step + startIndex;
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -351,6 +351,11 @@
   return wrap(unwrap(op)->getOperand(static_cast<unsigned>(pos)));
 }
 
+void mlirOperationSetOperand(MlirOperation op, intptr_t pos,
+                             MlirValue newValue) {
+  unwrap(op)->setOperand(static_cast<unsigned>(pos), unwrap(newValue));
+}
+
 intptr_t mlirOperationGetNumResults(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumResults());
 }
diff --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -215,6 +215,45 @@
 run(testOperationOperandsSlice)
 
 
+# CHECK-LABEL: TEST: testOperationOperandsSet
+def testOperationOperandsSet():
+  with Context() as ctx, Location.unknown(ctx):
+    ctx.allow_unregistered_dialects = True
+    module = Module.parse(r"""
+    func @f1() {
+      %0 = "test.producer0"() : () -> i64
+      %1 = "test.producer1"() : () -> i64
+      %2 = "test.producer2"() : () -> i64
+      "test.consumer"(%0) : (i64) -> ()
+      "test.consumer2"(%0, %0) : (i64, i64) -> ()
+      return
+    }""")
+    func = module.body.operations[0]
+    entry_block = func.regions[0].blocks[0]
+    producer1 = entry_block.operations[1]
+    producer2 = entry_block.operations[2]
+    consumer = entry_block.operations[3]
+    consumer2 = entry_block.operations[4]
+    assert len(consumer.operands) == 1
+
+    # CHECK: test.producer1
+    consumer.operands[0] = producer1.result
+    print(consumer.operands[0])
+
+    # CHECK: test.producer2
+    consumer.operands[-1] = producer2.result
+    print(consumer.operands[0])
+
+    # Assigning through a slice must update the corresponding underlying
+    # operand, not the operand at the same position in the full list.
+    # CHECK: test.producer2
+    consumer2.operands[1:][0] = producer2.result
+    print(consumer2.operands[1])
+
+
+run(testOperationOperandsSet)
+
+
 # CHECK-LABEL: TEST: testDetachedOperation
 def testDetachedOperation():
   ctx = Context()
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1511,6 +1511,73 @@
   return 0;
 }
 
+/// Tests operand APIs.
+int testOperands() {
+  fprintf(stderr, "@testOperands\n");
+  // CHECK-LABEL: @testOperands
+
+  MlirContext ctx = mlirContextCreate();
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+
+  // Create some constants to use as operands.
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+  MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
+
+  MlirAttribute indexOneLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
+  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexOneLiteral);
+  MlirOperationState constOneState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constOneState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
+  MlirOperation constOne = mlirOperationCreate(&constOneState);
+  MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
+
+  // Create the operation under test.
+  MlirOperationState opState =
+      mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op"), loc);
+  MlirValue initialOperands[] = {constZeroValue};
+  mlirOperationStateAddOperands(&opState, 1, initialOperands);
+  MlirOperation op = mlirOperationCreate(&opState);
+
+  // Test operand APIs.
+  intptr_t numOperands = mlirOperationGetNumOperands(op);
+  fprintf(stderr, "Num Operands: %ld\n", (long)numOperands);
+  // CHECK: Num Operands: 1
+
+  MlirValue opOperand = mlirOperationGetOperand(op, 0);
+  fprintf(stderr, "Original operand: ");
+  mlirValuePrint(opOperand, printToStderr, NULL);
+  // CHECK: Original operand: {{.+}} {value = 0 : index}
+
+  mlirOperationSetOperand(op, 0, constOneValue);
+  opOperand = mlirOperationGetOperand(op, 0);
+  fprintf(stderr, "Updated operand: ");
+  mlirValuePrint(opOperand, printToStderr, NULL);
+  // CHECK: Updated operand: {{.+}} {value = 1 : index}
+
+  // Destroy the user before the operations defining its operands, then the
+  // context they were created in.
+  mlirOperationDestroy(op);
+  mlirOperationDestroy(constZero);
+  mlirOperationDestroy(constOne);
+  mlirContextDestroy(ctx);
+
+  return 0;
+}
+
 // Wraps a diagnostic into additional text we can match against.
 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
   fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData);
@@ -1588,6 +1655,8 @@
     return 9;
   if (testBackreferences())
     return 10;
+  if (testOperands())
+    return 11;
 
   mlirContextDestroy(ctx);
 