diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -48,6 +48,7 @@
   };                                                                           \
   typedef struct name name

+DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void);
 DEFINE_C_API_STRUCT(MlirContext, void);
 DEFINE_C_API_STRUCT(MlirDialect, void);
 DEFINE_C_API_STRUCT(MlirDialectRegistry, void);
@@ -408,6 +409,24 @@
 MLIR_CAPI_EXPORTED void
 mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);

+//===----------------------------------------------------------------------===//
+// Bytecode writer config API.
+//===----------------------------------------------------------------------===//
+
+/// Creates a bytecode writer config with defaults, intended for customization.
+/// Must be freed with a call to mlirBytecodeWriterConfigDestroy().
+MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig
+mlirBytecodeWriterConfigCreate(void);
+
+/// Destroys a writer config created with mlirBytecodeWriterConfigCreate.
+MLIR_CAPI_EXPORTED void
+mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config);
+
+/// Sets the version to emit in the writer config.
+MLIR_CAPI_EXPORTED void
+mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig config,
+                                           int64_t version);
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//
@@ -546,10 +565,30 @@
                                                     MlirStringCallback callback,
                                                     void *userData);

-/// Same as mlirOperationPrint but writing the bytecode format out.
-MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op,
-                                                   MlirStringCallback callback,
-                                                   void *userData);
+/// Result of serializing an operation to bytecode.
+struct MlirBytecodeWriterResult {
+  /// The minimum bytecode version a reader must support to read the output.
+  int64_t minVersion;
+};
+typedef struct MlirBytecodeWriterResult MlirBytecodeWriterResult;
+
+/// Returns the minimum bytecode version required to read the written output.
+inline static int64_t
+mlirBytecodeWriterResultGetMinVersion(MlirBytecodeWriterResult res) {
+  return res.minVersion;
+}
+
+/// Same as mlirOperationPrint but writes the bytecode format. Returns the
+/// minimum bytecode version the consumer needs to support.
+MLIR_CAPI_EXPORTED MlirBytecodeWriterResult mlirOperationWriteBytecode(
+    MlirOperation op, MlirStringCallback callback, void *userData);
+
+/// Same as mlirOperationWriteBytecode but with writer config.
+MLIR_CAPI_EXPORTED MlirBytecodeWriterResult
+mlirOperationWriteBytecodeWithConfig(MlirOperation op,
+                                     MlirBytecodeWriterConfig config,
+                                     MlirStringCallback callback,
+                                     void *userData);

 /// Prints an operation to stderr.
 MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -233,6 +233,9 @@
   /// guaranteed to not die before the end of the bytecode process. The blob is
   /// written as-is, with no additional compression or compaction.
   virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
+
+  /// Return the bytecode version being emitted.
+  virtual int64_t getBytecodeVersion() const = 0;
 };

 //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -40,6 +40,12 @@
   /// Return an instance of the internal implementation.
   const Impl &getImpl() const { return *impl; }

+  /// Set the desired bytecode version to emit. Versions newer than the current
+  /// one are clamped to the current version. The desired version may not be
+  /// used depending on the features used; the actual version required is
+  /// returned by the bytecode writer entry point.
+  void setDesiredBytecodeVersion(int64_t bytecodeVersion);
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
@@ -69,14 +75,21 @@
   std::unique_ptr<Impl> impl;
 };

+/// Result of bytecode serialization.
+struct BytecodeWriterResult {
+  /// The minimum version of the reader required to read the serialized file.
+  int64_t minVersion;
+};
+
 //===----------------------------------------------------------------------===//
 // Entry Points
 //===----------------------------------------------------------------------===//

 /// Write the bytecode for the given operation to the provided output stream.
 /// For streams where it matters, the given stream should be in "binary" mode.
-void writeBytecodeToFile(Operation *op, raw_ostream &os,
-                         const BytecodeWriterConfig &config = {});
+BytecodeWriterResult
+writeBytecodeToFile(Operation *op, raw_ostream &os,
+                    const BytecodeWriterConfig &config = {});

 } // namespace mlir

diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h
--- a/mlir/include/mlir/CAPI/IR.h
+++ b/mlir/include/mlir/CAPI/IR.h
@@ -15,11 +15,13 @@
 #ifndef MLIR_CAPI_IR_H
 #define MLIR_CAPI_IR_H

+#include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/CAPI/Wrap.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"

+DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig)
 DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext)
 DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect)
 DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry)
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -90,6 +90,16 @@
   }
   StringRef getIrdlFile() const { return irdlFileFlag; }

+  /// Set the bytecode version to emit.
+  MlirOptMainConfig &setEmitBytecodeVersion(int64_t version) {
+    emitBytecodeVersion = version;
+    return *this;
+  }
+  /// Return the bytecode version to emit, if one was explicitly requested.
+  std::optional<int64_t> bytecodeVersionToEmit() const {
+    return emitBytecodeVersion;
+  }
+
   /// Set the callback to populate the pass manager.
   MlirOptMainConfig &
   setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -168,6 +178,9 @@
   /// Location Breakpoints to filter the action logging.
   std::vector<tracing::BreakpointManager *> logActionLocationFilter;

+  /// Emit bytecode at given version.
+  std::optional<int64_t> emitBytecodeVersion = std::nullopt;
+
   /// The callback to populate the pass manager.
   std::function<LogicalResult(PassManager &)> passPipelineCallback;

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -124,6 +124,9 @@

 Args:
   file: The file like object to write to.
+  desired_version: The bytecode version to emit; defaults to the current one.
+Returns:
+  A BytecodeResult whose min_version() is the minimum reader version needed.
 )";

 static const char kOperationStrDunderDocstring[] =
@@ -1131,12 +1134,24 @@
   mlirOpPrintingFlagsDestroy(flags);
 }

-void PyOperationBase::writeBytecode(const py::object &fileObject) {
+MlirBytecodeWriterResult
+PyOperationBase::writeBytecode(const py::object &fileObject,
+                               std::optional<int64_t> bytecodeVersion) {
   PyOperation &operation = getOperation();
   operation.checkValid();
   PyFileAccumulator accum(fileObject, /*binary=*/true);
-  mlirOperationWriteBytecode(operation, accum.getCallback(),
-                             accum.getUserData());
+
+  if (!bytecodeVersion.has_value())
+    return mlirOperationWriteBytecode(operation, accum.getCallback(),
+                                      accum.getUserData());
+
+  MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate();
+  mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion);
+  MlirBytecodeWriterResult result = mlirOperationWriteBytecodeWithConfig(
+      operation, config, accum.getCallback(), accum.getUserData());
+  // The config is owned by this function; release it once writing is done.
+  mlirBytecodeWriterConfigDestroy(config);
+  return result;
 }

 py::object PyOperationBase::getAsm(bool binary,
@@ -2757,6 +2772,7 @@
            py::arg("use_local_scope") = false,
            py::arg("assume_verified") = false, kOperationPrintDocstring)
       .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
+           py::arg("desired_version") = py::none(),
            kOperationPrintBytecodeDocstring)
       .def("get_asm", &PyOperationBase::getAsm,
            // Careful: Lots of arguments must match up with get_asm method.
@@ -3365,6 +3381,10 @@
                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
                   py::arg("callback"));

+  py::class_<MlirBytecodeWriterResult>(m, "BytecodeResult", py::module_local())
+      .def("min_version",
+           [](MlirBytecodeWriterResult &res) { return res.minVersion; });
+
   // Container bindings.
   PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -554,7 +554,9 @@
                           bool assumeVerified);

   // Implement the bound 'writeBytecode' method.
