diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -241,6 +241,10 @@
 /** Checks whether the underlying operation is null. */
 static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; }
 
+/** Checks whether two operation handles point to the same operation. This does
+ * not perform deep comparison. */
+int mlirOperationEqual(MlirOperation op, MlirOperation other);
+
 /** Returns the number of regions attached to the given operation. */
 intptr_t mlirOperationGetNumRegions(MlirOperation op);
 
@@ -348,6 +352,10 @@
 /** Checks whether a block is null. */
 static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; }
 
+/** Checks whether two block handles point to the same block. This does not
+ * perform deep comparison. */
+int mlirBlockEqual(MlirBlock block, MlirBlock other);
+
 /** Returns the block immediately following the given block in its parent
  * region. */
 MlirBlock mlirBlockGetNextInRegion(MlirBlock block);
@@ -397,6 +405,32 @@
 /** Returns whether the value is null. */
 static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; }
 
+/** Returns 1 if the value is a block argument, 0 otherwise. */
+int mlirValueIsABlockArgument(MlirValue value);
+
+/** Returns 1 if the value is an operation result, 0 otherwise. */
+int mlirValueIsAOpResult(MlirValue value);
+
+/** Returns the block in which this value is defined as an argument. Asserts if
+ * the value is not a block argument. */
+MlirBlock mlirBlockArgumentGetOwner(MlirValue value);
+
+/** Returns the position of the value in the argument list of its block.
+ * Asserts if the value is not a block argument. */
+intptr_t mlirBlockArgumentGetArgNumber(MlirValue value);
+
+/** Sets the type of the block argument to the given type. Asserts if the
+ * value is not a block argument. */
+void mlirBlockArgumentSetType(MlirValue value, MlirType type);
+
+/** Returns an operation that produced this value as its result. Asserts if the
+ * value is not an op result. */
+MlirOperation mlirOpResultGetOwner(MlirValue value);
+
+/** Returns the position of the value in the list of results of the operation
+ * that produced it. Asserts if the value is not an op result. */
+intptr_t mlirOpResultGetResultNumber(MlirValue value);
+
 /** Returns the type of the value. */
 MlirType mlirValueGetType(MlirValue value);
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -211,6 +211,10 @@
 
 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
 
+int mlirOperationEqual(MlirOperation op, MlirOperation other) {
+  return unwrap(op) == unwrap(other);
+}
+
 intptr_t mlirOperationGetNumRegions(MlirOperation op) {
   return static_cast<intptr_t>(unwrap(op)->getNumRegions());
 }
@@ -343,6 +347,10 @@
   return wrap(b);
 }
 
+int mlirBlockEqual(MlirBlock block, MlirBlock other) {
+  return unwrap(block) == unwrap(other);
+}
+
 MlirBlock mlirBlockGetNextInRegion(MlirBlock block) {
   return wrap(unwrap(block)->getNextNode());
 }
@@ -412,6 +420,36 @@
 /* Value API.                                                                 */
 /* ========================================================================== */
 
+int mlirValueIsABlockArgument(MlirValue value) {
+  return unwrap(value).isa<BlockArgument>();
+}
+
+int mlirValueIsAOpResult(MlirValue value) {
+  return unwrap(value).isa<OpResult>();
+}
+
+MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
+  return wrap(unwrap(value).cast<BlockArgument>().getOwner());
+}
+
+intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
+  return static_cast<intptr_t>(
+      unwrap(value).cast<BlockArgument>().getArgNumber());
+}
+
+void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
+  unwrap(value).cast<BlockArgument>().setType(unwrap(type));
+}
+
+MlirOperation mlirOpResultGetOwner(MlirValue value) {
+  return wrap(unwrap(value).cast<OpResult>().getOwner());
+}
+
+intptr_t mlirOpResultGetResultNumber(MlirValue value) {
+  return static_cast<intptr_t>(
+      unwrap(value).cast<OpResult>().getResultNumber());
+}
+
 MlirType mlirValueGetType(MlirValue value) {
   return wrap(unwrap(value).getType());
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -153,10 +153,12 @@
   unsigned numBlocks;
   unsigned numRegions;
   unsigned numValues;
+  unsigned numBlockArguments;
+  unsigned numOpResults;
 };
 typedef struct ModuleStats ModuleStats;
 
