diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -123,92 +123,6 @@
 equivalent to printing the operation that produced it.
 )";
 
-//------------------------------------------------------------------------------
-// Conversion utilities.
-//------------------------------------------------------------------------------
-
-namespace {
-
-/// Accumulates into a python string from a method that accepts an
-/// MlirStringCallback.
-struct PyPrintAccumulator {
-  py::list parts;
-
-  void *getUserData() { return this; }
-
-  MlirStringCallback getCallback() {
-    return [](const char *part, intptr_t size, void *userData) {
-      PyPrintAccumulator *printAccum =
-          static_cast<PyPrintAccumulator *>(userData);
-      py::str pyPart(part, size); // Decodes as UTF-8 by default.
-      printAccum->parts.append(std::move(pyPart));
-    };
-  }
-
-  py::str join() {
-    py::str delim("", 0);
-    return delim.attr("join")(parts);
-  }
-};
-
-/// Accumulates int a python file-like object, either writing text (default)
-/// or binary.
-class PyFileAccumulator {
-public:
-  PyFileAccumulator(py::object fileObject, bool binary)
-      : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
-
-  void *getUserData() { return this; }
-
-  MlirStringCallback getCallback() {
-    return [](const char *part, intptr_t size, void *userData) {
-      py::gil_scoped_acquire();
-      PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
-      if (accum->binary) {
-        // Note: Still has to copy and not avoidable with this API.
-        py::bytes pyBytes(part, size);
-        accum->pyWriteFunction(pyBytes);
-      } else {
-        py::str pyStr(part, size); // Decodes as UTF-8 by default.
-        accum->pyWriteFunction(pyStr);
-      }
-    };
-  }
-
-private:
-  py::object pyWriteFunction;
-  bool binary;
-};
-
-/// Accumulates into a python string from a method that is expected to make
-/// one (no more, no less) call to the callback (asserts internally on
-/// violation).
-struct PySinglePartStringAccumulator {
-  void *getUserData() { return this; }
-
-  MlirStringCallback getCallback() {
-    return [](const char *part, intptr_t size, void *userData) {
-      PySinglePartStringAccumulator *accum =
-          static_cast<PySinglePartStringAccumulator *>(userData);
-      assert(!accum->invoked &&
-             "PySinglePartStringAccumulator called back multiple times");
-      accum->invoked = true;
-      accum->value = py::str(part, size);
-    };
-  }
-
-  py::str takeValue() {
-    assert(invoked && "PySinglePartStringAccumulator not called back");
-    return std::move(value);
-  }
-
-private:
-  py::str value;
-  bool invoked = false;
-};
-
-} // namespace
-
 //------------------------------------------------------------------------------
 // Utilities.
 //------------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h
--- a/mlir/lib/Bindings/Python/PybindUtils.h
+++ b/mlir/lib/Bindings/Python/PybindUtils.h
@@ -9,11 +9,13 @@
 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
 
+#include "mlir-c/Support.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/Twine.h"
+
 #include
 #include
 
-#include "llvm/ADT/Optional.h"
-#include "llvm/ADT/Twine.h"
 
 namespace mlir {
 namespace python {
@@ -99,4 +101,92 @@
 } // namespace detail
 } // namespace pybind11
 
+//------------------------------------------------------------------------------
+// Conversion utilities.
+//------------------------------------------------------------------------------
+
+namespace mlir {
+
+/// Accumulates into a python string from a method that accepts an
+/// MlirStringCallback.
+struct PyPrintAccumulator {
+  pybind11::list parts;
+
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      PyPrintAccumulator *printAccum =
+          static_cast<PyPrintAccumulator *>(userData);
+      pybind11::str pyPart(part, size); // Decodes as UTF-8 by default.
+      printAccum->parts.append(std::move(pyPart));
+    };
+  }
+
+  pybind11::str join() {
+    pybind11::str delim("", 0);
+    return delim.attr("join")(parts);
+  }
+};
+
+/// Accumulates into a python file-like object, either writing text (default)
+/// or binary.
+class PyFileAccumulator {
+public:
+  PyFileAccumulator(pybind11::object fileObject, bool binary)
+      : pyWriteFunction(fileObject.attr("write")), binary(binary) {}
+
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      // The guard must be a named local: an unnamed temporary would be
+      // destroyed at the end of the statement, releasing the GIL at once.
+      pybind11::gil_scoped_acquire acquire;
+      PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
+      if (accum->binary) {
+        // Note: Still has to copy and not avoidable with this API.
+        pybind11::bytes pyBytes(part, size);
+        accum->pyWriteFunction(pyBytes);
+      } else {
+        pybind11::str pyStr(part, size); // Decodes as UTF-8 by default.
+        accum->pyWriteFunction(pyStr);
+      }
+    };
+  }
+
+private:
+  pybind11::object pyWriteFunction;
+  bool binary;
+};
+
+/// Accumulates into a python string from a method that is expected to make
+/// one (no more, no less) call to the callback (asserts internally on
+/// violation).
+struct PySinglePartStringAccumulator {
+  void *getUserData() { return this; }
+
+  MlirStringCallback getCallback() {
+    return [](const char *part, intptr_t size, void *userData) {
+      PySinglePartStringAccumulator *accum =
+          static_cast<PySinglePartStringAccumulator *>(userData);
+      assert(!accum->invoked &&
+             "PySinglePartStringAccumulator called back multiple times");
+      accum->invoked = true;
+      accum->value = pybind11::str(part, size);
+    };
+  }
+
+  pybind11::str takeValue() {
+    assert(invoked && "PySinglePartStringAccumulator not called back");
+    return std::move(value);
+  }
+
+private:
+  pybind11::str value;
+  bool invoked = false;
+};
+
+} // namespace mlir
+
 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H