diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -403,3 +403,25 @@ A block is encoded with an array of operations and block arguments. The first field is an encoding that combines the number of operations in the block, with a flag indicating if the block has arguments. + +### Dialect Version Section + +``` +dialect_version_section { + numVersionedDialects: varint, + dialects: dialect_info[] +} + +dialect_info { + dialect: varint, + size: varint, + version: byte[] +} +``` + +The dialect version section contains details about the version for each of the +versioned dialects contained in the module. A dialect is considered versioned if it +implements the OpAsmDialectInterface hooks as described in the MLIR Language +Reference. `dialect` is the index of the dialect name in the string section, and `size` is the number of bytes in the opaque `version` blob. The section itself is optional, and an entry is written only for +dialects that are loaded in the context and provide a non-empty version. When the version information for a versioned dialect +is missing, the dialect is parsed without performing upgrades. diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -845,3 +845,18 @@ that are directly usable by any other dialect in MLIR. These types cover a range from primitive integer and floating-point values, attribute dictionaries, dense multi-dimensional arrays, and more. + +### IR Versioning + +A dialect can opt in to handling versioning in a custom way. Two hooks to expose a +dialect version to the printer and parser are available through the +`OpAsmDialectInterface`. First, the `getProducerVersion()` method allows a dialect to +inject its target version into the printer, which writes it to the +output file. Second, the `upgradeFromVersion()` method lets a dialect retrieve +the version information once the input IR has been parsed, and gives each dialect +for which a version is present the opportunity to perform IR upgrades.
+
+Version information is stored in a buffer exposed through an `AsmDialectVersionHandle`.
+There is no restriction on what kind of information a dialect is allowed to
+encode to model its versioning. Currently, versioning is supported only by the
+bytecode format.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -450,6 +450,68 @@
   return p;
 }
 
+//===----------------------------------------------------------------------===//
+// Dialect Asm Version Interface
+//===----------------------------------------------------------------------===//
+
+namespace detail {
+/// Owning storage for the bytes of a dialect version. It is shared between all
+/// the copies of an `AsmDialectVersionHandle`.
+class AsmDialectVersionStorage {
+public:
+  explicit AsmDialectVersionStorage(ArrayRef<uint8_t> in)
+      : buffer(in.begin(), in.end()) {}
+  ArrayRef<uint8_t> getBuffer() const { return buffer; }
+
+private:
+  SmallVector<uint8_t> buffer;
+};
+} // namespace detail
+
+/// This class represents a handle to a dialect version storage. It is a
+/// light-weight class that holds a shared pointer to an immutable buffer and
+/// can be cheaply copied and passed around by value. Each dialect is free to
+/// encode and decode its version into the buffer exposed through the handle,
+/// using the `getProducerVersion` and `upgradeFromVersion` methods of the
+/// `OpAsmDialectInterface`.
+class AsmDialectVersionHandle {
+public:
+  AsmDialectVersionHandle() = default;
+  AsmDialectVersionHandle(ArrayRef<uint8_t> buffer, StringRef name)
+      : storage(std::make_shared<detail::AsmDialectVersionStorage>(buffer)),
+        dialectName(name) {}
+
+  /// Return true if the handle refers to a non-empty version buffer.
+  explicit operator bool() const {
+    return storage && !storage->getBuffer().empty();
+  }
+
+  /// Return the buffer holding the version. The handle must not be null.
+  ArrayRef<uint8_t> getBuffer() const {
+    assert(storage && "buffer must exist to be retrieved");
+    return storage->getBuffer();
+  }
+
+  /// Return an opaque pointer to the raw data, or nullptr for a null handle.
+  const void *getRawData() const {
+    return storage ? storage->getBuffer().data() : nullptr;
+  }
+
+  /// Return the size of the storage buffer, or 0 for a null handle.
+  size_t size() const { return storage ? storage->getBuffer().size() : 0; }
+
+  /// Return the name of the dialect that owns the version. The handle does not
+  /// own the name.
+  StringRef getDialectName() const { return dialectName; }
+
+private:
+  /// The data associated with the version, shared between copies.
+  std::shared_ptr<const detail::AsmDialectVersionStorage> storage;
+
+  /// The dialect owning the version.
+  StringRef dialectName;
+};
+
 //===----------------------------------------------------------------------===//
 // AsmParser
 //===----------------------------------------------------------------------===//
