diff --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md --- a/mlir/docs/CAPI.md +++ b/mlir/docs/CAPI.md @@ -71,10 +71,31 @@ ### Nullity A handle may refer to a _null_ object. It is the responsibility of the caller to -check if an object is null by using `MlirXIsNull(MlirX)`. API functions do _not_ +check if an object is null by using `mlirXIsNull(MlirX)`. API functions do _not_ expect null objects as arguments unless explicitly stated otherwise. API functions _may_ return null objects. +### Conversion To String and Printing + +IR objects can be converted to a string representation, for example for +printing, using `mlirXPrint(MlirX, MlirPrintCallback, void *)` functions. These +functions take as arguments a callback with signature `void (*)(const char +*, intptr_t, void *)` and a pointer to user-defined data. They call the callback +and supply it with chunks of the string representation, provided as a pointer to +the first character and a length, and forward the user-defined data unmodified. +It is up to the caller to allocate memory if the string representation must be +stored, and to perform the copy. There is no guarantee that the pointer supplied to +the callback points to a null-terminated string; the size argument should be +used to find the end of the string. The callback may be called multiple times +with consecutive chunks of the string representation (the printing itself is +buffered). + +*Rationale*: this approach allows the caller to have full control of the +allocation and to avoid unnecessary allocation and copying inside the printer. + +For convenience, `mlirXDump(MlirX)` functions are provided to print the given +object to the standard error stream. + ### Common Patterns The API adopts the following patterns for recurrent functionality in MLIR. diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -60,7 +60,7 @@ /** Named MLIR attribute.
* - * A named attribute is essentially a (name, attrbute) pair where the name is + * A named attribute is essentially a (name, attribute) pair where the name is * a string. */ struct MlirNamedAttribute { @@ -69,6 +69,17 @@ }; typedef struct MlirNamedAttribute MlirNamedAttribute; +/** A callback for printing IR objects. + * + * This function is called back by the printing functions with the following + * arguments: + * - a pointer to the beginning of a string; + * - the length of the string (the pointer may point to a larger buffer, not + * necessarily null-terminated); + * - a pointer to user data forwarded from the printing call. + */ +typedef void (*MlirPrintCallback)(const char *, intptr_t, void *); + /*============================================================================*/ /* Context API. */ /*============================================================================*/ @@ -91,6 +102,12 @@ /** Creates a location with unknown position owned by the given context. */ MlirLocation mlirLocationUnknownGet(MlirContext context); +/** Prints a location by sending chunks of the string representation and + * forwarding `userData` to `callback`. Note that the callback may be called + * several times with consecutive chunks of the string. */ +void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, + void *userData); + /*============================================================================*/ /* Module API. */ /*============================================================================*/ @@ -202,6 +219,14 @@ /** Returns an attrbute attached to the operation given its name. */ MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, const char *name); + +/** Prints an operation by sending chunks of the string representation and + * forwarding `userData` to `callback`. Note that the callback may be called + * several times with consecutive chunks of the string.
*/ +void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, + void *userData); + +/** Prints an operation to stderr. */ void mlirOperationDump(MlirOperation op); /*============================================================================*/ @@ -263,6 +288,12 @@ /** Returns `pos`-th argument of the block. */ MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos); +/** Prints a block by sending chunks of the string representation and + * forwarding `userData` to `callback`. Note that the callback may be called + * several times with consecutive chunks of the string. */ +void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, + void *userData); + /*============================================================================*/ /* Value API. */ /*============================================================================*/ @@ -270,6 +301,12 @@ /** Returns the type of the value. */ MlirType mlirValueGetType(MlirValue value); +/** Prints a value by sending chunks of the string representation and + * forwarding `userData` to `callback`. Note that the callback may be called + * several times with consecutive chunks of the string. */ +void mlirValuePrint(MlirValue value, MlirPrintCallback callback, + void *userData); + /*============================================================================*/ /* Type API. */ /*============================================================================*/ @@ -277,6 +314,11 @@ /** Parses a type. The type is owned by the context. */ MlirType mlirTypeParseGet(MlirContext context, const char *type); +/** Prints a type by sending chunks of the string representation and + * forwarding `userData` to `callback`. Note that the callback may be called + * several times with consecutive chunks of the string. */ +void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData); + /** Prints the type to the standard error stream.
*/ void mlirTypeDump(MlirType type); @@ -287,6 +329,12 @@ /** Parses an attribute. The attribute is owned by the context. */ MlirAttribute mlirAttributeParseGet(MlirContext context, const char *attr); +/** Prints an attribute by sending chunks of the string representation and + * forwarding `userData` to `callback`. Note that the callback may be called + * several times with consecutive chunks of the string. */ +void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, + void *userData); + /** Prints the attrbute to the standard error stream. */ void mlirAttributeDump(MlirAttribute attr); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -13,6 +13,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/Parser.h" +#include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -56,6 +57,33 @@ return storage; } +/* ========================================================================== */ +/* Printing helper. */ +/* ========================================================================== */ + +namespace { +/// A simple raw ostream subclass that forwards write_impl calls to the +/// user-supplied callback together with opaque user-supplied data. +class CallbackOstream : public llvm::raw_ostream { +public: + CallbackOstream(std::function<void(const char *, intptr_t, void *)> callback, + void *opaqueData) + : callback(callback), opaqueData(opaqueData), pos(0u) {} + + void write_impl(const char *ptr, size_t size) override { + callback(ptr, size, opaqueData); + pos += size; + } + + uint64_t current_pos() const override { return pos; } + +private: + std::function<void(const char *, intptr_t, void *)> callback; + void *opaqueData; + uint64_t pos; +}; +} // end namespace + /* ========================================================================== */ /* Context API.
*/ /* ========================================================================== */ @@ -81,6 +109,13 @@ return wrap(UnknownLoc::get(unwrap(context))); } +void mlirLocationPrint(MlirLocation location, MlirPrintCallback callback, + void *userData) { + CallbackOstream stream(callback, userData); + unwrap(location).print(stream); + stream.flush(); +} + /* ========================================================================== */ /* Module API. */ /* ========================================================================== */ @@ -239,6 +274,13 @@ return wrap(unwrap(op)->getAttr(name)); } +void mlirOperationPrint(MlirOperation op, MlirPrintCallback callback, + void *userData) { + CallbackOstream stream(callback, userData); + unwrap(op)->print(stream); + stream.flush(); +} + void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } /* ========================================================================== */ @@ -314,6 +356,13 @@ return wrap(unwrap(block)->getArgument(static_cast<unsigned>(pos))); } +void mlirBlockPrint(MlirBlock block, MlirPrintCallback callback, + void *userData) { + CallbackOstream stream(callback, userData); + unwrap(block)->print(stream); + stream.flush(); +} + /* ========================================================================== */ /* Value API. */ /* ========================================================================== */ @@ -322,6 +371,13 @@ return wrap(unwrap(value).getType()); } +void mlirValuePrint(MlirValue value, MlirPrintCallback callback, + void *userData) { + CallbackOstream stream(callback, userData); + unwrap(value).print(stream); + stream.flush(); +} + /* ========================================================================== */ /* Type API.
*/ /* ========================================================================== */ @@ -330,6 +386,12 @@ return wrap(mlir::parseType(type, unwrap(context))); } +void mlirTypePrint(MlirType type, MlirPrintCallback callback, void *userData) { + CallbackOstream stream(callback, userData); + unwrap(type).print(stream); + stream.flush(); +} + void mlirTypeDump(MlirType type) { unwrap(type).dump(); } /* ========================================================================== */ @@ -340,6 +402,13 @@ return wrap(mlir::parseAttribute(attr, unwrap(context))); } +void mlirAttributePrint(MlirAttribute attr, MlirPrintCallback callback, + void *userData) { + CallbackOstream stream(callback, userData); + unwrap(attr).print(stream); + stream.flush(); +} + void mlirAttributeDump(MlirAttribute attr) { unwrap(attr).dump(); } MlirNamedAttribute mlirNamedAttributeGet(const char *name, MlirAttribute attr) { diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -197,11 +197,47 @@ head = next; } while (head); - printf("Number of operations: %u\n", stats.numOperations); - printf("Number of attributes: %u\n", stats.numAttributes); - printf("Number of blocks: %u\n", stats.numBlocks); - printf("Number of regions: %u\n", stats.numRegions); - printf("Number of values: %u\n", stats.numValues); + fprintf(stderr, "Number of operations: %u\n", stats.numOperations); + fprintf(stderr, "Number of attributes: %u\n", stats.numAttributes); + fprintf(stderr, "Number of blocks: %u\n", stats.numBlocks); + fprintf(stderr, "Number of regions: %u\n", stats.numRegions); + fprintf(stderr, "Number of values: %u\n", stats.numValues); +} + +static void printToStderr(const char *str, intptr_t len, void *userData) { + (void)userData; + fwrite(str, 1, len, stderr); +} + +static void printFirstOfEach(MlirOperation operation) { + // Assuming we are given a module, go to the first operation of the first + // function. 
+ MlirRegion region = mlirOperationGetRegion(operation, 0); + MlirBlock block = mlirRegionGetFirstBlock(region); + operation = mlirBlockGetFirstOperation(block); + region = mlirOperationGetRegion(operation, 0); + block = mlirRegionGetFirstBlock(region); + operation = mlirBlockGetFirstOperation(block); + + // In the module we created, the first operation of the first function is a + // "std.constant", which has an attribute and a single result that we can use + // to test the printing mechanism. + mlirBlockPrint(block, printToStderr, NULL); + fprintf(stderr, "\n"); + mlirOperationPrint(operation, printToStderr, NULL); + fprintf(stderr, "\n"); + + MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation, 0); + mlirAttributePrint(namedAttr.attribute, printToStderr, NULL); + fprintf(stderr, "\n"); + + MlirValue value = mlirOperationGetResult(operation, 0); + mlirValuePrint(value, printToStderr, NULL); + fprintf(stderr, "\n"); + + MlirType type = mlirValueGetType(value); + mlirTypePrint(type, printToStderr, NULL); + fprintf(stderr, "\n"); } int main() { @@ -238,6 +274,24 @@ // CHECK: Number of values: 9 // clang-format on + printFirstOfEach(module); + // clang-format off + // CHECK: %[[C0:.*]] = constant 0 : index + // CHECK: %[[DIM:.*]] = dim %{{.*}}, %[[C0]] : memref<?xf32> + // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[DIM]] step %[[C1]] { + // CHECK: %[[LHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32> + // CHECK: %[[RHS:.*]] = load %{{.*}}[%[[I]]] : memref<?xf32> + // CHECK: %[[SUM:.*]] = addf %[[LHS]], %[[RHS]] : f32 + // CHECK: store %[[SUM]], %{{.*}}[%[[I]]] : memref<?xf32> + // CHECK: } + // CHECK: return + // CHECK: constant 0 : index + // CHECK: 0 : index + // CHECK: constant 0 : index + // CHECK: index + // clang-format on + mlirModuleDestroy(moduleOp); mlirContextDestroy(ctx);