diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -403,3 +403,24 @@
 A block is encoded with an array of operations and block arguments. The first
 field is an encoding that combines the number of operations in the block, with a
 flag indicating if the block has arguments.
+
+### Dialect Version Section
+
+```
+dialect_version_section {
+  dialects: dialect_info[]
+}
+
+dialect_info {
+  dialect: varint,
+  size: varint,
+  version: byte[]
+}
+```
+
+The dialect version section contains the version of each versioned dialect
+referenced in the module. A dialect is considered versioned if it implements
+the `writeVersion`/`parseVersion` hooks of the `BytecodeDialectInterface`. The
+version payload is empty unless the dialect is registered in the context and
+provides these hooks. When version information for a versioned dialect is
+missing, the dialect is parsed without performing upgrades.
\ No newline at end of file
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -845,3 +845,18 @@
 that are directly usable by any other dialect in MLIR. These types cover a range
 from primitive integer and floating-point values, attribute dictionaries, dense
 multi-dimensional arrays, and more.
+
+### IR Versioning
+
+A dialect can opt in to handle versioning in a custom way. Three hooks to
+expose a dialect version to the bytecode writer and reader are available
+through the `BytecodeDialectInterface`. First, `writeVersion()` emits the
+dialect version into the output bytecode. Second, `parseVersion()` reads that
+version back while parsing the input IR. Finally, `upgradeFromVersion()` is
+invoked once parsing completes and gives each dialect for which a version is
+present an opportunity to perform IR upgrades.
+
+Version information is stored in an opaque, per-dialect buffer and exposed to
+the dialect as a `DialectVersion` object. There is no restriction on what kind
+of information a dialect is allowed to encode to model its versioning.
+Currently, versioning is supported only for the bytecode format.
\ No newline at end of file
diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -285,6 +285,32 @@
                                   DialectBytecodeWriter &writer) const {
     return failure();
   }
+
+  /// Write the version of this dialect to the given writer. The bytecode
+  /// writer wraps whatever is emitted here in a dedicated section, so the
+  /// dialect does not need to encode the size of the entry itself.
+  virtual void writeVersion(DialectBytecodeWriter &writer) const {}
+
+  /// Read the version of this dialect, as emitted by `writeVersion`.
+  /// `versionHandle` holds the raw bytes of the version entry. Returns nullptr
+  /// on failure; the default implementation emits an error, as it does not
+  /// support versioning.
+  virtual std::unique_ptr<DialectVersion>
+  parseVersion(DialectBytecodeReader &reader,
+               ArrayRef<uint8_t> versionHandle) const {
+    reader.emitError("Dialect does not support versioning");
+    return nullptr;
+  }
+
+  /// Hook invoked after parsing completed, if a version directive was present
+  /// and included an entry for the current dialect. This hook offers the
+  /// opportunity to the dialect to visit the IR and upgrade constructs emitted
+  /// by the version of the dialect corresponding to the provided version.
+  virtual LogicalResult
+  upgradeFromVersion(Operation *topLevelOp,
+                     const DialectVersion &version) const {
+    return success();
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -463,6 +463,17 @@
   return p;
 }
 
+//===----------------------------------------------------------------------===//
+// Dialect Version
+//===----------------------------------------------------------------------===//
+
+/// Base class for dialect-specific version information. The virtual destructor
+/// allows subclasses to be owned and destroyed through a base pointer.
+class DialectVersion {
+public:
+  virtual ~DialectVersion() = default;
+};
+
 //===----------------------------------------------------------------------===//
 // AsmParser
 //===----------------------------------------------------------------------===//
