diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -77,6 +77,16 @@ }; typedef struct MlirNamedAttribute MlirNamedAttribute; +/// MLIR op operand. +/// +/// An op operand is essentially an (operation, index) pair where the operation +/// is the owner MlirOperation and the index is the operand number. +struct MlirOpOperand { + MlirOperation owner; + size_t operandNumber; +}; +typedef struct MlirOpOperand MlirOpOperand; + //===----------------------------------------------------------------------===// // Context API. //===----------------------------------------------------------------------===// @@ -727,6 +737,22 @@ MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); +/// Returns an op operand representing the first use of the value, or a null +/// op operand if there are no uses. +MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); + +//===----------------------------------------------------------------------===// +// OpOperand API. +//===----------------------------------------------------------------------===// + +/// Returns true if the op operand is null, i.e. its owner operation is null. +MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand); + +/// Returns an op operand representing the next use of the same value, or a +/// null op operand if there is no next use (or if `opOperand` is null). +MLIR_CAPI_EXPORTED MlirOpOperand +mlirOpOperandGetNextUse(MlirOpOperand opOperand); + //===----------------------------------------------------------------------===// // Type API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -719,6 +719,41 @@ unwrap(value).print(stream); } +MlirOpOperand mlirValueGetFirstUse(MlirValue mlirValue) { + Value value = unwrap(mlirValue); + if (value.use_empty()) + return {}; + + OpOperand &opOperand = *value.use_begin(); + + return MlirOpOperand{wrap(opOperand.getOwner()), + opOperand.getOperandNumber()}; +} + +//===----------------------------------------------------------------------===// +// OpOperand API.
+//===----------------------------------------------------------------------===// + +bool mlirOpOperandIsNull(MlirOpOperand mlirOpOperand) { + return mlirOpOperand.owner.ptr == nullptr; +} + +MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand mlirOpOperand) { + if (mlirOpOperandIsNull(mlirOpOperand)) + return {}; + + Operation *op = unwrap(mlirOpOperand.owner); + OpOperand &opOperand = op->getOpOperand(mlirOpOperand.operandNumber); + OpOperand *nextOpOperand = + static_cast<OpOperand *>(opOperand.getNextOperandUsingThisValue()); + + if (!nextOpOperand) + return {}; + + return MlirOpOperand{wrap(nextOpOperand->getOwner()), + nextOpOperand->getOperandNumber()}; +} + //===----------------------------------------------------------------------===// // Type API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -1758,17 +1758,44 @@ fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands); // CHECK: Num Operands: 1 - MlirValue opOperand = mlirOperationGetOperand(op, 0); + MlirValue opOperand1 = mlirOperationGetOperand(op, 0); fprintf(stderr, "Original operand: "); - mlirValuePrint(opOperand, printToStderr, NULL); + mlirValuePrint(opOperand1, printToStderr, NULL); // CHECK: Original operand: {{.+}} arith.constant 0 : index mlirOperationSetOperand(op, 0, constOneValue); - opOperand = mlirOperationGetOperand(op, 0); + MlirValue opOperand2 = mlirOperationGetOperand(op, 0); fprintf(stderr, "Updated operand: "); - mlirValuePrint(opOperand, printToStderr, NULL); + mlirValuePrint(opOperand2, printToStderr, NULL); // CHECK: Updated operand: {{.+}} arith.constant 1 : index + // Test op operand APIs.
+ MlirOpOperand use1 = mlirValueGetFirstUse(opOperand1); + if (!mlirOpOperandIsNull(use1)) { + fprintf(stderr, "ERROR: Use should be null\n"); + return 1; + } + + MlirOpOperand use2 = mlirValueGetFirstUse(opOperand2); + if (mlirOpOperandIsNull(use2)) { + fprintf(stderr, "ERROR: Use should not be null\n"); + return 2; + } + + fprintf(stderr, "Use owner: "); + mlirOperationPrint(use2.owner, printToStderr, NULL); + fprintf(stderr, "\n"); + // CHECK: Use owner: "dummy.op" + + fprintf(stderr, "Use operandNumber: %zu\n", use2.operandNumber); + // CHECK: Use operandNumber: 0 + + use2 = mlirOpOperandGetNextUse(use2); + if (!mlirOpOperandIsNull(use2)) { + fprintf(stderr, "ERROR: Next use should be null\n"); + return 3; + } + mlirOperationDestroy(op); mlirOperationDestroy(constZero); mlirOperationDestroy(constOne);