@@ -1582,6 +1644,22 @@
     return AliasResult::NoAlias;
   }
 
+  /// Hook provided by the dialect to emit a version when printing. The
+  /// handle will be available when parsing back, and the dialect implementation
+  /// can use it to upgrade IR emitted by previous versions. The encoding and
+  /// interpretation of the version are entirely up to the individual dialect.
+  virtual AsmDialectVersionHandle getProducerVersion() const { return {}; }
+
+  /// Hook invoked after parsing completed, if a version entry was present for
+  /// the current dialect. This hook offers the dialect the opportunity to visit
+  /// the IR and upgrade constructs emitted by the producer version described by
+  /// `versionHandle`. Returning failure causes parsing to fail.
+  virtual LogicalResult
+  upgradeFromVersion(Operation *topLevelOp,
+                     AsmDialectVersionHandle versionHandle) const {
+    return success();
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -61,8 +61,11 @@
   /// section.
   kResourceOffset = 6,
 
+  /// This section contains the version of each versioned dialect.
+  kDialectVersions = 7,
+
   /// The total number of section types.
-  kNumSections = 7,
+  kNumSections = 8,
 };
 } // namespace Section
 
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -47,6 +47,8 @@
     return "Resource (5)";
   case bytecode::Section::kResourceOffset:
     return "ResourceOffset (6)";
+  case bytecode::Section::kDialectVersions:
+    return "DialectVersions (7)";
   default:
     return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
   }
@@ -63,6 +65,7 @@
     return false;
   case bytecode::Section::kResource:
   case bytecode::Section::kResourceOffset:
+  case bytecode::Section::kDialectVersions:
     return true;
   default:
     llvm_unreachable("unknown section ID");
@@ -446,6 +449,9 @@
 
   /// The name of the dialect.
   StringRef name;
+
+  /// The version of the dialect parsed from the bytecode, null if absent.
+  AsmDialectVersionHandle versionHandle;
 };
 
 /// This struct represents an operation name entry within the bytecode.
@@ -1165,6 +1171,13 @@
   LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
 
+  //===--------------------------------------------------------------------===//
+  // Dialect Versions Section
+
+  /// Parse dialect versions.
+  LogicalResult
+  parseDialectVersionsSection(std::optional<ArrayRef<uint8_t>> sectionData);
+
   //===--------------------------------------------------------------------===//
   // Value Processing
 
@@ -1302,6 +1315,11 @@
   if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
     return failure();
 
+  // Process the dialect version section.
+  if (failed(parseDialectVersionsSection(
+          sectionDatas[bytecode::Section::kDialectVersions])))
+    return failure();
+
   // Process the resource section if present.
   if (failed(parseResourceSection(
           sectionDatas[bytecode::Section::kResource],
@@ -1440,6 +1458,23 @@
                        "not all forward unresolved forward operand references");
   }
 
+  // Parsing is complete: give each dialect for which a version was parsed the
+  // opportunity to visit the IR and perform upgrades.
+  for (const BytecodeDialect &byteCodeDialect : dialects) {
+    const AsmDialectVersionHandle &versionHandle =
+        byteCodeDialect.versionHandle;
+    if (!versionHandle)
+      continue;
+    Dialect *dialect = moduleOp->getContext()->getOrLoadDialect(
+        versionHandle.getDialectName());
+    // Dialects that don't implement the asm interface can't be upgraded.
+    auto *asmIface = llvm::dyn_cast_or_null<OpAsmDialectInterface>(dialect);
+    if (!asmIface)
+      continue;
+    if (failed(asmIface->upgradeFromVersion(*moduleOp, versionHandle)))
+      return failure();
+  }
+
   // Verify that the parsed operations are valid.
   if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
     return failure();
@@ -1673,6 +1708,45 @@
   return defineValues(reader, block->getArguments());
 }
 
