Index: mlir/include/mlir/IR/OpImplementation.h
===================================================================
--- mlir/include/mlir/IR/OpImplementation.h
+++ mlir/include/mlir/IR/OpImplementation.h
@@ -491,6 +491,12 @@
   /// the encoding process is not efficient.
   virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
 
+  /// Return the attribute describing the version of the provided dialect name,
+  /// if any. The version is provided by the `dialect_versions` directive at the
+  /// very beginning of the parsed input. A null attribute is returned when no
+  /// version was recorded for `dialect`.
+  virtual Attribute getDialectVersion(StringRef dialect) = 0;
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
@@ -1582,6 +1588,27 @@
     return AliasResult::NoAlias;
   }
 
+  /// Hook provided by the dialect to emit a version when printing. The
+  /// attribute will be available when parsing back, and the dialect
+  /// implementation will be able to use it to load previously known versions.
+  /// This management is entirely under the responsibility of the individual
+  /// dialects, and limited by possible changes in the MLIR syntax.
+  /// There is no restriction on what kind of attribute a dialect is using to
+  /// model its versioning. Returning a null attribute (the default) means the
+  /// dialect is not versioned.
+  virtual Attribute getProducerVersion() const { return {}; }
+
+  /// Hook invoked after parsing completed, if a version directive was present
+  /// and included an entry for the current dialect. This hook offers the
+  /// opportunity to the dialect to visit the IR and upgrade constructs emitted
+  /// by the version of the dialect corresponding to the provided version
+  /// attribute. It runs before the IR is verified, so it may rewrite
+  /// constructs that are no longer valid. Returning failure aborts parsing.
+  virtual LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                           Attribute producerVersion) const {
+    return success();
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
Index: mlir/lib/AsmParser/AsmParserImpl.h
===================================================================
--- mlir/lib/AsmParser/AsmParserImpl.h
+++ mlir/lib/AsmParser/AsmParserImpl.h
@@ -60,6 +60,10 @@
     return parser.getEncodedSourceLocation(loc);
   }
 
+  Attribute getDialectVersion(StringRef dialect) override {
+    return parser.getDialectVersion(dialect);
+  }
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
Index: mlir/lib/AsmParser/Parser.h
===================================================================
--- mlir/lib/AsmParser/Parser.h
+++ mlir/lib/AsmParser/Parser.h
@@ -36,6 +36,13 @@
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.config.getContext(); }
   const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
+  /// Return the version attribute recorded for `dialect` by the
+  /// `dialect_versions` directive, or a null attribute if there is none.
+  Attribute getDialectVersion(StringRef dialect) const {
+    if (!state.dialectVersions)
+      return {};
+    return state.dialectVersions.get(dialect);
+  }
 
   /// Parse a comma-separated list of elements up until the specified end token.
   ParseResult
Index: mlir/lib/AsmParser/Parser.cpp
===================================================================
--- mlir/lib/AsmParser/Parser.cpp
+++ mlir/lib/AsmParser/Parser.cpp
@@ -803,6 +803,23 @@
   if (failed(popSSANameScope()))
     return failure();
 
+  // Parsing is complete, give an opportunity to each versioned dialect to
+  // visit the IR and perform upgrades before it gets verified.
+  if (state.dialectVersions) {
+    for (NamedAttribute dialectVersion : state.dialectVersions) {
+      Dialect *dialect =
+          getContext()->getLoadedDialect(dialectVersion.getName());
+      if (!dialect)
+        continue;
+      auto *asmIface = dialect->getRegisteredInterface<OpAsmDialectInterface>();
+      if (!asmIface)
+        continue;
+      if (failed(asmIface->upgradeFromVersion(topLevelOp,
+                                              dialectVersion.getValue())))
+        return failure();
+    }
+  }
+
   // Verify that the parsed operations are valid.
   if (state.config.shouldVerifyAfterParse() && failed(verify(topLevelOp)))
     return failure();
@@ -2599,6 +2616,25 @@
   // Create a top-level operation to contain the parsed state.
   OwningOpRef<ModuleOp> topLevelOp(ModuleOp::create(parserLoc));
   OperationParser opParser(state, topLevelOp.get());