@@ -1463,9 +1474,9 @@
   //===--------------------------------------------------------------------===//
 
   struct Argument {
-    UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
-    Type type;                 // Type.
-    DictionaryAttr attrs;      // Attributes if present.
+    UnresolvedOperand ssaName;         // SourceLoc, SSA name, result #.
+    Type type;                         // Type.
+    DictionaryAttr attrs;              // Attributes if present.
     std::optional<Location> sourceLoc; // Source location specifier if present.
   };
 
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -23,8 +23,11 @@
 //===----------------------------------------------------------------------===//
 
 enum {
+  /// The minimum supported version of the bytecode.
+  kMinSupportedVersion = 0,
+
   /// The current bytecode version.
-  kVersion = 0,
+  kVersion = 1,
 
   /// An arbitrary value used to fill alignment padding.
   kAlignmentByte = 0xCB,
@@ -61,8 +64,11 @@
   /// section.
   kResourceOffset = 6,
 
+  /// Version of a single dialect, nested within the dialect section.
+  kDialectVersions = 7,
+
   /// The total number of section types.
-  kNumSections = 7,
+  kNumSections = 8,
 };
 } // namespace Section
 
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -47,6 +47,8 @@
     return "Resource (5)";
   case bytecode::Section::kResourceOffset:
     return "ResourceOffset (6)";
+  case bytecode::Section::kDialectVersions:
+    return "DialectVersions (7)";
   default:
     return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
   }
@@ -63,6 +65,7 @@
     return false;
   case bytecode::Section::kResource:
   case bytecode::Section::kResourceOffset:
+  case bytecode::Section::kDialectVersions:
     return true;
   default:
     llvm_unreachable("unknown section ID");
@@ -400,31 +403,15 @@
 //===----------------------------------------------------------------------===//
 
 namespace {
+class DialectReader;
+
 /// This struct represents a dialect entry within the bytecode.
 struct BytecodeDialect {
   /// Load the dialect into the provided context if it hasn't been loaded yet.
   /// Returns failure if the dialect couldn't be loaded *and* the provided
   /// context does not allow unregistered dialects. The provided reader is used
   /// for error emission if necessary.
-  LogicalResult load(EncodingReader &reader, MLIRContext *ctx) {
-    if (dialect)
-      return success();
-    Dialect *loadedDialect = ctx->getOrLoadDialect(name);
-    if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
-      return reader.emitError(
-          "dialect '", name,
-          "' is unknown. If this is intended, please call "
-          "allowUnregisteredDialects() on the MLIRContext, or use "
-          "-allow-unregistered-dialect with the MLIR tool used.");
-    }
-    dialect = loadedDialect;
-
-    // If the dialect was actually loaded, check to see if it has a bytecode
-    // interface.
-    if (loadedDialect)
-      interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
-    return success();
-  }
+  LogicalResult load(DialectReader &reader, MLIRContext *ctx);
 
   /// Return the loaded dialect, or nullptr if the dialect is unknown. This can
   /// only be called after `load`.
@@ -446,6 +433,11 @@
 
   /// The name of the dialect.
   StringRef name;
+
+  /// Raw bytes of the version entry encoded for this dialect, if any.
+  ArrayRef<uint8_t> versionHandle;
+  /// The dialect version decoded from `versionHandle`, if any.
+  std::unique_ptr<DialectVersion> loadedVersion;
 };
 
 /// This struct represents an operation name entry within the bytecode.
@@ -496,7 +488,7 @@
   initialize(Location fileLoc, const ParserConfig &config,
              MutableArrayRef<BytecodeDialect> dialects,
              StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
-             ArrayRef<uint8_t> offsetSectionData,
+             ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
              const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef);
 
   /// Parse a dialect resource handle from the resource section.
@@ -643,7 +635,7 @@
     Location fileLoc, const ParserConfig &config,
     MutableArrayRef<BytecodeDialect> dialects,
     StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData,
-    ArrayRef<uint8_t> offsetSectionData,
+    ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader,
     const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) {
   EncodingReader resourceReader(sectionData, fileLoc);
   EncodingReader offsetReader(offsetSectionData, fileLoc);
@@ -684,7 +676,7 @@
   while (!offsetReader.empty()) {
     BytecodeDialect *dialect;
     if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) ||
-        failed(dialect->load(resourceReader, ctx)))
+        failed(dialect->load(dialectReader, ctx)))
       return failure();
     Dialect *loadedDialect = dialect->getLoadedDialect();
     if (!loadedDialect) {
@@ -1049,7 +1041,8 @@
 LogicalResult
 AttrTypeReader::parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
                                  StringRef entryType) {
-  if (failed(entry.dialect->load(reader, fileLoc.getContext())))
+  DialectReader dialectReader(*this, stringReader, resourceReader, reader);
+  if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
     return failure();
 
   // Ensure that the dialect implements the bytecode interface.
@@ -1059,7 +1052,6 @@
   }
 
   // Ask the dialect to parse the entry.
