diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -450,6 +450,62 @@
   return p;
 }
 
+//===--------------------------------------------------------------------===//
+// Dialect Asm Version Interface.
+//===--------------------------------------------------------------------===//
+
+/// This class represents a handle to the version of a dialect: an opaque
+/// buffer of bytes, whose interpretation is up to the dialect, along with the
+/// name of the dialect owning it. Both are copied into storage shared by all
+/// copies of the handle, so a handle never refers to memory owned by the
+/// buffer it was created from (e.g. the parsed input).
+class AsmDialectVersionHandle {
+public:
+  AsmDialectVersionHandle() = default;
+  AsmDialectVersionHandle(ArrayRef<uint8_t> buffer, StringRef dialectName)
+      : storage(std::make_shared<Storage>(buffer, dialectName)) {}
+
+  /// Return true if the handle holds a non-empty version buffer.
+  explicit operator bool() const {
+    return storage && !storage->buffer.empty();
+  }
+
+  /// Return the version buffer, or an empty buffer for an empty handle.
+  ArrayRef<uint8_t> getBuffer() const {
+    if (!storage)
+      return {};
+    return storage->buffer;
+  }
+
+  /// Return an opaque pointer to the raw data, or nullptr for an empty handle.
+  const void *getRawData() const {
+    if (!storage)
+      return nullptr;
+    return storage->buffer.data();
+  }
+
+  /// Return the size of the version buffer, in bytes.
+  size_t size() const { return storage ? storage->buffer.size() : 0; }
+
+  /// Return the name of the dialect that owns the version.
+  StringRef getDialectName() const {
+    return storage ? StringRef(storage->dialectName) : StringRef();
+  }
+
+private:
+  /// Owned copies of the version bytes and of the dialect name.
+  struct Storage {
+    Storage(ArrayRef<uint8_t> bytes, StringRef name)
+        : buffer(bytes.begin(), bytes.end()), dialectName(name.str()) {}
+
+    SmallVector<uint8_t, 0> buffer;
+    std::string dialectName;
+  };
+
+  /// The data associated with the version; null for an empty handle.
+  std::shared_ptr<const Storage> storage;
+};
+
 //===----------------------------------------------------------------------===//
 // AsmParser
 //===----------------------------------------------------------------------===//
@@ -491,6 +547,11 @@
   /// the encoding process is not efficient.
   virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
 
+  /// Return the version of the provided dialect, as specified by the
+  /// `dialect_versions` directive at the beginning of the parsed input. An
+  /// empty handle is returned if the directive has no entry for the dialect.
+  virtual AsmDialectVersionHandle getDialectVersion(StringRef dialect) = 0;
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
@@ -1582,6 +1643,38 @@
     return AliasResult::NoAlias;
   }
 
+  /// Hook provided by the dialect to emit a version when printing. The
+  /// handle will be available when parsing back, and the dialect implementation
+  /// will be able to use it to load IR emitted by previous versions. This
+  /// management is entirely under the responsibility of the individual dialects.
+  virtual AsmDialectVersionHandle getProducerVersion() const { return {}; }
+
+  /// Hook invoked after parsing completed, if a version directive was present
+  /// and included an entry for the current dialect. This hook offers the
+  /// opportunity to the dialect to visit the IR and upgrade constructs emitted
+  /// by the version of the dialect corresponding to the provided version.
+  virtual LogicalResult
+  upgradeFromVersion(Operation *topLevelOp,
+                     AsmDialectVersionHandle versionHandle) const {
+    return success();
+  }
+
+  /// Hook exposed to the dialect to parse the version from the provided
+  /// string, i.e. the contents of `version<"...">` when it is not spelled as a
+  /// hex string, and return it as an `AsmDialectVersionHandle`. Returning
+  /// failure rejects the version directive.
+  virtual FailureOr<AsmDialectVersionHandle>
+  parseVersionAsString(StringRef token) const {
+    return failure();
+  }
+
+  /// Hook exposed to the dialect to print the version as a custom string. On
+  /// failure, the version is printed as a hex-encoded buffer instead.
+  virtual FailureOr<std::string>
+  printVersionAsString(AsmDialectVersionHandle versionHandle) const {
+    return failure();
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -60,6 +60,10 @@
     return parser.getEncodedSourceLocation(loc);
   }
 
