diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -547,9 +547,8 @@
                                                     void *userData);
 
 /// Same as mlirOperationPrint but writing the bytecode format out.
-MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op,
-                                                   MlirStringCallback callback,
-                                                   void *userData);
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecode(
+    MlirOperation op, MlirStringCallback callback, void *userData);
 
 /// Prints an operation to stderr.
 MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -233,6 +233,9 @@
   /// guaranteed to not die before the end of the bytecode process. The blob is
   /// written as-is, with no additional compression or compaction.
   virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
+
+  /// Return the bytecode version being emitted.
+  virtual int64_t getBytecodeVersion() const = 0;
 };
 
 //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -40,6 +40,10 @@
   /// Return an instance of the internal implementation.
   const Impl &getImpl() const { return *impl; }
 
+  /// Set the bytecode version to emit. Versions outside the range this writer
+  /// supports ([0, current version]) are clamped into that range.
+  void setBytecodeVersion(int64_t bytecodeVersion);
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
@@ -75,8 +79,10 @@
 
 /// Write the bytecode for the given operation to the provided output stream.
 /// For streams where it matters, the given stream should be in "binary" mode.
-void writeBytecodeToFile(Operation *op, raw_ostream &os,
-                         const BytecodeWriterConfig &config = {});
+/// Returns the bytecode version that was emitted; this may differ from the
+/// requested version if it was clamped (see setBytecodeVersion).
+int64_t writeBytecodeToFile(Operation *op, raw_ostream &os,
+                            const BytecodeWriterConfig &config = {});
 
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -74,6 +74,16 @@
   }
   bool shouldEmitBytecode() const { return emitBytecodeFlag; }
 
+  /// Set the bytecode version to emit.
+  MlirOptMainConfig &setEmitBytecodeVersion(int64_t version) {
+    emitBytecodeVersion = version;
+    return *this;
+  }
+  /// Return the bytecode version to emit, if one was explicitly requested.
+  std::optional<int64_t> bytecodeVersionToEmit() const {
+    return emitBytecodeVersion;
+  }
+
   /// Set the callback to populate the pass manager.
   MlirOptMainConfig &
   setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) {
@@ -149,6 +159,9 @@
   /// Emit bytecode instead of textual assembly when generating output.
   bool emitBytecodeFlag = false;
 
+  /// Bytecode version to emit, if explicitly requested.
+  std::optional<int64_t> emitBytecodeVersion = std::nullopt;
+
   /// The callback to populate the pass manager.
   std::function<LogicalResult(PassManager &)> passPipelineCallback;
 
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -27,6 +27,10 @@
 struct BytecodeWriterConfig::Impl {
   Impl(StringRef producer) : producer(producer) {}
 
+  /// The bytecode version to emit. Defaults to the current version and may
+  /// only be lowered (see BytecodeWriterConfig::setBytecodeVersion).
+  int64_t bytecodeVersion = bytecode::kVersion;
+
   /// The producer of the bytecode.
   StringRef producer;
 
@@ -48,6 +52,12 @@
   impl->externalResourcePrinters.emplace_back(std::move(printer));
 }
 
+void BytecodeWriterConfig::setBytecodeVersion(int64_t bytecodeVersion) {
+  // Clamp into the range of versions this writer knows how to emit.
+  impl->bytecodeVersion =
+      std::clamp<int64_t>(bytecodeVersion, 0, bytecode::kVersion);
+}
+
 //===----------------------------------------------------------------------===//
 // EncodingEmitter
 //===----------------------------------------------------------------------===//
@@ -295,7 +305,9 @@
 
 class DialectWriter : public DialectBytecodeWriter {
 public:
-  DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
+  DialectWriter(int64_t bytecodeVersion, EncodingEmitter &emitter,
+                IRNumberingState &numberingState,
                 StringSectionBuilder &stringSection)
-      : emitter(emitter), numberingState(numberingState),
+      : bytecodeVersion(bytecodeVersion), emitter(emitter),
+        numberingState(numberingState),
         stringSection(stringSection) {}
@@ -362,7 +374,10 @@
         reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
   }
 
+  int64_t getBytecodeVersion() const override { return bytecodeVersion; }
+
 private:
+  int64_t bytecodeVersion;
   EncodingEmitter &emitter;
   IRNumberingState &numberingState;
   StringSectionBuilder &stringSection;
@@ -421,11 +436,11 @@
 namespace {
 class BytecodeWriter {
 public:
-  BytecodeWriter(Operation *op) : numberingState(op) {}
+  BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config)
+      : numberingState(op), config(config) {}
 
   /// Write the bytecode for the given root operation.
-  void write(Operation *rootOp, raw_ostream &os,
-             const BytecodeWriterConfig::Impl &config);
+  void write(Operation *rootOp, raw_ostream &os);
 
 private:
   //===--------------------------------------------------------------------===//
@@ -449,8 +464,7 @@
   //===--------------------------------------------------------------------===//
   // Resources
 
-  void writeResourceSection(Operation *op, EncodingEmitter &emitter,
-                            const BytecodeWriterConfig::Impl &config);
+  void writeResourceSection(Operation *op, EncodingEmitter &emitter);
 
   //===--------------------------------------------------------------------===//
   // Strings
@@ -465,11 +479,13 @@
 
   /// The IR numbering state generated for the root operation.
   IRNumberingState numberingState;
+
+  /// Configuration dictating bytecode emission.
+  const BytecodeWriterConfig::Impl &config;
 };
 } // namespace
 
