diff --git a/mlir/include/mlir-c/Diagnostics.h b/mlir/include/mlir-c/Diagnostics.h
--- a/mlir/include/mlir-c/Diagnostics.h
+++ b/mlir/include/mlir-c/Diagnostics.h
@@ -40,12 +40,14 @@
 /// Opaque identifier of a diagnostic handler, useful to detach a handler.
 typedef uint64_t MlirDiagnosticHandlerID;
 
-/** Diagnostic handler type. Acceps a reference to a diagnostic, which is only
- * guaranteed to be live during the call. If the handler processed the
- * diagnostic completely, it is expected to return success. Otherwise, it is
- * expected to return failure to indicate that other handlers should attempt to
- * process the diagnostic. */
-typedef MlirLogicalResult (*MlirDiagnosticHandler)(MlirDiagnostic);
+/** Diagnostic handler type. Accepts a reference to a diagnostic, which is only
+ * guaranteed to be live during the call. The handler is passed the `userData`
+ * that was provided when the handler was attached to a context. If the handler
+ * processed the diagnostic completely, it is expected to return success.
+ * Otherwise, it is expected to return failure to indicate that other handlers
+ * should attempt to process the diagnostic. */
+typedef MlirLogicalResult (*MlirDiagnosticHandler)(MlirDiagnostic,
+                                                   void *userData);
 
 /// Prints a diagnostic using the provided callback.
 MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic,
@@ -71,9 +73,14 @@
 
 /** Attaches the diagnostic handler to the context. Handlers are invoked in the
  * reverse order of attachment until one of them processes the diagnostic
- * completely. Returns an identifier that can be used to detach the handler. */
+ * completely. When a handler is invoked it is passed the `userData` that was
+ * provided when it was attached. If non-NULL, `deleteUserData` is called once
+ * the system no longer needs to call the handler (for instance after the
+ * handler is detached or the context is destroyed). Returns an identifier that
+ * can be used to detach the handler. */
 MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(
-    MlirContext context, MlirDiagnosticHandler handler);
+    MlirContext context, MlirDiagnosticHandler handler, void *userData,
+    void (*deleteUserData)(void *));
 
 /** Detaches an attached diagnostic handler from the context given its
  * identifier. */
diff --git a/mlir/lib/CAPI/IR/Diagnostics.cpp b/mlir/lib/CAPI/IR/Diagnostics.cpp
--- a/mlir/lib/CAPI/IR/Diagnostics.cpp
+++ b/mlir/lib/CAPI/IR/Diagnostics.cpp
@@ -51,14 +51,22 @@
   return wrap(*std::next(unwrap(diagnostic).getNotes().begin(), pos));
 }
 
-MlirDiagnosticHandlerID
-mlirContextAttachDiagnosticHandler(MlirContext context,
-                                   MlirDiagnosticHandler handler) {
+/// Deleter used when the caller does not provide one: leaves user data alone.
+static void deleteUserDataNoop(void * /*userData*/) {}
+
+MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(
+    MlirContext context, MlirDiagnosticHandler handler, void *userData,
+    void (*deleteUserData)(void *)) {
   assert(handler && "unexpected null diagnostic handler");
+  if (deleteUserData == nullptr)
+    deleteUserData = deleteUserDataNoop;
+  // The closure owns the user data through a shared_ptr, so the deleter runs
+  // when the last copy of the registered handler is destroyed.
+  std::shared_ptr<void> sharedUserData(userData, deleteUserData);
   DiagnosticEngine::HandlerID id =
       unwrap(context)->getDiagEngine().registerHandler(
-          [handler](Diagnostic &diagnostic) {
-            return unwrap(handler(wrap(diagnostic)));
+          [handler, sharedUserData](Diagnostic &diagnostic) {
+            return unwrap(handler(wrap(diagnostic), sharedUserData.get()));
           });
   return static_cast<MlirDiagnosticHandlerID>(id);
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1248,31 +1248,41 @@
 }
 
 // Wraps a diagnostic into additional text we can match against.
-MlirLogicalResult errorHandler(MlirDiagnostic diagnostic) {
-  fprintf(stderr, "processing diagnostic <<\n");
+MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
+  fprintf(stderr, "processing diagnostic (userData: %d) <<\n",
+          (int)(intptr_t)userData);
   mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
   fprintf(stderr, "\n");
   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
   mlirLocationPrint(loc, printToStderr, NULL);
   assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
-  fprintf(stderr, ">> end of diagnostic\n");
+  fprintf(stderr, ">> end of diagnostic (userData: %d)\n",
+          (int)(intptr_t)userData);
   return mlirLogicalResultSuccess();
 }
 
+// Logs when the delete user data callback is called.
+static void deleteUserData(void *userData) {
+  fprintf(stderr, "deleting user data (userData: %d)\n",
+          (int)(intptr_t)userData);
+}
+
 void testDiagnostics() {
   MlirContext ctx = mlirContextCreate();
-  MlirDiagnosticHandlerID id =
-      mlirContextAttachDiagnosticHandler(ctx, errorHandler);
+  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
+      ctx, errorHandler, (void *)(intptr_t)42, deleteUserData);
   MlirLocation loc = mlirLocationUnknownGet(ctx);
   fprintf(stderr, "@test_diagnostics\n");
   mlirEmitError(loc, "test diagnostics");
   mlirContextDetachDiagnosticHandler(ctx, id);
   mlirEmitError(loc, "more test diagnostics");
+  mlirContextDestroy(ctx);
   // CHECK-LABEL: @test_diagnostics
-  // CHECK: processing diagnostic <<
+  // CHECK: processing diagnostic (userData: 42) <<
   // CHECK: test diagnostics
   // CHECK: loc(unknown)
-  // CHECK: >> end of diagnostic
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: deleting user data (userData: 42)
   // CHECK-NOT: processing diagnostic
   // CHECK: more test diagnostics
 }