-  DialectReader dialectReader(*this, stringReader, resourceReader, reader);
   if constexpr (std::is_same_v<T, Type>)
     entry.entry = entry.dialect->interface->readType(dialectReader);
   else
@@ -1120,7 +1112,8 @@
   // Resource Section
 
   LogicalResult
-  parseResourceSection(std::optional<ArrayRef<uint8_t>> resourceData,
+  parseResourceSection(EncodingReader &reader,
+                       std::optional<ArrayRef<uint8_t>> resourceData,
                        std::optional<ArrayRef<uint8_t>> resourceOffsetData);
 
   //===--------------------------------------------------------------------===//
@@ -1304,7 +1297,7 @@
 
   // Process the resource section if present.
   if (failed(parseResourceSection(
-          sectionDatas[bytecode::Section::kResource],
+          reader, sectionDatas[bytecode::Section::kResource],
           sectionDatas[bytecode::Section::kResourceOffset])))
     return failure();
 
@@ -1324,7 +1317,8 @@
 
   // Validate the bytecode version.
   uint64_t currentVersion = bytecode::kVersion;
-  if (version < currentVersion) {
+  uint64_t minSupportedVersion = bytecode::kMinSupportedVersion;
+  if (version < minSupportedVersion) {
     return reader.emitError("bytecode version ", version,
-                            " is older than the current version of ",
-                            currentVersion, ", and upgrade is not supported");
+                            " is older than the minimum supported version ",
+                            minSupportedVersion);
@@ -1340,6 +1334,36 @@
 //===----------------------------------------------------------------------===//
 // Dialect Section
 
+LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) {
+  if (dialect)
+    return success();
+  Dialect *loadedDialect = ctx->getOrLoadDialect(name);
+  if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
+    return reader.emitError("dialect '")
+           << name
+           << "' is unknown. If this is intended, please call "
+              "allowUnregisteredDialects() on the MLIRContext, or use "
+              "-allow-unregistered-dialect with the MLIR tool used.";
+  }
+  dialect = loadedDialect;
+
+  // If the dialect was actually loaded, check to see if it has a bytecode
+  // interface.
+  if (loadedDialect)
+    interface = dyn_cast<BytecodeDialectInterface>(loadedDialect);
+
+  // Only a loaded dialect implementing the bytecode interface can decode its
+  // version entry; the entry itself is decoded once the IR has been parsed,
+  // from a reader positioned on the entry's own bytes. Version entries of
+  // unregistered dialects are ignored, as their IR cannot be upgraded.
+  if (loadedDialect && !interface && !versionHandle.empty())
+    return reader.emitError("dialect '")
+           << name
+           << "' does not implement the bytecode interface, "
+              "but found a version entry";
+  return success();
+}
+
 LogicalResult
 BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
   EncodingReader sectionReader(sectionData, fileLoc);
@@ -1351,9 +1375,21 @@
   dialects.resize(numDialects);
 
   // Parse each of the dialects.
-  for (uint64_t i = 0; i < numDialects; ++i)
+  for (uint64_t i = 0; i < numDialects; ++i) {
     if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
       return failure();
+    // Before version 1, dialects did not carry any version information.
+    if (version == 0)
+      continue;
+    bytecode::Section::ID sectionID;
+    if (failed(
+            sectionReader.parseSection(sectionID, dialects[i].versionHandle)))
+      return failure();
+    if (sectionID != bytecode::Section::kDialectVersions) {
+      emitError(fileLoc, "expected dialect version section");
+      return failure();
+    }
+  }
 
   // Parse the operation names, which are grouped by dialect.
   auto parseOpName = [&](BytecodeDialect *dialect) {
@@ -1376,8 +1412,10 @@
 
   // Check to see if this operation name has already been resolved. If we
   // haven't, load the dialect and build the operation name.
   if (!opName->opName) {
-    if (failed(opName->dialect->load(reader, getContext())))
+    DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+                                reader);
+    if (failed(opName->dialect->load(dialectReader, getContext())))
       return failure();
     opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
                            getContext());
