diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -641,6 +641,9 @@
 MLIR_CAPI_EXPORTED MlirIdentifier mlirIdentifierGet(MlirContext context,
                                                     MlirStringRef str);
 
+/// Returns the context associated with this identifier.
+MLIR_CAPI_EXPORTED MlirContext mlirIdentifierGetContext(MlirIdentifier ident);
+
 /// Checks whether two identifiers are the same.
 MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident,
                                             MlirIdentifier other);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -661,6 +661,10 @@
   return wrap(Identifier::get(unwrap(str), unwrap(context)));
 }
 
+MlirContext mlirIdentifierGetContext(MlirIdentifier ident) {
+  return wrap(unwrap(ident).getContext());
+}
+
 bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) {
   return unwrap(ident) == unwrap(other);
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1456,6 +1456,8 @@
   mlirRegionAppendOwnedBlock(region, block);
   mlirOperationStateAddOwnedRegions(&opState, 1, &region);
   MlirOperation op = mlirOperationCreate(&opState);
+  MlirIdentifier ident =
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("identifier"));
 
   if (!mlirContextEqual(ctx, mlirOperationGetContext(op))) {
     fprintf(stderr, "ERROR: Getting context from operation failed\n");
@@ -1465,6 +1467,10 @@
     fprintf(stderr, "ERROR: Getting parent operation from block failed\n");
     return 2;
   }
+  if (!mlirContextEqual(ctx, mlirIdentifierGetContext(ident))) {
+    fprintf(stderr, "ERROR: Getting context from identifier failed\n");
+    return 3;
+  }
 
   mlirOperationDestroy(op);
   mlirContextDestroy(ctx);