diff --git a/mlir/include/mlir-c/Diagnostics.h b/mlir/include/mlir-c/Diagnostics.h
--- a/mlir/include/mlir-c/Diagnostics.h
+++ b/mlir/include/mlir-c/Diagnostics.h
@@ -40,12 +40,14 @@
 /// Opaque identifier of a diagnostic handler, useful to detach a handler.
 typedef uint64_t MlirDiagnosticHandlerID;
 
-/** Diagnostic handler type. Acceps a reference to a diagnostic, which is only
- * guaranteed to be live during the call. If the handler processed the
- * diagnostic completely, it is expected to return success. Otherwise, it is
- * expected to return failure to indicate that other handlers should attempt to
- * process the diagnostic. */
-typedef MlirLogicalResult (*MlirDiagnosticHandler)(MlirDiagnostic);
+/** Diagnostic handler type. Accepts a reference to a diagnostic, which is only
+ * guaranteed to be live during the call. The handler is passed the `userData`
+ * that was provided when the handler was attached to a context. If the handler
+ * processed the diagnostic completely, it is expected to return success.
+ * Otherwise, it is expected to return failure to indicate that other handlers
+ * should attempt to process the diagnostic. */
+typedef MlirLogicalResult (*MlirDiagnosticHandler)(MlirDiagnostic,
+                                                   void *userData);
 
 /// Prints a diagnostic using the provided callback.
 MLIR_CAPI_EXPORTED void mlirDiagnosticPrint(MlirDiagnostic diagnostic,
@@ -71,9 +73,15 @@
 
 /** Attaches the diagnostic handler to the context. Handlers are invoked in the
  * reverse order of attachment until one of them processes the diagnostic
- * completely. Returns an identifier that can be used to detach the handler. */
+ * completely. When a handler is invoked it is passed the `userData` that was
+ * provided when it was attached. If non-NULL, `deleteUserData` is called once
+ * the system no longer needs to call the handler (for instance after the
+ * handler is detached or the context is destroyed). Returns an identifier that
+ * can be used to detach the handler.
+ */
 MLIR_CAPI_EXPORTED MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(
-    MlirContext context, MlirDiagnosticHandler handler);
+    MlirContext context, MlirDiagnosticHandler handler, void *userData,
+    void (*deleteUserData)(void *));
 
 /** Detaches an attached diagnostic handler from the context given its
  * identifier. */
diff --git a/mlir/lib/CAPI/IR/Diagnostics.cpp b/mlir/lib/CAPI/IR/Diagnostics.cpp
--- a/mlir/lib/CAPI/IR/Diagnostics.cpp
+++ b/mlir/lib/CAPI/IR/Diagnostics.cpp
@@ -51,14 +51,22 @@
   return wrap(*std::next(unwrap(diagnostic).getNotes().begin(), pos));
 }
 
-MlirDiagnosticHandlerID
-mlirContextAttachDiagnosticHandler(MlirContext context,
-                                   MlirDiagnosticHandler handler) {
+/// Deleter used when the caller does not provide a `deleteUserData` callback.
+static void deleteUserDataNoop(void *userData) {}
+
+MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(
+    MlirContext context, MlirDiagnosticHandler handler, void *userData,
+    void (*deleteUserData)(void *)) {
   assert(handler && "unexpected null diagnostic handler");
+  if (!deleteUserData)
+    deleteUserData = deleteUserDataNoop;
+  // The registered closure may be copied; shared ownership guarantees that
+  // `deleteUserData` runs exactly once, when the last copy is destroyed.
+  std::shared_ptr<void> sharedUserData(userData, deleteUserData);
   DiagnosticEngine::HandlerID id =
       unwrap(context)->getDiagEngine().registerHandler(
-          [handler](Diagnostic &diagnostic) {
-            return unwrap(handler(wrap(diagnostic)));
+          [handler, sharedUserData](Diagnostic &diagnostic) {
+            return unwrap(handler(wrap(diagnostic), sharedUserData.get()));
           });
   return static_cast<MlirDiagnosticHandlerID>(id);
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1248,31 +1248,40 @@
 }
 
 // Wraps a diagnostic into additional text we can match against.
-MlirLogicalResult errorHandler(MlirDiagnostic diagnostic) {
-  fprintf(stderr, "processing diagnostic <<\n");
+MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
+  fprintf(stderr, "processing diagnostic (userData: %ld) <<\n",
+          (long)(intptr_t)userData);
   mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
   fprintf(stderr, "\n");
   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
   mlirLocationPrint(loc, printToStderr, NULL);
   assert(mlirDiagnosticGetNumNotes(diagnostic) == 0);
-  fprintf(stderr, ">> end of diagnostic\n");
+  fprintf(stderr, ">> end of diagnostic (userData: %ld)\n",
+          (long)(intptr_t)userData);
   return mlirLogicalResultSuccess();
 }
 
+// Logs when the delete user data callback is called
+static void deleteUserData(void *userData) {
+  fprintf(stderr, "deleting user data (userData: %ld)\n",
+          (long)(intptr_t)userData);
+}
+
 void testDiagnostics() {
   MlirContext ctx = mlirContextCreate();
-  MlirDiagnosticHandlerID id =
-      mlirContextAttachDiagnosticHandler(ctx, errorHandler);
+  MlirDiagnosticHandlerID id = mlirContextAttachDiagnosticHandler(
+      ctx, errorHandler, (void *)(intptr_t)42, deleteUserData);
   MlirLocation loc = mlirLocationUnknownGet(ctx);
   fprintf(stderr, "@test_diagnostics\n");
   mlirEmitError(loc, "test diagnostics");
   mlirContextDetachDiagnosticHandler(ctx, id);
   mlirEmitError(loc, "more test diagnostics");
   // CHECK-LABEL: @test_diagnostics
-  // CHECK: processing diagnostic <<
+  // CHECK: processing diagnostic (userData: 42) <<
   // CHECK: test diagnostics
   // CHECK: loc(unknown)
-  // CHECK: >> end of diagnostic
+  // CHECK: >> end of diagnostic (userData: 42)
+  // CHECK: deleting user data (userData: 42)
   // CHECK-NOT: processing diagnostic
   // CHECK: more test diagnostics
 }