diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -60,6 +60,7 @@
 DEFINE_C_API_STRUCT(MlirLocation, const void);
 DEFINE_C_API_STRUCT(MlirModule, const void);
 DEFINE_C_API_STRUCT(MlirType, const void);
+DEFINE_C_API_STRUCT(MlirTypeID, const void);
 DEFINE_C_API_STRUCT(MlirValue, const void);
 
 #undef DEFINE_C_API_STRUCT
@@ -356,6 +357,11 @@
 /// Gets the context this operation is associated with
 MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op);
 
+/// Gets the type id of the operation.
+/// Returns null if the operation does not have a registered operation
+/// description.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOperationGetTypeID(MlirOperation op);
+
 /// Gets the name of the operation as an identifier.
 MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op);
 
@@ -626,6 +632,9 @@
 /// Gets the context that a type was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type);
 
+/// Gets the type ID of the type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type);
+
 /// Checks whether a type is null.
 static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; }
 
@@ -655,6 +664,9 @@
 /// Gets the type of this attribute.
 MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute);
 
+/// Gets the type id of the attribute.
+MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute);
+
 /// Checks whether an attribute is null.
 static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; }
 
@@ -693,6 +705,21 @@
 /// Gets the string value of the identifier.
 MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident);
 
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether a type id is null.
+static inline bool mlirTypeIDIsNull(MlirTypeID typeID) {
+  return !typeID.ptr;
+}
+
+/// Checks if two type ids are equal.
+MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2);
+
+/// Returns the hash value of the type id.
+MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -33,6 +33,7 @@
 DEFINE_C_API_METHODS(MlirLocation, mlir::Location)
 DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp)
 DEFINE_C_API_METHODS(MlirType, mlir::Type)
+DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID)
 DEFINE_C_API_METHODS(MlirValue, mlir::Value)
 
 #endif // MLIR_CAPI_IR_H
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Parser.h"
 
 #include "llvm/Support/Debug.h"
+#include <cstddef>
 
 using namespace mlir;
 
@@ -345,6 +346,13 @@
   return wrap(unwrap(op)->getContext());
 }
 
+MlirTypeID mlirOperationGetTypeID(MlirOperation op) {
+  if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) {
+    return wrap(abstractOp->typeID);
+  }
+  return {nullptr};
+}
+
 MlirIdentifier mlirOperationGetName(MlirOperation op) {
   return wrap(unwrap(op)->getName().getIdentifier());
 }
@@ -658,6 +666,10 @@
   return wrap(unwrap(type).getContext());
 }
 
+MlirTypeID mlirTypeGetTypeID(MlirType type) {
+  return wrap(unwrap(type).getTypeID());
+}
+
 bool mlirTypeEqual(MlirType t1, MlirType t2) {
   return unwrap(t1) == unwrap(t2);
 }
@@ -685,6 +697,10 @@
   return wrap(unwrap(attribute).getType());
 }
 
+MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) {
+  return wrap(unwrap(attr).getTypeID());
+}
+
 bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) {
   return unwrap(a1) == unwrap(a2);
 }
@@ -721,3 +737,15 @@
 MlirStringRef mlirIdentifierStr(MlirIdentifier ident) {
   return wrap(unwrap(ident).strref());
 }
+
+//===----------------------------------------------------------------------===//
+// TypeID API.
+//===----------------------------------------------------------------------===//
+
+bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
+  return unwrap(typeID1) == unwrap(typeID2);
+}
+
+size_t mlirTypeIDHashValue(MlirTypeID typeID) {
+  return hash_value(unwrap(typeID));
+}