-  void writeBytecode(const pybind11::object &fileObject);
+  MlirBytecodeWriterResult
+  writeBytecode(const pybind11::object &fileObject,
+                std::optional<int64_t> bytecodeVersion);

   /// Moves the operation before or after the other operation.
   void moveAfter(PyOperationBase &other);
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -27,6 +27,10 @@
 struct BytecodeWriterConfig::Impl {
   Impl(StringRef producer) : producer(producer) {}

+  /// Version to use when writing.
+  /// Note: This only differs from kVersion if a specific version is set.
+  int64_t bytecodeVersion = bytecode::kVersion;
+
   /// The producer of the bytecode.
   StringRef producer;

@@ -48,6 +52,12 @@
   impl->externalResourcePrinters.emplace_back(std::move(printer));
 }

+void BytecodeWriterConfig::setDesiredBytecodeVersion(int64_t bytecodeVersion) {
+  // Clamp into [0, kVersion]: negative versions are meaningless.
+  impl->bytecodeVersion =
+      std::clamp<int64_t>(bytecodeVersion, 0, bytecode::kVersion);
+}
+
 //===----------------------------------------------------------------------===//
 // EncodingEmitter
 //===----------------------------------------------------------------------===//
@@ -295,10 +305,11 @@

 class DialectWriter : public DialectBytecodeWriter {
 public:
-  DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
-                StringSectionBuilder &stringSection)
-      : emitter(emitter), numberingState(numberingState),
-        stringSection(stringSection) {}
+  DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
+                IRNumberingState &numberingState,
+                StringSectionBuilder &stringSection)
+      : bytecodeVersion(bytecodeVersion), emitter(emitter),
+        numberingState(numberingState), stringSection(stringSection) {}

   //===--------------------------------------------------------------------===//
   // IR
@@ -362,7 +373,10 @@
         reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
   }

+  int64_t getBytecodeVersion() const override { return bytecodeVersion; }
+
 private:
+  int64_t bytecodeVersion;
   EncodingEmitter &emitter;
   IRNumberingState &numberingState;
   StringSectionBuilder &stringSection;
@@ -421,11 +435,11 @@
 namespace {
 class BytecodeWriter {
 public:
-  BytecodeWriter(Operation *op) : numberingState(op) {}
+  BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config)
+      : numberingState(op), config(config) {}

   /// Write the bytecode for the given root operation.
-  void write(Operation *rootOp, raw_ostream &os,
-             const BytecodeWriterConfig::Impl &config);
+  void write(Operation *rootOp, raw_ostream &os);

 private:
   //===--------------------------------------------------------------------===//
@@ -449,8 +463,7 @@
   //===--------------------------------------------------------------------===//
   // Resources

-  void writeResourceSection(Operation *op, EncodingEmitter &emitter,
-                            const BytecodeWriterConfig::Impl &config);
+  void writeResourceSection(Operation *op, EncodingEmitter &emitter);

   //===--------------------------------------------------------------------===//
   // Strings
@@ -465,11 +478,13 @@

   /// The IR numbering state generated for the root operation.
   IRNumberingState numberingState;
+
+  /// Configuration dictating bytecode emission.
+  const BytecodeWriterConfig::Impl &config;
 };
 } // namespace

-void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
-                           const BytecodeWriterConfig::Impl &config) {
+void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
   EncodingEmitter emitter;

   // Emit the bytecode file header. This is how we identify the output as a
@@ -477,7 +492,7 @@
   emitter.emitString("ML\xefR");

   // Emit the bytecode version.
-  emitter.emitVarInt(bytecode::kVersion);
+  emitter.emitVarInt(config.bytecodeVersion);

   // Emit the producer.
   emitter.emitNulTerminatedString(config.producer);
@@ -492,7 +507,7 @@
   writeIRSection(emitter, rootOp);

   // Emit the resources section.
-  writeResourceSection(rootOp, emitter, config);
+  writeResourceSection(rootOp, emitter);

   // Emit the string section.
   writeStringSection(emitter);
@@ -540,12 +555,18 @@
     // Write the string section and get the ID.
     size_t nameID = stringSection.insert(dialect.name);

+    // Version 0 does not emit dialect versions; write only the name ID.
+    if (config.bytecodeVersion == 0) {
+      dialectEmitter.emitVarInt(nameID);
+      continue;
+    }
+
     // Try writing the version to the versionEmitter.
     EncodingEmitter versionEmitter;
     if (dialect.interface) {
       // The writer used when emitting using a custom bytecode encoding.
-      DialectWriter versionWriter(versionEmitter, numberingState,
-                                  stringSection);
+      DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
+                                  numberingState, stringSection);
       dialect.interface->writeVersion(versionWriter);
     }