@@ -1389,7 +1427,7 @@
 // Resource Section
 
 LogicalResult BytecodeReader::parseResourceSection(
-    std::optional<ArrayRef<uint8_t>> resourceData,
+    EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData,
     std::optional<ArrayRef<uint8_t>> resourceOffsetData) {
   // Ensure both sections are either present or not.
   if (resourceData.has_value() != resourceOffsetData.has_value()) {
@@ -1406,9 +1444,11 @@
     return success();
 
   // Initialize the resource reader with the resource sections.
+  DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+                              reader);
   return resourceReader.initialize(fileLoc, config, dialects, stringReader,
                                    *resourceData, *resourceOffsetData,
-                                   bufferOwnerRef);
+                                   dialectReader, bufferOwnerRef);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1440,6 +1480,26 @@
         "not all forward unresolved forward operand references");
   }
 
+  // Parsing is complete: decode the version entry of each loaded dialect and
+  // give the dialect an opportunity to visit the IR and perform upgrades.
+  for (BytecodeDialect &byteCodeDialect : dialects) {
+    // Dialects without a version entry, or that were never loaded (e.g.
+    // unregistered ones), have nothing to upgrade.
+    if (byteCodeDialect.versionHandle.empty() || !byteCodeDialect.interface)
+      continue;
+    // Decode the version from its own buffer, so that the dialect hook only
+    // ever sees the bytes emitted by its `writeVersion` counterpart.
+    EncodingReader versionReader(byteCodeDialect.versionHandle, fileLoc);
+    DialectReader dialectReader(attrTypeReader, stringReader, resourceReader,
+                                versionReader);
+    byteCodeDialect.loadedVersion = byteCodeDialect.interface->parseVersion(
+        dialectReader, byteCodeDialect.versionHandle);
+    if (!byteCodeDialect.loadedVersion ||
+        failed(byteCodeDialect.interface->upgradeFromVersion(
+            *moduleOp, *byteCodeDialect.loadedVersion)))
+      return failure();
+  }
+
   // Verify that the parsed operations are valid.
   if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
     return failure();
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -261,6 +261,116 @@
   unsigned requiredAlignment = 1;
 };
 
+//===----------------------------------------------------------------------===//
+// StringSectionBuilder
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to simplify the process of emitting the string section.
+class StringSectionBuilder {
+public:
+  /// Add the given string to the string section, and return the index of the
+  /// string within the section.
+  size_t insert(StringRef str) {
+    auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
+    return it.first->second;
+  }
+
+  /// Write the current set of strings to the given emitter.
+  void write(EncodingEmitter &emitter) {
+    emitter.emitVarInt(strings.size());
+
+    // Emit the sizes in reverse order, so that we don't need to backpatch an
+    // offset to the string data or have a separate section.
+    for (const auto &it : llvm::reverse(strings))
+      emitter.emitVarInt(it.first.size() + 1);
+    // Emit the string data itself.
+    for (const auto &it : strings)
+      emitter.emitNulTerminatedString(it.first.val());
+  }
+
+private:
+  /// The strings referenced within the bytecode, mapped to their index within
+  /// the string section.
+  llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
+};
+
+class DialectWriter : public DialectBytecodeWriter {
+public:
+  DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
+                StringSectionBuilder &stringSection)
+      : emitter(emitter), numberingState(numberingState),
+        stringSection(stringSection) {}
+
+  //===--------------------------------------------------------------------===//
+  // IR
+  //===--------------------------------------------------------------------===//
+
+  void writeAttribute(Attribute attr) override {
+    emitter.emitVarInt(numberingState.getNumber(attr));
+  }
+  void writeType(Type type) override {
+    emitter.emitVarInt(numberingState.getNumber(type));
+  }
+
+  void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+    emitter.emitVarInt(numberingState.getNumber(resource));
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Primitives
+  //===--------------------------------------------------------------------===//
+
+  void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
+
+  void writeSignedVarInt(int64_t value) override {
+    emitter.emitSignedVarInt(value);
+  }
+
+  void writeAPIntWithKnownWidth(const APInt &value) override {
+    size_t bitWidth = value.getBitWidth();
+
+    // If the value is a single byte, just emit it directly without going
+    // through a varint.
+    if (bitWidth <= 8)
+      return emitter.emitByte(value.getLimitedValue());
+
+    // If the value fits within a single varint, emit it directly.
+    if (bitWidth <= 64)
+      return emitter.emitSignedVarInt(value.getLimitedValue());
+
+    // Otherwise, we need to encode a variable number of active words. We use
+    // active words instead of the number of total words under the observation
+    // that smaller values will be more common.
+    unsigned numActiveWords = value.getActiveWords();
+    emitter.emitVarInt(numActiveWords);
+
+    const uint64_t *rawValueData = value.getRawData();
+    for (unsigned i = 0; i < numActiveWords; ++i)
+      emitter.emitSignedVarInt(rawValueData[i]);
+  }
+
+  void writeAPFloatWithKnownSemantics(const APFloat &value) override {
+    writeAPIntWithKnownWidth(value.bitcastToAPInt());
+  }
+
+  void writeOwnedString(StringRef str) override {
+    emitter.emitVarInt(stringSection.insert(str));
+  }
+
+  void writeOwnedBlob(ArrayRef<char> blob) override {
+    emitter.emitVarInt(blob.size());
+    emitter.emitOwnedBlob(ArrayRef<uint8_t>(
+        reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
+  }
+
+private:
+  EncodingEmitter &emitter;
+  IRNumberingState &numberingState;
+  StringSectionBuilder &stringSection;
+};
+} // namespace
+
 /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need
 /// to go through an intermediate buffer when interacting with code that wants a
 /// raw_ostream.