+  AsmDialectVersionHandle getDialectVersion(StringRef dialect) override {
+    return parser.getDialectVersion(dialect);
+  }
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -36,6 +36,12 @@
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.config.getContext(); }
   const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
+
+  /// Return the version handle recorded for the given dialect by the
+  /// `dialect_versions` directive, or an empty handle if there is none.
+  AsmDialectVersionHandle getDialectVersion(StringRef dialect) const {
+    return state.dialectVersions.lookup(dialect);
+  }
 
   /// Parse a comma-separated list of elements up until the specified end token.
ParseResult diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -803,6 +803,24 @@ if (failed(popSSANameScope())) return failure(); + // Parsing is complete, give an opportunity to each dialect to visit the + // IR and perform upgrades. + if (!state.dialectVersions.empty()) { + for (auto &dialectVersion : state.dialectVersions) { + auto version = dialectVersion.getValue(); + Dialect *dialect = + topLevelOp->getContext()->getOrLoadDialect(version.getDialectName()); + if (!dialect) + continue; + auto *asmIface = dialect->getRegisteredInterface(); + if (!asmIface) + continue; + if (failed(asmIface->upgradeFromVersion(topLevelOp, + dialectVersion.getValue()))) + return failure(); + } + } + // Verify that the parsed operations are valid. if (state.config.shouldVerifyAfterParse() && failed(verify(topLevelOp))) return failure(); @@ -2353,6 +2371,14 @@ function_ref parseBody); ParseResult parseDialectResourceFileMetadata(); ParseResult parseExternalResourceFileMetadata(); + + /// Parse a top-level file dialect version dictionary. + /// + /// version-dict ::= 'dialect_versions {' version-entry* `}' + /// + ParseResult parseDialectVersionDictionary(); + ParseResult parseVersionEntry(AsmDialectVersionHandle &handle, + StringRef dialectName); }; /// This class represents an implementation of a resource entry for the MLIR @@ -2594,6 +2620,75 @@ }); } +ParseResult TopLevelOperationParser::parseDialectVersionDictionary() { + // If the input starts with a `dialect_versions` keyword, we expect a + // dictionary representing the version of the dialect at the time the IR was + // produced. This will be used for possibly upgrading the IR when parsing + // completes. + if (failed(parseToken(Token::kw_dialect_versions, + "expected 'dialect_versions'"))) + return failure(); + + return parseCommaSeparatedList(Delimiter::Braces, [&]() -> ParseResult { + // Parse the name of the dialect entry. 
+ StringRef dialectName = getTokenSpelling(); + consumeToken(); + + if (failed(parseToken(Token::equal, "expected '='"))) + return failure(); + + AsmDialectVersionHandle handle; + if (failed(parseVersionEntry(handle, dialectName))) + return failure(); + + state.dialectVersions.insert({dialectName, handle}); + return success(); + }); +} + +ParseResult +TopLevelOperationParser::parseVersionEntry(AsmDialectVersionHandle &handle, + StringRef dialectName) { + if (failed(parseToken(Token::kw_version, "expected 'version'"))) + return failure(); + if (failed(parseToken(Token::less, "expected '<' after 'version'"))) + return failure(); + + if (!getToken().is(Token::string)) { + emitWrongTokenError("expected string after 'version<'"); + return failure(); + } + + // If the token represents a hex string, parse it as hex. + std::optional result = getToken().getHexStringValue(); + auto *iface = llvm::dyn_cast_or_null( + getContext()->getOrLoadDialect(dialectName)); + if (!result.has_value() && !iface) + return failure(); + + // If we couldn't parse the token as hex string, then we can try using the + // custom parser exposed to the dialect. + if (!result.has_value() || (*result).empty()) { + FailureOr handleOr = + iface->parseVersionAsString(getToken().getStringValue()); + if (succeeded(handleOr)) + handle = *handleOr; + else + return failure(); + } + consumeToken(Token::string); + + if (failed(parseToken(Token::greater, "expected '>'"))) + return failure(); + + if (handle) + return success(); + + llvm::SmallVector cast((*result).begin(), (*result).end()); + handle = AsmDialectVersionHandle(cast, dialectName); + return success(); +} + ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, Location parserLoc) { // Create a top-level operation to contain the parsed state. @@ -2639,11 +2734,17 @@ return failure(); break; - // Parse a file-level metadata dictionary. + // Parse a file-level metadata dictionary. 
case Token::file_metadata_begin: if (parseFileMetadataDictionary()) return failure(); break; + + // Parse a file-level version dictionary. + case Token::kw_dialect_versions: + if (parseDialectVersionDictionary()) + return failure(); + break; } } } diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h --- a/mlir/lib/AsmParser/ParserState.h +++ b/mlir/lib/AsmParser/ParserState.h @@ -80,6 +80,9 @@ // popped when done. At the top-level we start with "builtin" as the // default, so that the top-level `module` operation parses as-is. SmallVector defaultDialectStack{"builtin"}; + + /// A map between a dialect name and its version. + llvm::StringMap dialectVersions; }; } // namespace detail diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -122,6 +122,8 @@ TOK_KEYWORD(type) TOK_KEYWORD(unit) TOK_KEYWORD(vector) +TOK_KEYWORD(version) +TOK_KEYWORD(dialect_versions) #undef TOK_MARKER #undef TOK_IDENTIFIER diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -61,8 +61,11 @@ /// section. kResourceOffset = 6, + /// This section contains the versions of each dialect. + kDialectVersions = 7, + /// The total number of section types. 
-  kNumSections = 7,
+  kNumSections = 8,
 };
 } // namespace Section
 
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -47,6 +47,8 @@
     return "Resource (5)";
   case bytecode::Section::kResourceOffset:
     return "ResourceOffset (6)";