-void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
-                           const BytecodeWriterConfig::Impl &config) {
+void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
   EncodingEmitter emitter;
 
   // Emit the bytecode file header. This is how we identify the output as a
@@ -477,7 +493,7 @@
   emitter.emitString("ML\xefR");
 
   // Emit the bytecode version.
-  emitter.emitVarInt(bytecode::kVersion);
+  emitter.emitVarInt(config.bytecodeVersion);
 
   // Emit the producer.
   emitter.emitNulTerminatedString(config.producer);
@@ -492,7 +508,7 @@
   writeIRSection(emitter, rootOp);
 
   // Emit the resources section.
-  writeResourceSection(rootOp, emitter, config);
+  writeResourceSection(rootOp, emitter);
 
   // Emit the string section.
   writeStringSection(emitter);
@@ -540,12 +556,17 @@
     // Write the string section and get the ID.
     size_t nameID = stringSection.insert(dialect.name);
 
+    if (config.bytecodeVersion == 0) {
+      dialectEmitter.emitVarInt(nameID);
+      continue;
+    }
+
     // Try writing the version to the versionEmitter.
     EncodingEmitter versionEmitter;
     if (dialect.interface) {
       // The writer used when emitting using a custom bytecode encoding.
-      DialectWriter versionWriter(versionEmitter, numberingState,
-                                  stringSection);
+      DialectWriter versionWriter(config.bytecodeVersion, versionEmitter,
+                                  numberingState, stringSection);
       dialect.interface->writeVersion(versionWriter);
     }
 
@@ -586,8 +607,8 @@
     bool hasCustomEncoding = false;
     if (const BytecodeDialectInterface *interface = entry.dialect->interface) {
       // The writer used when emitting using a custom bytecode encoding.
-      DialectWriter dialectWriter(attrTypeEmitter, numberingState,
-                                  stringSection);
+      DialectWriter dialectWriter(config.bytecodeVersion, attrTypeEmitter,
+                                  numberingState, stringSection);
 
       if constexpr (std::is_same_v<std::decay_t<decltype(entryValue)>, Type>) {
         // TODO: We don't currently support custom encoded mutable types.
@@ -787,9 +808,8 @@
 };
 } // namespace
 
-void BytecodeWriter::writeResourceSection(
-    Operation *op, EncodingEmitter &emitter,
-    const BytecodeWriterConfig::Impl &config) {
+void BytecodeWriter::writeResourceSection(Operation *op,
+                                          EncodingEmitter &emitter) {
   EncodingEmitter resourceEmitter;
   EncodingEmitter resourceOffsetEmitter;
   uint64_t prevOffset = 0;
@@ -868,8 +888,11 @@
 // Entry Points
 //===----------------------------------------------------------------------===//
 
-void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
-                               const BytecodeWriterConfig &config) {
-  BytecodeWriter writer(op);
-  writer.write(op, os, config.getImpl());
+int64_t mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
+                                  const BytecodeWriterConfig &config) {
+  BytecodeWriter writer(op, config.getImpl());
+  writer.write(op, os);
+  // Return the bytecode version emitted - currently there is no additional
+  // feedback as to minimum beyond the requested one.
+  return config.getImpl().bytecodeVersion;
 }
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -7,11 +7,13 @@
 //===----------------------------------------------------------------------===//
 
 #include "IRNumbering.h"
+#include "../Encoding.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::bytecode::detail;
@@ -41,6 +43,10 @@
   }
   void writeOwnedBlob(ArrayRef<char> blob) override {}
 
+  int64_t getBytecodeVersion() const override {
+    llvm_unreachable("unexpected querying of version in IRNumbering");
+  }
+
   /// The parent numbering state that is populated by this writer.
   IRNumberingState &state;
 };
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -507,10 +507,13 @@
   unwrap(op)->print(stream, *unwrap(flags));
 }
 
