diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -509,6 +509,58 @@
   PyThreadContextEntry::popContext(*this);
 }
 
+py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
+  PyDiagnosticHandler *pyHandler =
+      new PyDiagnosticHandler(get(), std::move(callback));
+  py::object pyHandlerObject =
+      py::cast(pyHandler, py::return_value_policy::take_ownership);
+
+  // We artificially increase the reference count. It will be decremented by
+  // the delete callback when the context is done with it.
+  pyHandlerObject.inc_ref();
+
+  // In these C callbacks, the userData is a PyDiagnosticHandler* that is
+  // guaranteed to be known to pybind.
+  auto handlerCallback =
+      +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
+    // Since this can be called from arbitrary C++ contexts, always get the
+    // gil. It has to be held before the first Python object is created, so
+    // it is acquired ahead of the casts below.
+    py::gil_scoped_acquire gil;
+    PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
+    py::object pyDiagnosticObject =
+        py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
+
+    auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
+    bool result = false;
+    try {
+      result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
+    } catch (const std::exception &e) {
+      fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
+              e.what());
+      pyHandler->hadError = true;
+    }
+
+    // Diagnostics are only valid during the callback: poison escaped wrappers.
+    pyDiagnostic->invalidate();
+    return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
+  };
+  auto deleteCallback = +[](void *userData) {
+    auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
+    assert(pyHandler->registeredID && "handler is not registered");
+    pyHandler->registeredID.reset();
+
+    // Decrement reference, balancing the inc_ref() above.
+    py::object pyHandlerObject =
+        py::cast(pyHandler, py::return_value_policy::reference);
+    pyHandlerObject.dec_ref();
+  };
+
+  pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
+      get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
+  return pyHandlerObject;
+}
+
 PyMlirContext &DefaultingPyMlirContext::resolve() {
   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
   if (!context) {
@@ -654,6 +706,83 @@
   stack.pop_back();
 }
 
+//------------------------------------------------------------------------------
+// PyDiagnostic*
+//------------------------------------------------------------------------------
+
+// Invalidates this diagnostic and any notes already materialized from it.
+void PyDiagnostic::invalidate() {
+  valid = false;
+  if (materializedNotes) {
+    for (auto &noteObject : *materializedNotes) {
+      PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
+      note->invalidate();
+    }
+  }
+}
+
+PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
+                                         py::object callback)
+    : context(context), callback(std::move(callback)) {}
+
+PyDiagnosticHandler::~PyDiagnosticHandler() {}
+
+void PyDiagnosticHandler::detach() {
+  if (!registeredID)
+    return;
+  MlirDiagnosticHandlerID localID = *registeredID;
+  // Detaching is expected to run the delete callback installed by
+  // attachDiagnosticHandler(), which is what resets registeredID.
+  mlirContextDetachDiagnosticHandler(context, localID);
+  assert(!registeredID && "should have unregistered");
+  // Not strictly necessary but keeps stale pointers from being around to cause
+  // issues.
+  context = {nullptr};
+}
+
+void PyDiagnostic::checkValid() {
+  if (!valid) {
+    throw std::invalid_argument(
+        "Diagnostic is invalid (used outside of callback)");
+  }
+}
+
+MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
+  checkValid();
+  return mlirDiagnosticGetSeverity(diagnostic);
+}
+
+PyLocation PyDiagnostic::getLocation() {
+  checkValid();
+  MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
+  MlirContext context = mlirLocationGetContext(loc);
+  return PyLocation(PyMlirContext::forContext(context), loc);
+}
+
+py::str PyDiagnostic::getMessage() {
+  checkValid();
+  py::object fileObject = py::module::import("io").attr("StringIO")();
+  PyFileAccumulator accum(fileObject, /*binary=*/false);
+  mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
+  return fileObject.attr("getvalue")();
+}
+
+py::tuple PyDiagnostic::getNotes() {
+  checkValid();
+  if (materializedNotes)
+    return *materializedNotes;
+  intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
+  materializedNotes = py::tuple(numNotes);
+  for (intptr_t i = 0; i < numNotes; ++i) {
+    MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
+    py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
+    // PyTuple_SET_ITEM steals a reference, so release ours into the tuple
+    // rather than letting the py::object destructor drop it as well.
+    PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.release().ptr());
+  }
+  return *materializedNotes;
+}
+
 //------------------------------------------------------------------------------
 // PyDialect, PyDialectDescriptor, PyDialects
 //------------------------------------------------------------------------------