+  case bytecode::Section::kDialectVersions:
+    return "DialectVersions (7)";
   default:
     return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str();
   }
@@ -63,6 +65,7 @@
     return false;
   case bytecode::Section::kResource:
   case bytecode::Section::kResourceOffset:
+  case bytecode::Section::kDialectVersions:
     return true;
   default:
     llvm_unreachable("unknown section ID");
@@ -446,6 +449,9 @@
 
   /// The name of the dialect.
   StringRef name;
+
+  /// The version of the dialect, if the bytecode provides one.
+  AsmDialectVersionHandle version;
 };
 
 /// This struct represents an operation name entry within the bytecode.
@@ -1165,6 +1171,13 @@
   LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
 
+  //===--------------------------------------------------------------------===//
+  // Dialect Versions Section
+
+  /// Parse the dialect versions section, if present.
+  LogicalResult
+  parseDialectVersionsSection(std::optional<ArrayRef<uint8_t>> sectionData);
+
   //===--------------------------------------------------------------------===//
   // Value Processing
 
@@ -1302,6 +1315,11 @@
   if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
     return failure();
 
+  // Process the dialect versions section, if present.
+  if (failed(parseDialectVersionsSection(
+          sectionDatas[bytecode::Section::kDialectVersions])))
+    return failure();
+
   // Process the resource section if present.
   if (failed(parseResourceSection(
           sectionDatas[bytecode::Section::kResource],
@@ -1440,6 +1458,23 @@
                       "not all forward unresolved forward operand references");
   }
 
+  // Parsing is complete, give an opportunity to each dialect with a version
+  // in the bytecode to visit the IR and perform upgrades.
+  for (BytecodeDialect &byteCodeDialect : dialects) {
+    if (!byteCodeDialect.version)
+      continue;
+    Dialect *dialect =
+        moduleOp->getContext()->getOrLoadDialect(byteCodeDialect.name);
+    if (!dialect)
+      continue;
+    auto *asmIface = dialect->getRegisteredInterface<OpAsmDialectInterface>();
+    if (!asmIface)
+      continue;
+    if (failed(
+            asmIface->upgradeFromVersion(*moduleOp, byteCodeDialect.version)))
+      return failure();
+  }
+
   // Verify that the parsed operations are valid.
   if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
     return failure();
@@ -1673,6 +1708,47 @@
   return defineValues(reader, block->getArguments());
 }
 
