diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -158,15 +158,22 @@ op_name_group { dialect: varint, + version: dialect_version_section, numOpNames: varint, opNames: varint[] } + +dialect_version_section { + size: varint, + version: byte[] +} + ``` Dialects are encoded as indexes to the name string within the string section. Operation names are encoded in groups by dialect, with each group containing the dialect, the number of operation names, and the array of indexes to each name -within the string section. +within the string section. The version is encoded as a nested section. ### Attribute/Type Sections @@ -403,24 +410,3 @@ A block is encoded with an array of operations and block arguments. The first field is an encoding that combines the number of operations in the block, with a flag indicating if the block has arguments. - -### Dialect Version Section - -``` -dialect_version_section { - dialects: dialect_info[] -} - -dialect_info { - dialect: varint, - size: varint, - version: byte[] -} -``` - -The dialect version section contains details about the version for each of the -versioned dialects contained in the module. A dialect is considered versioned if -implements the OpAsmDialectInterface hooks as described in the MLIR Language -Reference. The dialect info section is optional and written only if the dialect -is registered in the context. When version information for a versioned dialect -are missing, a dialect will be parsed without performing upgrades. \ No newline at end of file diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -848,15 +848,15 @@ ### IR Versionning -A dialect can opt-in to handle versioning in a custom way. Two hooks to expose a -dialect version to the printer and parser are available through the -`OpAsmDialectInterface`. 
First, the `getProducerVersion()` method allows to -inject a target dialect version to the printer, which will be written to the -output mlir file. Second, the `upgradeFromVersion()` method allows to retrieve -the version information while parsing the input IR, and gives an opportunity to -each dialect for which a version is present to perform IR upgrades. - -Version information are stored on a buffer through an `AsmDialectVersionHandle`. +A dialect can opt in to handle versioning through the +`BytecodeDialectInterface`. A few hooks are exposed to the dialect to allow +managing a version encoded into the bytecode file. The version is loaded lazily +and can be retrieved while parsing the input IR, which gives an +opportunity to each dialect for which a version is present to perform +IR upgrades post-parsing through the `upgradeFromVersion` method. Custom +Attribute and Type encodings can also be upgraded according to the dialect +version using the `readAttribute` and `readType` methods. + There is no restriction on what kind of information a dialect is allowed to encode to model its versioning. Currently, versioning is supported only for bytecode formats. \ No newline at end of file diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -235,6 +235,17 @@ virtual void writeOwnedBlob(ArrayRef blob) = 0; }; +//===--------------------------------------------------------------------===// +// Dialect Version Interface. +//===--------------------------------------------------------------------===// + +/// This class is used to represent the version of a dialect, for the purpose +/// of polymorphic destruction. 
+class DialectVersion { +public: + virtual ~DialectVersion() = default; +}; + //===----------------------------------------------------------------------===// // BytecodeDialectInterface //===----------------------------------------------------------------------===// @@ -256,6 +267,16 @@ return Attribute(); } + /// Read a versioned attribute encoding belonging to this dialect from the + /// given reader. This method should return null in the case of failure, and + /// falls back to the non-versioned reader in case the dialect implements + /// versioning but it does not support versioned custom encodings for the + /// attributes. + virtual Attribute readAttribute(DialectBytecodeReader &reader, + const DialectVersion &version) const { + return readAttribute(reader); + } + /// Read a type belonging to this dialect from the given reader. This method /// should return null in the case of failure. virtual Type readType(DialectBytecodeReader &reader) const { @@ -264,6 +285,16 @@ return Type(); } + /// Read a versioned type encoding belonging to this dialect from the given + /// reader. This method should return null in the case of failure, and + /// falls back to the non-versioned reader in case the dialect implements + /// versioning but it does not support versioned custom encodings for the + /// types. + virtual Type readType(DialectBytecodeReader &reader, + const DialectVersion &version) const { + return readType(reader); + } + //===--------------------------------------------------------------------===// // Writing //===--------------------------------------------------------------------===// @@ -285,6 +316,26 @@ DialectBytecodeWriter &writer) const { return failure(); } + + /// Write the version of this dialect to the given writer. + /// The first emitted VarInt should be the size of the entry. 
+ virtual void writeVersion(DialectBytecodeWriter &writer) const {} + + virtual std::unique_ptr + readVersion(DialectBytecodeReader &reader) const { + reader.emitError("Dialect does not support versioning"); + return nullptr; + } + + /// Hook invoked after parsing completed, if a version directive was present + /// and included an entry for the current dialect. This hook offers the + /// opportunity to the dialect to visit the IR and upgrades constructs emitted + /// by the version of the dialect corresponding to the provided version. + virtual LogicalResult + upgradeFromVersion(Operation *topLevelOp, + const DialectVersion &version) const { + return success(); + } }; } // namespace mlir diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -450,68 +450,6 @@ return p; } -//===--------------------------------------------------------------------===// -// Dialect Asm Version Interface. -//===--------------------------------------------------------------------===// - -namespace { -/// Simple wrapper around StorageAllocator to expose a buffer through -/// `AsmDialectVersionHandle`. -class AsmDialectVersionStorage : private StorageUniquer::StorageAllocator { -public: - AsmDialectVersionStorage(ArrayRef in) { buffer = copyInto(in); }; - auto getBuffer() const { return buffer; } - -private: - ArrayRef buffer; -}; -} // namespace - -/// This class represents a handle to a dialect version storage. It is a -/// light-weight class that holds a pointer to a buffer and can be easily -/// iterated on and passed around by value. Each dialect will be able to encode -/// and decode the a version onto the storage exposed through the handle using -/// the `getProducerVersion` and `upgradeFromVersion` methods of the -/// `OpAsmDialectInterface`. 
-class AsmDialectVersionHandle { -public: - AsmDialectVersionHandle() = default; - AsmDialectVersionHandle(ArrayRef buffer, StringRef name) - : storage(std::make_shared(buffer)), - dialectName(name) {} - operator bool() const { return storage && !storage->getBuffer().empty(); } - - /// Return a reference to the storage buffer. - auto getBuffer() const { - assert(storage && "buffer must exist to be retrieved"); - return storage->getBuffer(); - } - - /// Return an opaque pointer to the raw data. - const void *getRawData() const { - if (storage) - return reinterpret_cast(storage->getBuffer().data()); - return nullptr; - } - - /// Return the size of the storage buffer. - auto size() const { - if (!storage) - return size_t(0); - return storage->getBuffer().size(); - } - - /// Return the dialect that owns the version. - StringRef getDialectName() const { return dialectName; } - -private: - /// The data associated with the version. - std::shared_ptr storage; - - /// The dialect owning the version. - StringRef dialectName; -}; - //===----------------------------------------------------------------------===// // AsmParser //===----------------------------------------------------------------------===// @@ -1644,22 +1582,6 @@ return AliasResult::NoAlias; } - /// Hook provided by the dialect to emit a version when printing. The - /// handle will be available when parsing back, and the dialect implementation - /// will be able to use it to load previous known version. This management is - /// entirely under the responsibility of the individual dialects. - virtual AsmDialectVersionHandle getProducerVersion() const { return {}; } - - /// Hook invoked after parsing completed, if a version directive was present - /// and included an entry for the current dialect. This hook offers the - /// opportunity to the dialect to visit the IR and upgrades constructs emitted - /// by the version of the dialect corresponding to the provided version. 
- virtual LogicalResult - upgradeFromVersion(Operation *topLevelOp, - AsmDialectVersionHandle versionHandle) const { - return success(); - } - //===--------------------------------------------------------------------===// // Resources //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -23,8 +23,11 @@ //===----------------------------------------------------------------------===// enum { + /// The minimum supported version of the bytecode. + kMinSupportedVersion = 0, + /// The current bytecode version. - kVersion = 0, + kVersion = 1, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -403,31 +403,15 @@ //===----------------------------------------------------------------------===// namespace { +class DialectReader; + /// This struct represents a dialect entry within the bytecode. struct BytecodeDialect { /// Load the dialect into the provided context if it hasn't been loaded yet. /// Returns failure if the dialect couldn't be loaded *and* the provided /// context does not allow unregistered dialects. The provided reader is used /// for error emission if necessary. - LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { - if (dialect) - return success(); - Dialect *loadedDialect = ctx->getOrLoadDialect(name); - if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { - return reader.emitError( - "dialect '", name, - "' is unknown. 
If this is intended, please call " - "allowUnregisteredDialects() on the MLIRContext, or use " - "-allow-unregistered-dialect with the MLIR tool used."); - } - dialect = loadedDialect; - - // If the dialect was actually loaded, check to see if it has a bytecode - // interface. - if (loadedDialect) - interface = dyn_cast(loadedDialect); - return success(); - } + LogicalResult load(DialectReader &reader, MLIRContext *ctx); /// Return the loaded dialect, or nullptr if the dialect is unknown. This can /// only be called after `load`. @@ -450,8 +434,11 @@ /// The name of the dialect. StringRef name; - /// Handle for the dialect version parsed. - AsmDialectVersionHandle versionHandle; + /// A buffer containing the encoding of the dialect version parsed. + ArrayRef versionBuffer; + + /// Lazily loaded dialect version, decoded from the buffer above. + std::unique_ptr loadedVersion; }; /// This struct represents an operation name entry within the bytecode. @@ -502,7 +489,7 @@ initialize(Location fileLoc, const ParserConfig &config, MutableArrayRef dialects, StringSectionReader &stringReader, ArrayRef sectionData, - ArrayRef offsetSectionData, + ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef); /// Parse a dialect resource handle from the resource section. 
@@ -649,7 +636,7 @@ Location fileLoc, const ParserConfig &config, MutableArrayRef dialects, StringSectionReader &stringReader, ArrayRef sectionData, - ArrayRef offsetSectionData, + ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef) { EncodingReader resourceReader(sectionData, fileLoc); EncodingReader offsetReader(offsetSectionData, fileLoc); @@ -690,7 +677,7 @@ while (!offsetReader.empty()) { BytecodeDialect *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(resourceReader, ctx))) + failed(dialect->load(dialectReader, ctx))) return failure(); Dialect *loadedDialect = dialect->getLoadedDialect(); if (!loadedDialect) { @@ -1055,7 +1042,8 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - if (failed(entry.dialect->load(reader, fileLoc.getContext()))) + DialectReader dialectReader(*this, stringReader, resourceReader, reader); + if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1064,12 +1052,23 @@ "' does not implement the bytecode interface"); } - // Ask the dialect to parse the entry. - DialectReader dialectReader(*this, stringReader, resourceReader, reader); - if constexpr (std::is_same_v) - entry.entry = entry.dialect->interface->readType(dialectReader); - else - entry.entry = entry.dialect->interface->readAttribute(dialectReader); + // Ask the dialect to parse the entry. If the dialect is versioned, parse + // using the versioned encoding readers. 
+ if (entry.dialect->loadedVersion.get()) { + if constexpr (std::is_same_v) { + entry.entry = entry.dialect->interface->readType( + dialectReader, *entry.dialect->loadedVersion); + } else { + entry.entry = entry.dialect->interface->readAttribute( + dialectReader, *entry.dialect->loadedVersion); + } + } else { + if constexpr (std::is_same_v) { + entry.entry = entry.dialect->interface->readType(dialectReader); + } else { + entry.entry = entry.dialect->interface->readAttribute(dialectReader); + } + } return success(!!entry.entry); } @@ -1126,7 +1125,8 @@ // Resource Section LogicalResult - parseResourceSection(std::optional> resourceData, + parseResourceSection(EncodingReader &reader, + std::optional> resourceData, std::optional> resourceOffsetData); //===--------------------------------------------------------------------===// @@ -1171,13 +1171,6 @@ LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); - //===--------------------------------------------------------------------===// - // Dialect Versions Section - - /// Parse dialect versions. - LogicalResult - parseDialectVersionsSection(std::optional> sectionData); - //===--------------------------------------------------------------------===// // Value Processing @@ -1315,14 +1308,9 @@ if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); - // Process the dialect version section. - if (failed(parseDialectVersionsSection( - sectionDatas[bytecode::Section::kDialectVersions]))) - return failure(); - // Process the resource section if present. if (failed(parseResourceSection( - sectionDatas[bytecode::Section::kResource], + reader, sectionDatas[bytecode::Section::kResource], sectionDatas[bytecode::Section::kResourceOffset]))) return failure(); @@ -1342,7 +1330,8 @@ // Validate the bytecode version. 
uint64_t currentVersion = bytecode::kVersion; - if (version < currentVersion) { + uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; + if (version < minSupportedVersion) { return reader.emitError("bytecode version ", version, " is older than the current version of ", currentVersion, ", and upgrade is not supported"); @@ -1358,6 +1347,36 @@ //===----------------------------------------------------------------------===// // Dialect Section +LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { + if (dialect) + return success(); + Dialect *loadedDialect = ctx->getOrLoadDialect(name); + if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { + return reader.emitError("dialect '") + << name + << "' is unknown. If this is intended, please call " + "allowUnregisteredDialects() on the MLIRContext, or use " + "-allow-unregistered-dialect with the MLIR tool used."; + } + dialect = loadedDialect; + + // If the dialect was actually loaded, check to see if it has a bytecode + // interface. + if (loadedDialect) + interface = dyn_cast(loadedDialect); + if (!versionBuffer.empty()) { + if (!interface) + return reader.emitError("dialect '") + << name + << "' does not implement the bytecode interface, " + "but found a version entry"; + loadedVersion = interface->readVersion(reader); + if (!loadedVersion) + return failure(); + } + return success(); +} + LogicalResult BytecodeReader::parseDialectSection(ArrayRef sectionData) { EncodingReader sectionReader(sectionData, fileLoc); @@ -1369,9 +1388,21 @@ dialects.resize(numDialects); // Parse each of the dialects. - for (uint64_t i = 0; i < numDialects; ++i) + for (uint64_t i = 0; i < numDialects; ++i) { if (failed(stringReader.parseString(sectionReader, dialects[i].name))) return failure(); + /// Before version 1, there wasn't any versioning available for dialects. 
+ if (version == 0) + continue; + bytecode::Section::ID sectionID; + if (failed( + sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + return failure(); + if (sectionID != bytecode::Section::kDialectVersions) { + emitError(fileLoc, "expected dialect version section"); + return failure(); + } + } // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { @@ -1395,7 +1426,11 @@ // Check to see if this operation name has already been resolved. If we // haven't, load the dialect and build the operation name. if (!opName->opName) { - if (failed(opName->dialect->load(reader, getContext()))) + // Load the dialect and its version. + EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc); + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + versionReader); + if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), getContext()); @@ -1407,7 +1442,7 @@ // Resource Section LogicalResult BytecodeReader::parseResourceSection( - std::optional> resourceData, + EncodingReader &reader, std::optional> resourceData, std::optional> resourceOffsetData) { // Ensure both sections are either present or not. if (resourceData.has_value() != resourceOffsetData.has_value()) { @@ -1424,9 +1459,11 @@ return success(); // Initialize the resource reader with the resource sections. + DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, + reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, - bufferOwnerRef); + dialectReader, bufferOwnerRef); } //===----------------------------------------------------------------------===// @@ -1462,17 +1499,12 @@ for (const BytecodeDialect &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the // IR and perform upgrades. 
- if (byteCodeDialect.versionHandle) { - Dialect *dialect = moduleOp->getContext()->getOrLoadDialect( - byteCodeDialect.versionHandle.getDialectName()); - auto *asmIface = - llvm::dyn_cast_or_null(dialect); - if (!asmIface) - continue; - if (failed(asmIface->upgradeFromVersion(*moduleOp, - byteCodeDialect.versionHandle))) - return failure(); - } + if (!byteCodeDialect.loadedVersion) + continue; + if (byteCodeDialect.interface && + failed(byteCodeDialect.interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect.loadedVersion))) + return failure(); } // Verify that the parsed operations are valid. @@ -1708,45 +1740,6 @@ return defineValues(reader, block->getArguments()); } -//===--------------------------------------------------------------------===// -// Dialect Versions Section - -LogicalResult BytecodeReader::parseDialectVersionsSection( - std::optional> sectionData) { - // If the dialect versions are absent, there is nothing to do. - if (!sectionData.has_value()) - return success(); - - EncodingReader sectionReader(sectionData.value(), fileLoc); - - // Parse the number of dialects in the section. - uint64_t numVersionedDialects; - if (failed(sectionReader.parseVarInt(numVersionedDialects))) - return failure(); - - // Parse each of the dialect versions. 
- llvm::StringMap versionedDialectMap; - for (uint64_t i = 0; i < numVersionedDialects; ++i) { - StringRef dialectName; - if (failed(stringReader.parseString(sectionReader, dialectName))) - return failure(); - uint64_t dialectVersionSize; - if (failed(sectionReader.parseVarInt(dialectVersionSize))) - return failure(); - ArrayRef bytes; - if (failed(sectionReader.parseBytes(dialectVersionSize, bytes))) - return failure(); - versionedDialectMap.insert( - {dialectName, AsmDialectVersionHandle(bytes, dialectName)}); - } - - for (auto &dialect : dialects) - if (versionedDialectMap.count(dialect.name)) - dialect.versionHandle = versionedDialectMap.lookup(dialect.name); - - return success(); -} - //===----------------------------------------------------------------------===// // Value Processing diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -10,13 +10,10 @@ #include "../Encoding.h" #include "IRNumbering.h" #include "mlir/Bytecode/BytecodeImplementation.h" -#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/CachedHashString.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" -#include "llvm/Support/Debug.h" -#include #define DEBUG_TYPE "mlir-bytecode-writer" @@ -261,6 +258,116 @@ unsigned requiredAlignment = 1; }; +//===----------------------------------------------------------------------===// +// StringSectionBuilder +//===----------------------------------------------------------------------===// + +namespace { +/// This class is used to simplify the process of emitting the string section. +class StringSectionBuilder { +public: + /// Add the given string to the string section, and return the index of the + /// string within the section. 
+ size_t insert(StringRef str) { + auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); + return it.first->second; + } + + /// Write the current set of strings to the given emitter. + void write(EncodingEmitter &emitter) { + emitter.emitVarInt(strings.size()); + + // Emit the sizes in reverse order, so that we don't need to backpatch an + // offset to the string data or have a separate section. + for (const auto &it : llvm::reverse(strings)) + emitter.emitVarInt(it.first.size() + 1); + // Emit the string data itself. + for (const auto &it : strings) + emitter.emitNulTerminatedString(it.first.val()); + } + +private: + /// A set of strings referenced within the bytecode. The value of the map is + /// unused. + llvm::MapVector strings; +}; +} // namespace + +class DialectWriter : public DialectBytecodeWriter { +public: + DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState, + StringSectionBuilder &stringSection) + : emitter(emitter), numberingState(numberingState), + stringSection(stringSection) {} + + //===--------------------------------------------------------------------===// + // IR + //===--------------------------------------------------------------------===// + + void writeAttribute(Attribute attr) override { + emitter.emitVarInt(numberingState.getNumber(attr)); + } + void writeType(Type type) override { + emitter.emitVarInt(numberingState.getNumber(type)); + } + + void writeResourceHandle(const AsmDialectResourceHandle &resource) override { + emitter.emitVarInt(numberingState.getNumber(resource)); + } + + //===--------------------------------------------------------------------===// + // Primitives + //===--------------------------------------------------------------------===// + + void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } + + void writeSignedVarInt(int64_t value) override { + emitter.emitSignedVarInt(value); + } + + void writeAPIntWithKnownWidth(const APInt &value) override { + size_t 
bitWidth = value.getBitWidth(); + + // If the value is a single byte, just emit it directly without going + // through a varint. + if (bitWidth <= 8) + return emitter.emitByte(value.getLimitedValue()); + + // If the value fits within a single varint, emit it directly. + if (bitWidth <= 64) + return emitter.emitSignedVarInt(value.getLimitedValue()); + + // Otherwise, we need to encode a variable number of active words. We use + // active words instead of the number of total words under the observation + // that smaller values will be more common. + unsigned numActiveWords = value.getActiveWords(); + emitter.emitVarInt(numActiveWords); + + const uint64_t *rawValueData = value.getRawData(); + for (unsigned i = 0; i < numActiveWords; ++i) + emitter.emitSignedVarInt(rawValueData[i]); + } + + void writeAPFloatWithKnownSemantics(const APFloat &value) override { + writeAPIntWithKnownWidth(value.bitcastToAPInt()); + } + + void writeOwnedString(StringRef str) override { + emitter.emitVarInt(stringSection.insert(str)); + } + + void writeOwnedBlob(ArrayRef blob) override { + emitter.emitVarInt(blob.size()); + emitter.emitOwnedBlob(ArrayRef( + reinterpret_cast(blob.data()), blob.size())); + } + +private: + EncodingEmitter &emitter; + IRNumberingState &numberingState; + StringSectionBuilder &stringSection; +}; + /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need /// to go through an intermediate buffer when interacting with code that wants a /// raw_ostream. @@ -307,41 +414,6 @@ emitBytes({reinterpret_cast(&value), sizeof(value)}); } -//===----------------------------------------------------------------------===// -// StringSectionBuilder -//===----------------------------------------------------------------------===// - -namespace { -/// This class is used to simplify the process of emitting the string section. 
-class StringSectionBuilder { -public: - /// Add the given string to the string section, and return the index of the - /// string within the section. - size_t insert(StringRef str) { - auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); - return it.first->second; - } - - /// Write the current set of strings to the given emitter. - void write(EncodingEmitter &emitter) { - emitter.emitVarInt(strings.size()); - - // Emit the sizes in reverse order, so that we don't need to backpatch an - // offset to the string data or have a separate section. - for (const auto &it : llvm::reverse(strings)) - emitter.emitVarInt(it.first.size() + 1); - // Emit the string data itself. - for (const auto &it : strings) - emitter.emitNulTerminatedString(it.first.val()); - } - -private: - /// A set of strings referenced within the bytecode. The value of the map is - /// unused. - llvm::MapVector strings; -}; -} // namespace - //===----------------------------------------------------------------------===// // Bytecode Writer //===----------------------------------------------------------------------===// @@ -385,11 +457,6 @@ void writeStringSection(EncodingEmitter &emitter); - //===--------------------------------------------------------------------===// - // Dialect versions - - void writeDialectVersionsSection(EncodingEmitter &emitter); - //===--------------------------------------------------------------------===// // Fields @@ -430,9 +497,6 @@ // Emit the string section. writeStringSection(emitter); - // Emit the dialect section. - writeDialectVersionsSection(emitter); - // Write the generated bytecode to the provided output stream. emitter.writeTo(os); } @@ -472,8 +536,18 @@ // Emit the referenced dialects. 
auto dialects = numberingState.getDialects(); dialectEmitter.emitVarInt(llvm::size(dialects)); - for (DialectNumbering &dialect : dialects) + for (DialectNumbering &dialect : dialects) { dialectEmitter.emitVarInt(stringSection.insert(dialect.name)); + EncodingEmitter versionEmitter; + if (dialect.interface) { + // The writer used when emitting using a custom bytecode encoding. + DialectWriter versionWriter(versionEmitter, numberingState, + stringSection); + dialect.interface->writeVersion(versionWriter); + } + dialectEmitter.emitSection(bytecode::Section::kDialectVersions, + std::move(versionEmitter)); + } // Emit the referenced operation names grouped by dialect. auto emitOpName = [&](OpNameNumbering &name) { @@ -484,121 +558,9 @@ emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter)); } -void BytecodeWriter::writeDialectVersionsSection(EncodingEmitter &emitter) { - EncodingEmitter dialectEmitter; - - // Get dialect version. - auto getVersion = - [&](const OpAsmDialectInterface *asmIface) -> AsmDialectVersionHandle { - // A dialect can be nullptr if not loaded. In such a case, we can't print - // the version properly. - if (asmIface) - return asmIface->getProducerVersion(); - return {}; - }; - - // Emit the referenced dialects. 
- auto dialects = numberingState.getDialects(); - llvm::SmallVector> - dialectVersionPair; - - for (DialectNumbering &dialect : dialects) { - if (auto version = getVersion(dialect.asmInterface)) { - dialectVersionPair.push_back( - std::make_pair(stringSection.insert(dialect.name), version)); - } - } - dialectEmitter.emitVarInt(dialectVersionPair.size()); - for (auto item : dialectVersionPair) { - dialectEmitter.emitVarInt(item.first); - dialectEmitter.emitVarInt(item.second.size()); - dialectEmitter.emitBytes(item.second.getBuffer()); - } - - emitter.emitSection(bytecode::Section::kDialectVersions, - std::move(dialectEmitter)); -} - //===----------------------------------------------------------------------===// // Attributes and Types -namespace { -class DialectWriter : public DialectBytecodeWriter { -public: - DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState, - StringSectionBuilder &stringSection) - : emitter(emitter), numberingState(numberingState), - stringSection(stringSection) {} - - //===--------------------------------------------------------------------===// - // IR - //===--------------------------------------------------------------------===// - - void writeAttribute(Attribute attr) override { - emitter.emitVarInt(numberingState.getNumber(attr)); - } - void writeType(Type type) override { - emitter.emitVarInt(numberingState.getNumber(type)); - } - - void writeResourceHandle(const AsmDialectResourceHandle &resource) override { - emitter.emitVarInt(numberingState.getNumber(resource)); - } - - //===--------------------------------------------------------------------===// - // Primitives - //===--------------------------------------------------------------------===// - - void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } - - void writeSignedVarInt(int64_t value) override { - emitter.emitSignedVarInt(value); - } - - void writeAPIntWithKnownWidth(const APInt &value) override { - size_t bitWidth = 
value.getBitWidth(); - - // If the value is a single byte, just emit it directly without going - // through a varint. - if (bitWidth <= 8) - return emitter.emitByte(value.getLimitedValue()); - - // If the value fits within a single varint, emit it directly. - if (bitWidth <= 64) - return emitter.emitSignedVarInt(value.getLimitedValue()); - - // Otherwise, we need to encode a variable number of active words. We use - // active words instead of the number of total words under the observation - // that smaller values will be more common. - unsigned numActiveWords = value.getActiveWords(); - emitter.emitVarInt(numActiveWords); - - const uint64_t *rawValueData = value.getRawData(); - for (unsigned i = 0; i < numActiveWords; ++i) - emitter.emitSignedVarInt(rawValueData[i]); - } - - void writeAPFloatWithKnownSemantics(const APFloat &value) override { - writeAPIntWithKnownWidth(value.bitcastToAPInt()); - } - - void writeOwnedString(StringRef str) override { - emitter.emitVarInt(stringSection.insert(str)); - } - - void writeOwnedBlob(ArrayRef blob) override { - emitter.emitVarInt(blob.size()); - emitter.emitOwnedBlob(ArrayRef( - reinterpret_cast(blob.data()), blob.size())); - } - -private: - EncodingEmitter &emitter; - IRNumberingState &numberingState; - StringSectionBuilder &stringSection; -}; -} // namespace - void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { EncodingEmitter attrTypeEmitter; EncodingEmitter offsetEmitter; diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version 0 +// VERSION: bytecode version 127 is newer than 
the current version 1 //===--------------------------------------------------------------------===// // Producer diff --git a/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-attr-1.12.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-attr-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 +// CHECK1: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> () + +//===--------------------------------------------------------------------===// +// Test attribute upgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.0 +// COM: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> () +// COM: } +// RUN: mlir-opt %S/versioned-attr-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2 +// CHECK2: "test.versionedB"() {attribute = #test.attr_params<42, 24>} : () -> () diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir --- a/mlir/test/Bytecode/versioning/versioned_op.mlir +++ b/mlir/test/Bytecode/versioning/versioned_op.mlir @@ -1,5 +1,5 @@ -// This file contains various failure test cases related to the structure of -// the dialect section. +// This file contains test cases related to the dialect post-parsing upgrade +// mechanism. 
// Bytecode currently does not support big-endian platforms // UNSUPPORTED: target=s390x-{{.*}} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -10,15 +10,14 @@ #include "TestAttributes.h" #include "TestInterfaces.h" #include "TestTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" -#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OperationSupport.h" @@ -52,33 +51,11 @@ // TestDialect version utilities //===----------------------------------------------------------------------===// -struct TestDialectVersion { +struct TestDialectVersion : public DialectVersion { uint32_t major = 2; uint32_t minor = 0; }; -// Encode/decode a version attribute -AsmDialectVersionHandle encodeDialectVersion(MLIRContext *ctx, - const TestDialectVersion version) { - uint64_t encoding; - uint32_t *ptr = (uint32_t *)&encoding; - llvm::support::endian::write32le((void *)ptr++, version.major); - llvm::support::endian::write32le((void *)ptr, version.minor); - ArrayRef encode(reinterpret_cast(&encoding), - sizeof(encoding)); - return AsmDialectVersionHandle( - encode, ctx->getOrLoadDialect()->getNamespace()); -} - -TestDialectVersion decodeDialectVersion(AsmDialectVersionHandle handle) { - uint64_t encoding = *reinterpret_cast(handle.getRawData()); - uint32_t *ptr = (uint32_t *)&encoding; - TestDialectVersion version; - version.major = llvm::support::endian::read32le((void *)ptr++); - version.minor = llvm::support::endian::read32le((void *)ptr); - return 
version; -} - //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// @@ -102,6 +79,107 @@ TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase; }; +namespace { +enum test_encoding { k_attr_params = 0 }; +} + +// Test support for interacting with the Bytecode reader/writer. +struct TestBytecodeDialectInterface : public BytecodeDialectInterface { + using BytecodeDialectInterface::BytecodeDialectInterface; + TestBytecodeDialectInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const final { + if (auto concreteAttr = llvm::dyn_cast(attr)) { + writer.writeVarInt(test_encoding::k_attr_params); + writer.writeVarInt(concreteAttr.getV0()); + writer.writeVarInt(concreteAttr.getV1()); + return success(); + } + writer.writeAttribute(attr); + return success(); + } + + Attribute readAttribute(DialectBytecodeReader &reader, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if (version.major < 2) + return readAttrOldEncoding(reader); + if (version.major == 2 && version.minor == 0) + return readAttrNewEncoding(reader); + // Versions newer than 2.0 are not supported: return a null attribute. + return Attribute(); + } + + // Emit a specific version of the dialect.
+ void writeVersion(DialectBytecodeWriter &writer) const final { + auto version = TestDialectVersion(); + writer.writeVarInt(version.major); // major + writer.writeVarInt(version.minor); // minor + } + + std::unique_ptr + readVersion(DialectBytecodeReader &reader) const final { + uint64_t major, minor; + if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor))) + return nullptr; + auto version = std::make_unique(); + version->major = major; + version->minor = minor; + return version; + } + + LogicalResult upgradeFromVersion(Operation *topLevelOp, + const DialectVersion &version_) const final { + const auto &version = static_cast(version_); + if ((version.major == 2) && (version.minor == 0)) + return success(); + if (version.major > 2 || (version.major == 2 && version.minor > 0)) { + return topLevelOp->emitError() + << "current test dialect version is 2.0, can't parse version: " + << version.major << "." << version.minor; + } + // Prior to version 2.0, the old op supported only a single attribute called + // "dimensions". We can perform the upgrade. + topLevelOp->walk([](TestVersionedOpA op) { + if (auto dims = op->getAttr("dimensions")) { + op->removeAttr("dimensions"); + op->setAttr("dims", dims); + } + op->setAttr("modifier", BoolAttr::get(op->getContext(), false)); + }); + return success(); + } + +private: + Attribute readAttrNewEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The new encoding has v0 first, v1 second.
+ uint64_t v0, v1; + if (failed(reader.readVarInt(v0)) || failed(reader.readVarInt(v1))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } + + Attribute readAttrOldEncoding(DialectBytecodeReader &reader) const { + uint64_t encoding; + if (failed(reader.readVarInt(encoding)) || + encoding != test_encoding::k_attr_params) + return Attribute(); + // The old encoding has v1 first, v0 second. + uint64_t v0, v1; + if (failed(reader.readVarInt(v1)) || failed(reader.readVarInt(v0))) + return Attribute(); + return TestAttrParamsAttr::get(getContext(), static_cast(v0), + static_cast(v1)); + } +}; + // Test support for interacting with the AsmPrinter. struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; @@ -196,35 +274,6 @@ blobManager.buildResources(provider, referencedResources.getArrayRef()); } - // Test IR upgrade with dialect version lookup. - AsmDialectVersionHandle getProducerVersion() const final { - return encodeDialectVersion(getContext(), TestDialectVersion()); - } - - LogicalResult - upgradeFromVersion(Operation *topLevelOp, - AsmDialectVersionHandle producerVersion) const final { - auto version = decodeDialectVersion(producerVersion); - if ((version.major == 2) && (version.minor == 0)) - return success(); - if (version.major > 2 || (version.major == 2 && version.minor > 0)) { - return topLevelOp->emitError() - << "current test dialect version is 2.0, can't parse version: " - << version.major << "." << version.minor; - } - // Prior version 2.0, the old op supported only a single attribute called - // "dimensions". We can perform the upgrade. - topLevelOp->walk([](TestVersionedOpA op) { - if (auto dims = op->getAttr("dimensions")) { - op->removeAttr("dimensions"); - op->setAttr("dims", dims); - } - op->setAttr("modifier", BoolAttr::get(op->getContext(), false)); - }); - - return success(); - } - private: /// The blob manager for the dialect. 
TestResourceBlobManagerInterface &blobManager; @@ -427,7 +476,7 @@ addInterface(blobInterface); addInterfaces(); + TestReductionPatternInterface, TestBytecodeDialectInterface>(); allowUnknownOperations(); // Instantiate our fallback op interface that we'll use on specific diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3163,11 +3163,29 @@ // In the current version (2.0) "dimensions" was renamed to "dims", and a new // boolean attribute "modifier" was added. The previous version of the op // corresponds to "modifier=false". We support loading old IR through - // upgrading, see `upgradeFromVersion()` in `TestOpAsmInterface`. + // upgrading, see `upgradeFromVersion()` in `TestBytecodeDialectInterface`. let arguments = (ins AnyI64Attr:$dims, BoolAttr:$modifier ); } +def TestVersionedOpB : TEST_Op<"versionedB"> { + // In a previous version of the dialect (let's say 1.*) we encoded + // TestAttrParams with a custom encoding: + // + // #test.attr_params<X, Y> -> { varInt: Y, varInt: X } + // + // In the current version (2.0) the encoding changed and the two parameters of + // the attribute are swapped: + // + // #test.attr_params<X, Y> -> { varInt: X, varInt: Y } + // + // We support loading old IR through a custom readAttribute method, see + // `readAttribute()` in `TestBytecodeDialectInterface`. + let arguments = (ins + TestAttrParams:$attribute + ); +} + #endif // TEST_OPS