@@ -307,41 +417,6 @@
   emitBytes({reinterpret_cast<uint8_t *>(&value), sizeof(value)});
 }
 
-//===----------------------------------------------------------------------===//
-// StringSectionBuilder
-//===----------------------------------------------------------------------===//
-
-namespace {
-/// This class is used to simplify the process of emitting the string section.
-class StringSectionBuilder {
-public:
-  /// Add the given string to the string section, and return the index of the
-  /// string within the section.
-  size_t insert(StringRef str) {
-    auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
-    return it.first->second;
-  }
-
-  /// Write the current set of strings to the given emitter.
-  void write(EncodingEmitter &emitter) {
-    emitter.emitVarInt(strings.size());
-
-    // Emit the sizes in reverse order, so that we don't need to backpatch an
-    // offset to the string data or have a separate section.
-    for (const auto &it : llvm::reverse(strings))
-      emitter.emitVarInt(it.first.size() + 1);
-    // Emit the string data itself.
-    for (const auto &it : strings)
-      emitter.emitNulTerminatedString(it.first.val());
-  }
-
-private:
-  /// A set of strings referenced within the bytecode. The value of the map is
-  /// unused.
-  llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
-};
-} // namespace
-
 //===----------------------------------------------------------------------===//
 // Bytecode Writer
 //===----------------------------------------------------------------------===//
@@ -460,12 +535,23 @@
 
 void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
   EncodingEmitter dialectEmitter;
 
   // Emit the referenced dialects.
   auto dialects = numberingState.getDialects();
   dialectEmitter.emitVarInt(llvm::size(dialects));