@@ -586,8 +607,8 @@
     bool hasCustomEncoding = false;
     if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
       // The writer used when emitting using a custom bytecode encoding.
-      DialectWriter dialectWriter(attrTypeEmitter, numberingState,
-                                  stringSection);
+      DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
+                                  numberingState, stringSection);

       if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
         // TODO: We don't currently support custom encoded mutable types.
@@ -787,9 +808,8 @@
 };
 } // namespace

-void BytecodeWriter::writeResourceSection(
-    Operation *op, EncodingEmitter &emitter,
-    const BytecodeWriterConfig::Impl &config) {
+void BytecodeWriter::writeResourceSection(Operation *op,
+                                          EncodingEmitter &emitter) {
   EncodingEmitter resourceEmitter;
   EncodingEmitter resourceOffsetEmitter;
   uint64_t prevOffset = 0;
@@ -868,8 +888,12 @@
 // Entry Points
 //===----------------------------------------------------------------------===//

-void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
-                               const BytecodeWriterConfig &config) {
-  BytecodeWriter writer(op);
-  writer.write(op, os, config.getImpl());
+BytecodeWriterResult
+mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
+                          const BytecodeWriterConfig &config) {
+  BytecodeWriter writer(op, config.getImpl());
+  writer.write(op, os);
+  // Return the bytecode version emitted - currently there is no additional
+  // feedback as to minimum beyond the requested one.
+  return {config.getImpl().bytecodeVersion};
 }
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -7,11 +7,12 @@
 //===----------------------------------------------------------------------===//

 #include "IRNumbering.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/Support/ErrorHandling.h"

 using namespace mlir;
 using namespace mlir::bytecode::detail;
@@ -41,6 +42,10 @@
   }
   void writeOwnedBlob(ArrayRef<char> blob) override {}

+  int64_t getBytecodeVersion() const override {
+    llvm_unreachable("unexpected querying of version in IRNumbering");
+  }
+
   /// The parent numbering state that is populated by this writer.
   IRNumberingState &state;
 };
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -145,6 +145,23 @@
   unwrap(flags)->assumeVerified();
 }

+//===----------------------------------------------------------------------===//
+// Bytecode writer config API.
+//===----------------------------------------------------------------------===//
+
+MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() {
+  return wrap(new BytecodeWriterConfig());
+}
+
+void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) {
+  delete unwrap(config);
+}
+
+void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig config,
+                                                int64_t version) {
+  unwrap(config)->setDesiredBytecodeVersion(version);
+}
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//
@@ -507,10 +524,25 @@
   unwrap(op)->print(stream, *unwrap(flags));
 }