+//===--------------------------------------------------------------------===//
+// Dialect Versions Section
+
+LogicalResult BytecodeReader::parseDialectVersionsSection(
+    std::optional<ArrayRef<uint8_t>> sectionData) {
+  // If the dialect versions are absent, there is nothing to do.
+  if (!sectionData)
+    return success();
+
+  EncodingReader sectionReader(*sectionData, fileLoc);
+
+  // Parse the number of dialects in the section.
+  uint64_t numVersionedDialects;
+  if (failed(sectionReader.parseVarInt(numVersionedDialects)))
+    return failure();
+
+  // Parse each of the dialect versions: a dialect name (as an index into the
+  // string section), followed by the size and the bytes of the version.
+  llvm::StringMap<AsmDialectVersionHandle> versionedDialectMap;
+  for (uint64_t i = 0; i < numVersionedDialects; ++i) {
+    StringRef dialectName;
+    uint64_t dialectVersionSize;
+    ArrayRef<uint8_t> bytes;
+    if (failed(stringReader.parseString(sectionReader, dialectName)) ||
+        failed(sectionReader.parseVarInt(dialectVersionSize)) ||
+        failed(sectionReader.parseBytes(dialectVersionSize, bytes)))
+      return failure();
+    versionedDialectMap.try_emplace(
+        dialectName, AsmDialectVersionHandle(bytes, dialectName));
+  }
+
+  // Attach each version to its dialect. Iterate by reference: the handle must
+  // be stored in the `dialects` entries themselves, not in copies of them.
+  for (BytecodeDialect &dialect : dialects) {
+    auto it = versionedDialectMap.find(dialect.name);
+    if (it != versionedDialectMap.end())
+      dialect.version = it->second;
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Value Processing
 
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -385,6 +385,11 @@
 
   void writeStringSection(EncodingEmitter &emitter);
 
+  //===--------------------------------------------------------------------===//
+  // Dialect versions
+
+  void writeDialectVersionsSection(EncodingEmitter &emitter);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -425,6 +430,9 @@
   // Emit the string section.
   writeStringSection(emitter);
 
+  // Emit the dialect versions section.
+  writeDialectVersionsSection(emitter);
+
   // Write the generated bytecode to the provided output stream.
   emitter.writeTo(os);
 }
@@ -476,6 +484,43 @@
   emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter));
 }
 
+void BytecodeWriter::writeDialectVersionsSection(EncodingEmitter &emitter) {
+  // Collect the version of each referenced dialect that provides one. The
+  // asm interface is null for a dialect that does not implement
+  // `OpAsmDialectInterface` (for example, when it is not loaded); no version
+  // can be emitted for it.
+  //
+  // NOTE: the string section has already been emitted at this point, so
+  // `stringSection.insert` must only return indices of existing strings. This
+  // presumably holds because the dialect section inserts every dialect name;
+  // TODO confirm if that ever changes.
+  SmallVector<std::pair<uint64_t, AsmDialectVersionHandle>> versionedDialects;
+  auto dialects = numberingState.getDialects();
+  for (DialectNumbering &dialect : dialects) {
+    if (!dialect.asmInterface)
+      continue;
+    if (AsmDialectVersionHandle version =
+            dialect.asmInterface->getProducerVersion())
+      versionedDialects.emplace_back(stringSection.insert(dialect.name),
+                                     version);
+  }
+
+  // The section is optional: omit it when no dialect provides a version.
+  if (versionedDialects.empty())
+    return;
+
+  EncodingEmitter versionEmitter;
+  versionEmitter.emitVarInt(versionedDialects.size());
+  for (const auto &[nameIndex, version] : versionedDialects) {
+    versionEmitter.emitVarInt(nameIndex);
+    versionEmitter.emitVarInt(version.size());
+    versionEmitter.emitBytes(version.getBuffer());
+  }
+
+  emitter.emitSection(bytecode::Section::kDialectVersions,
+                      std::move(versionEmitter));
+}
+
 //===----------------------------------------------------------------------===//
 // Attributes and Types
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3061,6 +3061,10 @@
   void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
                                  Operation *op);
 
+  /// Print a dictionary containing dialect names and corresponding dialect
+  /// versions.
+  void printDialectVersions(Operation *op);
+
   // Contains the stack of default dialects to use when printing regions.
   // A new dialect is pushed to the stack before parsing regions nested under an
   // operation implementing `OpAsmOpInterface`, and popped when done. At the
@@ -3077,6 +3081,12 @@
 } // namespace
 
 void OperationPrinter::printTopLevelOperation(Operation *op) {
+  // If any dialect has a version to print, print the `dialect_versions`
+  // directive first: it introduces a dictionary mapping dialect names to the
+  // producer version of each dialect, which must be known before operations
+  // are parsed back since custom parsers may query it.
+  printDialectVersions(op);
+
   // Output the aliases at the top level that can't be deferred.
   state.getAliasState().printNonDeferredAliases(*this, newLine);
 
@@ -3159,6 +3169,44 @@
   os << newLine << " }";
 }
 
