diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -521,6 +521,11 @@ MlirStringCallback callback, void *userData); +/// Same as mlirOperationPrint but writing the bytecode format out. +MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, + MlirStringCallback callback, + void *userData); + /// Prints an operation to stderr. MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -119,6 +119,13 @@ argument. )"; +static const char kOperationPrintBytecodeDocstring[] = + R"(Write the bytecode form of the operation to a file-like object. + +Args: + file: The binary file-like object (e.g. io.BytesIO) to write to. +)"; + static const char kOperationStrDunderDocstring[] = R"(Gets the assembly form of the operation with default options. @@ -1022,6 +1029,14 @@ mlirOpPrintingFlagsDestroy(flags); } +void PyOperationBase::writeBytecode(py::object fileObject) { + PyOperation &operation = getOperation(); + operation.checkValid(); + PyFileAccumulator accum(fileObject, /*binary=*/true); + mlirOperationWriteBytecode(operation, accum.getCallback(), + accum.getUserData()); +} + py::object PyOperationBase::getAsm(bool binary, llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, @@ -2627,6 +2642,8 @@ py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, py::arg("assume_verified") = false, kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), + kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method.
py::arg("binary") = false, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -512,6 +512,9 @@ bool printGenericOpForm, bool useLocalScope, bool assumeVerified); + /// Implements the bound 'write_bytecode' method. + void writeBytecode(pybind11::object fileObject); + /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -12,6 +12,7 @@ Support.cpp LINK_LIBS PUBLIC + MLIRBytecodeWriter MLIRIR MLIRParser MLIRSupport diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -10,6 +10,7 @@ #include "mlir-c/Support.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" @@ -23,7 +24,6 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" -#include "llvm/Support/Debug.h" #include <cstddef> using namespace mlir; @@ -485,6 +485,12 @@ unwrap(op)->print(stream, *unwrap(flags)); } +void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + writeBytecodeToFile(unwrap(op), stream); +} + void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } bool mlirOperationVerify(MlirOperation op) { diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -566,6 +566,18 @@ print(str_value.__class__) print(f.getvalue()) + # Test roundtrip to bytecode.
+ bytecode_stream = io.BytesIO() + module.operation.write_bytecode(bytecode_stream) + bytecode = bytecode_stream.getvalue() + assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR" + module_roundtrip = Module.parse(bytecode, ctx) + f = io.StringIO() + module_roundtrip.operation.print(file=f) + roundtrip_value = f.getvalue() + assert str_value == roundtrip_value, "Mismatch after roundtrip bytecode" + + # Test print to binary file. f = io.BytesIO() # CHECK: diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -398,6 +398,7 @@ includes = ["include"], deps = [ ":AsmParser", + ":BytecodeWriter", ":ConversionPassIncGen", ":FuncDialect", ":InferTypeOpInterface",