diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -22,6 +22,7 @@
 add_library(MLIRBindingsPythonExtension ${PYEXT_LINK_MODE}
   MainModule.cpp
   IRModules.cpp
+  PybindUtils.cpp
 )
 
 target_include_directories(MLIRBindingsPythonExtension PRIVATE
@@ -68,7 +69,6 @@
 
 target_link_libraries(MLIRBindingsPythonExtension
   PRIVATE
-  MLIRIR
   MLIRCAPIIR
   MLIRCAPIRegistration
   ${PYEXT_LIBADD}
diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h
--- a/mlir/lib/Bindings/Python/IRModules.h
+++ b/mlir/lib/Bindings/Python/IRModules.h
@@ -13,7 +13,8 @@
 
 #include "mlir-c/IR.h"
 
-namespace py = pybind11;
+namespace mlir {
+namespace python {
 
 class PyMlirContext;
 class PyMlirModule;
@@ -48,6 +49,9 @@
   MlirModule module;
 };
 
-void populateIRSubmodule(py::module &m);
+void populateIRSubmodule(pybind11::module &m);
+
+} // namespace python
+} // namespace mlir
 
 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -7,6 +7,61 @@
 //===----------------------------------------------------------------------===//
 
 #include "IRModules.h"
+#include "PybindUtils.h"
+
+namespace py = pybind11;
+using namespace mlir::python;
+
+//------------------------------------------------------------------------------
+// Docstrings (trivial, non-duplicated docstrings are included inline).
+//------------------------------------------------------------------------------
+
+static const char kContextParseDocstring[] =
+    R"(Parses a module's assembly format from a string.
+
+Returns a new MlirModule or raises a ValueError if the parsing fails.
+)";
+
+static const char kOperationStrDunderDocstring[] =
+    R"(Prints the assembly form of the operation with default options.
+
+If more advanced control over the assembly formatting or I/O options is needed,
+use the dedicated print method, which supports keyword arguments to customize
+behavior.
+)";
+
+static const char kDumpDocstring[] =
+    R"(Dumps a debug representation of the object to stderr.)";
+
+//------------------------------------------------------------------------------
+// Conversion utilities.
+//------------------------------------------------------------------------------
+
+namespace {
+
+/// Accumulates into a python string from a method that accepts an
+/// MlirPrintCallback.
+struct PyPrintAccumulator {
+  py::list parts; // Chunks appended by the callback, in arrival order.
+
+  void *getUserData() { return this; }
+
+  MlirPrintCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      PyPrintAccumulator *printAccum =
+          static_cast<PyPrintAccumulator *>(userData);
+      py::str pyPart(part, size); // Decodes as UTF-8 by default.
+      printAccum->parts.append(std::move(pyPart));
+    };
+  }
+
+  py::str join() {
+    py::str delim("", 0); // Empty separator: chunks are concatenated as-is.
+    return delim.attr("join")(parts);
+  }
+};
+
+} // namespace
 
 //------------------------------------------------------------------------------
 // Context Wrapper Class.
@@ -14,6 +69,10 @@
 
 PyMlirModule PyMlirContext::parse(const std::string &module) {
   auto moduleRef = mlirModuleCreateParse(context, module.c_str());
+  if (!moduleRef.ptr) { // A null module means the parse failed.
+    throw SetPyError(PyExc_ValueError,
+                     "Unable to parse module assembly (see diagnostics)");
+  }
   return PyMlirModule(moduleRef);
 }
 
@@ -27,10 +86,22 @@
 // Populates the pybind11 IR submodule.
 //------------------------------------------------------------------------------
 
-void populateIRSubmodule(py::module &m) {
+void mlir::python::populateIRSubmodule(py::module &m) {
   py::class_<PyMlirContext>(m, "MlirContext")
       .def(py::init<>())
-      .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>());
+      .def("parse", &PyMlirContext::parse, py::keep_alive<0, 1>(),
+           kContextParseDocstring);
 
-  py::class_<PyMlirModule>(m, "MlirModule").def("dump", &PyMlirModule::dump);
+  py::class_<PyMlirModule>(m, "MlirModule")
+      .def("dump", &PyMlirModule::dump, kDumpDocstring)
+      .def(
+          "__str__",
+          [](PyMlirModule &self) {
+            auto operation = mlirModuleGetOperation(self.module);
+            PyPrintAccumulator printAccum;
+            mlirOperationPrint(operation, printAccum.getCallback(),
+                               printAccum.getUserData());
+            return printAccum.join();
+          },
+          kOperationStrDunderDocstring);
 }
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -11,21 +11,14 @@
 #include <pybind11/pybind11.h>
 
 #include "IRModules.h"