+
+  // If the input starts with a `dialect_versions` keyword, we expect a
+  // dictionary representing the version of the dialect at the time the IR was
+  // produced. This will be used for possibly upgrading the IR when parsing
+  // completes.
+  if (getToken().getSpelling() == "dialect_versions") {
+    consumeToken();
+    SMLoc versionsLoc = getToken().getLoc();
+    Attribute versions = opParser.parseAttribute();
+    if (!versions)
+      return failure();
+    state.dialectVersions = versions.dyn_cast<DictionaryAttr>();
+    if (!state.dialectVersions) {
+      emitError(versionsLoc,
+                "Expects Dictionary attribute for dialect versions");
+      return failure();
+    }
+  }
+
   while (true) {
     switch (getToken().getKind()) {
     default:
Index: mlir/lib/AsmParser/ParserState.h
===================================================================
--- mlir/lib/AsmParser/ParserState.h
+++ mlir/lib/AsmParser/ParserState.h
@@ -80,6 +80,10 @@
   // popped when done. At the top-level we start with "builtin" as the
   // default, so that the top-level `module` operation parses as-is.
   SmallVector<StringRef> defaultDialectStack{"builtin"};
+
+  /// Dialect versions provided by the `dialect_versions` directive at the
+  /// beginning of the input; null when the directive is absent.
+  DictionaryAttr dialectVersions;
 };
 
 } // namespace detail
Index: mlir/lib/Bytecode/Encoding.h
===================================================================
--- mlir/lib/Bytecode/Encoding.h
+++ mlir/lib/Bytecode/Encoding.h
@@ -61,8 +61,12 @@
   /// section.
   kResourceOffset = 6,
 
+  /// This section contains the versions of the dialects used in the IR, for
+  /// the dialects that provide one.
+  kDialectVersions = 7,
+
   /// The total number of section types.
-  kNumSections = 7,
+  kNumSections = 8,
 };
 
 } // namespace Section
Index: mlir/lib/Bytecode/Reader/BytecodeReader.cpp
===================================================================
--- mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -47,6 +47,8 @@
     return "Resource (5)";
   case bytecode::Section::kResourceOffset:
     return "ResourceOffset (6)";
+  case bytecode::Section::kDialectVersions:
+    return "DialectVersions (7)";
   default:
     return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
   }
@@ -63,6 +65,7 @@
     return false;
   case bytecode::Section::kResource:
   case bytecode::Section::kResourceOffset:
+  case bytecode::Section::kDialectVersions:
     return true;
   default:
     llvm_unreachable("unknown section ID");
@@ -446,6 +449,12 @@
 
   /// The name of the dialect.
   StringRef name;
+
+  /// Flag to identify if this dialect is versioned.
+  bool isDialectVersioned = false;
+
+  /// Index of the version attribute, only valid if `isDialectVersioned`.
+  uint64_t dialectVersionAttributeID = 0;
 };
 
 /// This struct represents an operation name entry within the bytecode.
@@ -1165,6 +1174,13 @@
   LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
 
+  //===--------------------------------------------------------------------===//
+  // Dialect Versions Section
+
+  /// Parse the dialect versions section, if present.
+  LogicalResult
+  parseDialectVersionsSection(std::optional<ArrayRef<uint8_t>> sectionData);
+
   //===--------------------------------------------------------------------===//
   // Value Processing
 
@@ -1302,6 +1318,11 @@
   if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
     return failure();
 
+  // Process the dialect versions section.
+  if (failed(parseDialectVersionsSection(
+          sectionDatas[bytecode::Section::kDialectVersions])))
+    return failure();
+
   // Process the resource section if present.
   if (failed(parseResourceSection(
           sectionDatas[bytecode::Section::kResource],
@@ -1440,6 +1461,31 @@
         "not all forward unresolved forward operand references");
   }
 
+  // Parsing is complete, give an opportunity to each versioned dialect to
+  // visit the IR and perform upgrades. Iterate by reference to avoid copying
+  // the dialect entries.
+  for (auto &dialect : dialects) {
+    if (!dialect.isDialectVersioned)
+      continue;
+
+    // Skip dialects that were never loaded while reading, as well as entries
+    // holding a null dialect (presumably an unregistered dialect): there is
+    // no interface to invoke for them.
+    if (!dialect.dialect || !*dialect.dialect)
+      continue;
+    auto *asmIface =
+        (*dialect.dialect)->getRegisteredInterface<OpAsmDialectInterface>();
+    if (!asmIface)
+      continue;
+
+    Attribute dialectVersion =
+        attrTypeReader.resolveAttribute(dialect.dialectVersionAttributeID);
+    if (!dialectVersion)
+      return failure();
+    if (failed(asmIface->upgradeFromVersion(*moduleOp, dialectVersion)))
+      return failure();
+  }
+
   // Verify that the parsed operations are valid.
   if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
     return failure();
@@ -1673,6 +1719,47 @@
   return defineValues(reader, block->getArguments());
 }
 