-void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
-                                void *userData) {
+MlirBytecodeWriterResult mlirOperationWriteBytecode(MlirOperation op,
+                                                    MlirStringCallback callback,
+                                                    void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  MlirBytecodeWriterResult res;
+  BytecodeWriterResult r = writeBytecodeToFile(unwrap(op), stream);
+  res.minVersion = r.minVersion;
+  return res;
+}
+
+MlirBytecodeWriterResult mlirOperationWriteBytecodeWithConfig(
+    MlirOperation op, MlirBytecodeWriterConfig config,
+    MlirStringCallback callback, void *userData) {
   detail::CallbackOstream stream(callback, userData);
-  writeBytecodeToFile(unwrap(op), stream);
+  BytecodeWriterResult r =
+      writeBytecodeToFile(unwrap(op), stream, *unwrap(config));
+  MlirBytecodeWriterResult res;
+  res.minVersion = r.minVersion;
+  return res;
 }

 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -52,6 +52,24 @@
 using namespace llvm;

 namespace {
+/// Command-line parser for -emit-bytecode-version: accepts a non-negative
+/// decimal integer and stores it in an std::optional<int64_t>.
+class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> {
+public:
+  BytecodeVersionParser(cl::Option &O)
+      : cl::parser<std::optional<int64_t>>(O) {}
+
+  bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
+             std::optional<int64_t> &v) {
+    long long w;
+    if (getAsSignedInteger(arg, 10, w) || w < 0)
+      return O.error("Invalid argument '" + arg +
+                     "', only non-negative integer is supported.");
+    v = w;
+    return false;
+  }
+};
+
 /// This class is intended to manage the handling of command line options for
 /// creating a *-opt config. This is a singleton.
 struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
@@ -74,6 +92,13 @@
         "emit-bytecode", cl::desc("Emit bytecode when generating output"),
         cl::location(emitBytecodeFlag), cl::init(false));

+    static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true,
+                   BytecodeVersionParser>
+        bytecodeVersion(
+            "emit-bytecode-version",
+            cl::desc("Use specified bytecode version when generating output"),
+            cl::location(emitBytecodeVersion), cl::init(std::nullopt));
+
     static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile(
         "irdl-file",
         cl::desc("IRDL file to register before processing the input"),
@@ -241,13 +266,28 @@
   TimingScope outputTiming = timing.nest("Output");
   if (config.shouldEmitBytecode()) {
     BytecodeWriterConfig writerConfig(fallbackResourceMap);
+    if (auto v = config.bytecodeVersionToEmit()) {
+      writerConfig.setDesiredBytecodeVersion(*v);
+      // Fail (with a diagnostic) if the requested version couldn't be honored.
+      BytecodeWriterResult result =
+          writeBytecodeToFile(op.get(), os, writerConfig);
+      if (result.minVersion > *v)
+        return emitError(UnknownLoc::get(pm.getContext()))
+               << "unable to emit bytecode at requested version " << *v
+               << ", minimum required version is " << result.minVersion;
+      return success();
+    }
     writeBytecodeToFile(op.get(), os, writerConfig);
-  } else {
-    AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
-                      &fallbackResourceMap);
-    op.get()->print(os, asmState);
-    os << '\n';
+    return success();
   }
+
+  if (config.bytecodeVersionToEmit().has_value())
+    return emitError(UnknownLoc::get(pm.getContext()))
+           << "bytecode version specified while not emitting bytecode";
+  AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
+                    &fallbackResourceMap);
+  op.get()->print(os, asmState);
+  os << '\n';
   return success();
 }

diff --git a/mlir/test/Bytecode/versioning/versioned_bytecode.mlir b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir
@@ -0,0 +1,14 @@
+// This file contains test cases related to roundtripping.
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: target=s390x-{{.*}}
+
+//===--------------------------------------------------------------------===//
+// Test roundtrip
+//===--------------------------------------------------------------------===//
+
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -emit-bytecode \
+// RUN:   -emit-bytecode-version=0 | mlir-opt -o %t.1 && \
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -o %t.2 && \
+// RUN: diff %t.1 %t.2
+
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -571,7 +571,8 @@

   # Test roundtrip to bytecode.
   bytecode_stream = io.BytesIO()
-  module.operation.write_bytecode(bytecode_stream)
+  result = module.operation.write_bytecode(bytecode_stream, desired_version=1)
+  assert result.min_version() == 1, "Expected the requested bytecode version to be used"
   bytecode = bytecode_stream.getvalue()
   assert bytecode.startswith(b'ML\xefR'), "Expected bytecode to start with MLïR"
   module_roundtrip = Module.parse(bytecode, ctx)