diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -6,7 +6,8 @@ ## Magic Number -MLIR uses the following four-byte magic number to indicate bytecode files: +MLIR uses the following four-byte magic number to +indicate bytecode files: '\[‘M’8, ‘L’8, ‘ï’8, ‘R’8\]' @@ -157,16 +158,25 @@ } op_name_group { - dialect: varint, + dialect: varint, // (dialectID << 1) | (hasVersion) + version: dialect_version_section, // only present if hasVersion is set numOpNames: varint, opNames: varint[] } + +dialect_version_section { + size: varint, + version: byte[] +} + ``` -Dialects are encoded as indexes to the name string within the string section. -Operation names are encoded in groups by dialect, with each group containing the -dialect, the number of operation names, and the array of indexes to each name -within the string section. +Dialects are encoded as a `varint` containing the index to the name string +within the string section, plus a flag indicating whether the dialect is +versioned. When the flag is set, the dialect version follows as a nested +section. Operation names are encoded in groups by dialect, with each group +containing the dialect, the number of operation names, and the array of +indexes to each name within the string section. ### Attribute/Type Sections diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -845,3 +845,18 @@ that are directly usable by any other dialect in MLIR. These types cover a range from primitive integer and floating-point values, attribute dictionaries, dense multi-dimensional arrays, and more. + +### IR Versioning + +A dialect can opt in to versioning through the +`BytecodeDialectInterface`. A few hooks are exposed to allow the dialect to +manage a version encoded into the bytecode file.
The version is loaded lazily +and allows to retrieve the version information while parsing the input IR, and +gives an opportunity to each dialect for which a version is present to perform +IR upgrades post-parsing through the `upgradeFromVersion` method. Custom +Attribute and Type encodings can also be upgraded according to the dialect +version using readAttribute and readType methods. + +There is no restriction on what kind of information a dialect is allowed to +encode to model its versioning. Currently, versioning is supported only for +bytecode formats. diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -235,6 +235,17 @@ virtual void writeOwnedBlob(ArrayRef blob) = 0; }; +//===--------------------------------------------------------------------===// +// Dialect Version Interface. +//===--------------------------------------------------------------------===// + +/// This class is used to represent the version of a dialect, for the purpose +/// of polymorphic destruction. +class DialectVersion { +public: + virtual ~DialectVersion() = default; +}; + //===----------------------------------------------------------------------===// // BytecodeDialectInterface //===----------------------------------------------------------------------===// @@ -256,6 +267,19 @@ return Attribute(); } + /// Read a versioned attribute encoding belonging to this dialect from the + /// given reader. This method should return null in the case of failure, and + /// falls back to the non-versioned reader in case the dialect implements + /// versioning but it does not support versioned custom encodings for the + /// attributes. 
+ virtual Attribute readAttribute(DialectBytecodeReader &reader, + const DialectVersion &version) const { + // By default the version is ignored and the non-versioned reader is used, + // so dialects that do not version their custom attribute encodings need + // not override this hook. + return readAttribute(reader); + } + /// Read a type belonging to this dialect from the given reader. This method /// should return null in the case of failure. virtual Type readType(DialectBytecodeReader &reader) const { @@ -264,6 +288,19 @@ return Type(); } + /// Read a versioned type encoding belonging to this dialect from the given + /// reader. This method should return null in the case of failure, and + /// falls back to the non-versioned reader in case the dialect implements + /// versioning but it does not support versioned custom encodings for the + /// types. + virtual Type readType(DialectBytecodeReader &reader, + const DialectVersion &version) const { + // By default the version is ignored and the non-versioned reader is used, + // so dialects that do not version their custom type encodings need not + // override this hook. + return readType(reader); + } + //===--------------------------------------------------------------------===// // Writing //===--------------------------------------------------------------------===// @@ -285,6 +322,27 @@ DialectBytecodeWriter &writer) const { return failure(); } + + /// Write the version of this dialect to the given writer. + virtual void writeVersion(DialectBytecodeWriter &writer) const {} + + /// Read the version of this dialect from the provided reader and return it + /// as a `unique_ptr` to a dialect version object. + virtual std::unique_ptr + readVersion(DialectBytecodeReader &reader) const { + reader.emitError("Dialect does not support versioning"); + return nullptr; + } + + /// Hook invoked after parsing completed, if a version directive was present + /// and included an entry for the current dialect.
This hook offers the + /// opportunity to the dialect to visit the IR and upgrades constructs emitted + /// by the version of the dialect corresponding to the provided version. + virtual LogicalResult + upgradeFromVersion(Operation *topLevelOp, + const DialectVersion &version) const { + return success(); + } }; } // namespace mlir diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -23,8 +23,11 @@ //===----------------------------------------------------------------------===// enum { + /// The minimum supported version of the bytecode. + kMinSupportedVersion = 0, + /// The current bytecode version. - kVersion = 0, + kVersion = 1, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, @@ -61,8 +64,11 @@ /// section. kResourceOffset = 6, + /// This section contains the versions of each dialect. + kDialectVersions = 7, + /// The total number of section types. - kNumSections = 7, + kNumSections = 8, }; } // namespace Section diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -47,6 +47,8 @@ return "Resource (5)"; case bytecode::Section::kResourceOffset: return "ResourceOffset (6)"; + case bytecode::Section::kDialectVersions: + return "DialectVersions (7)"; default: return ("Unknown (" + Twine(static_cast(sectionID)) + ")").str(); } @@ -63,6 +65,7 @@ return false; case bytecode::Section::kResource: case bytecode::Section::kResourceOffset: + case bytecode::Section::kDialectVersions: return true; default: llvm_unreachable("unknown section ID"); @@ -350,6 +353,13 @@ return parseEntry(reader, strings, result, "string"); } + /// Parse a shared string from the string section. The shared string is + /// encoded using an index to a corresponding string in the string section. 
+ LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, + StringRef &result) { + return resolveEntry(reader, strings, index, result, "string"); + } + private: /// The table of strings referenced within the bytecode file. SmallVector strings; @@ -400,31 +410,15 @@ //===----------------------------------------------------------------------===// namespace { +class DialectReader; + /// This struct represents a dialect entry within the bytecode. struct BytecodeDialect { /// Load the dialect into the provided context if it hasn't been loaded yet. /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. - LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { - if (dialect) - return success(); - Dialect *loadedDialect = ctx->getOrLoadDialect(name); - if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { - return reader.emitError( - "dialect '", name, - "' is unknown. If this is intended, please call " - "allowUnregisteredDialects() on the MLIRContext, or use " - "-allow-unregistered-dialect with the MLIR tool used."); - } - dialect = loadedDialect; - - // If the dialect was actually loaded, check to see if it has a bytecode - // interface. - if (loadedDialect) - interface = dyn_cast(loadedDialect); - return success(); - } + LogicalResult load(DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -446,6 +440,12 @@ /// The name of the dialect. StringRef name; + + /// A buffer containing the encoding of the dialect version parsed. + ArrayRef versionBuffer; + + /// Lazy loaded dialect version from the handle above. + std::unique_ptr loadedVersion; }; /// This struct represents an operation name entry within the bytecode. 
@@ -496,7 +496,7 @@ initialize(Location fileLoc, const ParserConfig &config, MutableArrayRef dialects, StringSectionReader &stringReader, ArrayRef sectionData, - ArrayRef offsetSectionData, + ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef); /// Parse a dialect resource handle from the resource section. @@ -643,7 +643,7 @@ Location fileLoc, const ParserConfig &config, MutableArrayRef dialects, StringSectionReader &stringReader, ArrayRef sectionData, - ArrayRef offsetSectionData, + ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef) { EncodingReader resourceReader(sectionData, fileLoc); EncodingReader offsetReader(offsetSectionData, fileLoc); @@ -684,7 +684,10 @@ while (!offsetReader.empty()) { BytecodeDialect *dialect; + // NOTE(review): `dialectReader` is built by parseResourceSection over its + // `reader` argument, not over this dialect's `versionBuffer`; if this is the + // first load of a versioned dialect, `load` reads the version from it. if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(resourceReader, ctx))) + failed(dialect->load(dialectReader, ctx))) return failure(); Dialect *loadedDialect = dialect->getLoadedDialect(); if (!loadedDialect) { @@ -1051,7 +1054,12 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - if (failed(entry.dialect->load(reader, fileLoc.getContext()))) + // Load the dialect (and its version, if any) from the dialect's own version + // buffer: `reader` holds this attribute/type entry and must not be consumed. + EncodingReader versionReader(entry.dialect->versionBuffer, fileLoc); + DialectReader versionDialectReader(*this, stringReader, resourceReader, + versionReader); + if (failed(entry.dialect->load(versionDialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1060,12 +1068,23 @@ "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. - DialectReader dialectReader(*this, stringReader, resourceReader, reader); - if constexpr (std::is_same_v) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); + // Ask the dialect to parse the entry. If the dialect is versioned, parse + // using the versioned encoding readers. + DialectReader dialectReader(*this, stringReader, resourceReader, reader);
+ if (entry.dialect->loadedVersion) { + if constexpr (std::is_same_v) + entry.entry = entry.dialect->interface->readType( + dialectReader, *entry.dialect->loadedVersion); + else + entry.entry = entry.dialect->interface->readAttribute( + dialectReader, *entry.dialect->loadedVersion); + + } else { + if constexpr (std::is_same_v) + entry.entry = entry.dialect->interface->readType(dialectReader); + else + entry.entry = entry.dialect->interface->readAttribute(dialectReader); + } return success(!!entry.entry); } @@ -1122,7 +1133,8 @@ // Resource Section LogicalResult - parseResourceSection(std::optional> resourceData, + parseResourceSection(EncodingReader &reader, + std::optional> resourceData, std::optional> resourceOffsetData); //===--------------------------------------------------------------------===// @@ -1306,7 +1318,7 @@ // Process the resource section if present. if (failed(parseResourceSection( - sectionDatas[bytecode::Section::kResource], + reader, sectionDatas[bytecode::Section::kResource], sectionDatas[bytecode::Section::kResourceOffset]))) return failure(); @@ -1326,7 +1338,8 @@ // Validate the bytecode version. uint64_t currentVersion = bytecode::kVersion; - if (version < currentVersion) { + uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; + if (version < minSupportedVersion) { return reader.emitError("bytecode version ", version, - " is older than the current version of ", currentVersion, + " is older than the minimum supported version of ", minSupportedVersion, ", and upgrade is not supported"); @@ -1342,6 +1355,36 @@ // Dialect Section +LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { + if (dialect) + return success(); + Dialect *loadedDialect = ctx->getOrLoadDialect(name); + if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { + return reader.emitError("dialect '") + << name + << "' is unknown.
If this is intended, please call " + "allowUnregisteredDialects() on the MLIRContext, or use " + "-allow-unregistered-dialect with the MLIR tool used."; + } + dialect = loadedDialect; + + // If the dialect was actually loaded, check to see if it has a bytecode + // interface. + if (loadedDialect) + interface = dyn_cast(loadedDialect); + if (!versionBuffer.empty()) { + if (!interface) + return reader.emitError("dialect '") + << name + << "' does not implement the bytecode interface, " + "but found a version entry"; + loadedVersion = interface->readVersion(reader); + if (!loadedVersion) + return failure(); + } + return success(); +} + LogicalResult BytecodeReader::parseDialectSection(ArrayRef sectionData) { EncodingReader sectionReader(sectionData, fileLoc); @@ -1353,9 +1396,34 @@ dialects.resize(numDialects); // Parse each of the dialects. - for (uint64_t i = 0; i < numDialects; ++i) - if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + for (uint64_t i = 0; i < numDialects; ++i) { + /// Before version 1, there wasn't any versioning available for dialects, + /// and the entryIdx represent the string itself. + if (version == 0) { + if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + return failure(); + continue; + } + // Parse ID representing dialect and version. + uint64_t dialectNameIdx; + bool versionAvailable; + if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, + versionAvailable))) + return failure(); + if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, + dialects[i].name))) return failure(); + if (versionAvailable) { + bytecode::Section::ID sectionID; + if (failed( + sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + return failure(); + if (sectionID != bytecode::Section::kDialectVersions) { + emitError(fileLoc, "expected dialect version section"); + return failure(); + } + } + } // Parse the operation names, which are grouped by dialect. 
auto parseOpName = [&](BytecodeDialect *dialect) { @@ -1379,7 +1447,11 @@ // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. if (!opName->opName) { - if (failed(opName->dialect->load(reader, getContext()))) + // Load the dialect and its version. + EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc); + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + versionReader); + if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), getContext()); @@ -1391,7 +1463,7 @@ // Resource Section LogicalResult BytecodeReader::parseResourceSection( - std::optional> resourceData, + EncodingReader &reader, std::optional> resourceData, std::optional> resourceOffsetData) { // Ensure both sections are either present or not. if (resourceData.has_value() != resourceOffsetData.has_value()) { @@ -1408,9 +1480,11 @@ return success(); // Initialize the resource reader with the resource sections. + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, - bufferOwnerRef); + dialectReader, bufferOwnerRef); } //===----------------------------------------------------------------------===// @@ -1442,6 +1516,18 @@ "not all forward unresolved forward operand references"); } + // Resolve dialect version. + for (const BytecodeDialect &byteCodeDialect : dialects) { + // Parsing is complete, give an opportunity to each dialect to visit the + // IR and perform upgrades. + if (!byteCodeDialect.loadedVersion) + continue; + if (byteCodeDialect.interface && + failed(byteCodeDialect.interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect.loadedVersion))) + return failure(); + } + // Verify that the parsed operations are valid. 
if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) return failure(); diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -10,13 +10,10 @@ #include "../Encoding.h" #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" -#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" -#include "llvm/Support/Debug.h" -#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -261,6 +258,116 @@ unsigned requiredAlignment = 1; }; +//===----------------------------------------------------------------------===// +// StringSectionBuilder +//===----------------------------------------------------------------------===// + +namespace { +/// This class is used to simplify the process of emitting the string section. +class StringSectionBuilder { +public: + /// Add the given string to the string section, and return the index of the + /// string within the section. + size_t insert(StringRef str) { + auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); + return it.first->second; + } + + /// Write the current set of strings to the given emitter. + void write(EncodingEmitter &emitter) { + emitter.emitVarInt(strings.size()); + + // Emit the sizes in reverse order, so that we don't need to backpatch an + // offset to the string data or have a separate section. + for (const auto &it : llvm::reverse(strings)) + emitter.emitVarInt(it.first.size() + 1); + // Emit the string data itself. + for (const auto &it : strings) + emitter.emitNulTerminatedString(it.first.val()); + } + +private: + /// A set of strings referenced within the bytecode. The value of the map is + /// unused. 
+ llvm::MapVector strings; +}; +} // namespace + +class DialectWriter : public DialectBytecodeWriter { +public: + DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState, + StringSectionBuilder &stringSection) + : emitter(emitter), numberingState(numberingState), + stringSection(stringSection) {} + + //===--------------------------------------------------------------------===// + // IR + //===--------------------------------------------------------------------===// + + void writeAttribute(Attribute attr) override { + emitter.emitVarInt(numberingState.getNumber(attr)); + } + void writeType(Type type) override { + emitter.emitVarInt(numberingState.getNumber(type)); + } + + void writeResourceHandle(const AsmDialectResourceHandle &resource) override { + emitter.emitVarInt(numberingState.getNumber(resource)); + } + + //===--------------------------------------------------------------------===// + // Primitives + //===--------------------------------------------------------------------===// + + void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } + + void writeSignedVarInt(int64_t value) override { + emitter.emitSignedVarInt(value); + } + + void writeAPIntWithKnownWidth(const APInt &value) override { + size_t bitWidth = value.getBitWidth(); + + // If the value is a single byte, just emit it directly without going + // through a varint. + if (bitWidth <= 8) + return emitter.emitByte(value.getLimitedValue()); + + // If the value fits within a single varint, emit it directly. + if (bitWidth <= 64) + return emitter.emitSignedVarInt(value.getLimitedValue()); + + // Otherwise, we need to encode a variable number of active words. We use + // active words instead of the number of total words under the observation + // that smaller values will be more common. 
+ unsigned numActiveWords = value.getActiveWords(); + emitter.emitVarInt(numActiveWords); + + const uint64_t *rawValueData = value.getRawData(); + for (unsigned i = 0; i < numActiveWords; ++i) + emitter.emitSignedVarInt(rawValueData[i]); + } + + void writeAPFloatWithKnownSemantics(const APFloat &value) override { + writeAPIntWithKnownWidth(value.bitcastToAPInt()); + } + + void writeOwnedString(StringRef str) override { + emitter.emitVarInt(stringSection.insert(str)); + } + + void writeOwnedBlob(ArrayRef blob) override { + emitter.emitVarInt(blob.size()); + emitter.emitOwnedBlob(ArrayRef( + reinterpret_cast(blob.data()), blob.size())); + } + +private: + EncodingEmitter &emitter; + IRNumberingState &numberingState; + StringSectionBuilder &stringSection; +}; + /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need /// to go through an intermediate buffer when interacting with code that wants a /// raw_ostream. @@ -307,41 +414,6 @@ emitBytes({reinterpret_cast(&value), sizeof(value)}); } -//===----------------------------------------------------------------------===// -// StringSectionBuilder -//===----------------------------------------------------------------------===// - -namespace { -/// This class is used to simplify the process of emitting the string section. -class StringSectionBuilder { -public: - /// Add the given string to the string section, and return the index of the - /// string within the section. - size_t insert(StringRef str) { - auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); - return it.first->second; - } - - /// Write the current set of strings to the given emitter. - void write(EncodingEmitter &emitter) { - emitter.emitVarInt(strings.size()); - - // Emit the sizes in reverse order, so that we don't need to backpatch an - // offset to the string data or have a separate section. 
- for (const auto &it : llvm::reverse(strings)) - emitter.emitVarInt(it.first.size() + 1); - // Emit the string data itself. - for (const auto &it : strings) - emitter.emitNulTerminatedString(it.first.val()); - } - -private: - /// A set of strings referenced within the bytecode. The value of the map is - /// unused. - llvm::MapVector strings; -}; -} // namespace - //===----------------------------------------------------------------------===// // Bytecode Writer //===----------------------------------------------------------------------===// @@ -464,8 +536,28 @@ // Emit the referenced dialects. auto dialects = numberingState.getDialects(); dialectEmitter.emitVarInt(llvm::size(dialects)); - for (DialectNumbering &dialect : dialects) - dialectEmitter.emitVarInt(stringSection.insert(dialect.name)); + for (DialectNumbering &dialect : dialects) { + // Insert the dialect name into the string section and get its ID. + size_t nameID = stringSection.insert(dialect.name); + + // Try writing the version to the versionEmitter. + EncodingEmitter versionEmitter; + if (dialect.interface) { + // The writer used when emitting using a custom bytecode encoding. + DialectWriter versionWriter(versionEmitter, numberingState, + stringSection); + dialect.interface->writeVersion(versionWriter); + } + + // If the version emitter is empty, version is not available. We can encode + // this in the dialect ID, so if there is no version, we don't write the + // section. + bool versionAvailable = versionEmitter.size() > 0; + dialectEmitter.emitVarIntWithFlag(nameID, versionAvailable); + if (versionAvailable) + dialectEmitter.emitSection(bytecode::Section::kDialectVersions, + std::move(versionEmitter)); + } // Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) { @@ -479,83 +571,6 @@ //===----------------------------------------------------------------------===// // Attributes and Types -namespace { -class DialectWriter : public DialectBytecodeWriter { -public: - DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState, - StringSectionBuilder &stringSection) - : emitter(emitter), numberingState(numberingState), - stringSection(stringSection) {} - - //===--------------------------------------------------------------------===// - // IR - //===--------------------------------------------------------------------===// - - void writeAttribute(Attribute attr) override { - emitter.emitVarInt(numberingState.getNumber(attr)); - } - void writeType(Type type) override { - emitter.emitVarInt(numberingState.getNumber(type)); - } - - void writeResourceHandle(const AsmDialectResourceHandle &resource) override { - emitter.emitVarInt(numberingState.getNumber(resource)); - } - - //===--------------------------------------------------------------------===// - // Primitives - //===--------------------------------------------------------------------===// - - void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } - - void writeSignedVarInt(int64_t value) override { - emitter.emitSignedVarInt(value); - } - - void writeAPIntWithKnownWidth(const APInt &value) override { - size_t bitWidth = value.getBitWidth(); - - // If the value is a single byte, just emit it directly without going - // through a varint. - if (bitWidth <= 8) - return emitter.emitByte(value.getLimitedValue()); - - // If the value fits within a single varint, emit it directly. - if (bitWidth <= 64) - return emitter.emitSignedVarInt(value.getLimitedValue()); - - // Otherwise, we need to encode a variable number of active words. We use - // active words instead of the number of total words under the observation - // that smaller values will be more common. 
- unsigned numActiveWords = value.getActiveWords(); - emitter.emitVarInt(numActiveWords); - - const uint64_t *rawValueData = value.getRawData(); - for (unsigned i = 0; i < numActiveWords; ++i) - emitter.emitSignedVarInt(rawValueData[i]); - } - - void writeAPFloatWithKnownSemantics(const APFloat &value) override { - writeAPIntWithKnownWidth(value.bitcastToAPInt()); - } - - void writeOwnedString(StringRef str) override { - emitter.emitVarInt(stringSection.insert(str)); - } - - void writeOwnedBlob(ArrayRef blob) override { - emitter.emitVarInt(blob.size()); - emitter.emitOwnedBlob(ArrayRef( - reinterpret_cast(blob.data()), blob.size())); - } - -private: - EncodingEmitter &emitter; - IRNumberingState &numberingState; - StringSectionBuilder &stringSection; -}; -} // namespace - void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { EncodingEmitter attrTypeEmitter; EncodingEmitter offsetEmitter; diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version 0 +// VERSION: bytecode version 127 is newer than the current version 1 //===--------------------------------------------------------------------===// // Producer diff --git a/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-attr-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 +// CHECK1: "test.versionedB"() {attribute = 
#test.attr_params<42, 24>} : () -> () + +//===--------------------------------------------------------------------===// +// Test attribute upgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.0 +// COM: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-attr-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2 +// CHECK2: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> () diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/versioning/versioned_op.mlir @@ -0,0 +1,41 @@ +// This file contains test cases related to the dialect post-parsing upgrade +// mechanism. + +// Bytecode currently does not support big-endian platforms +// UNSUPPORTED: target=s390x-{{.*}} + +//===--------------------------------------------------------------------===// +// Test generic +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.0 +// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 +// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () + +//===--------------------------------------------------------------------===// +// Test upgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 1.12 +// COM: "test.versionedA"() {dimensions = 123 : i64} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2 +// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () + 
+//===--------------------------------------------------------------------===// +// Test forbidden downgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.2 +// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () +// COM: } +// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION +// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -10,15 +10,14 @@ #include "TestAttributes.h" #include "TestInterfaces.h" #include "TestTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" @@ -32,9 +31,9 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" -#include #include +#include // Include this before the using namespace lines below to // test that we don't have namespace dependencies. 
@@ -47,6 +46,15 @@ registry.insert(); } +//===----------------------------------------------------------------------===// +// TestDialect version utilities +//===----------------------------------------------------------------------===// + +struct TestDialectVersion : public DialectVersion { + uint32_t major = 2; + uint32_t minor = 0; +}; + //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// @@ -70,6 +78,107 @@ TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase; }; +namespace { +enum test_encoding { k_attr_params = 0 }; +} + +// Test support for interacting with the Bytecode reader/writer. +struct TestBytecodeDialectInterface : public BytecodeDialectInterface { + using BytecodeDialectInterface::BytecodeDialectInterface; + TestBytecodeDialectInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const final { + if (auto concreteAttr = llvm::dyn_cast(attr)) { + writer.writeVarInt(test_encoding::k_attr_params); + writer.writeVarInt(concreteAttr.getV0()); + writer.writeVarInt(concreteAttr.getV1()); + return success(); + } + writer.writeAttribute(attr); + return success(); + } + + Attribute readAttribute(DialectBytecodeReader &reader, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if (version.major < 2) + return readAttrOldEncoding(reader); + if (version.major == 2 && version.minor == 0) + return readAttrNewEncoding(reader); + // Forbid reading future versions by returning nullptr. + return Attribute(); + } + + // Emit a specific version of the dialect. 
+ void writeVersion(DialectBytecodeWriter &writer) const final { + auto version = TestDialectVersion(); + writer.writeVarInt(version.major); // major + writer.writeVarInt(version.minor); // minor + } + + std::unique_ptr + readVersion(DialectBytecodeReader &reader) const final { + uint64_t major, minor; + if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor))) + return nullptr; + auto version = std::make_unique(); + version->major = major; + version->minor = minor; + return version; + } + + LogicalResult upgradeFromVersion(Operation *topLevelOp, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if ((version.major == 2) && (version.minor == 0)) + return success(); + if (version.major > 2 || (version.major == 2 && version.minor > 0)) { + return topLevelOp->emitError() + << "current test dialect version is 2.0, can't parse version: " + << version.major << "." << version.minor; + } + // Prior version 2.0, the old op supported only a single attribute called + // "dimensions". We can perform the upgrade. + topLevelOp->walk([](TestVersionedOpA op) { + if (auto dims = op->getAttr("dimensions")) { + op->removeAttr("dimensions"); + op->setAttr("dims", dims); + } + op->setAttr("modifier", BoolAttr::get(op->getContext(), false)); + }); + return success(); + } + +private: + Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The new encoding has v0 first, v1 second. 
+ uint64_t v0, v1; + if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } + + Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The old encoding has v1 first, v0 second. + uint64_t v0, v1; + if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } +}; + // Test support for interacting with the AsmPrinter. struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; @@ -367,7 +476,7 @@ addInterface(blobInterface); addInterfaces(); + TestReductionPatternInterface, TestBytecodeDialectInterface>(); allowUnknownOperations(); // Instantiate our fallback op interface that we'll use on specific @@ -1103,9 +1212,7 @@ return getOperand(); } -OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { - return getValue(); -} +OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } LogicalResult TestOpWithVariadicResultsAndFolder::fold( FoldAdaptor adaptor, SmallVectorImpl &results) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3149,4 +3149,43 @@ }]; } +//===----------------------------------------------------------------------===// +// Test Ops to upgrade base on the dialect versions +//===----------------------------------------------------------------------===// + +def TestVersionedOpA : TEST_Op<"versionedA"> { + // A previous version of the dialect (let's say 1.*) supported an attribute + // named "dimensions": + // let arguments = (ins + // AnyI64Attr:$dimensions + // ); + 
+ // In the current version (2.0) "dimensions" was renamed to "dims", and a new + // boolean attribute "modifier" was added. The previous version of the op + // corresponds to "modifier=false". We support loading old IR through + // upgrading, see `upgradeFromVersion()` in `TestBytecodeDialectInterface`. + let arguments = (ins + AnyI64Attr:$dims, + BoolAttr:$modifier + ); +} + +def TestVersionedOpB : TEST_Op<"versionedB"> { + // In a previous version of the dialect (let's say 1.*), TestAttrParams + // used a custom encoding: + // + // #test.attr_params<X, Y> -> { varInt: Y, varInt: X } + // + // In the current version (2.0) the encoding changed and the two parameters of + // the attribute are swapped: + // + // #test.attr_params<X, Y> -> { varInt: X, varInt: Y } + // + // We support loading old IR through a custom readAttribute method, see + // `readAttribute()` in `TestBytecodeDialectInterface`. + let arguments = (ins + TestAttrParams:$attribute + ); +} + #endif // TEST_OPS