diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -15,6 +15,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Debug.h"
+#include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 //#include "mlir-c/Registration.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -181,6 +182,17 @@
   return mlirStringRefCreate(s.data(), s.size());
 }
 
+/// Exception raised by the bindings to report a failure together with the
+/// error diagnostics captured for it; `what()` yields the summary message.
+struct MLIRError : public std::exception {
+  MLIRError(const Twine &message, std::vector<py::tuple> &&errorDiagnostics)
+      : message(message.str()),
+        errorDiagnostics(std::move(errorDiagnostics)) {}
+  const char *what() const noexcept override { return message.c_str(); }
+  std::string message;
+  std::vector<py::tuple> errorDiagnostics;
+};
+
 /// Wrapper for the global LLVM debugging flag.
 struct PyGlobalDebugFlag {
   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
@@ -870,6 +882,15 @@
   return *materializedNotes;
 }
 
+/// Flattens this diagnostic into a (severity, location, message, notes) tuple
+/// in which `notes` is a list of tuples built the same way from getNotes().
+py::tuple PyDiagnostic::asTuple() {
+  py::list notes;
+  for (py::handle x : getNotes())
+    notes.append(x.cast<PyDiagnostic>().asTuple());
+  return py::make_tuple(getSeverity(), getLocation(), getMessage(), notes);
+}
+
 //------------------------------------------------------------------------------
 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
 //------------------------------------------------------------------------------
@@ -1059,16 +1080,46 @@
   return created;
 }
 
+/// Attaches a diagnostic handler to `ctx` on construction and detaches it on
+/// destruction; in between, each error diagnostic is appended to
+/// `errorDiagnostics` as a tuple.
+struct PyScopedErrorDiagCapture {
+  MlirContext ctx;
+  MlirDiagnosticHandlerID handlerID;
+  std::vector<py::tuple> errorDiagnostics;
+
+  explicit PyScopedErrorDiagCapture(MlirContext ctx)
+      : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler(
+                      ctx, handler, this, /*deleteUserData=*/nullptr)) {}
+  ~PyScopedErrorDiagCapture() {
+    mlirContextDetachDiagnosticHandler(ctx, handlerID);
+  }
+  // The attached handler holds `this` as its user data, so the object must
+  // stay at a fixed address.
+  PyScopedErrorDiagCapture(const PyScopedErrorDiagCapture &) = delete;
+  PyScopedErrorDiagCapture &
+  operator=(const PyScopedErrorDiagCapture &) = delete;
+
+  static MlirLogicalResult handler(MlirDiagnostic diag, void *userData) {
+    // Non-error diagnostics are not recorded; failure is returned for them.
+    if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError)
+      return mlirLogicalResultFailure();
+    auto *diagCapture = static_cast<PyScopedErrorDiagCapture *>(userData);
+    diagCapture->errorDiagnostics.push_back(PyDiagnostic(diag).asTuple());
+    return mlirLogicalResultSuccess();
+  }
+};
+
 PyOperationRef PyOperation::parse(PyMlirContextRef contextRef,
                                   const std::string &sourceStr,
                                   const std::string &sourceName) {
+  PyScopedErrorDiagCapture diagCapture(contextRef->get());
   MlirOperation op =
       mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr),
                                toMlirStringRef(sourceName));
-  // TODO: Include error diagnostic messages in the exception message
   if (mlirOperationIsNull(op))
-    throw py::value_error(
-        "Unable to parse operation assembly (see diagnostics)");
+    throw MLIRError("Unable to parse operation assembly",
+                    std::move(diagCapture.errorDiagnostics));
   return PyOperation::createDetached(std::move(contextRef), op);
 }
 
@@ -2845,9 +2896,10 @@
             MlirStringRef parsedOpName =
                 mlirIdentifierStr(mlirOperationGetName(*parsed.get()));
             if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName)))
-              throw py::value_error(
+              throw MLIRError(
                   "Expected a '" + clsOpName + "' op, got: '" +
-                  std::string(parsedOpName.data, parsedOpName.length) + "'");
+                      std::string(parsedOpName.data, parsedOpName.length) + "'",
+                  {});
             return PyOpView::constructDerived(cls, *parsed.get());
           },
           py::arg("cls"), py::arg("source"), py::kw_only(),
@@ -3351,4 +3403,15 @@
 
   // Attribute builder getter.
PyAttrBuilderMap::bind(m); + + py::register_local_exception_translator([](std::exception_ptr p) { + try { + if (p) + std::rethrow_exception(p); + } catch (const MLIRError &e) { + py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("MLIRError")(e.message, e.errorDiagnostics); + PyErr_SetObject(PyExc_ValueError, obj.ptr()); + } + }); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -294,6 +294,7 @@ PyLocation getLocation(); pybind11::str getMessage(); pybind11::tuple getNotes(); + pybind11::tuple asTuple(); private: MlirDiagnostic diagnostic; diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -100,8 +100,27 @@ # all dialects. It is being done here in order to preserve existing # behavior. See: https://github.com/llvm/llvm-project/issues/56037 self.load_all_available_dialects() - ir.Context = Context + class MLIRError(ValueError): + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(str(self)) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ':' + for severity, loc, msg, notes in self.error_diagnostics: + assert severity == ir.DiagnosticSeverity.ERROR + s += f"\nerror: {loc}: {msg}" + for severity, loc, msg, notes in notes: + assert severity == ir.DiagnosticSeverity.NOTE + assert not notes + s += f"\n {loc}: {msg}" + return s + ir.MLIRError = MLIRError + _site_initialize() diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -929,3 +929,44 @@ o = Operation.parse('"test.foo"() : () -> ()', source_name="my-source-string") # CHECK: op_with_source_name: "test.foo"() 
: () -> () loc("my-source-string":1:1) print(f"op_with_source_name: {o.get_asm(enable_debug_info=True, use_local_scope=True)}") + + +# CHECK-LABEL: TEST: test_exception +@run +def test_exception(): + from pprint import pprint + with Context() as ctx, Location.unknown(): + ctx.allow_unregistered_dialects = True + try: + Operation.parse(""" + func.func @foo() { + "test.use"(%0) : (i64) -> () loc("use") + %0 = "test.def"() : () -> i64 loc("def") + return + } + """) + except MLIRError as e: + # CHECK: exception: + # CHECK: Unable to parse operation assembly: + # CHECK: error: loc("use"): operand #0 does not dominate this use + # CHECK: loc("use"): see current operation: "test.use"(%0) : (i64) -> () + # CHECK: loc("def"): operand defined here (op in the same block) + print(f"exception:\n{e}") + + # CHECK: message: Unable to parse operation assembly + print(f"message: {e.message}") + + # CHECK: error_diagnostics: + # CHECK: [(, + # CHECK: loc("use"), + # CHECK: 'operand #0 does not dominate this use', + # CHECK: [(, + # CHECK: loc("use"), + # CHECK: 'see current operation: "test.use"(%0) : (i64) -> ()', + # CHECK: []), + # CHECK: (, + # CHECK: loc("def"), + # CHECK: 'operand defined here (op in the same block)', + # CHECK: [])])] + print("error_diagnostics: ") + pprint(e.error_diagnostics)