-void collectStatsSingle(OpListNode *head, ModuleStats *stats) {
+int collectStatsSingle(OpListNode *head, ModuleStats *stats) {
   MlirOperation operation = head->op;
   stats->numOperations += 1;
   stats->numValues += mlirOperationGetNumResults(operation);
@@ -166,12 +168,39 @@
 
   stats->numRegions += numRegions;
 
+  intptr_t numResults = mlirOperationGetNumResults(operation);
+  for (intptr_t i = 0; i < numResults; ++i) {
+    MlirValue result = mlirOperationGetResult(operation, i);
+    if (!mlirValueIsAOpResult(result))
+      return 1;
+    if (mlirValueIsABlockArgument(result))
+      return 2;
+    if (!mlirOperationEqual(operation, mlirOpResultGetOwner(result)))
+      return 3;
+    if (i != mlirOpResultGetResultNumber(result))
+      return 4;
+    ++stats->numOpResults;
+  }
+
   for (unsigned i = 0; i < numRegions; ++i) {
     MlirRegion region = mlirOperationGetRegion(operation, i);
     for (MlirBlock block = mlirRegionGetFirstBlock(region);
          !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) {
       ++stats->numBlocks;
-      stats->numValues += mlirBlockGetNumArguments(block);
+      intptr_t numArgs = mlirBlockGetNumArguments(block);
+      stats->numValues += numArgs;
+      for (intptr_t j = 0; j < numArgs; ++j) {
+        MlirValue arg = mlirBlockGetArgument(block, j);
+        if (!mlirValueIsABlockArgument(arg))
+          return 5;
+        if (mlirValueIsAOpResult(arg))
+          return 6;
+        if (!mlirBlockEqual(block, mlirBlockArgumentGetOwner(arg)))
+          return 7;
+        if (j != mlirBlockArgumentGetArgNumber(arg))
+          return 8;
+        ++stats->numBlockArguments;
+      }
 
       for (MlirOperation child = mlirBlockGetFirstOperation(block);
            !mlirOperationIsNull(child);
@@ -183,9 +212,10 @@
       }
     }
   }
+  return 0;
 }
 
-void collectStats(MlirOperation operation) {
+int collectStats(MlirOperation operation) {
   OpListNode *head = malloc(sizeof(OpListNode));
   head->op = operation;
   head->next = NULL;
@@ -196,9 +226,20 @@
   stats.numBlocks = 0;
   stats.numRegions = 0;
   stats.numValues = 0;
+  stats.numBlockArguments = 0;
+  stats.numOpResults = 0;
 
   do {
-    collectStatsSingle(head, &stats);
+    int retval = collectStatsSingle(head, &stats);
+    if (retval) {
+      // Release every node still queued so the failure path does not leak.
+      while (head) {
+        OpListNode *rest = head->next;
+        free(head);
+        head = rest;
+      }
+      return retval;
+    }
     OpListNode *next = head->next;
     free(head);
     head = next;
@@ -209,6 +250,11 @@
   fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks);
   fprintf(stderr, "Number of regions: %u\n", stats.numRegions);
   fprintf(stderr, "Number of values: %u\n", stats.numValues);
+  fprintf(stderr, "Number of block arguments: %u\n", stats.numBlockArguments);
+  fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
+  if (stats.numValues != stats.numBlockArguments + stats.numOpResults)
+    return 100;
+  return 0;
 }
 
 static void printToStderr(const char *str, intptr_t len, void *userData) {
@@ -914,13 +960,19 @@
   // CHECK: }
   // clang-format on
 
-  collectStats(module);
+  fprintf(stderr, "@stats\n");
+  int errcode = collectStats(module);
+  fprintf(stderr, "%d\n", errcode);
   // clang-format off
+  // CHECK-LABEL: @stats
   // CHECK: Number of operations: 13
   // CHECK: Number of attributes: 4
   // CHECK: Number of blocks: 3
   // CHECK: Number of regions: 3
   // CHECK: Number of values: 9
+  // CHECK: Number of block arguments: 3
+  // CHECK: Number of op results: 6
+  // CHECK: 0
   // clang-format on
 
   printFirstOfEach(ctx, module);
@@ -988,7 +1040,7 @@
   // CHECK: 0
   // clang-format on
   fprintf(stderr, "@types\n");
-  int errcode = printStandardTypes(ctx);
+  errcode = printStandardTypes(ctx);
   fprintf(stderr, "%d\n", errcode);
   // clang-format off