+//===----------------------------------------------------------------------===//
+// Dialect Versions Section
+
+LogicalResult BytecodeReader::parseDialectVersionsSection(
+    std::optional<ArrayRef<uint8_t>> sectionData) {
+  // If the dialect versions are absent, there is nothing to do.
+  if (!sectionData.has_value())
+    return success();
+
+  EncodingReader sectionReader(sectionData.value(), fileLoc);
+
+  // Parse the number of dialects in the section.
+  uint64_t numVersionedDialects;
+  if (failed(sectionReader.parseVarInt(numVersionedDialects)))
+    return failure();
+
+  // Parse each of the dialect versions.
+  llvm::StringMap<AsmDialectVersionHandle> versionedDialectMap;
+  for (uint64_t i = 0; i < numVersionedDialects; ++i) {
+    StringRef dialectName;
+    if (failed(stringReader.parseString(sectionReader, dialectName)))
+      return failure();
+    uint64_t dialectVersionSize;
+    if (failed(sectionReader.parseVarInt(dialectVersionSize)))
+      return failure();
+    ArrayRef<uint8_t> bytes;
+    if (failed(sectionReader.parseBytes(dialectVersionSize, bytes)))
+      return failure();
+    versionedDialectMap.insert(
+        {dialectName, AsmDialectVersionHandle(bytes, dialectName)});
+  }
+
+  // Dialects without a version entry keep a null version handle.
+  for (BytecodeDialect &dialect : dialects)
+    dialect.versionHandle = versionedDialectMap.lookup(dialect.name);
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Value Processing
 
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -385,6 +385,11 @@
 
   void writeStringSection(EncodingEmitter &emitter);
 
+  //===--------------------------------------------------------------------===//
+  // Dialect versions
+
+  void writeDialectVersionsSection(EncodingEmitter &emitter);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -425,6 +430,9 @@
+  // Emit the dialect versions section (before the string section it uses).
+  writeDialectVersionsSection(emitter);
+
   // Emit the string section.
   writeStringSection(emitter);
 
   // Write the generated bytecode to the provided output stream.
   emitter.writeTo(os);
 }
@@ -476,6 +484,35 @@
   emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter));
 }
 
+void BytecodeWriter::writeDialectVersionsSection(EncodingEmitter &emitter) {
+  // Collect the versions of the referenced dialects. The interface may be null
+  // (e.g. if the dialect is not loaded), in which case no version is emitted.
+  SmallVector<std::pair<size_t, AsmDialectVersionHandle>> dialectVersions;
+  for (DialectNumbering &dialect : numberingState.getDialects()) {
+    if (!dialect.asmInterface)
+      continue;
+    if (AsmDialectVersionHandle version =
+            dialect.asmInterface->getProducerVersion())
+      dialectVersions.emplace_back(stringSection.insert(dialect.name), version);
+  }
+
+  // The section is optional: don't emit it when no dialect is versioned, so
+  // that the output is unchanged for IR without versioned dialects.
+  if (dialectVersions.empty())
+    return;
+
+  EncodingEmitter versionEmitter;
+  versionEmitter.emitVarInt(dialectVersions.size());
+  for (const auto &[nameIndex, version] : dialectVersions) {
+    versionEmitter.emitVarInt(nameIndex);
+    versionEmitter.emitVarInt(version.size());
+    versionEmitter.emitBytes(version.getBuffer());
+  }
+
+  emitter.emitSection(bytecode::Section::kDialectVersions,
+                      std::move(versionEmitter));
+}
+
 //===----------------------------------------------------------------------===//
 // Attributes and Types
 