-  for (DialectNumbering &dialect : dialects)
+  for (DialectNumbering &dialect : dialects) {
     dialectEmitter.emitVarInt(stringSection.insert(dialect.name));
+    // Emit the dialect version as a nested section; it is left empty when the
+    // dialect does not provide a bytecode interface.
+    EncodingEmitter versionEmitter;
+    if (dialect.interface) {
+      DialectWriter versionWriter(versionEmitter, numberingState,
+                                  stringSection);
+      dialect.interface->writeVersion(versionWriter);
+    }
+    dialectEmitter.emitSection(bytecode::Section::kDialectVersions,
+                               std::move(versionEmitter));
+  }
 
   // Emit the referenced operation names grouped by dialect.
   auto emitOpName = [&](OpNameNumbering &name) {
@@ -479,83 +565,6 @@
 //===----------------------------------------------------------------------===//
 // Attributes and Types
 
-namespace {
-class DialectWriter : public DialectBytecodeWriter {
-public:
-  DialectWriter(EncodingEmitter &emitter, IRNumberingState &numberingState,
-                StringSectionBuilder &stringSection)
-      : emitter(emitter), numberingState(numberingState),
-        stringSection(stringSection) {}
-
-  //===--------------------------------------------------------------------===//
-  // IR
-  //===--------------------------------------------------------------------===//
-
-  void writeAttribute(Attribute attr) override {
-    emitter.emitVarInt(numberingState.getNumber(attr));
-  }
-  void writeType(Type type) override {
-    emitter.emitVarInt(numberingState.getNumber(type));
-  }
-
-  void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
-    emitter.emitVarInt(numberingState.getNumber(resource));
-  }
-
-  //===--------------------------------------------------------------------===//
-  // Primitives
-  //===--------------------------------------------------------------------===//
-
-  void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); }
-
-  void writeSignedVarInt(int64_t value) override {
-    emitter.emitSignedVarInt(value);
-  }
-
-  void writeAPIntWithKnownWidth(const APInt &value) override {
-    size_t bitWidth = value.getBitWidth();
-
-    // If the value is a single byte, just emit it directly without going
-    // through a varint.
-    if (bitWidth <= 8)
-      return emitter.emitByte(value.getLimitedValue());
-
-    // If the value fits within a single varint, emit it directly.
-    if (bitWidth <= 64)
-      return emitter.emitSignedVarInt(value.getLimitedValue());
-
-    // Otherwise, we need to encode a variable number of active words. We use
-    // active words instead of the number of total words under the observation
-    // that smaller values will be more common.
-    unsigned numActiveWords = value.getActiveWords();
-    emitter.emitVarInt(numActiveWords);
-
-    const uint64_t *rawValueData = value.getRawData();
-    for (unsigned i = 0; i < numActiveWords; ++i)
-      emitter.emitSignedVarInt(rawValueData[i]);
-  }
-
-  void writeAPFloatWithKnownSemantics(const APFloat &value) override {
-    writeAPIntWithKnownWidth(value.bitcastToAPInt());
-  }
-
-  void writeOwnedString(StringRef str) override {
-    emitter.emitVarInt(stringSection.insert(str));
-  }
-
-  void writeOwnedBlob(ArrayRef<char> blob) override {
-    emitter.emitVarInt(blob.size());
-    emitter.emitOwnedBlob(ArrayRef<uint8_t>(
-        reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
-  }
-
-private:
-  EncodingEmitter &emitter;
-  IRNumberingState &numberingState;
-  StringSectionBuilder &stringSection;
-};
-} // namespace
-
 void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
   EncodingEmitter attrTypeEmitter;
   EncodingEmitter offsetEmitter;
diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -9,7 +9,7 @@
 //===--------------------------------------------------------------------===//
 
 // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
-// VERSION: bytecode version 127 is newer than the current version 0
+// VERSION: bytecode version 127 is newer than the current version
 
 //===--------------------------------------------------------------------===//
 // Producer
diff --git a/mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc
new file mode 100644
diff --git a/mlir/test/Bytecode/versioning/versioned-op-2.0.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-2.0.mlirbc
new file mode 100644
diff --git a/mlir/test/Bytecode/versioning/versioned-op-2.2.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-2.2.mlirbc
new file mode 100644
diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/versioning/versioned_op.mlir
@@ -0,0 +1,41 @@
+// This file contains test cases for dialect versioning: parsing the current
+// version, upgrading IR from an older one, and rejecting a newer one.
+
+// Bytecode currently does not support big-endian platforms
+// UNSUPPORTED: target=s390x-{{.*}}
+
+//===--------------------------------------------------------------------===//
+// Test generic
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM:   version: 2.0
+// COM:   "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
+// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test upgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM:   version: 1.12
+// COM:   "test.versionedA"() {dimensions = 123 : i64} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
+// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test forbidden downgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM:   version: 2.2
+// COM:   "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+// COM: }
+// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION
+// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -10,6 +10,7 @@
 #include "TestAttributes.h"
 #include "TestInterfaces.h"
 #include "TestTypes.h"
+#include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/DLTI/DLTI.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,6 +22,7 @@
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -30,11 +32,14 @@
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSwitch.h"
 
-#include 
+#include 
+#include 
 #include 