+//===----------------------------------------------------------------------===//
+// Dialect Versions Section
+
+LogicalResult BytecodeReader::parseDialectVersionsSection(
+    std::optional<ArrayRef<uint8_t>> sectionData) {
+  // If the dialect versions are absent, there is nothing to do.
+  if (!sectionData.has_value())
+    return success();
+
+  EncodingReader sectionReader(sectionData.value(), fileLoc);
+
+  // Parse the number of dialects in the section.
+  uint64_t numVersionedDialects;
+  if (failed(sectionReader.parseVarInt(numVersionedDialects)))
+    return failure();
+
+  // Parse each of the dialect versions: a dialect name (string section index)
+  // followed by the index of the version attribute.
+  llvm::StringMap<uint64_t> versionedDialectMap;
+  for (uint64_t i = 0; i < numVersionedDialects; ++i) {
+    StringRef dialectName;
+    uint64_t dialectVersion;
+    if (failed(stringReader.parseString(sectionReader, dialectName)))
+      return failure();
+    if (failed(sectionReader.parseVarInt(dialectVersion)))
+      return failure();
+    versionedDialectMap.insert({dialectName, dialectVersion});
+  }
+
+  // Record the version on the matching dialect entries. The loop must bind by
+  // reference, otherwise the updates would only apply to temporary copies.
+  for (auto &dialect : dialects) {
+    auto it = versionedDialectMap.find(dialect.name);
+    if (it == versionedDialectMap.end())
+      continue;
+    dialect.isDialectVersioned = true;
+    dialect.dialectVersionAttributeID = it->second;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Value Processing
 
Index: mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
===================================================================
--- mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -385,6 +385,11 @@
 
   void writeStringSection(EncodingEmitter &emitter);
 
+  //===--------------------------------------------------------------------===//
+  // Dialect versions
+
+  void writeDialectVersionsSection(EncodingEmitter &emitter);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -425,6 +430,11 @@
   // Emit the string section.
   writeStringSection(emitter);
 
+  // Emit the dialect versions section. It is written after the string
+  // section, so every string it references (the dialect names) must already
+  // have been added to the string section before this point.
+  writeDialectVersionsSection(emitter);
+
   // Write the generated bytecode to the provided output stream.
   emitter.writeTo(os);
 }
@@ -476,6 +486,38 @@
   emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter));
 }
 
+void BytecodeWriter::writeDialectVersionsSection(EncodingEmitter &emitter) {
+  EncodingEmitter dialectEmitter;
+
+  // Returns the producer version of a dialect, or null if it has none.
+  auto getVersion = [&](const OpAsmDialectInterface *asmIface) -> Attribute {
+    // The interface is null if the dialect is not loaded or does not implement
+    // it; no version can be emitted in that case.
+    if (asmIface)
+      return asmIface->getProducerVersion();
+    return {};
+  };
+
+  // Map each versioned dialect name (string id) to its version attribute id.
+  auto dialects = numberingState.getDialects();
+  llvm::SmallVector<std::pair<uint64_t, uint64_t>> dialectVersionPair;
+
+  for (DialectNumbering &dialect : dialects)
+    if (Attribute version = getVersion(dialect.asmInterface))
+      dialectVersionPair.push_back(
+          std::make_pair(stringSection.insert(dialect.name),
+                         numberingState.getNumber(version)));
+
+  dialectEmitter.emitVarInt(dialectVersionPair.size());
+  for (const auto &[nameIndex, versionIndex] : dialectVersionPair) {
+    dialectEmitter.emitVarInt(nameIndex);
+    dialectEmitter.emitVarInt(versionIndex);
+  }
+
+  emitter.emitSection(bytecode::Section::kDialectVersions,
+                      std::move(dialectEmitter));
+}
+
 //===----------------------------------------------------------------------===//
 // Attributes and Types
 
Index: mlir/lib/Bytecode/Writer/IRNumbering.cpp
===================================================================
--- mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/OpImplementation.h"
 
 using namespace mlir;
 using namespace mlir::bytecode::detail;
@@ -223,6 +224,11 @@
     numbering = &numberDialect(dialect->getNamespace());
     numbering->interface = dyn_cast<BytecodeDialectInterface>(dialect);
     numbering->asmInterface = dyn_cast<OpAsmDialectInterface>(dialect);
+    // Number the producer version so that the dialect versions section can
+    // reference it.
+    if (const OpAsmDialectInterface *asmIface = numbering->asmInterface)
+      if (Attribute version = asmIface->getProducerVersion())
+        number(version);
   }
   return *numbering;
 }
