diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -54,11 +54,12 @@ DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); -DEFINE_C_API_STRUCT(MlirValue, const void); DEFINE_C_API_STRUCT(MlirAttribute, const void); -DEFINE_C_API_STRUCT(MlirType, const void); +DEFINE_C_API_STRUCT(MlirIdentifier, const void); DEFINE_C_API_STRUCT(MlirLocation, const void); DEFINE_C_API_STRUCT(MlirModule, const void); +DEFINE_C_API_STRUCT(MlirType, const void); +DEFINE_C_API_STRUCT(MlirValue, const void); /** Named MLIR attribute. * @@ -285,6 +286,9 @@ * not perform deep comparison. */ int mlirOperationEqual(MlirOperation op, MlirOperation other); +/** Gets the name of the operation as an identifier. */ +MlirIdentifier mlirOperationGetName(MlirOperation op); + /** Gets the block that owns this operation, returning null if the operation is * not owned. */ MlirBlock mlirOperationGetBlock(MlirOperation op); @@ -552,6 +556,19 @@ /** Associates an attribute with the name. Takes ownership of neither. */ MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr); +/*============================================================================*/ +/* Identifier API. */ +/*============================================================================*/ + +/** Gets an identifier with the given string value. */ +MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str); + +/** Checks whether two identifiers are the same (returns 1 if so, 0 if not). */ +int mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other); + +/** Gets the string value of the identifier.
*/ +MlirStringRef mlirIdentifierStr(MlirIdentifier ident); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -16,6 +16,7 @@ #define MLIR_INCLUDE_MLIR_CAPI_IR_H #include "mlir/CAPI/Wrap.h" +#include "mlir/IR/Identifier.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/Operation.h" @@ -28,9 +29,10 @@ DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) +DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) DEFINE_C_API_METHODS(MlirLocation, mlir::Location) +DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) DEFINE_C_API_METHODS(MlirType, mlir::Type) DEFINE_C_API_METHODS(MlirValue, mlir::Value) -DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) #endif // MLIR_INCLUDE_MLIR_CAPI_IR_H diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -249,6 +249,10 @@ return unwrap(op) == unwrap(other); } +MlirIdentifier mlirOperationGetName(MlirOperation op) { + return wrap(unwrap(op)->getName().getIdentifier()); +} + MlirBlock mlirOperationGetBlock(MlirOperation op) { return wrap(unwrap(op)->getBlock()); } @@ -576,3 +580,19 @@ MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { return MlirNamedAttribute{name, attr}; } + +/*============================================================================*/ +/* Identifier API.
*/ +/*============================================================================*/ + +MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { + return wrap(Identifier::get(unwrap(str), unwrap(context))); +} + +int mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { + return unwrap(ident) == unwrap(other); +} + +MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { + return wrap(unwrap(ident).strref()); +} diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -281,6 +281,19 @@ mlirOperationPrint(operation, printToStderr, NULL); fprintf(stderr, "\n"); + // Get the operation name as an identifier and print its string value. + MlirIdentifier ident = mlirOperationGetName(operation); + MlirStringRef identStr = mlirIdentifierStr(ident); + fprintf(stderr, "Operation name: '"); + for (size_t i = 0; i < identStr.length; ++i) + fputc(identStr.data[i], stderr); + fprintf(stderr, "'\n"); + + // Re-create the identifier from its string and verify it compares equal. + MlirIdentifier identAgain = mlirIdentifierGet(ctx, identStr); + fprintf(stderr, "Identifier equal: %d\n", + mlirIdentifierEqual(ident, identAgain)); + // Get the block terminator and print it. MlirOperation terminator = mlirBlockGetTerminator(block); fprintf(stderr, "Terminator: "); @@ -1127,6 +1140,8 @@ // CHECK: } // CHECK: return // CHECK: First operation: {{.*}} = constant 0 : index + // CHECK: Operation name: 'std.constant' + // CHECK: Identifier equal: 1 // CHECK: Terminator: return // CHECK: Get attr 0: 0 : index // CHECK: Get attr 0 by name: 0 : index