@@ -2022,6 +2151,36 @@
 //------------------------------------------------------------------------------
 
 void mlir::python::populateIRCore(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Enums.
+  //----------------------------------------------------------------------------
+  py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
+      .value("ERROR", MlirDiagnosticError)
+      .value("WARNING", MlirDiagnosticWarning)
+      .value("NOTE", MlirDiagnosticNote)
+      .value("REMARK", MlirDiagnosticRemark);
+
+  //----------------------------------------------------------------------------
+  // Mapping of Diagnostics.
+  //----------------------------------------------------------------------------
+  py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
+      .def_property_readonly("severity", &PyDiagnostic::getSeverity)
+      .def_property_readonly("location", &PyDiagnostic::getLocation)
+      .def_property_readonly("message", &PyDiagnostic::getMessage)
+      .def_property_readonly("notes", &PyDiagnostic::getNotes)
+      .def("__str__", [](PyDiagnostic &self) -> py::str {
+        if (!self.isValid())
+          return "";
+        return self.getMessage();
+      });
+
+  py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
+      .def("detach", &PyDiagnosticHandler::detach)
+      .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
+      .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
+      .def("__enter__", &PyDiagnosticHandler::contextEnter)
+      .def("__exit__", &PyDiagnosticHandler::contextExit);
+
   //----------------------------------------------------------------------------
   // Mapping of MlirContext.
   //----------------------------------------------------------------------------
@@ -2077,6 +2236,9 @@
           [](PyMlirContext &self, bool value) {
             mlirContextSetAllowUnregisteredDialects(self.get(), value);
           })
