diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -52,6 +52,7 @@
 DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
 DEFINE_C_API_STRUCT(MlirOperation, void);
+DEFINE_C_API_STRUCT(MlirOpOperand, void);
 DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void);
 DEFINE_C_API_STRUCT(MlirBlock, void);
 DEFINE_C_API_STRUCT(MlirRegion, void);
@@ -728,6 +729,29 @@
 MLIR_CAPI_EXPORTED void
 mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
 
+/// Gets the first use of the value as an op operand; the result is a null op
+/// operand when the value has no uses.
+MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
+
+//===----------------------------------------------------------------------===//
+// OpOperand API.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given op operand is null.
+MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand);
+
+/// Gets the operation that owns the given op operand.
+MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand);
+
+/// Gets the operand number of the given op operand.
+MLIR_CAPI_EXPORTED unsigned
+mlirOpOperandGetOperandNumber(MlirOpOperand opOperand);
+
+/// Gets the op operand for the next use of the same value; the result is a
+/// null op operand when there is no next use.
+MLIR_CAPI_EXPORTED MlirOpOperand
+mlirOpOperandGetNextUse(MlirOpOperand opOperand);
+
 //===----------------------------------------------------------------------===//
 // Type API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -25,6 +25,7 @@
 DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
 DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation)
 DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block)
+DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand)
 DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags)
 DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region)
 DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable)
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -719,6 +719,54 @@
   unwrap(value).print(stream);
 }
 
+// Returns the first use of `value` as an op operand, or a null op operand if
+// the value has no uses.
+MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
+  Value cppValue = unwrap(value);
+  if (cppValue.use_empty())
+    return {};
+
+  OpOperand *opOperand = cppValue.use_begin().getOperand();
+
+  return wrap(opOperand);
+}
+
+//===----------------------------------------------------------------------===//
+// OpOperand API.
+//===----------------------------------------------------------------------===//
+
+// Returns true when the op operand wraps a null pointer.
+bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; }
+
+// Returns the operation owning the op operand; the operand is dereferenced
+// without a null check.
+MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) {
+  return wrap(unwrap(opOperand)->getOwner());
+}
+
+// Returns the operand number of the op operand; the operand is dereferenced
+// without a null check.
+unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) {
+  return unwrap(opOperand)->getOperandNumber();
+}
+
+// Returns the next use of the same value, or a null op operand if `opOperand`
+// is null or has no next use.
+MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) {
+  if (mlirOpOperandIsNull(opOperand))
+    return {};
+
+  // Downcast to OpOperand. NOTE(review): presumably the use-list accessor
+  // returns a base-class pointer -- confirm.
+  OpOperand *nextOpOperand = static_cast<OpOperand *>(
+      unwrap(opOperand)->getNextOperandUsingThisValue());
+
+  if (!nextOpOperand)
+    return {};
+
+  return wrap(nextOpOperand);
+}
+
 //===----------------------------------------------------------------------===//
 // Type API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1758,17 +1758,50 @@
   fprintf(stderr, "Num Operands: %" PRIdPTR "\n", numOperands);
   // CHECK: Num Operands: 1
 
-  MlirValue opOperand = mlirOperationGetOperand(op, 0);
+  MlirValue opOperand1 = mlirOperationGetOperand(op, 0);
   fprintf(stderr, "Original operand: ");
-  mlirValuePrint(opOperand, printToStderr, NULL);
+  mlirValuePrint(opOperand1, printToStderr, NULL);
   // CHECK: Original operand: {{.+}} arith.constant 0 : index
 
   mlirOperationSetOperand(op, 0, constOneValue);
-  opOperand = mlirOperationGetOperand(op, 0);
+  MlirValue opOperand2 = mlirOperationGetOperand(op, 0);
   fprintf(stderr, "Updated operand: ");
-  mlirValuePrint(opOperand, printToStderr, NULL);
+  mlirValuePrint(opOperand2, printToStderr, NULL);
   // CHECK: Updated operand: {{.+}} arith.constant 1 : index
 
+  // Test op operand APIs.
+  // opOperand1 was read before operand 0 was replaced above; the test expects
+  // it to have no uses left.
+  MlirOpOperand use1 = mlirValueGetFirstUse(opOperand1);
+  if (!mlirOpOperandIsNull(use1)) {
+    fprintf(stderr, "ERROR: Use should be null\n");
+    return 1;
+  }
+
+  // opOperand2 is the value now set as operand 0, so a use must exist.
+  MlirOpOperand use2 = mlirValueGetFirstUse(opOperand2);
+  if (mlirOpOperandIsNull(use2)) {
+    fprintf(stderr, "ERROR: Use should not be null\n");
+    return 2;
+  }
+
+  fprintf(stderr, "Use owner: ");
+  mlirOperationPrint(mlirOpOperandGetOwner(use2), printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // CHECK: Use owner: "dummy.op"
+
+  // The operand number is returned as `unsigned`, so print it with %u.
+  fprintf(stderr, "Use operandNumber: %u\n",
+          mlirOpOperandGetOperandNumber(use2));
+  // CHECK: Use operandNumber: 0
+
+  // The test expects that use to be the only one.
+  use2 = mlirOpOperandGetNextUse(use2);
+  if (!mlirOpOperandIsNull(use2)) {
+    fprintf(stderr, "ERROR: Next use should be null\n");
+    return 3;
+  }
+
   mlirOperationDestroy(op);
   mlirOperationDestroy(constZero);
   mlirOperationDestroy(constOne);