Index: mlir/lib/IR/AsmPrinter.cpp
===================================================================
--- mlir/lib/IR/AsmPrinter.cpp
+++ mlir/lib/IR/AsmPrinter.cpp
@@ -3075,6 +3075,33 @@
 } // namespace
 
 void OperationPrinter::printTopLevelOperation(Operation *op) {
+  // Collect the producer version of every loaded dialect that provides one,
+  // querying each dialect interface only once.
+  const auto &asmInterfaces = state.getDialectInterfaces();
+  SmallVector<std::pair<StringRef, Attribute>> dialectVersions;
+  for (Dialect *dialect : op->getContext()->getLoadedDialects()) {
+    const OpAsmDialectInterface *asmIface =
+        asmInterfaces.getInterfaceFor(dialect);
+    if (!asmIface)
+      continue;
+    if (Attribute version = asmIface->getProducerVersion())
+      dialectVersions.emplace_back(dialect->getNamespace(), version);
+  }
+
+  // If any dialect has a version to print, print the `dialect_versions`
+  // directive. This directive has to be the very first directive in the file
+  // and introduces a dictionary where the keys are dialect names and the
+  // values are attributes representing the producer version for the given
+  // dialect. There is no restriction on what kind of attribute a dialect is
+  // using to model its versioning.
+  if (!dialectVersions.empty()) {
+    os << "dialect_versions { ";
+    interleaveComma(dialectVersions, [&](const auto &dialectVersion) {
+      os << dialectVersion.first << " = " << dialectVersion.second;
+    });
+    os << " }" << newLine;
+  }
+
   // Output the aliases at the top level that can't be deferred.
   state.getAliasState().printNonDeferredAliases(*this, newLine);
 
Index: mlir/test/lib/Dialect/Test/TestDialect.cpp
===================================================================
--- mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -164,6 +164,35 @@
     blobManager.buildResources(provider, referencedResources.getArrayRef());
   }
 
+  // Test IR upgrade with dialect version lookup.
+  Attribute getProducerVersion() const final {
+    return Builder(getDialect()->getContext()).getI32IntegerAttr(42);
+  }
+  LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                   Attribute producerVersion) const final {
+    auto version = producerVersion.dyn_cast_or_null<IntegerAttr>();
+    if (!version) {
+      return topLevelOp->emitError()
+             << "Expected an IntegerAttr for test dialect version, got: "
+             << producerVersion;
+    }
+    if (version.getInt() == 42)
+      return success();
+    if (version.getInt() > 42) {
+      return topLevelOp->emitError()
+             << "Current test dialect version is 42, can't parse version: "
+             << version;
+    }
+    // Older versions named the `dims` attribute of `versionedA` `dimensions`.
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (Attribute dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+    });
+    return success();
+  }
+
 private:
   /// The blob manager for the dialect.
   TestResourceBlobManagerInterface &blobManager;
@@ -1634,6 +1663,26 @@
   setResultRanges(getResult(), range);
 }
 
+//===----------------------------------------------------------------------===//
+// Test dialect_version upgrade by supporting an old syntax
+//===----------------------------------------------------------------------===//
+
+/// Parse `versionedB`. IR produced by a test dialect version older than 42
+/// spells the op with the `deprecated_syntax` keyword, which is accepted and
+/// printed back in the current form; otherwise `current_version` is expected.
+ParseResult TestVersionedOpB::parse(mlir::OpAsmParser &parser,
+                                    mlir::OperationState &state) {
+  auto version =
+      parser.getDialectVersion("test").dyn_cast_or_null<IntegerAttr>();
+  if (version && version.getInt() < 42)
+    return parser.parseKeyword("deprecated_syntax");
+  return parser.parseKeyword("current_version");
+}
+
+void TestVersionedOpB::print(OpAsmPrinter &printer) {
+  printer << " current_version";
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestTypeInterfaces.cpp.inc"
Index: mlir/test/lib/Dialect/Test/TestOps.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestOps.td
+++ mlir/test/lib/Dialect/Test/TestOps.td
@@ -3149,4 +3149,24 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+  let arguments = (ins
+    // A previous version of this dialect used the name "dimensions" for this
+    // attribute, it got renamed but we support loading old IR through
+    // upgrading, see `upgradeFromVersion()` in `TestOpAsmInterface`.
+    AnyI64Attr:$dims
+  );
+}
+
+// This op will be able to parse based on an old syntax and "auto-upgrade".
+def TestVersionedOpB : TEST_Op<"versionedB"> {
+  let arguments = (ins);
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 #endif // TEST_OPS