+#include 
 
 // Include this before the using namespace lines below to
 // test that we don't have namespace dependencies.
@@ -47,6 +52,17 @@
   registry.insert<TestDialect>();
 }
 
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+/// Version of the test dialect. A default-constructed instance describes the
+/// current version (2.0), which is what the dialect emits and accepts.
+struct TestDialectVersion : DialectVersion {
+  uint32_t major = 2;
+  uint32_t minor = 0;
+};
+
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -70,6 +86,67 @@
       TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
 };
 
+// Test support for interacting with the Bytecode reader/writer.
+struct TestBytecodeDialectInterface : public BytecodeDialectInterface {
+  using BytecodeDialectInterface::BytecodeDialectInterface;
+
+  // Emit the current version of the dialect.
+  void writeVersion(DialectBytecodeWriter &writer) const final {
+    const TestDialectVersion current{};
+    writer.writeVarInt(current.major);
+    writer.writeVarInt(current.minor);
+  }
+
+  // Read back the (major, minor) pair emitted by `writeVersion`.
+  std::unique_ptr<DialectVersion>
+  parseVersion(DialectBytecodeReader &reader,
+               ArrayRef<uint8_t> versionHandle) const final {
+    uint64_t major, minor;
+    if (failed(reader.readVarInt(major)) || failed(reader.readVarInt(minor)))
+      return nullptr;
+    // Reject values that would be silently truncated by the 32-bit fields.
+    if (major > UINT32_MAX || minor > UINT32_MAX) {
+      reader.emitError("test dialect version component out of range");
+      return nullptr;
+    }
+    auto version = std::make_unique<TestDialectVersion>();
+    version->major = major;
+    version->minor = minor;
+    return version;
+  }
+
+  // Upgrade IR emitted by an older version of the dialect; IR from a newer
+  // version cannot be handled and is rejected.
+  LogicalResult
+  upgradeFromVersion(Operation *topLevelOp,
+                     const DialectVersion &dialectVersion) const final {
+    const auto &version =
+        static_cast<const TestDialectVersion &>(dialectVersion);
+    const TestDialectVersion current{};
+    if (version.major == current.major && version.minor == current.minor)
+      return success();
+    if (version.major > current.major ||
+        (version.major == current.major && version.minor > current.minor)) {
+      return topLevelOp->emitError()
+             << "current test dialect version is " << current.major << "."
+             << current.minor << ", can't parse version: " << version.major
+             << "." << version.minor;
+    }
+    // Prior to version 2.0, the op supported only a single attribute called
+    // "dimensions", which was renamed to "dims"; "modifier" did not exist and
+    // the old behavior corresponds to `modifier = false`.
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (auto dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+      op->setAttr("modifier", BoolAttr::get(op->getContext(), false));
+    });
+
+    return success();
+  }
+};
+
 // Test support for interacting with the AsmPrinter.
 struct TestOpAsmInterface : public OpAsmDialectInterface {
   using OpAsmDialectInterface::OpAsmDialectInterface;
@@ -366,8 +443,8 @@
   auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
   addInterface<TestOpAsmInterface>(blobInterface);
 
-  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
-                TestReductionPatternInterface>();
+  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
+                TestReductionPatternInterface, TestBytecodeDialectInterface>();
   allowUnknownOperations();
 
   // Instantiate our fallback op interface that we'll use on specific
@@ -1103,9 +1180,7 @@
   return getOperand();
 }
 
-OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) {
-  return getValue();
-}
+OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); }
 
 LogicalResult TestOpWithVariadicResultsAndFolder::fold(
     FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) {
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3149,4 +3149,25 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+  // A previous version of the dialect (let's say 1.*) supported an attribute
+  // named "dimensions":
+  // let arguments = (ins
+  //   AnyI64Attr:$dimensions
+  // );
+
+  // In the current version (2.0) "dimensions" was renamed to "dims", and a new
+  // boolean attribute "modifier" was added. The previous version of the op
+  // corresponds to "modifier=false". We support loading old IR through
+  // upgrading, see `upgradeFromVersion()` in `TestBytecodeDialectInterface`.
+  let arguments = (ins
+    AnyI64Attr:$dims,
+    BoolAttr:$modifier
+  );
+}
+
 #endif // TEST_OPS