+void OperationPrinter::printDialectVersions(Operation *op) {
+  const auto &interfaces = state.getDialectInterfaces();
+
+  // Collect each loaded dialect whose asm interface provides a producer
+  // version, keeping the interface around to print the version.
+  struct VersionedDialect {
+    Dialect *dialect;
+    const OpAsmDialectInterface *asmIface;
+    AsmDialectVersionHandle version;
+  };
+  SmallVector<VersionedDialect> versionedDialects;
+  for (Dialect *dialect : op->getContext()->getLoadedDialects()) {
+    const OpAsmDialectInterface *asmIface = interfaces.getInterfaceFor(dialect);
+    if (!asmIface)
+      continue;
+    if (AsmDialectVersionHandle version = asmIface->getProducerVersion())
+      versionedDialects.push_back({dialect, asmIface, version});
+  }
+  if (versionedDialects.empty())
+    return;
+
+  os << "dialect_versions { ";
+  interleaveComma(versionedDialects, [&](const VersionedDialect &entry) {
+    ::printKeywordOrString(entry.dialect->getNamespace(), os);
+    os << " = version<\"";
+    // Prefer the dialect's custom spelling, escaped since it may contain any
+    // character; otherwise fall back to hex-encoding the raw version buffer.
+    FailureOr<std::string> custom =
+        entry.asmIface->printVersionAsString(entry.version);
+    if (succeeded(custom))
+      llvm::printEscapedString(*custom, os);
+    else
+      os << "0x" << llvm::toHex(entry.version.getBuffer());
+    os << "\">";
+  });
+  os << " }" << newLine;
+}
+
 /// Print a block argument in the usual format of:
 ///   %ssaName : type {attr1=42} loc("here")
 /// where location printing is controlled by the standard internal option.
diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir
--- a/mlir/test/Bytecode/general.mlir
+++ b/mlir/test/Bytecode/general.mlir
@@ -3,6 +3,7 @@
 // Bytecode currently does not support big-endian platforms
 // UNSUPPORTED: target=s390x-{{.*}}
 
+// CHECK: dialect_versions { test = version<"1.42"> }
 // CHECK-LABEL: "bytecode.test1"
 // CHECK-NEXT: "bytecode.empty"() : () -> ()
 // CHECK-NEXT: "bytecode.attributes"() {attra = 10 : i64, attrb = #bytecode.attr} : () -> ()
diff --git a/mlir/test/IR/ir_upgrade.mlir b/mlir/test/IR/ir_upgrade.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/ir_upgrade.mlir
@@ -0,0 +1,58 @@
+// RUN: mlir-opt -split-input-file --verify-diagnostics %s | FileCheck %s
+
+dialect_versions { test = version<"0x0100000029000000"> }
+
+// -----
+
+dialect_versions { test = version<"1.39"> }
+
+// -----
+
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+"test.versionedA"() {dims = 123 : i64} : () -> ()
+
+// -----
+
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+dialect_versions { test = version<"0x0100000029000000"> }
+"test.versionedA"() {dimensions = 123 : i64} : () -> ()
+
+// -----
+
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+dialect_versions { test = version<"0x0100000029000000"> }
+"test.versionedA"() {dims = 123 : i64} : () -> ()
+
+// -----
+
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: test.versionedB current_version
+dialect_versions { test = version<"0x0120000009000000"> }
+test.versionedB deprecated_syntax
+
+// -----
+
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: test.versionedB current_version
+dialect_versions { test = version<"0x010000002A000000"> }
+test.versionedB current_version
+
+// -----
+
+dialect_versions { test = version<"0x010000002A000000"> }
+// expected-error@+1{{custom op 'test.versionedB' expected 'current_version'}}
+test.versionedB deprecated_syntax
+
+// -----
+
+// Without a `dialect_versions` directive, the current syntax is expected.
+// CHECK: test.versionedB current_version
+test.versionedB current_version
+
+// -----
+// expected-error@-2{{current test dialect version is 1.42, can't parse version: 1.43}}
+dialect_versions { test = version<"1.43"> }
+"test.versionedA"() {dims = 123 : i64} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -47,6 +47,40 @@
   registry.insert<TestDialect>();
 }
 
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+/// The version of the test dialect; default-constructed, it is the current
+/// version (1.42).
+struct TestDialectVersion {
+  int major = 1;
+  int minor = 42;
+};
+
+/// Encode `version` as a version handle owned by the test dialect. The raw
+/// bytes of the struct are used, so the encoding depends on host endianness.
+static AsmDialectVersionHandle
+encodeDialectVersion(MLIRContext *ctx, const TestDialectVersion version) {
+  ArrayRef<uint8_t> encode(reinterpret_cast<const uint8_t *>(&version),
+                           sizeof(TestDialectVersion));
+  return AsmDialectVersionHandle(
+      encode, ctx->getOrLoadDialect<TestDialect>()->getNamespace());
+}
+
+/// Decode a version encoded by `encodeDialectVersion`. An empty or malformed
+/// handle, e.g. when the input has no `dialect_versions` entry for the test
+/// dialect, decodes to the current version.
+static TestDialectVersion decodeDialectVersion(AsmDialectVersionHandle handle) {
+  TestDialectVersion version;
+  if (handle.size() != sizeof(TestDialectVersion))
+    return version;
+  // Copy rather than dereference a cast pointer: the buffer carries no
+  // alignment guarantee for `TestDialectVersion`.
+  std::memcpy(&version, handle.getRawData(), sizeof(TestDialectVersion));
+  return version;
+}
+
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -164,6 +198,53 @@
     blobManager.buildResources(provider, referencedResources.getArrayRef());
   }
 
