diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -6,7 +6,8 @@
## Magic Number
-MLIR uses the following four-byte magic number to indicate bytecode files:
+MLIR uses the following four-byte magic number to
+indicate bytecode files:
'\[‘M’8, ‘L’8, ‘ï’8, ‘R’8\]'
@@ -157,16 +158,25 @@
}
op_name_group {
- dialect: varint,
+ dialect: varint, // (dialectID << 1) | (hasVersion)
+ version: dialect_version_section, // present only if hasVersion is set
numOpNames: varint,
opNames: varint[]
}
+
+dialect_version_section {
+ size: varint,
+ version: byte[]
+}
+
```
-Dialects are encoded as indexes to the name string within the string section.
-Operation names are encoded in groups by dialect, with each group containing the
-dialect, the number of operation names, and the array of indexes to each name
-within the string section.
+Dialects are encoded as a `varint` containing the index to the name string
+within the string section, plus a flag indicating whether the dialect is
+versioned. Operation names are encoded in groups by dialect, with each group
+containing the dialect, the number of operation names, and the array of indexes
+to each name within the string section. The version is encoded as a nested
+section.
### Attribute/Type Sections
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -845,3 +845,18 @@
that are directly usable by any other dialect in MLIR. These types cover a range
from primitive integer and floating-point values, attribute dictionaries, dense
multi-dimensional arrays, and more.
+
+### IR Versioning
+
+A dialect can opt in to handle versioning through the
+`BytecodeDialectInterface`. A few hooks are exposed to the dialect to allow
+managing a version encoded into the bytecode file. The version is loaded lazily,
+which allows retrieving the version information while parsing the input IR, and
+gives each dialect for which a version is present an opportunity to perform
+IR upgrades post-parsing through the `upgradeFromVersion` method. Custom
+Attribute and Type encodings can also be upgraded according to the dialect
+version using the `readAttribute` and `readType` methods.
+
+There is no restriction on what kind of information a dialect is allowed to
+encode to model its versioning. Currently, versioning is supported only for
+bytecode formats.
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -235,6 +235,17 @@
virtual void writeOwnedBlob(ArrayRef blob) = 0;
};
+//===--------------------------------------------------------------------===//
+// Dialect Version Interface.
+//===--------------------------------------------------------------------===//
+
+/// Base class for a dialect-defined version object. It declares nothing but
+/// a virtual destructor, which allows polymorphic destruction.
+class DialectVersion {
+public:
+ virtual ~DialectVersion() = default;
+};
+
//===----------------------------------------------------------------------===//
// BytecodeDialectInterface
//===----------------------------------------------------------------------===//
@@ -256,6 +267,19 @@
return Attribute();
}
+ /// Read a versioned attribute encoding belonging to this dialect from the
+ /// given reader. This method should return null in the case of failure.
+ /// Dialects that implement versioning must override it to read custom
+ /// attribute encodings: the default implementation emits an error and
+ /// returns a null attribute rather than using the non-versioned reader.
+ virtual Attribute readAttribute(DialectBytecodeReader &reader,
+ const DialectVersion &version) const {
+ reader.emitError()
+ << "dialect " << getDialect()->getNamespace()
+ << " does not support reading versioned attributes from bytecode";
+ return Attribute();
+ }
+
/// Read a type belonging to this dialect from the given reader. This method
/// should return null in the case of failure.
virtual Type readType(DialectBytecodeReader &reader) const {
@@ -264,6 +288,19 @@
return Type();
}
+ /// Read a versioned type encoding belonging to this dialect from the given
+ /// reader. This method should return null in the case of failure.
+ /// Dialects that implement versioning must override it to read custom
+ /// type encodings: the default implementation emits an error and
+ /// returns a null type rather than using the non-versioned reader.
+ virtual Type readType(DialectBytecodeReader &reader,
+ const DialectVersion &version) const {
+ reader.emitError()
+ << "dialect " << getDialect()->getNamespace()
+ << " does not support reading versioned types from bytecode";
+ return Type();
+ }
+
//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
@@ -285,6 +322,27 @@
DialectBytecodeWriter &writer) const {
return failure();
}
+
+ /// Write the version of this dialect to the given writer (no-op by default).
+ virtual void writeVersion(DialectBytecodeWriter &writer) const {}
+
+ /// Read the version of this dialect from the provided reader and return it as
+ /// a `unique_ptr` to a dialect version object, or nullptr on failure.
+ virtual std::unique_ptr
+ readVersion(DialectBytecodeReader &reader) const {
+ reader.emitError("Dialect does not support versioning");
+ return nullptr;
+ }
+
+ /// Hook invoked after parsing completed, if a version directive was present
+ /// and included an entry for the current dialect. This hook offers the
+ /// dialect the opportunity to visit the IR and upgrade constructs emitted
+ /// by the version of the dialect corresponding to the provided version.
+ virtual LogicalResult
+ upgradeFromVersion(Operation *topLevelOp,
+ const DialectVersion &version) const {
+ return success();
+ }
};
} // namespace mlir
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -23,8 +23,11 @@
//===----------------------------------------------------------------------===//
enum {
+ /// The minimum supported version of the bytecode.
+ kMinSupportedVersion = 0,
+
/// The current bytecode version.
- kVersion = 0,
+ kVersion = 1,
/// An arbitrary value used to fill alignment padding.
kAlignmentByte = 0xCB,
@@ -61,8 +64,11 @@
/// section.
kResourceOffset = 6,
+ /// This section contains the versions of each dialect.
+ kDialectVersions = 7,
+
/// The total number of section types.
- kNumSections = 7,
+ kNumSections = 8,
};
} // namespace Section
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -47,6 +47,8 @@
return "Resource (5)";
case bytecode::Section::kResourceOffset:
return "ResourceOffset (6)";
+ case bytecode::Section::kDialectVersions:
+ return "DialectVersions (7)";
default:
return ("Unknown (" + Twine(static_cast(sectionID)) + ")").str();
}
@@ -63,6 +65,7 @@
return false;
case bytecode::Section::kResource:
case bytecode::Section::kResourceOffset:
+ case bytecode::Section::kDialectVersions:
return true;
default:
llvm_unreachable("unknown section ID");
@@ -350,6 +353,13 @@
return parseEntry(reader, strings, result, "string");
}
+ /// Resolve a shared string from the string section, given an index into
+ /// that section which has already been decoded by the caller.
+ LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index,
+ StringRef &result) {
+ return resolveEntry(reader, strings, index, result, "string");
+ }
+
private:
/// The table of strings referenced within the bytecode file.
SmallVector strings;
@@ -400,31 +410,15 @@
//===----------------------------------------------------------------------===//
namespace {
+class DialectReader;
+
/// This struct represents a dialect entry within the bytecode.
struct BytecodeDialect {
/// Load the dialect into the provided context if it hasn't been loaded yet.
/// Returns failure if the dialect couldn't be loaded *and* the provided
/// context does not allow unregistered dialects. The provided reader is used
/// for error emission if necessary.
- LogicalResult load(EncodingReader &reader, MLIRContext *ctx) {
- if (dialect)
- return success();
- Dialect *loadedDialect = ctx->getOrLoadDialect(name);
- if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
- return reader.emitError(
- "dialect '", name,
- "' is unknown. If this is intended, please call "
- "allowUnregisteredDialects() on the MLIRContext, or use "
- "-allow-unregistered-dialect with the MLIR tool used.");
- }
- dialect = loadedDialect;
-
- // If the dialect was actually loaded, check to see if it has a bytecode
- // interface.
- if (loadedDialect)
- interface = dyn_cast(loadedDialect);
- return success();
- }
+ LogicalResult load(DialectReader &reader, MLIRContext *ctx);
/// Return the loaded dialect, or nullptr if the dialect is unknown. This can
/// only be called after `load`.
@@ -446,6 +440,12 @@
/// The name of the dialect.
StringRef name;
+
+ /// A buffer containing the encoding of the dialect version parsed.
+ ArrayRef versionBuffer;
+
+ /// Lazily loaded dialect version, populated by `load`.
+ std::unique_ptr loadedVersion;
};
/// This struct represents an operation name entry within the bytecode.
@@ -496,7 +496,7 @@
initialize(Location fileLoc, const ParserConfig &config,
MutableArrayRef dialects,
StringSectionReader &stringReader, ArrayRef sectionData,
- ArrayRef offsetSectionData,
+ ArrayRef offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr &bufferOwnerRef);
/// Parse a dialect resource handle from the resource section.
@@ -643,7 +643,7 @@
Location fileLoc, const ParserConfig &config,
MutableArrayRef dialects,
StringSectionReader &stringReader, ArrayRef sectionData,
- ArrayRef offsetSectionData,
+ ArrayRef offsetSectionData, DialectReader &dialectReader,
const std::shared_ptr &bufferOwnerRef) {
EncodingReader resourceReader(sectionData, fileLoc);
EncodingReader offsetReader(offsetSectionData, fileLoc);
@@ -684,7 +684,7 @@
while (!offsetReader.empty()) {
BytecodeDialect *dialect;
if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
- failed(dialect->load(resourceReader, ctx)))
+ failed(dialect->load(dialectReader, ctx)))
return failure();
Dialect *loadedDialect = dialect->getLoadedDialect();
if (!loadedDialect) {
@@ -1051,7 +1051,8 @@
LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry,
EncodingReader &reader,
StringRef entryType) {
- if (failed(entry.dialect->load(reader, fileLoc.getContext())))
+ DialectReader dialectReader(*this, stringReader, resourceReader, reader);
+ if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
return failure();
// Ensure that the dialect implements the bytecode interface.
@@ -1060,12 +1061,22 @@
"' does not implement the bytecode interface");
}
- // Ask the dialect to parse the entry.
- DialectReader dialectReader(*this, stringReader, resourceReader, reader);
- if constexpr (std::is_same_v)
- entry.entry = entry.dialect->interface->readType(dialectReader);
- else
- entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ // Ask the dialect to parse the entry. If the dialect is versioned, parse
+ // using the versioned encoding readers.
+ if (entry.dialect->loadedVersion.get()) {
+ if constexpr (std::is_same_v)
+ entry.entry = entry.dialect->interface->readType(
+ dialectReader, *entry.dialect->loadedVersion);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(
+ dialectReader, *entry.dialect->loadedVersion);
+
+ } else {
+ if constexpr (std::is_same_v)
+ entry.entry = entry.dialect->interface->readType(dialectReader);
+ else
+ entry.entry = entry.dialect->interface->readAttribute(dialectReader);
+ }
return success(!!entry.entry);
}
@@ -1122,7 +1133,8 @@
// Resource Section
LogicalResult
- parseResourceSection(std::optional> resourceData,
+ parseResourceSection(EncodingReader &reader,
+ std::optional> resourceData,
std::optional> resourceOffsetData);
//===--------------------------------------------------------------------===//
@@ -1306,7 +1318,7 @@
// Process the resource section if present.
if (failed(parseResourceSection(
- sectionDatas[bytecode::Section::kResource],
+ reader, sectionDatas[bytecode::Section::kResource],
sectionDatas[bytecode::Section::kResourceOffset])))
return failure();
@@ -1326,7 +1338,8 @@
// Validate the bytecode version.
uint64_t currentVersion = bytecode::kVersion;
- if (version < currentVersion) {
+ uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
+ if (version < minSupportedVersion) {
return reader.emitError("bytecode version ", version,
" is older than the current version of ",
currentVersion, ", and upgrade is not supported");
@@ -1342,6 +1355,36 @@
//===----------------------------------------------------------------------===//
// Dialect Section
+LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
+ if (dialect)
+ return success();
+ Dialect *loadedDialect = ctx->getOrLoadDialect(name);
+ if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
+ return reader.emitError("dialect '")
+ << name
+ << "' is unknown. If this is intended, please call "
+ "allowUnregisteredDialects() on the MLIRContext, or use "
+ "-allow-unregistered-dialect with the MLIR tool used.";
+ }
+ dialect = loadedDialect;
+
+ // Look up the bytecode interface of the loaded dialect, if any. NOTE(review):
+ // version is decoded from `reader`, not `versionBuffer`; verify call sites.
+ if (loadedDialect)
+ interface = dyn_cast(loadedDialect);
+ if (!versionBuffer.empty()) {
+ if (!interface)
+ return reader.emitError("dialect '")
+ << name
+ << "' does not implement the bytecode interface, "
+ "but found a version entry";
+ loadedVersion = interface->readVersion(reader);
+ if (!loadedVersion)
+ return failure();
+ }
+ return success();
+}
+
LogicalResult
BytecodeReader::parseDialectSection(ArrayRef sectionData) {
EncodingReader sectionReader(sectionData, fileLoc);
@@ -1353,9 +1396,34 @@
dialects.resize(numDialects);
// Parse each of the dialects.
- for (uint64_t i = 0; i < numDialects; ++i)
- if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ for (uint64_t i = 0; i < numDialects; ++i) {
+ // Before version 1, there wasn't any versioning available for dialects,
+ // and each entry is directly the index of the dialect name string.
+ if (version == 0) {
+ if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
+ return failure();
+ continue;
+ }
+ // Parse ID representing dialect and version.
+ uint64_t dialectNameIdx;
+ bool versionAvailable;
+ if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx,
+ versionAvailable)))
+ return failure();
+ if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx,
+ dialects[i].name)))
return failure();
+ if (versionAvailable) {
+ bytecode::Section::ID sectionID;
+ if (failed(
+ sectionReader.parseSection(sectionID, dialects[i].versionBuffer)))
+ return failure();
+ if (sectionID != bytecode::Section::kDialectVersions) {
+ emitError(fileLoc, "expected dialect version section");
+ return failure();
+ }
+ }
+ }
// Parse the operation names, which are grouped by dialect.
auto parseOpName = [&](BytecodeDialect *dialect) {
@@ -1379,7 +1447,11 @@
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
- if (failed(opName->dialect->load(reader, getContext())))
+ // Load the dialect and its version.
+ EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc);
+ DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+ versionReader);
+ if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
getContext());
@@ -1391,7 +1463,7 @@
// Resource Section
LogicalResult BytecodeReader::parseResourceSection(
- std::optional> resourceData,
+ EncodingReader &reader, std::optional> resourceData,
std::optional> resourceOffsetData) {
// Ensure both sections are either present or not.
if (resourceData.has_value() != resourceOffsetData.has_value()) {
@@ -1408,9 +1480,11 @@
return success();
// Initialize the resource reader with the resource sections.
+ DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+ reader);
return resourceReader.initialize(fileLoc, config, dialects, stringReader,
*resourceData, *resourceOffsetData,
- bufferOwnerRef);
+ dialectReader, bufferOwnerRef);
}
//===----------------------------------------------------------------------===//
@@ -1442,6 +1516,18 @@
"not all forward unresolved forward operand references");
}
+ // Resolve dialect version.
+ for (const BytecodeDialect &byteCodeDialect : dialects) {
+ // Parsing is complete, give an opportunity to each dialect to visit the
+ // IR and perform upgrades.
+ if (!byteCodeDialect.loadedVersion)
+ continue;
+ if (byteCodeDialect.interface &&
+ failed(byteCodeDialect.interface->upgradeFromVersion(
+ *moduleOp, *byteCodeDialect.loadedVersion)))
+ return failure();
+ }
+
// Verify that the parsed operations are valid.
if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
return failure();
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -10,13 +10,10 @@
#include "../Encoding.h"
#include "IRNumbering.h"
#include "mlir/Bytecode/BytecodeImplementation.h"
-#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/CachedHashString.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
-#include "llvm/Support/Debug.h"
-#include
#define DEBUG_TYPE "mlir-bytecode-writer"
@@ -261,6 +258,116 @@
unsigned requiredAlignment = 1;
};
+//===----------------------------------------------------------------------===//
+// StringSectionBuilder
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to simplify the process of emitting the string section.
+class StringSectionBuilder {
+public:
+ /// Add the given string to the string section, and return the index of the
+ /// string within the section.
+ size_t insert(StringRef str) {
+ auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
+ return it.first->second;
+ }
+
+ /// Write the current set of strings to the given emitter.
+ void write(EncodingEmitter &emitter) {
+ emitter.emitVarInt(strings.size());
+
+ // Emit the sizes in reverse order, so that we don't need to backpatch an
+ // offset to the string data or have a separate section.
+ for (const auto &it : llvm::reverse(strings))
+ emitter.emitVarInt(it.first.size() + 1);
+ // Emit the string data itself.
+ for (const auto &it : strings)
+ emitter.emitNulTerminatedString(it.first.val());
+ }
+
+private:
+ /// The strings referenced within the bytecode. The value of the map is the
+ /// index of the string within the section, as returned by `insert`.
+ llvm::MapVector strings;
+};
+
+/// Writer handed to dialect bytecode hooks to emit their custom encodings.
+class DialectWriter : public DialectBytecodeWriter {
+public:
+ DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
+ StringSectionBuilder &stringSection)
+ : emitter(emitter), numberingState(numberingState),
+ stringSection(stringSection) {}
+
+ //===--------------------------------------------------------------------===//
+ // IR
+ //===--------------------------------------------------------------------===//
+
+ void writeAttribute(Attribute attr) override {
+ emitter.emitVarInt(numberingState.getNumber(attr));
+ }
+ void writeType(Type type) override {
+ emitter.emitVarInt(numberingState.getNumber(type));
+ }
+
+ void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+ emitter.emitVarInt(numberingState.getNumber(resource));
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Primitives
+ //===--------------------------------------------------------------------===//
+
+ void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
+
+ void writeSignedVarInt(int64_t value) override {
+ emitter.emitSignedVarInt(value);
+ }
+
+ void writeAPIntWithKnownWidth(const APInt &value) override {
+ size_t bitWidth = value.getBitWidth();
+
+ // If the value is a single byte, just emit it directly without going
+ // through a varint.
+ if (bitWidth <= 8)
+ return emitter.emitByte(value.getLimitedValue());
+
+ // If the value fits within a single varint, emit it directly.
+ if (bitWidth <= 64)
+ return emitter.emitSignedVarInt(value.getLimitedValue());
+
+ // Otherwise, we need to encode a variable number of active words. We use
+ // active words instead of the number of total words under the observation
+ // that smaller values will be more common.
+ unsigned numActiveWords = value.getActiveWords();
+ emitter.emitVarInt(numActiveWords);
+
+ const uint64_t *rawValueData = value.getRawData();
+ for (unsigned i = 0; i < numActiveWords; ++i)
+ emitter.emitSignedVarInt(rawValueData[i]);
+ }
+
+ void writeAPFloatWithKnownSemantics(const APFloat &value) override {
+ writeAPIntWithKnownWidth(value.bitcastToAPInt());
+ }
+
+ void writeOwnedString(StringRef str) override {
+ emitter.emitVarInt(stringSection.insert(str));
+ }
+
+ void writeOwnedBlob(ArrayRef blob) override {
+ emitter.emitVarInt(blob.size());
+ emitter.emitOwnedBlob(ArrayRef(
+ reinterpret_cast(blob.data()), blob.size()));
+ }
+
+private:
+ EncodingEmitter &emitter;
+ IRNumberingState &numberingState;
+ StringSectionBuilder &stringSection;
+};
+} // namespace
/// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
/// to go through an intermediate buffer when interacting with code that wants a
/// raw_ostream.
@@ -307,41 +414,6 @@
emitBytes({reinterpret_cast(&value), sizeof(value)});
}
-//===----------------------------------------------------------------------===//
-// StringSectionBuilder
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class is used to simplify the process of emitting the string section.
-class StringSectionBuilder {
-public:
- /// Add the given string to the string section, and return the index of the
- /// string within the section.
- size_t insert(StringRef str) {
- auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
- return it.first->second;
- }
-
- /// Write the current set of strings to the given emitter.
- void write(EncodingEmitter &emitter) {
- emitter.emitVarInt(strings.size());
-
- // Emit the sizes in reverse order, so that we don't need to backpatch an
- // offset to the string data or have a separate section.
- for (const auto &it : llvm::reverse(strings))
- emitter.emitVarInt(it.first.size() + 1);
- // Emit the string data itself.
- for (const auto &it : strings)
- emitter.emitNulTerminatedString(it.first.val());
- }
-
-private:
- /// A set of strings referenced within the bytecode. The value of the map is
- /// unused.
- llvm::MapVector strings;
-};
-} // namespace
-
//===----------------------------------------------------------------------===//
// Bytecode Writer
//===----------------------------------------------------------------------===//
@@ -464,8 +536,28 @@
// Emit the referenced dialects.
auto dialects = numberingState.getDialects();
dialectEmitter.emitVarInt(llvm::size(dialects));
- for (DialectNumbering &dialect : dialects)
- dialectEmitter.emitVarInt(stringSection.insert(dialect.name));
+ for (DialectNumbering &dialect : dialects) {
+ // Write the string section and get the ID.
+ size_t nameID = stringSection.insert(dialect.name);
+
+ // Try writing the version to the versionEmitter.
+ EncodingEmitter versionEmitter;
+ if (dialect.interface) {
+ // The writer used when emitting using a custom bytecode encoding.
+ DialectWriter versionWriter(versionEmitter, numberingState,
+ stringSection);
+ dialect.interface->writeVersion(versionWriter);
+ }
+
+ // If the version emitter is empty, version is not available. We can encode
+ // this in the dialect ID, so if there is no version, we don't write the
+ // section.
+ size_t versionAvailable = versionEmitter.size() > 0;
+ dialectEmitter.emitVarIntWithFlag(nameID, versionAvailable);
+ if (versionAvailable)
+ dialectEmitter.emitSection(bytecode::Section::kDialectVersions,
+ std::move(versionEmitter));
+ }
// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
@@ -479,83 +571,6 @@
//===----------------------------------------------------------------------===//
// Attributes and Types
-namespace {
-class DialectWriter : public DialectBytecodeWriter {
-public:
- DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
- StringSectionBuilder &stringSection)
- : emitter(emitter), numberingState(numberingState),
- stringSection(stringSection) {}
-
- //===--------------------------------------------------------------------===//
- // IR
- //===--------------------------------------------------------------------===//
-
- void writeAttribute(Attribute attr) override {
- emitter.emitVarInt(numberingState.getNumber(attr));
- }
- void writeType(Type type) override {
- emitter.emitVarInt(numberingState.getNumber(type));
- }
-
- void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
- emitter.emitVarInt(numberingState.getNumber(resource));
- }
-
- //===--------------------------------------------------------------------===//
- // Primitives
- //===--------------------------------------------------------------------===//
-
- void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
-
- void writeSignedVarInt(int64_t value) override {
- emitter.emitSignedVarInt(value);
- }
-
- void writeAPIntWithKnownWidth(const APInt &value) override {
- size_t bitWidth = value.getBitWidth();
-
- // If the value is a single byte, just emit it directly without going
- // through a varint.
- if (bitWidth <= 8)
- return emitter.emitByte(value.getLimitedValue());
-
- // If the value fits within a single varint, emit it directly.
- if (bitWidth <= 64)
- return emitter.emitSignedVarInt(value.getLimitedValue());
-
- // Otherwise, we need to encode a variable number of active words. We use
- // active words instead of the number of total words under the observation
- // that smaller values will be more common.
- unsigned numActiveWords = value.getActiveWords();
- emitter.emitVarInt(numActiveWords);
-
- const uint64_t *rawValueData = value.getRawData();
- for (unsigned i = 0; i < numActiveWords; ++i)
- emitter.emitSignedVarInt(rawValueData[i]);
- }
-
- void writeAPFloatWithKnownSemantics(const APFloat &value) override {
- writeAPIntWithKnownWidth(value.bitcastToAPInt());
- }
-
- void writeOwnedString(StringRef str) override {
- emitter.emitVarInt(stringSection.insert(str));
- }
-
- void writeOwnedBlob(ArrayRef blob) override {
- emitter.emitVarInt(blob.size());
- emitter.emitOwnedBlob(ArrayRef(
- reinterpret_cast(blob.data()), blob.size()));
- }
-
-private:
- EncodingEmitter &emitter;
- IRNumberingState &numberingState;
- StringSectionBuilder &stringSection;
-};
-} // namespace
-
void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
EncodingEmitter attrTypeEmitter;
EncodingEmitter offsetEmitter;
diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
//===--------------------------------------------------------------------===//
// RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 0
+// VERSION: bytecode version 127 is newer than the current version 1
//===--------------------------------------------------------------------===//
// Producer
diff --git a/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc
new file mode 100644
index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000
GIT binary patch
literal 0
Hc$@} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-attr-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
+// CHECK1: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test attribute upgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 2.0
+// COM: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-attr-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
+// CHECK2: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> ()
diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioned_op.mlir
@@ -0,0 +1,41 @@
+// This file contains test cases related to the dialect post-parsing upgrade
+// mechanism.
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: target=s390x-{{.*}}
+
+//===--------------------------------------------------------------------===//
+// Test generic
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 2.0
+// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
+// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test upgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 1.12
+// COM: "test.versionedA"() {dimensions = 123 : i64} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
+// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test forbidden downgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 2.2
+// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: }
+// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION
+// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -10,15 +10,14 @@
#include "TestAttributes.h"
#include "TestInterfaces.h"
#include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
-#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OperationSupport.h"
@@ -32,9 +31,9 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
-#include
#include
+#include
// Include this before the using namespace lines below to
// test that we don't have namespace dependencies.
@@ -47,6 +46,15 @@
registry.insert();
}
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+/// Version of the test dialect, expressed as a (major, minor) pair.
+/// A default-constructed instance describes version 2.0.
+struct TestDialectVersion : DialectVersion {
+  uint32_t major{2};
+  uint32_t minor{0};
+};
+
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
@@ -70,6 +78,107 @@
TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
};
+namespace {
+/// Tag written as the first varint of an attribute that the test dialect
+/// encodes with a custom bytecode layout.
+enum test_encoding {
+  k_attr_params = 0,
+};
+} // namespace
+
+/// Test support for interacting with the Bytecode reader/writer.
+///
+/// Exercises dialect versioning: the current version (a default-constructed
+/// `TestDialectVersion`) is written to the bytecode, older versions are
+/// upgraded on load, and newer versions are rejected.
+struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
+  using BytecodeDialectInterface::BytecodeDialectInterface;
+  TestBytecodeDialectInterface(Dialect *dialect)
+      : BytecodeDialectInterface(dialect) {}
+
+  /// Writes `attr` using the current custom encoding:
+  ///   { varInt: k_attr_params, varInt: v0, varInt: v1 }
+  /// Only `TestAttrParamsAttr` has a custom encoding; failure is returned for
+  /// every other attribute.
+  LogicalResult writeAttribute(Attribute attr,
+                               DialectBytecodeWriter &writer) const final {
+    if (auto concreteAttr = llvm::dyn_cast<TestAttrParamsAttr>(attr)) {
+      writer.writeVarInt(test_encoding::k_attr_params);
+      writer.writeVarInt(concreteAttr.getV0());
+      writer.writeVarInt(concreteAttr.getV1());
+      return success();
+    }
+    // There is no custom encoding for this attribute. Report failure instead
+    // of emitting a payload that `readAttribute` below cannot decode (it
+    // expects the `k_attr_params` tag first), so that the bytecode writer can
+    // use its fallback encoding for the attribute.
+    return failure();
+  }
+
+  /// Reads an attribute that was written by `writeAttribute`, selecting the
+  /// decoder that matches the dialect version found in the bytecode.
+  /// Returns a null attribute on malformed input or for versions newer than
+  /// 2.0.
+  Attribute readAttribute(DialectBytecodeReader &reader,
+                          const DialectVersion &version_) const final {
+    const auto &version = static_cast<const TestDialectVersion &>(version_);
+    if (version.major < 2)
+      return readAttrOldEncoding(reader);
+    if (version.major == 2 && version.minor == 0)
+      return readAttrNewEncoding(reader);
+    // Forbid reading future versions by returning nullptr.
+    return Attribute();
+  }
+
+  /// Emits the current dialect version as two varints: major, then minor.
+  void writeVersion(DialectBytecodeWriter &writer) const final {
+    auto version = TestDialectVersion();
+    writer.writeVarInt(version.major); // major
+    writer.writeVarInt(version.minor); // minor
+  }
+
+  /// Reads back a version written by `writeVersion`. Returns nullptr if either
+  /// varint cannot be read.
+  std::unique_ptr<DialectVersion>
+  readVersion(DialectBytecodeReader &reader) const final {
+    uint64_t major = 0, minor = 0;
+    if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor)))
+      return nullptr;
+    auto version = std::make_unique<TestDialectVersion>();
+    // The version fields are 32 bits wide; make the narrowing explicit.
+    version->major = static_cast<uint32_t>(major);
+    version->minor = static_cast<uint32_t>(minor);
+    return version;
+  }
+
+  /// Upgrades IR nested under `topLevelOp` that was produced with the dialect
+  /// version `version_` to the current version (2.0). Emits an error and
+  /// fails when `version_` is newer than 2.0.
+  LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                   const DialectVersion &version_) const final {
+    const auto &version = static_cast<const TestDialectVersion &>(version_);
+    // Already at the current version: nothing to upgrade.
+    if ((version.major == 2) && (version.minor == 0))
+      return success();
+    if (version.major > 2 || (version.major == 2 && version.minor > 0)) {
+      return topLevelOp->emitError()
+             << "current test dialect version is 2.0, can't parse version: "
+             << version.major << "." << version.minor;
+    }
+    // Prior to version 2.0, the old op supported only a single attribute
+    // called "dimensions". We can perform the upgrade.
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (auto dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+      // The old op corresponds to the new one with "modifier" set to false.
+      op->setAttr("modifier", BoolAttr::get(op->getContext(), false));
+    });
+    return success();
+  }
+
+private:
+  /// Decodes the current (2.0) layout: tag, then v0, then v1. Returns a null
+  /// attribute if a read fails or the tag is not `k_attr_params`.
+  Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const {
+    uint64_t encoding = 0;
+    if (failed(reader.readVarInt(encoding)) ||
+        encoding != test_encoding::k_attr_params)
+      return Attribute();
+    // The new encoding has v0 first, v1 second.
+    uint64_t v0 = 0, v1 = 0;
+    if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1)))
+      return Attribute();
+    return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
+                                   static_cast<int>(v1));
+  }
+
+  /// Decodes the pre-2.0 layout: tag, then v1, then v0. Returns a null
+  /// attribute if a read fails or the tag is not `k_attr_params`.
+  Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const {
+    uint64_t encoding = 0;
+    if (failed(reader.readVarInt(encoding)) ||
+        encoding != test_encoding::k_attr_params)
+      return Attribute();
+    // The old encoding has v1 first, v0 second.
+    uint64_t v0 = 0, v1 = 0;
+    if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0)))
+      return Attribute();
+    return TestAttrParamsAttr::get(getContext(), static_cast<int>(v0),
+                                   static_cast<int>(v1));
+  }
+};
+
// Test support for interacting with the AsmPrinter.
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
@@ -367,7 +476,7 @@
addInterface(blobInterface);
addInterfaces();
+ TestReductionPatternInterface, TestBytecodeDialectInterface>();
allowUnknownOperations();
// Instantiate our fallback op interface that we'll use on specific
@@ -1103,9 +1212,7 @@
return getOperand();
}
-OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
- return getValue();
-}
+/// Folds this op to whatever `getValue()` returns; `adaptor` is not consulted.
+OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
FoldAdaptor adaptor, SmallVectorImpl &results) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3149,4 +3149,43 @@
}];
}
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+  // Attribute history of this op:
+  //   * A previous version of the dialect (let's say 1.*) had a single
+  //     attribute named "dimensions":
+  //       let arguments = (ins AnyI64Attr:$dimensions);
+  //   * The current version (2.0) renames "dimensions" to "dims" and adds a
+  //     boolean attribute "modifier". An op from the previous version
+  //     corresponds to "modifier=false".
+  // Old IR is still loadable because it gets upgraded, see
+  // `upgradeFromVersion()` in `TestBytecodeDialectInterface`.
+  let arguments = (ins AnyI64Attr:$dims, BoolAttr:$modifier);
+}
+
+def TestVersionedOpB : TEST_Op<"versionedB"> {
+  // Bytecode encoding history of TestAttrParams:
+  //   * In a previous version of the dialect (let's say 1.*) the custom
+  //     encoding stored the second parameter first:
+  //       #test.attr_params<X, Y> -> { varInt: Y, varInt: X }
+  //   * In the current version (2.0) the two parameters are swapped in the
+  //     encoding:
+  //       #test.attr_params<X, Y> -> { varInt: X, varInt: Y }
+  // Old IR is still loadable through a version-aware reader, see
+  // `readAttribute()` in `TestBytecodeDialectInterface`.
+  let arguments = (ins TestAttrParams:$attribute);
+}
+
#endif // TEST_OPS