diff --git a/mlir/include/mlir-c/Diagnostics.h b/mlir/include/mlir-c/Diagnostics.h
--- a/mlir/include/mlir-c/Diagnostics.h
+++ b/mlir/include/mlir-c/Diagnostics.h
@@ -45,7 +45,9 @@
  * diagnostic completely, it is expected to return success. Otherwise, it is
  * expected to return failure to indicate that other handlers should attempt to
- * process the diagnostic. */
-typedef MlirLogicalResult (*MlirDiagnosticHandler)(MlirDiagnostic);
+ * process the diagnostic. The `userData` argument is the opaque pointer that
+ * was supplied when the handler was attached, passed through unchanged. */
+typedef MlirLogicalResult (*MlirDiagnosticHandler)(MlirDiagnostic,
+                                                   void *userData);
 
 /** Prints a diagnostic using the provided callback. */
 void mlirDiagnosticPrint(MlirDiagnostic diagnostic, MlirStringCallback callback,
@@ -67,9 +69,11 @@
 /** Attaches the diagnostic handler to the context. Handlers are invoked in the
  * reverse order of attachment until one of them processes the diagnostic
- * completely. Returns an identifier that can be used to detach the handler. */
-MlirDiagnosticHandlerID
-mlirContextAttachDiagnosticHandler(MlirContext context,
-                                   MlirDiagnosticHandler handler);
+ * completely. `userData` is forwarded unchanged to every invocation of
+ * `handler`; the context neither copies nor frees it, so it must remain valid
+ * while the handler is attached. Returns an identifier that can be used to
+ * detach the handler. */
+MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(
+    MlirContext context, MlirDiagnosticHandler handler, void *userData);
 
 /** Detaches an attached diagnostic handler from the context given its
  * identifier. */
diff --git a/mlir/lib/CAPI/IR/Diagnostics.cpp b/mlir/lib/CAPI/IR/Diagnostics.cpp
--- a/mlir/lib/CAPI/IR/Diagnostics.cpp
+++ b/mlir/lib/CAPI/IR/Diagnostics.cpp
@@ -52,14 +52,13 @@
   return wrap(*std::next(unwrap(diagnostic).getNotes().begin(), pos));
 }
 
-MlirDiagnosticHandlerID
-mlirContextAttachDiagnosticHandler(MlirContext context,
-                                   MlirDiagnosticHandler handler) {
+MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler(
+    MlirContext context, MlirDiagnosticHandler handler, void *userData) {
   assert(handler && "unexpected null diagnostic handler");
   DiagnosticEngine::HandlerID id =
       unwrap(context)->getDiagEngine().registerHandler(
-          [handler](Diagnostic &diagnostic) {
-            return unwrap(handler(wrap(diagnostic)));
+          [handler, userData](Diagnostic &diagnostic) {
+            return unwrap(handler(wrap(diagnostic), userData));
           });
   return static_cast<MlirDiagnosticHandlerID>(id);
 }
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1064,8 +1064,10 @@
 }
 
 // Wraps a diagnostic into additional text we can match against.
-MlirLogicalResult errorHandler(MlirDiagnostic diagnostic) {
-  fprintf(stderr, "processing diagnostic <<\n");
+MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
+  // Cast via intptr_t: a direct void* -> int cast narrows on 64-bit targets.
+  fprintf(stderr, "processing diagnostic (userData: %d) <<\n",
+          (int)(intptr_t)userData);
   mlirDiagnosticPrint(diagnostic, printToStderr, NULL);
   fprintf(stderr, "\n");
   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
@@ -1078,7 +1080,7 @@
 void testDiagnostics() {
   MlirContext ctx = mlirContextCreate();
   MlirDiagnosticHandlerID id =
-      mlirContextAttachDiagnosticHandler(ctx, errorHandler);
+      mlirContextAttachDiagnosticHandler(ctx, errorHandler, (void *)42);
   MlirLocation loc = mlirLocationUnknownGet(ctx);
   mlirEmitError(loc, "test diagnostics");
   mlirContextDetachDiagnosticHandler(ctx, id);
@@ -1281,7 +1283,7 @@
   testDiagnostics();
   // clang-format off
   // CHECK-LABEL: @test_diagnostics
-  // CHECK: processing diagnostic <<
+  // CHECK: processing diagnostic (userData: 42) <<
   // CHECK: test diagnostics
   // CHECK: loc(unknown)
   // CHECK: >> end of diagnostic