diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -158,15 +158,22 @@ op_name_group { dialect: varint, + version: dialect_version_section, numOpNames: varint, opNames: varint[] } + +dialect_version_section { + size: varint, + version: byte[] +} + ``` Dialects are encoded as indexes to the name string within the string section. Operation names are encoded in groups by dialect, with each group containing the dialect, the number of operation names, and the array of indexes to each name -within the string section. +within the string section. The version is encoded as a nested section. ### Attribute/Type Sections diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -845,3 +845,18 @@ that are directly usable by any other dialect in MLIR. These types cover a range from primitive integer and floating-point values, attribute dictionaries, dense multi-dimensional arrays, and more. + +### IR Versioning + +A dialect can opt in to handle versioning through the +`BytecodeDialectInterface`. A few hooks are exposed to the dialect to allow +managing a version encoded into the bytecode file. The version is loaded lazily, +which allows retrieving the version information while parsing the input IR, and +gives an opportunity to each dialect for which a version is present to perform +IR upgrades post-parsing through the `upgradeFromVersion` method. Custom +Attribute and Type encodings can also be upgraded according to the dialect +version using the `readAttribute` and `readType` methods. + +There is no restriction on what kind of information a dialect is allowed to +encode to model its versioning. Currently, versioning is supported only for +bytecode formats. 
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -235,6 +235,17 @@ virtual void writeOwnedBlob(ArrayRef blob) = 0; }; +//===--------------------------------------------------------------------===// +// Dialect Version Interface. +//===--------------------------------------------------------------------===// + +/// This class is used to represent the version of a dialect, for the purpose +/// of polymorphic destruction. +class DialectVersion { +public: + virtual ~DialectVersion() = default; +}; + //===----------------------------------------------------------------------===// // BytecodeDialectInterface //===----------------------------------------------------------------------===// @@ -256,6 +267,16 @@ return Attribute(); } + /// Read a versioned attribute encoding belonging to this dialect from the + /// given reader. This method should return null in the case of failure, and + /// falls back to the non-versioned reader in case the dialect implements + /// versioning but it does not support versioned custom encodings for the + /// attributes. + virtual Attribute readAttribute(DialectBytecodeReader &reader, + const DialectVersion &version) const { + return readAttribute(reader); + } + /// Read a type belonging to this dialect from the given reader. This method /// should return null in the case of failure. virtual Type readType(DialectBytecodeReader &reader) const { @@ -264,6 +285,16 @@ return Type(); } + /// Read a versioned type encoding belonging to this dialect from the given + /// reader. This method should return null in the case of failure, and + /// falls back to the non-versioned reader in case the dialect implements + /// versioning but it does not support versioned custom encodings for the + /// types. 
+ virtual Type readType(DialectBytecodeReader &reader, + const DialectVersion &version) const { + return readType(reader); + } + //===--------------------------------------------------------------------===// // Writing //===--------------------------------------------------------------------===// @@ -285,6 +316,26 @@ DialectBytecodeWriter &writer) const { return failure(); } + + /// Write the version of this dialect to the given writer. + /// The first emitted VarInt should be the size of the entry. + virtual void writeVersion(DialectBytecodeWriter &writer) const {} + + virtual std::unique_ptr + readVersion(DialectBytecodeReader &reader) const { + reader.emitError("Dialect does not support versioning"); + return nullptr; + } + + /// Hook invoked after parsing completed, if a version directive was present + /// and included an entry for the current dialect. This hook offers the + /// opportunity to the dialect to visit the IR and upgrades constructs emitted + /// by the version of the dialect corresponding to the provided version. + virtual LogicalResult + upgradeFromVersion(Operation *topLevelOp, + const DialectVersion &version) const { + return success(); + } }; } // namespace mlir diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1463,9 +1463,9 @@ //===--------------------------------------------------------------------===// struct Argument { - UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. - Type type; // Type. - DictionaryAttr attrs; // Attributes if present. + UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. + Type type; // Type. + DictionaryAttr attrs; // Attributes if present. std::optional sourceLoc; // Source location specifier if present. 
}; diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -23,8 +23,11 @@ //===----------------------------------------------------------------------===// enum { + /// The minimum supported version of the bytecode. + kMinSupportedVersion = 0, + /// The current bytecode version. - kVersion = 0, + kVersion = 1, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, @@ -61,8 +64,11 @@ /// section. kResourceOffset = 6, + /// This section contains the versions of each dialect. + kDialectVersions = 7, + /// The total number of section types. - kNumSections = 7, + kNumSections = 8, }; } // namespace Section diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -47,6 +47,8 @@ return "Resource (5)"; case bytecode::Section::kResourceOffset: return "ResourceOffset (6)"; + case bytecode::Section::kDialectVersions: + return "DialectVersions (7)"; default: return ("Unknown (" + Twine(static_cast(sectionID)) + ")").str(); } @@ -63,6 +65,7 @@ return false; case bytecode::Section::kResource: case bytecode::Section::kResourceOffset: + case bytecode::Section::kDialectVersions: return true; default: llvm_unreachable("unknown section ID"); @@ -400,31 +403,15 @@ //===----------------------------------------------------------------------===// namespace { +class DialectReader; + /// This struct represents a dialect entry within the bytecode. struct BytecodeDialect { /// Load the dialect into the provided context if it hasn't been loaded yet. /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. 
- LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { - if (dialect) - return success(); - Dialect *loadedDialect = ctx->getOrLoadDialect(name); - if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { - return reader.emitError( - "dialect '", name, - "' is unknown. If this is intended, please call " - "allowUnregisteredDialects() on the MLIRContext, or use " - "-allow-unregistered-dialect with the MLIR tool used."); - } - dialect = loadedDialect; - - // If the dialect was actually loaded, check to see if it has a bytecode - // interface. - if (loadedDialect) - interface = dyn_cast(loadedDialect); - return success(); - } + LogicalResult load(DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -446,6 +433,12 @@ /// The name of the dialect. StringRef name; + + /// A buffer containing the encoding of the dialect version parsed. + ArrayRef versionBuffer; + + /// Lazy loaded dialect version from the handle above. + std::unique_ptr loadedVersion; }; /// This struct represents an operation name entry within the bytecode. @@ -496,7 +489,7 @@ initialize(Location fileLoc, const ParserConfig &config, MutableArrayRef dialects, StringSectionReader &stringReader, ArrayRef sectionData, - ArrayRef offsetSectionData, + ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef); /// Parse a dialect resource handle from the resource section. 
@@ -643,7 +636,7 @@ Location fileLoc, const ParserConfig &config, MutableArrayRef dialects, StringSectionReader &stringReader, ArrayRef sectionData, - ArrayRef offsetSectionData, + ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef) { EncodingReader resourceReader(sectionData, fileLoc); EncodingReader offsetReader(offsetSectionData, fileLoc); @@ -684,7 +677,7 @@ while (!offsetReader.empty()) { BytecodeDialect *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(resourceReader, ctx))) + failed(dialect->load(dialectReader, ctx))) return failure(); Dialect *loadedDialect = dialect->getLoadedDialect(); if (!loadedDialect) { @@ -1051,7 +1044,8 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - if (failed(entry.dialect->load(reader, fileLoc.getContext()))) + DialectReader dialectReader(*this, stringReader, resourceReader, reader); + if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1060,12 +1054,23 @@ "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. - DialectReader dialectReader(*this, stringReader, resourceReader, reader); - if constexpr (std::is_same_v) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); + // Ask the dialect to parse the entry. If the dialect is versioned, parse + // using the versioned encoding readers. 
+ if (entry.dialect->loadedVersion.get()) { + if constexpr (std::is_same_v) { + entry.entry = entry.dialect->interface->readType( + dialectReader, *entry.dialect->loadedVersion); + } else { + entry.entry = entry.dialect->interface->readAttribute( + dialectReader, *entry.dialect->loadedVersion); + } + } else { + if constexpr (std::is_same_v) { + entry.entry = entry.dialect->interface->readType(dialectReader); + } else { + entry.entry = entry.dialect->interface->readAttribute(dialectReader); + } + } return success(!!entry.entry); } @@ -1122,7 +1127,8 @@ // Resource Section LogicalResult - parseResourceSection(std::optional> resourceData, + parseResourceSection(EncodingReader &reader, + std::optional> resourceData, std::optional> resourceOffsetData); //===--------------------------------------------------------------------===// @@ -1306,7 +1312,7 @@ // Process the resource section if present. if (failed(parseResourceSection( - sectionDatas[bytecode::Section::kResource], + reader, sectionDatas[bytecode::Section::kResource], sectionDatas[bytecode::Section::kResourceOffset]))) return failure(); @@ -1326,7 +1332,8 @@ // Validate the bytecode version. uint64_t currentVersion = bytecode::kVersion; - if (version < currentVersion) { + uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; + if (version < minSupportedVersion) { return reader.emitError("bytecode version ", version, " is older than the current version of ", currentVersion, ", and upgrade is not supported"); @@ -1342,6 +1349,36 @@ //===----------------------------------------------------------------------===// // Dialect Section +LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { + if (dialect) + return success(); + Dialect *loadedDialect = ctx->getOrLoadDialect(name); + if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { + return reader.emitError("dialect '") + << name + << "' is unknown. 
If this is intended, please call " + "allowUnregisteredDialects() on the MLIRContext, or use " + "-allow-unregistered-dialect with the MLIR tool used."; + } + dialect = loadedDialect; + + // If the dialect was actually loaded, check to see if it has a bytecode + // interface. + if (loadedDialect) + interface = dyn_cast(loadedDialect); + if (!versionBuffer.empty()) { + if (!interface) + return reader.emitError("dialect '") + << name + << "' does not implement the bytecode interface, " + "but found a version entry"; + loadedVersion = interface->readVersion(reader); + if (!loadedVersion) + return failure(); + } + return success(); +} + LogicalResult BytecodeReader::parseDialectSection(ArrayRef sectionData) { EncodingReader sectionReader(sectionData, fileLoc); @@ -1353,9 +1390,21 @@ dialects.resize(numDialects); // Parse each of the dialects. - for (uint64_t i = 0; i < numDialects; ++i) + for (uint64_t i = 0; i < numDialects; ++i) { if (failed(stringReader.parseString(sectionReader, dialects[i].name))) return failure(); + /// Before version 1, there wasn't any versioning available for dialects. + if (version == 0) + continue; + bytecode::Section::ID sectionID; + if (failed( + sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + return failure(); + if (sectionID != bytecode::Section::kDialectVersions) { + emitError(fileLoc, "expected dialect version section"); + return failure(); + } + } // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { @@ -1379,7 +1428,11 @@ // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. if (!opName->opName) { - if (failed(opName->dialect->load(reader, getContext()))) + // Load the dialect and its version. 
+ EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc); + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + versionReader); + if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), getContext()); @@ -1391,7 +1444,7 @@ // Resource Section LogicalResult BytecodeReader::parseResourceSection( - std::optional> resourceData, + EncodingReader &reader, std::optional> resourceData, std::optional> resourceOffsetData) { // Ensure both sections are either present or not. if (resourceData.has_value() != resourceOffsetData.has_value()) { @@ -1408,9 +1461,11 @@ return success(); // Initialize the resource reader with the resource sections. + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, - bufferOwnerRef); + dialectReader, bufferOwnerRef); } //===----------------------------------------------------------------------===// @@ -1442,6 +1497,18 @@ "not all forward unresolved forward operand references"); } + // Resolve dialect version. + for (const BytecodeDialect &byteCodeDialect : dialects) { + // Parsing is complete, give an opportunity to each dialect to visit the + // IR and perform upgrades. + if (!byteCodeDialect.loadedVersion) + continue; + if (byteCodeDialect.interface && + failed(byteCodeDialect.interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect.loadedVersion))) + return failure(); + } + // Verify that the parsed operations are valid. 
if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) return failure(); diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -10,13 +10,10 @@ #include "../Encoding.h" #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" -#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" -#include "llvm/Support/Debug.h" -#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -261,6 +258,116 @@ unsigned requiredAlignment = 1; }; +//===----------------------------------------------------------------------===// +// StringSectionBuilder +//===----------------------------------------------------------------------===// + +namespace { +/// This class is used to simplify the process of emitting the string section. +class StringSectionBuilder { +public: + /// Add the given string to the string section, and return the index of the + /// string within the section. + size_t insert(StringRef str) { + auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); + return it.first->second; + } + + /// Write the current set of strings to the given emitter. + void write(EncodingEmitter &emitter) { + emitter.emitVarInt(strings.size()); + + // Emit the sizes in reverse order, so that we don't need to backpatch an + // offset to the string data or have a separate section. + for (const auto &it : llvm::reverse(strings)) + emitter.emitVarInt(it.first.size() + 1); + // Emit the string data itself. + for (const auto &it : strings) + emitter.emitNulTerminatedString(it.first.val()); + } + +private: + /// A set of strings referenced within the bytecode. The value of the map is + /// unused. 
+ llvm::MapVector strings; +}; +} // namespace + +class DialectWriter : public DialectBytecodeWriter { +public: + DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState, + StringSectionBuilder &stringSection) + : emitter(emitter), numberingState(numberingState), + stringSection(stringSection) {} + + //===--------------------------------------------------------------------===// + // IR + //===--------------------------------------------------------------------===// + + void writeAttribute(Attribute attr) override { + emitter.emitVarInt(numberingState.getNumber(attr)); + } + void writeType(Type type) override { + emitter.emitVarInt(numberingState.getNumber(type)); + } + + void writeResourceHandle(const AsmDialectResourceHandle &resource) override { + emitter.emitVarInt(numberingState.getNumber(resource)); + } + + //===--------------------------------------------------------------------===// + // Primitives + //===--------------------------------------------------------------------===// + + void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } + + void writeSignedVarInt(int64_t value) override { + emitter.emitSignedVarInt(value); + } + + void writeAPIntWithKnownWidth(const APInt &value) override { + size_t bitWidth = value.getBitWidth(); + + // If the value is a single byte, just emit it directly without going + // through a varint. + if (bitWidth <= 8) + return emitter.emitByte(value.getLimitedValue()); + + // If the value fits within a single varint, emit it directly. + if (bitWidth <= 64) + return emitter.emitSignedVarInt(value.getLimitedValue()); + + // Otherwise, we need to encode a variable number of active words. We use + // active words instead of the number of total words under the observation + // that smaller values will be more common. 
+ unsigned numActiveWords = value.getActiveWords(); + emitter.emitVarInt(numActiveWords); + + const uint64_t *rawValueData = value.getRawData(); + for (unsigned i = 0; i < numActiveWords; ++i) + emitter.emitSignedVarInt(rawValueData[i]); + } + + void writeAPFloatWithKnownSemantics(const APFloat &value) override { + writeAPIntWithKnownWidth(value.bitcastToAPInt()); + } + + void writeOwnedString(StringRef str) override { + emitter.emitVarInt(stringSection.insert(str)); + } + + void writeOwnedBlob(ArrayRef blob) override { + emitter.emitVarInt(blob.size()); + emitter.emitOwnedBlob(ArrayRef( + reinterpret_cast(blob.data()), blob.size())); + } + +private: + EncodingEmitter &emitter; + IRNumberingState &numberingState; + StringSectionBuilder &stringSection; +}; + /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need /// to go through an intermediate buffer when interacting with code that wants a /// raw_ostream. @@ -307,41 +414,6 @@ emitBytes({reinterpret_cast(&value), sizeof(value)}); } -//===----------------------------------------------------------------------===// -// StringSectionBuilder -//===----------------------------------------------------------------------===// - -namespace { -/// This class is used to simplify the process of emitting the string section. -class StringSectionBuilder { -public: - /// Add the given string to the string section, and return the index of the - /// string within the section. - size_t insert(StringRef str) { - auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); - return it.first->second; - } - - /// Write the current set of strings to the given emitter. - void write(EncodingEmitter &emitter) { - emitter.emitVarInt(strings.size()); - - // Emit the sizes in reverse order, so that we don't need to backpatch an - // offset to the string data or have a separate section. 
- for (const auto &it : llvm::reverse(strings)) - emitter.emitVarInt(it.first.size() + 1); - // Emit the string data itself. - for (const auto &it : strings) - emitter.emitNulTerminatedString(it.first.val()); - } - -private: - /// A set of strings referenced within the bytecode. The value of the map is - /// unused. - llvm::MapVector strings; -}; -} // namespace - //===----------------------------------------------------------------------===// // Bytecode Writer //===----------------------------------------------------------------------===// @@ -464,8 +536,18 @@ // Emit the referenced dialects. auto dialects = numberingState.getDialects(); dialectEmitter.emitVarInt(llvm::size(dialects)); - for (DialectNumbering &dialect : dialects) + for (DialectNumbering &dialect : dialects) { dialectEmitter.emitVarInt(stringSection.insert(dialect.name)); + EncodingEmitter versionEmitter; + if (dialect.interface) { + // The writer used when emitting using a custom bytecode encoding. + DialectWriter versionWriter(versionEmitter, numberingState, + stringSection); + dialect.interface->writeVersion(versionWriter); + } + dialectEmitter.emitSection(bytecode::Section::kDialectVersions, + std::move(versionEmitter)); + } // Emit the referenced operation names grouped by dialect. 
auto emitOpName = [&](OpNameNumbering &name) { @@ -479,83 +561,6 @@ //===----------------------------------------------------------------------===// // Attributes and Types -namespace { -class DialectWriter : public DialectBytecodeWriter { -public: - DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState, - StringSectionBuilder &stringSection) - : emitter(emitter), numberingState(numberingState), - stringSection(stringSection) {} - - //===--------------------------------------------------------------------===// - // IR - //===--------------------------------------------------------------------===// - - void writeAttribute(Attribute attr) override { - emitter.emitVarInt(numberingState.getNumber(attr)); - } - void writeType(Type type) override { - emitter.emitVarInt(numberingState.getNumber(type)); - } - - void writeResourceHandle(const AsmDialectResourceHandle &resource) override { - emitter.emitVarInt(numberingState.getNumber(resource)); - } - - //===--------------------------------------------------------------------===// - // Primitives - //===--------------------------------------------------------------------===// - - void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } - - void writeSignedVarInt(int64_t value) override { - emitter.emitSignedVarInt(value); - } - - void writeAPIntWithKnownWidth(const APInt &value) override { - size_t bitWidth = value.getBitWidth(); - - // If the value is a single byte, just emit it directly without going - // through a varint. - if (bitWidth <= 8) - return emitter.emitByte(value.getLimitedValue()); - - // If the value fits within a single varint, emit it directly. - if (bitWidth <= 64) - return emitter.emitSignedVarInt(value.getLimitedValue()); - - // Otherwise, we need to encode a variable number of active words. We use - // active words instead of the number of total words under the observation - // that smaller values will be more common. 
- unsigned numActiveWords = value.getActiveWords(); - emitter.emitVarInt(numActiveWords); - - const uint64_t *rawValueData = value.getRawData(); - for (unsigned i = 0; i < numActiveWords; ++i) - emitter.emitSignedVarInt(rawValueData[i]); - } - - void writeAPFloatWithKnownSemantics(const APFloat &value) override { - writeAPIntWithKnownWidth(value.bitcastToAPInt()); - } - - void writeOwnedString(StringRef str) override { - emitter.emitVarInt(stringSection.insert(str)); - } - - void writeOwnedBlob(ArrayRef blob) override { - emitter.emitVarInt(blob.size()); - emitter.emitOwnedBlob(ArrayRef( - reinterpret_cast(blob.data()), blob.size())); - } - -private: - EncodingEmitter &emitter; - IRNumberingState &numberingState; - StringSectionBuilder &stringSection; -}; -} // namespace - void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { EncodingEmitter attrTypeEmitter; EncodingEmitter offsetEmitter; diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version 0 +// VERSION: bytecode version 127 is newer than the current version 1 //===--------------------------------------------------------------------===// // Producer diff --git a/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-attr-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 +// CHECK1: "test.versionedB"() {attribute = 
#test.attr_params<42, 24>} : () -> () + +//===--------------------------------------------------------------------===// +// Test attribute upgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.0 +// COM: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-attr-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2 +// CHECK2: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> () diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/versioning/versioned_op.mlir @@ -0,0 +1,41 @@ +// This file contains test cases related to the dialect post-parsing upgrade +// mechanism. + +// Bytecode currently does not support big-endian platforms +// UNSUPPORTED: target=s390x-{{.*}} + +//===--------------------------------------------------------------------===// +// Test generic +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.0 +// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 +// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () + +//===--------------------------------------------------------------------===// +// Test upgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 1.12 +// COM: "test.versionedA"() {dimensions = 123 : i64} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2 +// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () + 
+//===--------------------------------------------------------------------===// +// Test forbidden downgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.2 +// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () +// COM: } +// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION +// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -10,15 +10,14 @@ #include "TestAttributes.h" #include "TestInterfaces.h" #include "TestTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" @@ -32,9 +31,10 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" -#include +#include "llvm/Support/Endian.h" #include +#include // Include this before the using namespace lines below to // test that we don't have namespace dependencies. 
@@ -47,6 +47,15 @@ registry.insert(); } +//===----------------------------------------------------------------------===// +// TestDialect version utilities +//===----------------------------------------------------------------------===// + +struct TestDialectVersion : public DialectVersion { + uint32_t major = 2; + uint32_t minor = 0; +}; + //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// @@ -70,6 +79,107 @@ TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase; }; +namespace { +enum test_encoding { k_attr_params = 0 }; +} + +// Test support for interacting with the Bytecode reader/writer. +struct TestBytecodeDialectInterface : public BytecodeDialectInterface { + using BytecodeDialectInterface::BytecodeDialectInterface; + TestBytecodeDialectInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const final { + if (auto concreteAttr = llvm::dyn_cast(attr)) { + writer.writeVarInt(test_encoding::k_attr_params); + writer.writeVarInt(concreteAttr.getV0()); + writer.writeVarInt(concreteAttr.getV1()); + return success(); + } + writer.writeAttribute(attr); + return success(); + } + + Attribute readAttribute(DialectBytecodeReader &reader, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if (version.major < 2) + return readAttrOldEncoding(reader); + if (version.major == 2 && version.minor == 0) + return readAttrNewEncoding(reader); + // For deprecated syntax, return nullptr. + return Attribute(); + } + + // Emit a specific version of the dialect. 
+ void writeVersion(DialectBytecodeWriter &writer) const final { + auto version = TestDialectVersion(); + writer.writeVarInt(version.major); // major + writer.writeVarInt(version.minor); // minor + } + + std::unique_ptr + readVersion(DialectBytecodeReader &reader) const final { + uint64_t major, minor; + if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor))) + return nullptr; + auto version = std::make_unique(); + version->major = major; + version->minor = minor; + return version; + } + + LogicalResult upgradeFromVersion(Operation *topLevelOp, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if ((version.major == 2) && (version.minor == 0)) + return success(); + if (version.major > 2 || (version.major == 2 && version.minor > 0)) { + return topLevelOp->emitError() + << "current test dialect version is 2.0, can't parse version: " + << version.major << "." << version.minor; + } + // Prior to version 2.0, the old op supported only a single attribute called + // "dimensions". We can perform the upgrade. + topLevelOp->walk([](TestVersionedOpA op) { + if (auto dims = op->getAttr("dimensions")) { + op->removeAttr("dimensions"); + op->setAttr("dims", dims); + } + op->setAttr("modifier", BoolAttr::get(op->getContext(), false)); + }); + return success(); + } + +private: + Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The new encoding has v0 first, v1 second. 
+ uint64_t v0, v1; + if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } + + Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The old encoding has v1 first, v0 second. + uint64_t v0, v1; + if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } +}; + // Test support for interacting with the AsmPrinter. struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; @@ -367,7 +477,7 @@ addInterface(blobInterface); addInterfaces(); + TestReductionPatternInterface, TestBytecodeDialectInterface>(); allowUnknownOperations(); // Instantiate our fallback op interface that we'll use on specific @@ -1103,9 +1213,7 @@ return getOperand(); } -OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { - return getValue(); -} +OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } LogicalResult TestOpWithVariadicResultsAndFolder::fold( FoldAdaptor adaptor, SmallVectorImpl &results) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3149,4 +3149,43 @@ }]; } +//===----------------------------------------------------------------------===// +// Test Ops to upgrade based on the dialect versions +//===----------------------------------------------------------------------===// + +def TestVersionedOpA : TEST_Op<"versionedA"> { + // A previous version of the dialect (let's say 1.*) supported an attribute + // named "dimensions": + // let arguments = (ins + // AnyI64Attr:$dimensions + // ); + 
+ // In the current version (2.0) "dimensions" was renamed to "dims", and a new + // boolean attribute "modifier" was added. The previous version of the op + // corresponds to "modifier=false". We support loading old IR through + // upgrading, see `upgradeFromVersion()` in `TestBytecodeDialectInterface`. + let arguments = (ins + AnyI64Attr:$dims, + BoolAttr:$modifier + ); +} + +def TestVersionedOpB : TEST_Op<"versionedB"> { + // In a previous version of the dialect (let's say 1.*) we encoded TestAttrParams + // with a custom encoding: + // + // #test.attr_params<X, Y> -> { varInt: Y, varInt: X } + // + // In the current version (2.0) the encoding changed and the two parameters of + // the attribute are swapped: + // + // #test.attr_params<X, Y> -> { varInt: X, varInt: Y } + // + // We support loading old IR through a custom readAttribute method, see + // `readAttribute()` in `TestBytecodeDialectInterface`. + let arguments = (ins + TestAttrParams:$attribute + ); +} + #endif // TEST_OPS