diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -92,7 +92,9 @@ int mlirContextEqual(MlirContext ctx1, MlirContext ctx2); /** Checks whether a context is null. */ -inline int mlirContextIsNull(MlirContext context) { return !context.ptr; } +static inline int mlirContextIsNull(MlirContext context) { + return !context.ptr; +} /** Takes an MLIR context owned by the caller and destroys it. */ void mlirContextDestroy(MlirContext context); @@ -127,7 +129,9 @@ MlirContext mlirDialectGetContext(MlirDialect dialect); /** Checks if the dialect is null. */ -int mlirDialectIsNull(MlirDialect dialect); +static inline int mlirDialectIsNull(MlirDialect dialect) { + return !dialect.ptr; +} /** Checks if two dialects that belong to the same context are equal. Dialects * from different contexts will not compare equal. */ @@ -171,7 +175,7 @@ MlirContext mlirModuleGetContext(MlirModule module); /** Checks whether a module is null. */ -inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; } +static inline int mlirModuleIsNull(MlirModule module) { return !module.ptr; } /** Takes a module owned by the caller and deletes it. */ void mlirModuleDestroy(MlirModule module); @@ -235,7 +239,7 @@ void mlirOperationDestroy(MlirOperation op); /** Checks whether the underlying operation is null. */ -int mlirOperationIsNull(MlirOperation op); +static inline int mlirOperationIsNull(MlirOperation op) { return !op.ptr; } /** Returns the number of regions attached to the given operation. */ intptr_t mlirOperationGetNumRegions(MlirOperation op); @@ -275,6 +279,15 @@ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, const char *name); +/** Sets an attribute by name, replacing the existing if it exists or + * adding a new one otherwise. */ +void mlirOperationSetAttributeByName(MlirOperation op, const char *name, + MlirAttribute attr); + +/** Removes an attribute by name.
Returns 0 if the attribute was not found + * and !0 if removed. */ +int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name); + /** Prints an operation by sending chunks of the string representation and * forwarding `userData to `callback`. Note that the callback may be called * several times with consecutive chunks of the string. */ @@ -295,7 +308,7 @@ void mlirRegionDestroy(MlirRegion region); /** Checks whether a region is null. */ -int mlirRegionIsNull(MlirRegion region); +static inline int mlirRegionIsNull(MlirRegion region) { return !region.ptr; } /** Gets the first block in the region. */ MlirBlock mlirRegionGetFirstBlock(MlirRegion region); @@ -333,7 +346,7 @@ void mlirBlockDestroy(MlirBlock block); /** Checks whether a block is null. */ -int mlirBlockIsNull(MlirBlock block); +static inline int mlirBlockIsNull(MlirBlock block) { return !block.ptr; } /** Returns the block immediately following the given block in its parent * region. */ @@ -381,6 +394,9 @@ /* Value API. */ /*============================================================================*/ +/** Returns whether the value is null. */ +static inline int mlirValueIsNull(MlirValue value) { return !value.ptr; } + /** Returns the type of the value. */ MlirType mlirValueGetType(MlirValue value); @@ -401,7 +417,7 @@ MlirContext mlirTypeGetContext(MlirType type); /** Checks whether a type is null. */ -inline int mlirTypeIsNull(MlirType type) { return !type.ptr; } +static inline int mlirTypeIsNull(MlirType type) { return !type.ptr; } /** Checks if two types are equal. */ int mlirTypeEqual(MlirType t1, MlirType t2); @@ -425,7 +441,7 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute); /** Checks whether an attribute is null. */ -inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } +static inline int mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } /** Checks if two attributes are equal.
*/ int mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -66,10 +66,6 @@ return wrap(unwrap(dialect)->getContext()); } -int mlirDialectIsNull(MlirDialect dialect) { - return unwrap(dialect) == nullptr; -} - int mlirDialectEqual(MlirDialect dialect1, MlirDialect dialect2) { return unwrap(dialect1) == unwrap(dialect2); } @@ -215,8 +211,6 @@ void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } -int mlirOperationIsNull(MlirOperation op) { return unwrap(op) == nullptr; } - intptr_t mlirOperationGetNumRegions(MlirOperation op) { return static_cast<intptr_t>(unwrap(op)->getNumRegions()); } @@ -267,6 +261,16 @@ return wrap(unwrap(op)->getAttr(name)); } +void mlirOperationSetAttributeByName(MlirOperation op, const char *name, + MlirAttribute attr) { + unwrap(op)->setAttr(name, unwrap(attr)); +} + +int mlirOperationRemoveAttributeByName(MlirOperation op, const char *name) { + auto removeResult = unwrap(op)->removeAttr(name); + return removeResult == MutableDictionaryAttr::RemoveResult::Removed; +} + void mlirOperationPrint(MlirOperation op, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); @@ -328,8 +332,6 @@ delete static_cast<Region *>(region.ptr); } -int mlirRegionIsNull(MlirRegion region) { return unwrap(region) == nullptr; } - /* ========================================================================== */ /* Block API.
*/ /* ========================================================================== */ @@ -391,8 +393,6 @@ void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } -int mlirBlockIsNull(MlirBlock block) { return unwrap(block) == nullptr; } - intptr_t mlirBlockGetNumArguments(MlirBlock block) { return static_cast<intptr_t>(unwrap(block)->getNumArguments()); } diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -216,7 +216,7 @@ fwrite(str, 1, len, stderr); } -static void printFirstOfEach(MlirOperation operation) { +static void printFirstOfEach(MlirContext ctx, MlirOperation operation) { // Assuming we are given a module, go to the first operation of the first // function. MlirRegion region = mlirOperationGetRegion(operation, 0); @@ -227,24 +227,59 @@ operation = mlirBlockGetFirstOperation(block); // In the module we created, the first operation of the first function is an - // "std.dim", which has an attribute an a single result that we can use to + // "std.constant", which has an attribute and one result that we can use to // test the printing mechanism. mlirBlockPrint(block, printToStderr, NULL); fprintf(stderr, "\n"); + fprintf(stderr, "First operation: "); mlirOperationPrint(operation, printToStderr, NULL); fprintf(stderr, "\n"); - MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0); - mlirAttributePrint(namedAttr.attribute, printToStderr, NULL); + // Get the attribute by index. + MlirNamedAttribute namedAttr0 = mlirOperationGetAttribute(operation, 0); + fprintf(stderr, "Get attr 0: "); + mlirAttributePrint(namedAttr0.attribute, printToStderr, NULL); fprintf(stderr, "\n"); + // Now re-get the attribute by name.
+ MlirAttribute attr0ByName = + mlirOperationGetAttributeByName(operation, namedAttr0.name); + fprintf(stderr, "Get attr 0 by name: "); + mlirAttributePrint(attr0ByName, printToStderr, NULL); + fprintf(stderr, "\n"); + + // Get a non-existent attribute and check that it is null (sanity). + fprintf(stderr, "does_not_exist is null: %d\n", + mlirAttributeIsNull( + mlirOperationGetAttributeByName(operation, "does_not_exist"))); + + // Get result 0 and its type. MlirValue value = mlirOperationGetResult(operation, 0); + fprintf(stderr, "Result 0: "); mlirValuePrint(value, printToStderr, NULL); fprintf(stderr, "\n"); + fprintf(stderr, "Value is null: %d\n", mlirValueIsNull(value)); MlirType type = mlirValueGetType(value); + fprintf(stderr, "Result 0 type: "); mlirTypePrint(type, printToStderr, NULL); fprintf(stderr, "\n"); + + // Set a custom attribute. + mlirOperationSetAttributeByName(operation, "custom_attr", + mlirBoolAttrGet(ctx, 1)); + fprintf(stderr, "Op with set attr: "); + mlirOperationPrint(operation, printToStderr, NULL); + fprintf(stderr, "\n"); + + // Remove the attribute.
+ fprintf(stderr, "Remove attr: %d\n", + mlirOperationRemoveAttributeByName(operation, "custom_attr")); + fprintf(stderr, "Remove attr again: %d\n", + mlirOperationRemoveAttributeByName(operation, "custom_attr")); + fprintf(stderr, "Removed attr is null: %d\n", + mlirAttributeIsNull( + mlirOperationGetAttributeByName(operation, "custom_attr"))); } /// Creates an operation with a region containing multiple blocks with @@ -884,7 +919,7 @@ // CHECK: Number of values: 9 // clang-format on - printFirstOfEach(module); + printFirstOfEach(ctx, module); // clang-format off // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32> @@ -896,10 +931,17 @@ // CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32> // CHECK: } // CHECK: return - // CHECK: constant 0 : index - // CHECK: 0 : index - // CHECK: constant 0 : index - // CHECK: index + // CHECK: First operation: {{.*}} = constant 0 : index + // CHECK: Get attr 0: 0 : index + // CHECK: Get attr 0 by name: 0 : index + // CHECK: does_not_exist is null: 1 + // CHECK: Result 0: {{.*}} = constant 0 : index + // CHECK: Value is null: 0 + // CHECK: Result 0 type: index + // CHECK: Op with set attr: {{.*}} {custom_attr = true} + // CHECK: Remove attr: 1 + // CHECK: Remove attr again: 0 + // CHECK: Removed attr is null: 1 // clang-format on mlirModuleDestroy(moduleOp);