+  // Test IR upgrade with dialect version lookup.
+  AsmDialectVersionHandle getProducerVersion() const final {
+    return encodeDialectVersion(getContext(), TestDialectVersion());
+  }
+
+  // We represent our version as a string `"major.minor"`.
+  FailureOr<AsmDialectVersionHandle>
+  parseVersionAsString(StringRef token) const final {
+    auto [majorStr, minorStr] = token.split('.');
+    TestDialectVersion version;
+    // Reject malformed input; `getAsInteger` returns true on error.
+    if (majorStr.getAsInteger(/*Radix=*/10, version.major) ||
+        minorStr.getAsInteger(/*Radix=*/10, version.minor))
+      return failure();
+    return encodeDialectVersion(getContext(), version);
+  }
+
+  // We would like to represent our version as version<"major.minor">.
+  FailureOr<std::string>
+  printVersionAsString(AsmDialectVersionHandle attr) const final {
+    TestDialectVersion version = decodeDialectVersion(attr);
+    return std::to_string(version.major) + "." +
+           std::to_string(version.minor);
+  }
+
+  LogicalResult
+  upgradeFromVersion(Operation *topLevelOp,
+                     AsmDialectVersionHandle producerVersion) const final {
+    TestDialectVersion version = decodeDialectVersion(producerVersion);
+    if (version.minor == 42)
+      return success();
+    if (version.minor > 42) {
+      return topLevelOp->emitError()
+             << "current test dialect version is 1.42, can't parse version: "
+             << version.major << "." << version.minor;
+    }
+    // Older versions named the `dims` attribute of `test.versionedA`
+    // `dimensions`: rename it.
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (Attribute dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+    });
+    return success();
+  }
+
 private:
   /// The blob manager for the dialect.
   TestResourceBlobManagerInterface &blobManager;
@@ -1634,6 +1715,25 @@
   setResultRanges(getResult(), range);
 }
 
+//===----------------------------------------------------------------------===//
+// Test dialect_versions upgrade by supporting an old syntax
+//===----------------------------------------------------------------------===//
+
+ParseResult TestVersionedOpB::parse(mlir::OpAsmParser &parser,
+                                    mlir::OperationState &state) {
+  // Without a `dialect_versions` entry for the test dialect, the handle is
+  // empty and decodes to the current version.
+  TestDialectVersion version =
+      decodeDialectVersion(parser.getDialectVersion("test"));
+  if (version.minor < 42)
+    return parser.parseKeyword("deprecated_syntax");
+  return parser.parseKeyword("current_version");
+}
+
+void TestVersionedOpB::print(OpAsmPrinter &printer) {
+  printer << " current_version";
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestTypeInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3149,4 +3149,24 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+  let arguments = (ins
+    // A previous version of this dialect used the name "dimensions" for this
+    // attribute, it got renamed but we support loading old IR through
+    // upgrading, see `upgradeFromVersion()` in `TestOpAsmInterface`.
+    AnyI64Attr:$dims
+  );
+}
+
+// This op will be able to parse based on an old syntax and "auto-upgrade".
+def TestVersionedOpB : TEST_Op<"versionedB"> {
+  let arguments = (ins);
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 #endif // TEST_OPS