+      .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
+           py::arg("callback"),
+           "Attaches a diagnostic handler that will receive callbacks")
       .def(
           "enable_multithreading",
           [](PyMlirContext &self, bool enable) {
@@ -2202,7 +2364,8 @@
           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
       .def_static(
           "fused",
-          [](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
+          [](const std::vector<PyLocation> &pyLocations,
+             llvm::Optional<PyAttribute> metadata,
              DefaultingPyMlirContext context) {
             if (pyLocations.empty())
               throw py::value_error("No locations provided");
@@ -2234,6 +2397,12 @@
           "context",
           [](PyLocation &self) { return self.getContext().getObject(); },
           "Context that owns the Location")
+      .def(
+          "emit_error",
+          [](PyLocation &self, std::string message) {
+            mlirEmitError(self, message.c_str());
+          },
+          py::arg("message"), "Emits an error at this location")
       .def("__repr__", [](PyLocation &self) {
         PyPrintAccumulator printAccum;
         mlirLocationPrint(self, printAccum.getCallback(),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -15,6 +15,7 @@
 
 #include "mlir-c/AffineExpr.h"
 #include "mlir-c/AffineMap.h"
+#include "mlir-c/Diagnostics.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/IntegerSet.h"
 #include "llvm/ADT/DenseMap.h"
@@ -24,6 +25,8 @@
 namespace python {
 
 class PyBlock;
+class PyDiagnostic;
+class PyDiagnosticHandler;
 class PyInsertionPoint;
 class PyLocation;
 class DefaultingPyLocation;
@@ -206,6 +209,10 @@
   void contextExit(pybind11::object excType, pybind11::object excVal,
                    pybind11::object excTb);
 
+  /// Attaches a Python callback as a diagnostic handler, returning a
+  /// registration object (internally a PyDiagnosticHandler).
+  pybind11::object attachDiagnosticHandler(pybind11::object callback);
+
 private:
   PyMlirContext(MlirContext context);
   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
@@ -266,6 +273,75 @@
   PyMlirContextRef contextRef;
 };
 
+/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
+/// are only valid for the duration of a diagnostic callback and attempting
+/// to access them outside of that will raise an exception. This applies to
+/// nested diagnostics (in the notes) as well.
+class PyDiagnostic {
+public:
+  PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
+  void invalidate();
+  bool isValid() { return valid; }
+  MlirDiagnosticSeverity getSeverity();
+  PyLocation getLocation();
+  pybind11::str getMessage();
+  pybind11::tuple getNotes();
+
+private:
+  MlirDiagnostic diagnostic;
+
+  void checkValid();
+  /// If notes have been materialized from the diagnostic, then this will
+  /// be populated with the corresponding objects (all castable to
+  /// PyDiagnostic).
+  llvm::Optional<pybind11::tuple> materializedNotes;
+  bool valid = true;
+};
+
+/// Represents a diagnostic handler attached to the context. The handler's
+/// callback will be invoked with PyDiagnostic instances until the detach()
+/// method is called or the context is destroyed. A diagnostic handler can be
+/// the subject of a `with` block, which will detach it when the block exits.
+///
+/// Since diagnostic handlers can call back into Python code which can do
+/// unsafe things (e.g. recursively emitting diagnostics, raising exceptions,
+/// etc), this is generally not deemed to be a great user-level API. Users
+/// should generally use some form of DiagnosticCollector. If the handler raises
+/// any exceptions, they will just be emitted to stderr and dropped.
+///
+/// The unique usage of this class means that its lifetime management is
+/// different from most other parts of the API. Instances are always created
+/// in an attached state and can transition to a detached state by either:
+///   a) The context being destroyed and unregistering all handlers.
+///   b) An explicit call to detach().
+/// The object may remain live from a Python perspective for an arbitrary time
+/// after detachment, but there is nothing the user can do with it (since there
+/// is no way to attach an existing handler object).
+class PyDiagnosticHandler {
+public:
+  PyDiagnosticHandler(MlirContext context, pybind11::object callback);
+  ~PyDiagnosticHandler();
+
+  bool isAttached() { return registeredID.hasValue(); }
+  bool getHadError() { return hadError; }
+
+  /// Detaches the handler. Does nothing if not attached.
+  void detach();
+
+  pybind11::object contextEnter() { return pybind11::cast(this); }
+  void contextExit(pybind11::object excType, pybind11::object excVal,
+                   pybind11::object excTb) {
+    detach();
+  }
+
+private:
+  MlirContext context;
+  pybind11::object callback;
+  llvm::Optional<MlirDiagnosticHandlerID> registeredID;
+  bool hadError = false;
+  friend class PyMlirContext;
+};
+
 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
 /// order to differentiate it from the `Dialect` base class which is extended by
 /// plugins which extend dialect functionality through extension python code.
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -7,7 +7,7 @@
 # * Local edits to signatures and types that MyPy did not auto detect (or
 #   detected incorrectly).
 
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence
+from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple
 
 from typing import overload
 
@@ -43,6 +43,9 @@
     "Dialect",
     "DialectDescriptor",
     "Dialects",
+    "Diagnostic",
+    "DiagnosticHandler",
+    "DiagnosticSeverity",
     "DictAttr",
     "F16Type",
     "F32Type",
@@ -425,8 +428,9 @@
     def _get_live_count() -> int: ...
     def _get_live_module_count(self) -> int: ...
     def _get_live_operation_count(self) -> int: ...
+    def attach_diagnostic_handler(self, callback: Callable[["Diagnostic"], bool]) -> "DiagnosticHandler": ...
     def enable_multithreading(self, enable: bool) -> None: ...
-    def get_dialect_descriptor(name: dialect_name: str) -> "DialectDescriptor": ...
+    def get_dialect_descriptor(self, dialect_name: str) -> "DialectDescriptor": ...
     def is_registered_operation(self, operation_name: str) -> bool: ...
     def __enter__(self) -> "Context": ...
     def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
@@ -479,6 +483,31 @@
     def __getattr__(self, arg0: str) -> "Dialect": ...
     def __getitem__(self, arg0: str) -> "Dialect": ...
 
+class Diagnostic:
+    @property
+    def severity(self) -> "DiagnosticSeverity": ...
+    @property
+    def location(self) -> "Location": ...
+    @property
+    def message(self) -> str: ...
+    @property
+    def notes(self) -> Tuple["Diagnostic", ...]: ...
+
+class DiagnosticHandler:
+    def detach(self) -> None: ...
+    @property
+    def attached(self) -> bool: ...
+    @property
+    def had_error(self) -> bool: ...
+    def __enter__(self) -> "DiagnosticHandler": ...
+    def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ...
+
+class DiagnosticSeverity:
+    ERROR: "DiagnosticSeverity"
+    WARNING: "DiagnosticSeverity"
+    NOTE: "DiagnosticSeverity"
+    REMARK: "DiagnosticSeverity"
+
 # TODO: Auto-generated. Audit and fix.
 class DictAttr(Attribute):
     def __init__(self, cast_from_attr: Attribute) -> None: ...
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -0,0 +1,172 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+from mlir.ir import *
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+  gc.collect()
+  assert Context._get_live_count() == 0
+  return f
+
+
+@run
+def testLifecycleContextDestroy():
+  ctx = Context()
+  def callback(foo): ...
+  handler = ctx.attach_diagnostic_handler(callback)
+  assert handler.attached
+  # If context is destroyed before the handler, it should auto-detach.
+  ctx = None
+  gc.collect()
+  assert not handler.attached
+
+  # And finally collecting the handler should be fine.
+  handler = None
+  gc.collect()
+
+
+@run
+def testLifecycleExplicitDetach():
+  ctx = Context()
+  def callback(foo): ...
+  handler = ctx.attach_diagnostic_handler(callback)
+  assert handler.attached
+  handler.detach()
+  assert not handler.attached
+
+
+@run
+def testLifecycleWith():
+  ctx = Context()
+  def callback(foo): ...
+  with ctx.attach_diagnostic_handler(callback) as handler:
+    assert handler.attached
+  assert not handler.attached
+
+
+@run
+def testLifecycleWithAndExplicitDetach():
+  ctx = Context()
+  def callback(foo): ...
+  with ctx.attach_diagnostic_handler(callback) as handler:
+    assert handler.attached
+    handler.detach()
+  assert not handler.attached
+
+
+# CHECK-LABEL: TEST: testDiagnosticCallback
+@run
+def testDiagnosticCallback():
+  ctx = Context()
+  def callback(d):
+    # CHECK: DIAGNOSTIC: message='foobar', severity=DiagnosticSeverity.ERROR, loc=loc(unknown)
+    print(f"DIAGNOSTIC: message='{d.message}', severity={d.severity}, loc={d.location}")
+    return True
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert not handler.had_error
+
+
+# CHECK-LABEL: TEST: testDiagnosticEmptyNotes
+# TODO: Come up with a way to inject a diagnostic with notes from this API.
+@run
+def testDiagnosticEmptyNotes():
+  ctx = Context()
+  def callback(d):
+    # CHECK: DIAGNOSTIC: notes=()
+    print(f"DIAGNOSTIC: notes={d.notes}")
+    return True
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert not handler.had_error
+
+
+# CHECK-LABEL: TEST: testDiagnosticCallbackException
+@run
+def testDiagnosticCallbackException():
+  ctx = Context()
+  def callback(d):
+    raise ValueError("Error in handler")
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert handler.had_error
+
+
+# CHECK-LABEL: TEST: testEscapingDiagnostic
+@run
+def testEscapingDiagnostic():
+  ctx = Context()
+  diags = []
+  def callback(d):
+    diags.append(d)
+    return True
+  handler = ctx.attach_diagnostic_handler(callback)
+  loc = Location.unknown(ctx)
+  loc.emit_error("foobar")
+  assert not handler.had_error
+
+  # CHECK: DIAGNOSTIC:
+  print(f"DIAGNOSTIC: {str(diags[0])}")
+  try:
+    diags[0].severity
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+  try:
+    diags[0].location
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+  try:
+    diags[0].message
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+  try:
+    diags[0].notes
+    raise RuntimeError("expected exception")
+  except ValueError:
+    pass
+
+
+
+# CHECK-LABEL: TEST: testDiagnosticReturnTrueHandles
+@run
+def testDiagnosticReturnTrueHandles():
+  ctx = Context()
+  def callback1(d):
+    print(f"CALLBACK1: {d}")
+    return True
+  def callback2(d):
+    print(f"CALLBACK2: {d}")
+    return True
+  ctx.attach_diagnostic_handler(callback1)
+  ctx.attach_diagnostic_handler(callback2)
+  loc = Location.unknown(ctx)
+  # CHECK-NOT: CALLBACK1
+  # CHECK: CALLBACK2: foobar
+  # CHECK-NOT: CALLBACK1
+  loc.emit_error("foobar")
+
+
+# CHECK-LABEL: TEST: testDiagnosticReturnFalseDoesNotHandle
+@run
+def testDiagnosticReturnFalseDoesNotHandle():
+  ctx = Context()
+  def callback1(d):
+    print(f"CALLBACK1: {d}")
+    return True
+  def callback2(d):
+    print(f"CALLBACK2: {d}")
+    return False
+  ctx.attach_diagnostic_handler(callback1)
+  ctx.attach_diagnostic_handler(callback2)
+  loc = Location.unknown(ctx)
+  # CHECK: CALLBACK2: foobar
+  # CHECK: CALLBACK1: foobar
+  loc.emit_error("foobar")