diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -40,6 +40,11 @@ /// Return an instance of the internal implementation. const Impl &getImpl() const { return *impl; } + /// Set the bytecode version to emit. Returns failure, leaving the current + /// version unchanged, if `bytecodeVersion` is negative or newer than the + /// latest version this writer can emit. + LogicalResult setBytecodeVersion(int64_t bytecodeVersion); + //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -74,6 +74,16 @@ } bool shouldEmitBytecode() const { return emitBytecodeFlag; } + /// Set the bytecode version to emit. + MlirOptMainConfig &setEmitBytecodeVersion(int64_t version) { + emitBytecodeVersion = version; + return *this; + } + /// Return the bytecode version to emit, if one was explicitly set. + std::optional<int64_t> bytecodeVersionToEmit() const { + return emitBytecodeVersion; + } + /// Set the callback to populate the pass manager. MlirOptMainConfig & setPassPipelineSetupFn(std::function<LogicalResult(PassManager &)> callback) { @@ -149,6 +159,9 @@ /// Emit bytecode instead of textual assembly when generating output. bool emitBytecodeFlag = false; + /// Bytecode version to emit, or std::nullopt to use the writer's default. + std::optional<int64_t> emitBytecodeVersion = std::nullopt; + /// The callback to populate the pass manager.
std::function<LogicalResult(PassManager &)> passPipelineCallback; diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -25,7 +25,12 @@ //===----------------------------------------------------------------------===// struct BytecodeWriterConfig::Impl { - Impl(StringRef producer) : producer(producer) {} + Impl(StringRef producer) + : bytecodeVersion(bytecode::kVersion), producer(producer) {} + + /// Version to use when writing. Defaults to the current bytecode::kVersion; + /// it is only overwritten by setBytecodeVersion with an emittable version. + int64_t bytecodeVersion; /// The producer of the bytecode. StringRef producer; @@ -48,6 +53,14 @@ impl->externalResourcePrinters.emplace_back(std::move(printer)); } +LogicalResult +BytecodeWriterConfig::setBytecodeVersion(int64_t bytecodeVersion) { + if (bytecodeVersion < 0 || bytecodeVersion > bytecode::kVersion) + return failure(); + impl->bytecodeVersion = bytecodeVersion; + return success(); +} + //===----------------------------------------------------------------------===// // EncodingEmitter //===----------------------------------------------------------------------===// @@ -421,11 +434,11 @@ namespace { class BytecodeWriter { public: - BytecodeWriter(Operation *op) : numberingState(op) {} + BytecodeWriter(Operation *op, const BytecodeWriterConfig::Impl &config) + : numberingState(op), config(config) {} /// Write the bytecode for the given root operation.
- void write(Operation *rootOp, raw_ostream &os, - const BytecodeWriterConfig::Impl &config); + void write(Operation *rootOp, raw_ostream &os); private: //===--------------------------------------------------------------------===// @@ -449,8 +462,7 @@ //===--------------------------------------------------------------------===// // Resources - void writeResourceSection(Operation *op, EncodingEmitter &emitter, - const BytecodeWriterConfig::Impl &config); + void writeResourceSection(Operation *op, EncodingEmitter &emitter); //===--------------------------------------------------------------------===// // Strings @@ -465,11 +477,13 @@ /// The IR numbering state generated for the root operation. IRNumberingState numberingState; + + /// Configuration dictating bytecode emission (must outlive the writer). + const BytecodeWriterConfig::Impl &config; }; } // namespace -void BytecodeWriter::write(Operation *rootOp, raw_ostream &os, - const BytecodeWriterConfig::Impl &config) { +void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { EncodingEmitter emitter; // Emit the bytecode file header. This is how we identify the output as a @@ -477,7 +491,7 @@ emitter.emitString("ML\xefR"); // Emit the bytecode version. - emitter.emitVarInt(bytecode::kVersion); + emitter.emitVarInt(config.bytecodeVersion); // Emit the producer. emitter.emitNulTerminatedString(config.producer); @@ -492,7 +506,7 @@ writeIRSection(emitter, rootOp); // Emit the resources section. - writeResourceSection(rootOp, emitter, config); + writeResourceSection(rootOp, emitter); // Emit the string section. writeStringSection(emitter); @@ -540,6 +554,11 @@ // Write the string section and get the ID. size_t nameID = stringSection.insert(dialect.name); + if (config.bytecodeVersion == 0) { + dialectEmitter.emitVarInt(nameID); + continue; + } + // Try writing the version to the versionEmitter.
EncodingEmitter versionEmitter; if (dialect.interface) { @@ -787,9 +806,8 @@ }; } // namespace -void BytecodeWriter::writeResourceSection( - Operation *op, EncodingEmitter &emitter, - const BytecodeWriterConfig::Impl &config) { +void BytecodeWriter::writeResourceSection(Operation *op, + EncodingEmitter &emitter) { EncodingEmitter resourceEmitter; EncodingEmitter resourceOffsetEmitter; uint64_t prevOffset = 0; @@ -870,6 +888,6 @@ void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, const BytecodeWriterConfig &config) { - BytecodeWriter writer(op); - writer.write(op, os, config.getImpl()); + BytecodeWriter writer(op, config.getImpl()); + writer.write(op, os); } diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -43,6 +43,23 @@ using namespace llvm; namespace { +/// Parses the value of `-emit-bytecode-version` as a base-10 signed integer. +class BytecodeVersionParser : public cl::parser<std::optional<int64_t>> { +public: + BytecodeVersionParser(cl::Option &O) + : cl::parser<std::optional<int64_t>>(O) {} + + bool parse(cl::Option &O, StringRef /*argName*/, StringRef arg, + std::optional<int64_t> &v) { + long long w; + if (getAsSignedInteger(arg, 10, w)) + return O.error("Invalid argument '" + arg + + "', only integer is supported."); + v = w; + return false; + } +}; + /// This class is intended to manage the handling of command line options for /// creating a *-opt config. This is a singleton.
struct MlirOptMainConfigCLOptions : public MlirOptMainConfig { @@ -65,6 +82,13 @@ "emit-bytecode", cl::desc("Emit bytecode when generating output"), cl::location(emitBytecodeFlag), cl::init(false)); + static cl::opt<std::optional<int64_t>, /*ExternalStorage=*/true, + BytecodeVersionParser> + bytecodeVersion("emit-bytecode-version", + cl::desc("Bytecode version to use with -emit-bytecode"), + cl::location(emitBytecodeVersion), + cl::init(std::nullopt)); + static cl::opt<bool, /*ExternalStorage=*/true> explicitModule( "no-implicit-module", cl::desc("Disable implicit addition of a top-level module op during " @@ -180,6 +204,10 @@ TimingScope outputTiming = timing.nest("Output"); if (config.shouldEmitBytecode()) { BytecodeWriterConfig writerConfig(fallbackResourceMap); + if (auto v = config.bytecodeVersionToEmit(); + v && failed(writerConfig.setBytecodeVersion(*v))) + return emitError(UnknownLoc::get(pm.getContext())) + << "unsupported bytecode version chosen to emit"; writeBytecodeToFile(op.get(), os, writerConfig); } else { AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr, diff --git a/mlir/test/Bytecode/versioning/versioned_bytecode.mlir b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir @@ -0,0 +1,14 @@ +// Tests emitting bytecode at an older version (0) and roundtripping it. + +// Bytecode currently does not support big-endian platforms +// UNSUPPORTED: target=s390x-{{.*}} + +//===--------------------------------------------------------------------===// +// Test roundtrip +//===--------------------------------------------------------------------===// + +// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -emit-bytecode \ +// RUN: -emit-bytecode-version=0 | mlir-opt -o %t.1 && \ +// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -o %t.2 && \ +// RUN: diff %t.1 %t.2 +