-#include "mlir/IR/MLIRContext.h"
 
+namespace py = pybind11;
 using namespace mlir;
+using namespace mlir::python;
 
 PYBIND11_MODULE(_mlir, m) {
   m.doc() = "MLIR Python Native Extension";
 
-  m.def("get_test_value", []() {
-    // This is just calling a method on the MLIRContext as a smoketest
-    // for linkage.
-    MLIRContext context;
-    return std::make_tuple(std::string("From the native module"),
-                           context.isMultithreadingEnabled());
-  });
-
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
   populateIRSubmodule(irModule);
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -0,0 +1,28 @@
+//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
+#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
+
+#include <pybind11/pybind11.h>
+
+#include "llvm/ADT/Twine.h"
+
+namespace mlir {
+namespace python {
+
+// Sets a python error, ready to be thrown to return control back to the
+// python runtime.
+// Correct usage:
+//   throw SetPyError(PyExc_ValueError, "Foobar'd");
+pybind11::error_already_set SetPyError(PyObject *excClass, llvm::Twine message);
+
+} // namespace python
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bindings/Python/PybindUtils.cpp
@@ -0,0 +1,18 @@
+//===- PybindUtils.cpp - Utilities for interop with pybind11 --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PybindUtils.h"
+
+namespace py = pybind11;
+
+pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass,
+                                                     llvm::Twine message) {
+  auto messageStr = message.str(); // Own the text so c_str() stays valid.
+  PyErr_SetString(excClass, messageStr.c_str());
+  return pybind11::error_already_set();
+}
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -126,6 +126,8 @@
 
 MlirModule mlirModuleCreateParse(MlirContext context, const char *module) {
   OwningModuleRef owning = parseSourceString(module, unwrap(context));
+  if (!owning) // Parse failed: hand back a null module.
+    return MlirModule{nullptr};
   return MlirModule{owning.release().getOperation()};
 }
 
diff --git a/mlir/test/Bindings/Python/ir_module_test.py b/mlir/test/Bindings/Python/ir_module_test.py
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bindings/Python/ir_module_test.py
@@ -0,0 +1,49 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import mlir
+
+def run(f):
+  print("TEST:", f.__name__)
+  f()
+
+# Verify successful parse.
+# CHECK-LABEL: TEST: testParseSuccess
+# CHECK: module @successfulParse
+def testParseSuccess():
+  ctx = mlir.ir.MlirContext()
+  module = ctx.parse(r"""module @successfulParse {}""")
+  module.dump()  # Just outputs to stderr. Verifies that it functions.
+  print(str(module))
+
+run(testParseSuccess)
+
+
+# Verify parse error.
+# CHECK-LABEL: TEST: testParseError
+# CHECK: testParseError: Unable to parse module assembly (see diagnostics)
+def testParseError():
+  ctx = mlir.ir.MlirContext()
+  try:
+    module = ctx.parse(r"""}SYNTAX ERROR{""")
+  except ValueError as e:
+    print("testParseError:", e)
+  else:
+    print("Exception not produced")
+
+run(testParseError)
+
+
+# Verify round-trip of ASM that contains unicode.
+# Note that this does not test that the print path converts unicode properly
+# because MLIR asm always normalizes it to the hex encoding.
+# CHECK-LABEL: TEST: testRoundtripUnicode
+# CHECK: func @roundtripUnicode()
+# CHECK: foo = "\F0\9F\98\8A"
+def testRoundtripUnicode():
+  ctx = mlir.ir.MlirContext()
+  module = ctx.parse(r"""
+    func @roundtripUnicode() attributes { foo = "😊" }
+  """)
+  print(str(module))
+
+run(testRoundtripUnicode)
diff --git a/mlir/test/Bindings/Python/ir_test.py b/mlir/test/Bindings/Python/ir_test.py
deleted file mode 100644
--- a/mlir/test/Bindings/Python/ir_test.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# RUN: %PYTHON %s | FileCheck %s
-
-import mlir
-
-TEST_MLIR_ASM = r"""
-module {
-}
-"""
-
-ctx = mlir.ir.MlirContext()
-module = ctx.parse(TEST_MLIR_ASM)
-module.dump()
-print(bool(module))
-# CHECK: True
diff --git a/mlir/test/Bindings/Python/smoke_test.py b/mlir/test/Bindings/Python/smoke_test.py
deleted file mode 100644
--- a/mlir/test/Bindings/Python/smoke_test.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# RUN: %PYTHON %s | FileCheck %s
-
-import mlir
-
-# CHECK: From the native module
-print(mlir.get_test_value())