diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -344,15 +344,10 @@ c.def_static( "get", [](PyType &type, double value, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(type)).cast() + - "' and expected floating point type."); - } + if (mlirAttributeIsNull(attr)) + throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -15,6 +15,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" +#include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" //#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" @@ -38,7 +39,7 @@ static const char kContextParseTypeDocstring[] = R"(Parses the assembly form of a type. -Returns a Type object or raises a ValueError if the type cannot be parsed. +Returns a Type object or raises an MLIRError if the type cannot be parsed. See also: https://mlir.llvm.org/docs/LangRef/#type-system )"; @@ -58,7 +59,7 @@ static const char kModuleParseDocstring[] = R"(Parses a module's assembly format from a string. -Returns a new MlirModule or raises a ValueError if the parsing fails. +Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; @@ -654,6 +655,28 @@ return pyHandlerObject; } +MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, + void *userData) { + auto *self = static_cast(userData); + // Check if the context requested we emit errors instead of capturing them. + if (self->ctx->emitErrorDiagnostics) + return mlirLogicalResultFailure(); + + if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) + return mlirLogicalResultFailure(); + + PyDiagnostic pyDiag(diag); + std::vector> notes; + for (py::handle n : pyDiag.getNotes()) { + auto note = n.cast(); + notes.emplace_back(note.getLocation(), note.getMessage()); + } + self->errors.emplace_back(pyDiag.getLocation(), pyDiag.getMessage(), + std::move(notes)); + + return mlirLogicalResultSuccess(); +} + PyMlirContext &DefaultingPyMlirContext::resolve() { PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); if (!context) { @@ -1062,13 +1085,12 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName) { + PyMlirContext::ErrorCapture errors(contextRef); MlirOperation op = mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), toMlirStringRef(sourceName)); - // TODO: Include error diagnostic messages in the exception message if (mlirOperationIsNull(op)) - throw py::value_error( - "Unable to parse operation assembly (see diagnostics)"); + throw MLIRError("Unable to parse operation assembly", errors.take()); return PyOperation::createDetached(std::move(contextRef), op); } @@ -1155,6 +1177,14 @@ operation.parentKeepAlive = otherOp.parentKeepAlive; } +bool PyOperationBase::verify() { + PyOperation &op = getOperation(); + PyMlirContext::ErrorCapture errors(op.getContext()); + if (!mlirOperationVerify(op.get())) + throw MLIRError("Verification failed", errors.take()); + return true; +} + std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) @@ -2375,6 +2405,11 @@ mlirContextAppendDialectRegistry(self.get(), registry); }, py::arg("registry")) + .def_property("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") .def("load_all_available_dialects", [](PyMlirContext &self) { mlirContextLoadAllAvailableDialects(self.get()); }); @@ -2566,16 +2601,12 @@ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( "parse", - [](const std::string moduleAsm, DefaultingPyMlirContext context) { + [](const std::string &moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParse( context->get(), toMlirStringRef(moduleAsm)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirModuleIsNull(module)) { - throw SetPyError( - PyExc_ValueError, - "Unable to parse module assembly (see diagnostics)"); - } + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, py::arg("asm"), py::arg("context") = py::none(), @@ -2724,13 +2755,9 @@ py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, py::arg("assume_verified") = false, kOperationGetAsmDocstring) - .def( - "verify", - [](PyOperationBase &self) { - return mlirOperationVerify(self.getOperation()); - }, - "Verify the operation and return true if it passes, false if it " - "fails.") + .def("verify", &PyOperationBase::verify, + "Verify the operation. Raises MLIRError if verification fails, and " + "returns true otherwise.") .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), "Puts self immediately after the other operation in its parent " "block.") @@ -2833,12 +2860,12 @@ // directly. std::string clsOpName = py::cast(cls.attr("OPERATION_NAME")); - MlirStringRef parsedOpName = + MlirStringRef identifier = mlirIdentifierStr(mlirOperationGetName(*parsed.get())); - if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName))) - throw py::value_error( - "Expected a '" + clsOpName + "' op, got: '" + - std::string(parsedOpName.data, parsedOpName.length) + "'"); + std::string_view parsedOpName(identifier.data, identifier.length); + if (clsOpName != parsedOpName) + throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + + parsedOpName + "'"); return PyOpView::constructDerived(cls, *parsed.get()); }, py::arg("cls"), py::arg("source"), py::kw_only(), @@ -3071,19 +3098,16 @@ .def_static( "parse", [](std::string attrSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirAttribute type = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirAttributeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse attribute: '") + - attrSpec + "'"); - } + if (mlirAttributeIsNull(type)) + throw MLIRError("Unable to parse attribute", errors.take()); return PyAttribute(context->getRef(), type); }, py::arg("asm"), py::arg("context") = py::none(), - "Parses an attribute from an assembly form") + "Parses an attribute from an assembly form. Raises an MLIRError on " + "failure.") .def_property_readonly( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, @@ -3182,15 +3206,11 @@ .def_static( "parse", [](std::string typeSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirType type = mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse type: '") + typeSpec + - "'"); - } + if (mlirTypeIsNull(type)) + throw MLIRError("Unable to parse type", errors.take()); return PyType(context->getRef(), type); }, py::arg("asm"), py::arg("context") = py::none(), @@ -3342,4 +3362,17 @@ // Attribute builder getter. PyAttrBuilderMap::bind(m); + + py::register_local_exception_translator([](std::exception_ptr p) { + // We can't define exceptions with custom fields through pybind, so instead + // the exception class is defined in python and imported here. + try { + if (p) + std::rethrow_exception(p); + } catch (const MLIRError &e) { + py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("MLIRError")(e.message, e.errorDiagnostics); + PyErr_SetObject(PyExc_Exception, obj.ptr()); + } + }); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -221,6 +221,36 @@ /// registration object (internally a PyDiagnosticHandler). pybind11::object attachDiagnosticHandler(pybind11::object callback); + /// Controls whether error diagnostics should be propagated to diagnostic + /// handlers, instead of being captured by `ErrorCapture`. + void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } + + /// RAII object that captures any error diagnostics emitted to the provided + /// context. + struct ErrorCapture { + ErrorCapture(PyMlirContextRef ctx) + : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( + ctx->get(), handler, /*userData=*/this, + /*deleteUserData=*/nullptr)) {} + ~ErrorCapture() { + mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); + assert(errors.empty() && "unhandled captured errors"); + } + + // Tuple of {location, message, notes} for an error diagnostic. + using ErrorInfo = + std::tuple>>; + std::vector take() { return std::move(errors); }; + + private: + PyMlirContextRef ctx; + MlirDiagnosticHandlerID handlerID; + std::vector errors; + + static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); + }; + private: PyMlirContext(MlirContext context); // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, @@ -248,6 +278,8 @@ llvm::DenseMap>; LiveOperationMap liveOperations; + bool emitErrorDiagnostics = false; + MlirContext context; friend class PyModule; friend class PyOperation; @@ -519,6 +551,10 @@ void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); + /// Verify the operation. Throws `MLIRError` if verification fails, and + /// returns `true` otherwise. + bool verify(); + /// Each must provide access to the raw Operation. virtual PyOperation &getOperation() = 0; }; @@ -1073,6 +1109,17 @@ MlirSymbolTable symbolTable; }; +/// Custom exception that allows access to error diagnostic information. This is +/// converted to the `ir.MLIRError` python exception when thrown. +struct MLIRError { + MLIRError(llvm::Twine message, + std::vector + &&errorDiagnostics = {}) + : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} + std::string message; + std::vector errorDiagnostics; +}; + void populateIRAffine(pybind11::module &m); void populateIRAttributes(pybind11::module &m); void populateIRCore(pybind11::module &m); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -407,17 +407,11 @@ "get", [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyVectorType(elementType.getContext(), t); }, py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), @@ -438,20 +432,12 @@ "get", [](std::vector shape, PyType &elementType, std::optional &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirRankedTensorTypeGetChecked( loc, shape.size(), shape.data(), elementType, encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyRankedTensorType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), @@ -479,18 +465,10 @@ c.def_static( "get", [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, py::arg("element_type"), py::arg("loc") = py::none(), @@ -511,23 +489,15 @@ [](std::vector shape, PyType &elementType, PyAttribute *layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); MlirAttribute memSpaceAttr = memorySpace ? *memorySpace : mlirAttributeGetNull(); MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(), layoutAttr, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), @@ -570,23 +540,15 @@ "get", [](PyType &elementType, PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute memSpaceAttr = {}; if (memorySpace) memSpaceAttr = *memorySpace; MlirType t = mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyUnrankedMemRefType(elementType.getContext(), t); }, py::arg("element_type"), py::arg("memory_space"), diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -117,15 +117,16 @@ .def( "run", [](PyPassManager &passManager, PyOperationBase &op) { + PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRun(passManager.get(), op.getOperation().get()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_RuntimeError, - "Failure while executing pass pipeline."); + throw MLIRError("Failure while executing pass pipeline", + errors.take()); }, py::arg("operation"), - "Run the pass manager on the provided operation, throw a " - "RuntimeError on failure.") + "Run the pass manager on the provided operation, raising an " + "MLIRError on failure.") .def( "__str__", [](PyPassManager &self) { diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -100,8 +100,29 @@ # all dialects. It is being done here in order to preserve existing # behavior. See: https://github.com/llvm/llvm-project/issues/56037 self.load_all_available_dialects() - ir.Context = Context + class MLIRError(Exception): + """ + An exception with diagnostic information. Has the following fields: + message: str + error_diagnostics: List[Tuple[ir.Location, str, List[Tuple[ir.Location, str]]]] + """ + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ':' + for loc, msg, notes in self.error_diagnostics: + s += "\nerror: " + str(loc)[4:-1] + ": " + msg.replace('\n', '\n ') + for loc, msg, in notes: + s += "\n note: " + str(loc)[4:-1] + ": " + msg.replace('\n', '\n ') + return s + ir.MLIRError = MLIRError + _site_initialize() diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -28,16 +28,17 @@ # CHECK-LABEL: TEST: testParseError -# TODO: Hook the diagnostic manager to capture a more meaningful error -# message. @run def testParseError(): with Context(): try: t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST") - except ValueError as e: - # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST' - print("testParseError:", e) + except MLIRError as e: + # CHECK: testParseError: < + # CHECK: Unable to parse attribute: + # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value + # CHECK: > + print(f"testParseError: <{e}>") else: print("Exception not produced") @@ -180,8 +181,9 @@ try: fattr_invalid = FloatAttr.get( IntegerType.get_signless(32), 42) - except ValueError as e: - # CHECK: invalid 'Type(i32)' and expected floating point type. + except MLIRError as e: + # CHECK: Invalid attribute: + # CHECK: error: unknown: expected floating point type print(e) else: print("Exception not produced") diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -26,16 +26,17 @@ # CHECK-LABEL: TEST: testParseError -# TODO: Hook the diagnostic manager to capture a more meaningful error -# message. @run def testParseError(): ctx = Context() try: t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx) - except ValueError as e: - # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST' - print("testParseError:", e) + except MLIRError as e: + # CHECK: testParseError: < + # CHECK: Unable to parse type: + # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type + # CHECK: > + print(f"testParseError: <{e}>") else: print("Exception not produced") @@ -292,8 +293,9 @@ none = NoneType.get() try: vector_invalid = VectorType.get(shape, none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point or integer type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: vector elements must be int/index/float type but got 'none' print(e) else: print("Exception not produced") @@ -313,9 +315,9 @@ none = NoneType.get() try: tensor_invalid = RankedTensorType.get(shape, none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid tensor element type: 'none' print(e) else: print("Exception not produced") @@ -361,9 +363,9 @@ none = NoneType.get() try: tensor_invalid = UnrankedTensorType.get(none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid tensor element type: 'none' print(e) else: print("Exception not produced") @@ -400,9 +402,9 @@ none = NoneType.get() try: memref_invalid = MemRefType.get(shape, none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid memref element type print(e) else: print("Exception not produced") @@ -444,9 +446,9 @@ none = NoneType.get() try: memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid memref element type print(e) else: print("Exception not produced") diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py --- a/mlir/test/python/ir/diagnostic_handler.py +++ b/mlir/test/python/ir/diagnostic_handler.py @@ -89,6 +89,7 @@ @run def testDiagnosticNonEmptyNotes(): ctx = Context() + ctx.emit_error_diagnostics = True def callback(d): # CHECK: DIAGNOSTIC: # CHECK: message='arith.addi' op requires one result @@ -99,7 +100,10 @@ return True handler = ctx.attach_diagnostic_handler(callback) loc = Location.unknown(ctx) - Operation.create('arith.addi', loc=loc).verify() + try: + Operation.create('arith.addi', loc=loc).verify() + except MLIRError: + pass assert not handler.had_error # CHECK-LABEL: TEST: testDiagnosticCallbackException diff --git a/mlir/test/python/ir/exception.py b/mlir/test/python/ir/exception.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/ir/exception.py @@ -0,0 +1,78 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert Context._get_live_count() == 0 + return f + + +# CHECK-LABEL: TEST: test_exception +@run +def test_exception(): + ctx = Context() + ctx.allow_unregistered_dialects = True + try: + Operation.parse(""" + func.func @foo() { + "test.use"(%0) : (i64) -> () loc("use") + %0 = "test.def"() : () -> i64 loc("def") + return + } + """, context=ctx) + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Unable to parse operation assembly: + # CHECK: error: "use": operand #0 does not dominate this use + # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> () + # CHECK: note: "def": operand defined here (op in the same block) + # CHECK: > + print(f"Exception: <{e}>") + + # CHECK: message: Unable to parse operation assembly + print(f"message: {e.message}") + + # CHECK: error_diagnostics: [( + # CHECK-SAME: loc("use"), 'operand #0 does not dominate this use', [ + # CHECK-SAME: (loc("use"), 'see current operation: "test.use"(%0) : (i64) -> ()'), + # CHECK-SAME: (loc("def"), 'operand defined here (op in the same block)') + # CHECK-SAME: ] + # CHECK-SAME: )] + print(f"error_diagnostics: {e.error_diagnostics}") + + +# CHECK-LABEL: test_emit_error_diagnostics +@run +def test_emit_error_diagnostics(): + ctx = Context() + loc = Location.unknown(ctx) + handler_diags = [] + def handler(d): + handler_diags.append(d.message) + return True + ctx.attach_diagnostic_handler(handler) + + try: + Attribute.parse("not an attr", ctx) + except MLIRError as e: + # CHECK: emit_error_diagnostics=False: + # CHECK: e.error_diagnostics: [(loc("not an attr":1:1), 'expected attribute value', [])] + # CHECK: handler_diags: [] + print(f"emit_error_diagnostics=False:") + print(f"e.error_diagnostics: {e.error_diagnostics}") + print(f"handler_diags: {handler_diags}") + + ctx.emit_error_diagnostics = True + try: + Attribute.parse("not an attr", ctx) + except MLIRError as e: + # CHECK: emit_error_diagnostics=True: + # CHECK: e.error_diagnostics: [] + # CHECK: handler_diags: ['expected attribute value'] + print(f"emit_error_diagnostics=True:") + print(f"e.error_diagnostics: {e.error_diagnostics}") + print(f"handler_diags: {handler_diags}") diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -28,14 +28,17 @@ # Verify parse error. # CHECK-LABEL: TEST: testParseError -# CHECK: testParseError: Unable to parse module assembly (see diagnostics) +# CHECK: testParseError: < +# CHECK: Unable to parse module assembly: +# CHECK: error: "-":1:1: expected operation name in quotes +# CHECK: > @run def testParseError(): ctx = Context() try: module = Module.parse(r"""}SYNTAX ERROR{""", ctx) - except ValueError as e: - print("testParseError:", e) + except MLIRError as e: + print(f"testParseError: <{e}>") else: print("Exception not produced") diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -685,8 +685,19 @@ # CHECK: "builtin.module"() ({ # CHECK: }) : () -> () print(invalid_op) - # CHECK: .verify = False - print(f".verify = {invalid_op.operation.verify()}") + try: + invalid_op.verify() + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Verification failed: + # CHECK: error: unknown: 'builtin.module' op requires one region + # CHECK: note: unknown: see current operation: + # CHECK: "builtin.module"() ({ + # CHECK: ^bb0: + # CHECK: }, { + # CHECK: }) : () -> () + # CHECK: > + print(f"Exception: <{e}>") # CHECK-LABEL: TEST: testInvalidModuleStrSoftFails @@ -920,7 +931,7 @@ assert isinstance(m, ModuleOp) try: ModuleOp.parse('"test.foo"() : () -> ()') - except ValueError as e: + except MLIRError as e: # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo' print(f"error: {e}") else: diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -118,7 +118,7 @@ # Verify that a pass manager can execute on IR -# CHECK-LABEL: TEST: testRun +# CHECK-LABEL: TEST: testRunPipeline def testRunPipeline(): with Context(): pm = PassManager.parse("any(print-op-stats{json=false})") @@ -128,3 +128,20 @@ # CHECK: func.func , 1 # CHECK: func.return , 1 run(testRunPipeline) + +# CHECK-LABEL: TEST: testRunPipelineError +@run +def testRunPipelineError(): + with Context() as ctx: + ctx.allow_unregistered_dialects = True + op = Operation.parse('"test.op"() : () -> ()') + pm = PassManager.parse("any(cse)") + try: + pm.run(op) + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Failure while executing pass pipeline: + # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation + # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> () + # CHECK: > + print(f"Exception: <{e}>")