diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -344,15 +344,10 @@ c.def_static( "get", [](PyType &type, double value, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(type)).cast<std::string>() + - "' and expected floating point type."); - } + if (mlirAttributeIsNull(attr)) + throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -15,6 +15,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" +#include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" //#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" @@ -38,7 +39,7 @@ static const char kContextParseTypeDocstring[] = R"(Parses the assembly form of a type. -Returns a Type object or raises a ValueError if the type cannot be parsed. +Returns a Type object or raises an MLIRError if the type cannot be parsed. See also: https://mlir.llvm.org/docs/LangRef/#type-system )"; @@ -58,7 +59,7 @@ static const char kModuleParseDocstring[] = R"(Parses a module's assembly format from a string. -Returns a new MlirModule or raises a ValueError if the parsing fails. +Returns a new MlirModule or raises an MLIRError if the parsing fails. 
See also: https://mlir.llvm.org/docs/LangRef/ )"; @@ -654,6 +655,20 @@ return pyHandlerObject; } +MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, + void *userData) { + auto *self = static_cast<ErrorCapture *>(userData); + // If the context asks for errors to be emitted rather than captured, leave the diagnostic unhandled so it reaches the other handlers. + if (self->ctx->emitErrorDiagnostics) + return mlirLogicalResultFailure(); + // Only error-severity diagnostics are captured; all other severities are left unhandled. + if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) + return mlirLogicalResultFailure(); + + self->errors.emplace_back(PyDiagnostic(diag).getInfo()); + return mlirLogicalResultSuccess(); +} + PyMlirContext &DefaultingPyMlirContext::resolve() { PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); if (!context) { @@ -870,6 +885,13 @@ return *materializedNotes; } +PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { + std::vector<DiagnosticInfo> notes; + for (py::handle n : getNotes()) + notes.emplace_back(n.cast<PyDiagnostic>().getInfo()); + return {getSeverity(), getLocation(), getMessage(), std::move(notes)}; +} + //------------------------------------------------------------------------------ // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry //------------------------------------------------------------------------------ @@ -1062,13 +1084,12 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName) { + PyMlirContext::ErrorCapture errors(contextRef); MlirOperation op = mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), toMlirStringRef(sourceName)); - // TODO: Include error diagnostic messages in the exception message if (mlirOperationIsNull(op)) - throw py::value_error( - "Unable to parse operation assembly (see diagnostics)"); + throw MLIRError("Unable to parse operation assembly", errors.take()); return PyOperation::createDetached(std::move(contextRef), op); } @@ -1155,6 +1176,14 @@ operation.parentKeepAlive = otherOp.parentKeepAlive; } +bool PyOperationBase::verify() { + 
PyOperation &op = getOperation(); + PyMlirContext::ErrorCapture errors(op.getContext()); + if (!mlirOperationVerify(op.get())) + throw MLIRError("Verification failed", errors.take()); + return true; +} + std::optional<PyOperationRef> PyOperation::getParentOperation() { checkValid(); if (!isAttached()) @@ -2287,6 +2316,16 @@ return self.getMessage(); }); + py::class_<PyDiagnostic::DiagnosticInfo>(m, "DiagnosticInfo", + py::module_local()) + .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); })) + .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity) + .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location) + .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message) + .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes) + .def("__str__", + [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); + py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local()) .def("detach", &PyDiagnosticHandler::detach) .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) @@ -2375,6 +2414,11 @@ mlirContextAppendDialectRegistry(self.get(), registry); }, py::arg("registry")) + .def_property("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") .def("load_all_available_dialects", [](PyMlirContext &self) { mlirContextLoadAllAvailableDialects(self.get()); }); @@ -2566,16 +2610,12 @@ .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( "parse", - [](const std::string moduleAsm, DefaultingPyMlirContext context) { + [](const std::string &moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParse( context->get(), toMlirStringRef(moduleAsm)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirModuleIsNull(module)) { - throw SetPyError( - PyExc_ValueError, - "Unable to parse module assembly (see diagnostics)"); - } + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, py::arg("asm"), py::arg("context") = py::none(), @@ -2724,13 +2764,9 @@ py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, py::arg("assume_verified") = false, kOperationGetAsmDocstring) - .def( - "verify", - [](PyOperationBase &self) { - return mlirOperationVerify(self.getOperation()); - }, - "Verify the operation and return true if it passes, false if it " - "fails.") + .def("verify", &PyOperationBase::verify, + "Verify the operation. Raises MLIRError if verification fails, and " + "returns true otherwise.") .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), "Puts self immediately after the other operation in its parent " "block.") @@ -2833,12 +2869,12 @@ // directly.
std::string clsOpName = py::cast<std::string>(cls.attr("OPERATION_NAME")); - MlirStringRef parsedOpName = + MlirStringRef identifier = mlirIdentifierStr(mlirOperationGetName(*parsed.get())); - if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName))) - throw py::value_error( - "Expected a '" + clsOpName + "' op, got: '" + - std::string(parsedOpName.data, parsedOpName.length) + "'"); + std::string_view parsedOpName(identifier.data, identifier.length); + if (clsOpName != parsedOpName) + throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + + parsedOpName + "'"); return PyOpView::constructDerived(cls, *parsed.get()); }, py::arg("cls"), py::arg("source"), py::kw_only(), @@ -3071,19 +3107,16 @@ .def_static( "parse", [](std::string attrSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirAttribute type = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirAttributeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse attribute: '") + - attrSpec + "'"); - } + if (mlirAttributeIsNull(type)) + throw MLIRError("Unable to parse attribute", errors.take()); return PyAttribute(context->getRef(), type); }, py::arg("asm"), py::arg("context") = py::none(), - "Parses an attribute from an assembly form") + "Parses an attribute from an assembly form. Raises an MLIRError on " + "failure.") .def_property_readonly( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, @@ -3182,15 +3215,11 @@ .def_static( "parse", [](std::string typeSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirType type = mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse type: '") + typeSpec + - "'"); - } + if (mlirTypeIsNull(type)) + throw MLIRError("Unable to parse type", errors.take()); return PyType(context->getRef(), type); }, py::arg("asm"), py::arg("context") = py::none(), @@ -3342,4 +3371,17 @@ // Attribute builder getter. PyAttrBuilderMap::bind(m); + + py::register_local_exception_translator([](std::exception_ptr p) { + // We can't define exceptions with custom fields through pybind, so instead + // the exception class is defined in python and imported here. + try { + if (p) + std::rethrow_exception(p); + } catch (const MLIRError &e) { + py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("MLIRError")(e.message, e.errorDiagnostics); + PyErr_SetObject(PyExc_Exception, obj.ptr()); + } + }); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -221,6 +221,11 @@ /// registration object (internally a PyDiagnosticHandler). pybind11::object attachDiagnosticHandler(pybind11::object callback); + /// Controls whether error diagnostics should be propagated to diagnostic + /// handlers, instead of being captured by `ErrorCapture`. + void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } + struct ErrorCapture; + private: PyMlirContext(MlirContext context); // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, @@ -248,6 +253,8 @@ llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; LiveOperationMap liveOperations; + bool emitErrorDiagnostics = false; + MlirContext context; friend class PyModule; friend class PyOperation; @@ -281,6 +288,34 @@ PyMlirContextRef contextRef; }; +/// Wrapper around an MlirLocation.
+class PyLocation : public BaseContextObject { +public: + PyLocation(PyMlirContextRef contextRef, MlirLocation loc) + : BaseContextObject(std::move(contextRef)), loc(loc) {} + + operator MlirLocation() const { return loc; } + MlirLocation get() const { return loc; } + + /// Enter and exit the context manager. + pybind11::object contextEnter(); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); + + /// Gets a capsule wrapping the void* within the MlirLocation. + pybind11::object getCapsule(); + + /// Creates a PyLocation from the MlirLocation wrapped by a capsule. + /// Note that PyLocation instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirLocation + /// is taken by calling this function. + static PyLocation createFromCapsule(pybind11::object capsule); + +private: + MlirLocation loc; +}; + /// Python class mirroring the C MlirDiagnostic struct. Note that these structs /// are only valid for the duration of a diagnostic callback and attempting /// to access them outside of that will raise an exception. This applies to @@ -295,6 +330,16 @@ pybind11::str getMessage(); pybind11::tuple getNotes(); + /// Materialized diagnostic information. This is safe to access outside the + /// diagnostic callback. + struct DiagnosticInfo { + MlirDiagnosticSeverity severity; + PyLocation location; + std::string message; + std::vector<DiagnosticInfo> notes; + }; + DiagnosticInfo getInfo(); + private: MlirDiagnostic diagnostic; @@ -351,6 +396,35 @@ friend class PyMlirContext; }; +/// RAII object that captures any error diagnostics emitted to the provided
+/// context. +struct PyMlirContext::ErrorCapture { + explicit ErrorCapture(PyMlirContextRef ctx) + : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( + ctx->get(), handler, /*userData=*/this, + /*deleteUserData=*/nullptr)) {} + ~ErrorCapture() { + mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); + assert(errors.empty() && "unhandled captured errors"); + } + // The handler is registered with `this` as its user data and detached by the + // destructor, so a copy would hold a stale pointer and detach the handler twice. + ErrorCapture(const ErrorCapture &) = delete; + ErrorCapture &operator=(const ErrorCapture &) = delete; + + /// Returns the captured errors, moving them out so this capture is left empty. + std::vector<PyDiagnostic::DiagnosticInfo> take() { + return std::move(errors); + } + +private: + PyMlirContextRef ctx; + MlirDiagnosticHandlerID handlerID; + std::vector<PyDiagnostic::DiagnosticInfo> errors; + + static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); +}; + /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in /// order to differentiate it from the `Dialect` base class which is extended by /// plugins which extend dialect functionality through extension python code. @@ -416,34 +490,6 @@ MlirDialectRegistry registry; }; -/// Wrapper around an MlirLocation. -class PyLocation : public BaseContextObject { -public: - PyLocation(PyMlirContextRef contextRef, MlirLocation loc) - : BaseContextObject(std::move(contextRef)), loc(loc) {} - - operator MlirLocation() const { return loc; } - MlirLocation get() const { return loc; } - - /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); - - /// Gets a capsule wrapping the void* within the MlirLocation. - pybind11::object getCapsule(); - - /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 
- /// Note that PyLocation instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirLocation - /// is taken by calling this function. - static PyLocation createFromCapsule(pybind11::object capsule); - -private: - MlirLocation loc; -}; - /// Used in function arguments when None should resolve to the current context /// manager set instance.
class DefaultingPyLocation @@ -519,6 +565,10 @@ void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); + /// Verify the operation. Throws `MLIRError` if verification fails, and + /// returns `true` otherwise. + bool verify(); + /// Each must provide access to the raw Operation. virtual PyOperation &getOperation() = 0; }; @@ -1073,6 +1123,16 @@ MlirSymbolTable symbolTable; }; +/// Custom exception that allows access to error diagnostic information. This is +/// converted to the `ir.MLIRError` python exception when thrown. +struct MLIRError { + MLIRError(const llvm::Twine &message, + std::vector<PyDiagnostic::DiagnosticInfo> &&errorDiagnostics = {}) + : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} + std::string message; + std::vector<PyDiagnostic::DiagnosticInfo> errorDiagnostics; +}; + void populateIRAffine(pybind11::module &m); void populateIRAttributes(pybind11::module &m); void populateIRCore(pybind11::module &m); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -407,17 +407,11 @@ "get", [](std::vector<int64_t> shape, PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API.
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast<std::string>() + - "' and expected floating point or integer type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyVectorType(elementType.getContext(), t); }, py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), @@ -438,20 +432,12 @@ "get", [](std::vector<int64_t> shape, PyType &elementType, std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirRankedTensorTypeGetChecked( loc, shape.size(), shape.data(), elementType, encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast<std::string>() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyRankedTensorType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), @@ -479,18 +465,10 @@ c.def_static( "get", [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API.
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast<std::string>() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, py::arg("element_type"), py::arg("loc") = py::none(), @@ -511,23 +489,15 @@ [](std::vector<int64_t> shape, PyType &elementType, PyAttribute *layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); MlirAttribute memSpaceAttr = memorySpace ? *memorySpace : mlirAttributeGetNull(); MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(), layoutAttr, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast<std::string>() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), @@ -570,23 +540,15 @@ "get", [](PyType &elementType, PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute memSpaceAttr = {}; if (memorySpace) memSpaceAttr = *memorySpace; MlirType t = mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API.
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast<std::string>() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyUnrankedMemRefType(elementType.getContext(), t); }, py::arg("element_type"), py::arg("memory_space"), diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -117,15 +117,16 @@ .def( "run", [](PyPassManager &passManager, PyOperationBase &op) { + PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( passManager.get(), op.getOperation().get()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_RuntimeError, - "Failure while executing pass pipeline."); + throw MLIRError("Failure while executing pass pipeline", + errors.take()); }, py::arg("operation"), - "Run the pass manager on the provided operation, throw a " - "RuntimeError on failure.") + "Run the pass manager on the provided operation, raising an " + "MLIRError on failure.") .def( "__str__", [](PyPassManager &self) { diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -100,8 +100,29 @@ # all dialects. It is being done here in order to preserve existing # behavior. See: https://github.com/llvm/llvm-project/issues/56037 self.load_all_available_dialects() - ir.Context = Context + class MLIRError(Exception): + """ + An exception with diagnostic information.
Has the following fields: + message: str -- short summary of the operation that failed + error_diagnostics: List[ir.DiagnosticInfo] -- captured error diagnostics; may be empty, e.g. when Context.emit_error_diagnostics is enabled + """ + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ':' + for diag in self.error_diagnostics: + s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ') + for note in diag.notes: + s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ') + return s + ir.MLIRError = MLIRError + _site_initialize() diff --git a/mlir/test/python/ir/attributes.py b/mlir/test/python/ir/attributes.py --- a/mlir/test/python/ir/attributes.py +++ b/mlir/test/python/ir/attributes.py @@ -28,16 +28,17 @@ # CHECK-LABEL: TEST: testParseError -# TODO: Hook the diagnostic manager to capture a more meaningful error -# message. @run def testParseError(): with Context(): try: t = Attribute.parse("BAD_ATTR_DOES_NOT_EXIST") - except ValueError as e: - # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST' - print("testParseError:", e) + except MLIRError as e: + # CHECK: testParseError: < + # CHECK: Unable to parse attribute: + # CHECK: error: "BAD_ATTR_DOES_NOT_EXIST":1:1: expected attribute value + # CHECK: > + print(f"testParseError: <{e}>") else: print("Exception not produced") @@ -180,8 +181,9 @@ try: fattr_invalid = FloatAttr.get( IntegerType.get_signless(32), 42) - except ValueError as e: - # CHECK: invalid 'Type(i32)' and expected floating point type.
+ except MLIRError as e: + # CHECK: Invalid attribute: + # CHECK: error: unknown: expected floating point type print(e) else: print("Exception not produced") diff --git a/mlir/test/python/ir/builtin_types.py b/mlir/test/python/ir/builtin_types.py --- a/mlir/test/python/ir/builtin_types.py +++ b/mlir/test/python/ir/builtin_types.py @@ -26,16 +26,17 @@ # CHECK-LABEL: TEST: testParseError -# TODO: Hook the diagnostic manager to capture a more meaningful error -# message. @run def testParseError(): ctx = Context() try: t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx) - except ValueError as e: - # CHECK: Unable to parse type: 'BAD_TYPE_DOES_NOT_EXIST' - print("testParseError:", e) + except MLIRError as e: + # CHECK: testParseError: < + # CHECK: Unable to parse type: + # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type + # CHECK: > + print(f"testParseError: <{e}>") else: print("Exception not produced") @@ -292,8 +293,9 @@ none = NoneType.get() try: vector_invalid = VectorType.get(shape, none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point or integer type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: vector elements must be int/index/float type but got 'none' print(e) else: print("Exception not produced") @@ -313,9 +315,9 @@ none = NoneType.get() try: tensor_invalid = RankedTensorType.get(shape, none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid tensor element type: 'none' print(e) else: print("Exception not produced") @@ -361,9 +363,9 @@ none = NoneType.get() try: tensor_invalid = UnrankedTensorType.get(none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type.
+ except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid tensor element type: 'none' print(e) else: print("Exception not produced") @@ -400,9 +402,9 @@ none = NoneType.get() try: memref_invalid = MemRefType.get(shape, none) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid memref element type print(e) else: print("Exception not produced") @@ -444,9 +446,9 @@ none = NoneType.get() try: memref_invalid = UnrankedMemRefType.get(none, Attribute.parse("2")) - except ValueError as e: - # CHECK: invalid 'Type(none)' and expected floating point, integer, vector - # CHECK: or complex type. + except MLIRError as e: + # CHECK: Invalid type: + # CHECK: error: unknown: invalid memref element type print(e) else: print("Exception not produced") diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py --- a/mlir/test/python/ir/diagnostic_handler.py +++ b/mlir/test/python/ir/diagnostic_handler.py @@ -89,6 +89,7 @@ @run def testDiagnosticNonEmptyNotes(): ctx = Context() + ctx.emit_error_diagnostics = True def callback(d): # CHECK: DIAGNOSTIC: # CHECK: message='arith.addi' op requires one result @@ -99,7 +100,10 @@ return True handler = ctx.attach_diagnostic_handler(callback) loc = Location.unknown(ctx) - Operation.create('arith.addi', loc=loc).verify() + try: + Operation.create('arith.addi', loc=loc).verify() + except MLIRError: + pass assert not handler.had_error # CHECK-LABEL: TEST: testDiagnosticCallbackException diff --git a/mlir/test/python/ir/exception.py b/mlir/test/python/ir/exception.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/ir/exception.py @@ -0,0 +1,76 @@ +# RUN: %PYTHON %s | FileCheck %s + +import gc +from mlir.ir import * + +def run(f): + print("\nTEST:", f.__name__) + f() + gc.collect() + assert
Context._get_live_count() == 0 + return f + + +# CHECK-LABEL: TEST: test_exception +@run +def test_exception(): + ctx = Context() + ctx.allow_unregistered_dialects = True + try: + Operation.parse(""" + func.func @foo() { + "test.use"(%0) : (i64) -> () loc("use") + %0 = "test.def"() : () -> i64 loc("def") + return + } + """, context=ctx) + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Unable to parse operation assembly: + # CHECK: error: "use": operand #0 does not dominate this use + # CHECK: note: "use": see current operation: "test.use"(%0) : (i64) -> () + # CHECK: note: "def": operand defined here (op in the same block) + # CHECK: > + print(f"Exception: <{e}>") + + # CHECK: message: Unable to parse operation assembly + print(f"message: {e.message}") + + # CHECK: error_diagnostics[0]: loc("use") operand #0 does not dominate this use + # CHECK: error_diagnostics[0].notes[0]: loc("use") see current operation: "test.use"(%0) : (i64) -> () + # CHECK: error_diagnostics[0].notes[1]: loc("def") operand defined here (op in the same block) + print("error_diagnostics[0]: ", e.error_diagnostics[0].location, e.error_diagnostics[0].message) + print("error_diagnostics[0].notes[0]: ", e.error_diagnostics[0].notes[0].location, e.error_diagnostics[0].notes[0].message) + print("error_diagnostics[0].notes[1]: ", e.error_diagnostics[0].notes[1].location, e.error_diagnostics[0].notes[1].message) + + +# CHECK-LABEL: TEST: test_emit_error_diagnostics +@run +def test_emit_error_diagnostics(): + ctx = Context() + handler_diags = [] + def handler(d): + handler_diags.append(str(d)) + return True + ctx.attach_diagnostic_handler(handler) + + try: + Attribute.parse("not an attr", ctx) + except MLIRError as e: + # CHECK: emit_error_diagnostics=False: + # CHECK: e.error_diagnostics: ['expected attribute value'] + # CHECK: handler_diags: [] + print("emit_error_diagnostics=False:") + print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}") +
print(f"handler_diags: {handler_diags}") + + ctx.emit_error_diagnostics = True + try: + Attribute.parse("not an attr", ctx) + except MLIRError as e: + # CHECK: emit_error_diagnostics=True: + # CHECK: e.error_diagnostics: [] + # CHECK: handler_diags: ['expected attribute value'] + print("emit_error_diagnostics=True:") + print(f"e.error_diagnostics: {[str(diag) for diag in e.error_diagnostics]}") + print(f"handler_diags: {handler_diags}") diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -28,14 +28,17 @@ # Verify parse error. # CHECK-LABEL: TEST: testParseError -# CHECK: testParseError: Unable to parse module assembly (see diagnostics) +# CHECK: testParseError: < +# CHECK: Unable to parse module assembly: +# CHECK: error: "-":1:1: expected operation name in quotes +# CHECK: > @run def testParseError(): ctx = Context() try: module = Module.parse(r"""}SYNTAX ERROR{""", ctx) - except ValueError as e: - print("testParseError:", e) + except MLIRError as e: + print(f"testParseError: <{e}>") else: print("Exception not produced") diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -685,8 +685,19 @@ # CHECK: "builtin.module"() ({ # CHECK: }) : () -> () print(invalid_op) - # CHECK: .verify = False - print(f".verify = {invalid_op.operation.verify()}") + try: + invalid_op.verify() + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Verification failed: + # CHECK: error: unknown: 'builtin.module' op requires one region + # CHECK: note: unknown: see current operation: + # CHECK: "builtin.module"() ({ + # CHECK: ^bb0: + # CHECK: }, { + # CHECK: }) : () -> () + # CHECK: > + print(f"Exception: <{e}>") # CHECK-LABEL: TEST: testInvalidModuleStrSoftFails @@ -920,7 +931,7 @@ assert isinstance(m, ModuleOp) try: ModuleOp.parse('"test.foo"() : () -> ()') - except ValueError
as e: + except MLIRError as e: # CHECK: error: Expected a 'builtin.module' op, got: 'test.foo' print(f"error: {e}") else: diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -118,7 +118,7 @@ # Verify that a pass manager can execute on IR -# CHECK-LABEL: TEST: testRun +# CHECK-LABEL: TEST: testRunPipeline def testRunPipeline(): with Context(): pm = PassManager.parse("any(print-op-stats{json=false})") @@ -128,3 +128,20 @@ # CHECK: func.func , 1 # CHECK: func.return , 1 run(testRunPipeline) + +# CHECK-LABEL: TEST: testRunPipelineError +@run +def testRunPipelineError(): + with Context() as ctx: + ctx.allow_unregistered_dialects = True + op = Operation.parse('"test.op"() : () -> ()') + pm = PassManager.parse("any(cse)") + try: + pm.run(op) + except MLIRError as e: + # CHECK: Exception: < + # CHECK: Failure while executing pass pipeline: + # CHECK: error: "-":1:1: 'test.op' op trying to schedule a pass on an unregistered operation + # CHECK: note: "-":1:1: see current operation: "test.op"() : () -> () + # CHECK: > + print(f"Exception: <{e}>")