-void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
-                                void *userData) {
+MlirLogicalResult mlirOperationWriteBytecode(MlirOperation op,
+                                             MlirStringCallback callback,
+                                             void *userData) {
   detail::CallbackOstream stream(callback, userData);
-  writeBytecodeToFile(unwrap(op), stream);
+  // No specific version is requested, so the current version is emitted.
+  (void)writeBytecodeToFile(unwrap(op), stream);
+  return mlirLogicalResultSuccess();
 }
 
 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -896,7 +896,8 @@
 
     std::string rawBytecodeBuffer;
     llvm::raw_string_ostream os(rawBytecodeBuffer);
-    writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
+    // No specific version is requested, so the current version is emitted.
+    (void)writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
     result.output = llvm::encodeBase64(rawBytecodeBuffer);
   }
   return result;
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -43,6 +43,22 @@
 using namespace llvm;
 
 namespace {
+class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> {
+public:
+  BytecodeVersionParser(cl::Option &O)
+      : cl::parser<std::optional<int64_t>>(O) {}
+
+  bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg,
+             std::optional<int64_t> &v) {
+    long long w;
+    if (getAsSignedInteger(arg, 10, w))
+      return O.error("Invalid argument '" + arg +
+                     "', only integer is supported.");
+    v = w;
+    return false;
+  }
+};
+
 /// This class is intended to manage the handling of command line options for
 /// creating a *-opt config. This is a singleton.
 struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
@@ -65,6 +81,13 @@
         "emit-bytecode", cl::desc("Emit bytecode when generating output"),
         cl::location(emitBytecodeFlag), cl::init(false));
 
+    static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true,
+                   BytecodeVersionParser>
+        bytecodeVersion(
+            "emit-bytecode-version",
+            cl::desc("Use specified bytecode version when generating output"),
+            cl::location(emitBytecodeVersion), cl::init(std::nullopt));
+
     static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
         "no-implicit-module",
         cl::desc("Disable implicit addition of a top-level module op during "
@@ -180,13 +203,26 @@
   TimingScope outputTiming = timing.nest("Output");
   if (config.shouldEmitBytecode()) {
     BytecodeWriterConfig writerConfig(fallbackResourceMap);
+    if (auto v = config.bytecodeVersionToEmit()) {
+      writerConfig.setBytecodeVersion(*v);
+      // Fail if the writer could not honor the requested version (e.g. the
+      // request was clamped into the supported range).
+      if (writeBytecodeToFile(op.get(), os, writerConfig) != *v)
+        return emitError(UnknownLoc::get(pm.getContext()))
+               << "failed to emit bytecode at requested version " << *v;
+      return success();
+    }
     writeBytecodeToFile(op.get(), os, writerConfig);
-  } else {
-    AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
-                      &fallbackResourceMap);
-    op.get()->print(os, asmState);
-    os << '\n';
+    return success();
   }
+
+  if (config.bytecodeVersionToEmit().has_value())
+    return emitError(UnknownLoc::get(pm.getContext()))
+           << "bytecode version specified without -emit-bytecode";
+  AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
+                    &fallbackResourceMap);
+  op.get()->print(os, asmState);
+  os << '\n';
   return success();
 }
 
diff --git a/mlir/test/Bytecode/versioning/versioned_bytecode.mlir b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir
@@ -0,0 +1,14 @@
+// This file contains test cases related to roundtripping.
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: target=s390x-{{.*}}
+
+//===--------------------------------------------------------------------===//
+// Test roundtrip
+//===--------------------------------------------------------------------===//
+
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -emit-bytecode \
+// RUN:   -emit-bytecode-version=0 | mlir-opt -o %t.1 && \
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -o %t.2 && \
+// RUN: diff %t.1 %t.2
+