diff --git a/mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-1.12.mlirbc
new file mode 100644
index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000
GIT binary patch
literal 0
Hc$@ ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1
+// CHECK1: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===//
+// Test upgrade
+//===--------------------------------------------------------------------===//
+
+// COM: bytecode contains
+// COM: module {
+// COM: version: 1.12
+// COM: "test.versionedA"() {dimensions = 123 : i64} : () -> ()
+// COM: }
+// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2
+// CHECK2: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> ()
+
+//===--------------------------------------------------------------------===// +// Test forbidden downgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.2 +// COM: "test.versionedA"() {dims = 123 : i64, modifier = false} : () -> () +// COM: } +// RUN: not mlir-opt %S/versioned-op-2.2.mlirbc 2>&1 | FileCheck %s --check-prefix=ERR_NEW_VERSION +// ERR_NEW_VERSION: current test dialect version is 2.0, can't parse version: 2.2 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -32,9 +32,10 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" -#include +#include "llvm/Support/Endian.h" #include +#include // Include this before the using namespace lines below to // test that we don't have namespace dependencies. 
@@ -47,6 +48,45 @@
   registry.insert<TestDialect>();
 }
 
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A version of the test dialect. Default-constructed, it denotes the current
+/// version of the dialect: 2.0.
+struct TestDialectVersion {
+  uint32_t major = 2;
+  uint32_t minor = 0;
+};
+} // namespace
+
+/// Size in bytes of an encoded `TestDialectVersion`: the major then the minor
+/// version, each as a little-endian 32-bit integer.
+static constexpr size_t kEncodedTestDialectVersionSize = 2 * sizeof(uint32_t);
+
+/// Encode `version` into a handle owned by the test dialect.
+static AsmDialectVersionHandle
+encodeDialectVersion(const TestDialectVersion &version) {
+  uint8_t encoding[kEncodedTestDialectVersionSize];
+  llvm::support::endian::write32le(encoding, version.major);
+  llvm::support::endian::write32le(encoding + sizeof(uint32_t), version.minor);
+  return AsmDialectVersionHandle(encoding, TestDialect::getDialectNamespace());
+}
+
+/// Decode a test dialect version from `handle`. The handle comes from the
+/// bytecode being parsed, so its size is checked before reading from it.
+static FailureOr<TestDialectVersion>
+decodeDialectVersion(const AsmDialectVersionHandle &handle) {
+  if (handle.size() != kEncodedTestDialectVersionSize)
+    return failure();
+  const uint8_t *data = handle.getBuffer().data();
+  TestDialectVersion version;
+  version.major = llvm::support::endian::read32le(data);
+  version.minor = llvm::support::endian::read32le(data + sizeof(uint32_t));
+  return version;
+}
+
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -164,6 +204,38 @@
     blobManager.buildResources(provider, referencedResources.getArrayRef());
   }
 
+  // Test IR upgrade with dialect version lookup.
+  AsmDialectVersionHandle getProducerVersion() const final {
+    return encodeDialectVersion(TestDialectVersion());
+  }
+
+  LogicalResult
+  upgradeFromVersion(Operation *topLevelOp,
+                     AsmDialectVersionHandle producerVersion) const final {
+    FailureOr<TestDialectVersion> version =
+        decodeDialectVersion(producerVersion);
+    if (failed(version))
+      return topLevelOp->emitError("malformed test dialect version");
+    if (version->major == 2 && version->minor == 0)
+      return success();
+    if (version->major > 2 || (version->major == 2 && version->minor > 0)) {
+      return topLevelOp->emitError()
+             << "current test dialect version is 2.0, can't parse version: "
+             << version->major << "." << version->minor;
+    }
+    // Prior to version 2.0, the op supported only a single attribute called
+    // "dimensions". It was renamed to "dims" and "modifier" was added.
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (Attribute dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+      op->setAttr("modifier", BoolAttr::get(op->getContext(), false));
+    });
+
+    return success();
+  }
+
 private:
   /// The blob manager for the dialect.
   TestResourceBlobManagerInterface &blobManager;
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3149,4 +3149,25 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+  // A previous version of the dialect (let's say 1.*) supported an attribute
+  // named "dimensions":
+  // let arguments = (ins
+  //   AnyI64Attr:$dimensions
+  // );
+
+  // In the current version (2.0) "dimensions" was renamed to "dims", and a new
+  // boolean attribute "modifier" was added. The previous version of the op
+  // corresponds to "modifier=false". We support loading old IR through
+  // upgrading, see `upgradeFromVersion()` in `TestOpAsmInterface`.
+  let arguments = (ins
+    AnyI64Attr:$dims,
+    BoolAttr:$modifier
+  );
+}
+
 #endif // TEST_OPS