diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -102,6 +102,7 @@
   // Attributes.
   NamedAttribute getNamedAttr(StringRef name, Attribute val);
 
+  VersionAttr getVersionAttr(ArrayRef<uint8_t> data);
   UnitAttr getUnitAttr();
   BoolAttr getBoolAttr(bool value);
   DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -1227,4 +1227,49 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// VersionAttr
+//===----------------------------------------------------------------------===//
+
+def Builtin_VersionAttrRawDataParameter : ArrayRefParameter<
+    "uint8_t", "storage for version attribute"> {
+  let allocator = [{
+    if (!$_self.empty()) {
+      auto *alloc = static_cast<uint8_t *>(
+          $_allocator.allocate($_self.size(), 1));
+      std::uninitialized_copy($_self.begin(), $_self.end(), alloc);
+      $_dst = ArrayRef<uint8_t>(alloc, $_self.size());
+    }
+  }];
+}
+
+def Builtin_VersionAttr : Builtin_Attr<"Version"> {
+  let summary = "An Attribute containing a blob of data representing a version";
+  let description = [{
+    A version attribute is an attribute that represents a blob of data managed
+    by each dialect independently. It is the responsibility of the dialect to
+    define a proper dialect version format and encode/decode it into/from this
+    attribute.
+
+    Syntax:
+
+    ```
+    version-attribute ::= `version` `<` string-literal `>`
+    ```
+
+    Examples:
+
+    ```mlir
+    version<"0x1A">
+    ```
+  }];
+
+  let parameters = (ins Builtin_VersionAttrRawDataParameter:$rawData);
+
+  let extraClassDeclaration = [{
+    /// Get the size of the data blob.
+    size_t size() const { return getRawData().size(); }
+  }];
+}
+
 #endif // BUILTIN_ATTRIBUTES
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -491,6 +491,11 @@
   /// the encoding process is not efficient.
   virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
 
+  /// Return the attribute describing the version of the provided dialect name,
+  /// if any. The version is provided by the `dialect_versions` directive at the
+  /// very beginning of the parsing.
+  virtual VersionAttr getDialectVersion(StringRef dialect) = 0;
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
@@ -1582,6 +1587,41 @@
     return AliasResult::NoAlias;
   }
 
+  /// Hook provided by the dialect to emit a version when printing. The
+  /// attribute will be available when parsing back, and the dialect
+  /// implementation will be able to use it to load previous known version.
+  /// This management is entirely under the responsibility of the individual
+  /// dialects, and limited by possible changes in the MLIR syntax.
+  /// The version is an opaque blob of bytes: the dialect is free to choose
+  /// how it encodes its versioning into the returned `VersionAttr`. The
+  /// default implementation returns a null attribute, i.e. no version.
+  virtual VersionAttr getProducerVersion() const { return {}; }
+
+  /// Hook invoked after parsing completed, if a version directive was present
+  /// and included an entry for the current dialect. This hook offers the
+  /// opportunity to the dialect to visit the IR and upgrade constructs emitted
+  /// by the version of the dialect corresponding to the provided version
+  /// attribute.
+  virtual LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                           VersionAttr producerVersion) const {
+    return success();
+  }
+
+  /// Hook to parse the dialect-specific textual form of a version, i.e. the
+  /// string found in `version<"...">` when it is not a hex string. The default
+  /// implementation fails, meaning only the hex form is supported.
+  virtual FailureOr<VersionAttr> parseVersionAttr(StringRef token) const {
+    return failure();
+  }
+
+  /// Hook to print `attr` in a dialect-specific textual form. The returned
+  /// string is expected to be accepted by `parseVersionAttr`. The default
+  /// implementation fails, in which case the version is printed as hex.
+  virtual FailureOr<std::string>
+  printVersionAttrAsString(VersionAttr attr) const {
+    return failure();
+  }
+
   //===--------------------------------------------------------------------===//
   // Resources
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h
--- a/mlir/lib/AsmParser/AsmParserImpl.h
+++ b/mlir/lib/AsmParser/AsmParserImpl.h
@@ -60,6 +60,10 @@
     return parser.getEncodedSourceLocation(loc);
   }
 
+  VersionAttr getDialectVersion(StringRef dialect) override {
+    return parser.getDialectVersion(dialect);
+  }
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -46,6 +46,7 @@
 ///                      `:` (tensor-type | vector-type)
 ///                    | `strided` `<` `[` comma-separated-int-or-question `]`
 ///                      (`,` `offset` `:` integer-literal)? `>`
+///                    | `version` `<` string-literal `>`
 ///                    | extended-attribute
 ///
 Attribute Parser::parseAttribute(Type type) {
@@ -155,6 +156,10 @@
   case Token::kw_strided:
     return parseStridedLayoutAttr();
 
+  // Parse a version attribute.
+  case Token::kw_version:
+    return parseVersionAttr();
+
   // Parse a string attribute.
   case Token::string: {
     auto val = getToken().getStringValue();
@@ -324,7 +329,13 @@
       return success();
     }
 
-    auto attr = parseAttribute();
+    // The entry name is forwarded as a dialect name so that entries of the
+    // `dialect_versions` dictionary can use the dialect's own version syntax.
+    Attribute attr;
+    if (getToken().is(Token::kw_version))
+      attr = parseVersionAttr(*nameId);
+    else
+      attr = parseAttribute();
     if (!attr)
       return failure();
     attributes.push_back({*nameId, attr});
@@ -448,8 +459,8 @@
 
 /// Parse elements values stored within a hex string. On success, the values are
 /// stored into 'result'.
-static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
-                                             std::string &result) {
+static ParseResult parseHexValues(Parser &parser, Token tok,
+                                  std::string &result) {
   if (std::optional<std::string> value = tok.getHexStringValue()) {
     result = std::move(*value);
     return success();
@@ -719,7 +730,7 @@
   }
 
   std::string data;
-  if (parseElementAttrHexValues(p, *hexStorage, data))
+  if (parseHexValues(p, *hexStorage, data))
     return nullptr;
 
   ArrayRef<char> rawData(data.data(), data.size());
@@ -1214,3 +1225,62 @@
   return StridedLayoutAttr::get(getContext(), *offset, strides);
   // return getChecked<StridedLayoutAttr>(loc,getContext(), *offset, strides);
 }
+
+/// Parse a version attribute that is not attached to any dialect: only the hex
+/// string form is accepted.
+Attribute Parser::parseVersionAttr() { return parseVersionAttr(StringAttr()); }
+
+/// Parse a version attribute:
+///
+///   version-attribute ::= `version` `<` string-literal `>`
+///
+/// A hex string is always decoded by the builtin parser. Any other string is
+/// handed to the custom version parser of the dialect named `dialectName`,
+/// when `dialectName` is non-null and that dialect provides an
+/// OpAsmDialectInterface. A diagnostic is emitted on every failure path.
+Attribute Parser::parseVersionAttr(StringAttr dialectName) {
+  consumeToken(Token::kw_version);
+  if (failed(parseToken(Token::less, "expected '<' after 'version'")))
+    return nullptr;
+
+  Token tok = getToken();
+  if (!tok.is(Token::string)) {
+    emitWrongTokenError(dialectName ? "expected string after 'version<'"
+                                    : "expected hex string after 'version<'");
+    return nullptr;
+  }
+
+  // Only look up (and possibly load) the dialect when the string is not in hex
+  // form, so that plain hex versions never trigger a dialect load.
+  const OpAsmDialectInterface *iface = nullptr;
+  if (dialectName && !tok.getHexStringValue().has_value())
+    iface = llvm::dyn_cast_or_null<OpAsmDialectInterface>(
+        getContext()->getOrLoadDialect(dialectName));
+
+  VersionAttr attr;
+  if (iface) {
+    FailureOr<VersionAttr> attrOr =
+        iface->parseVersionAttr(tok.getStringValue());
+    if (failed(attrOr) || !*attrOr) {
+      emitError(tok.getLoc())
+          << "failed to parse version of dialect '" << dialectName.getValue()
+          << "'";
+      return nullptr;
+    }
+    attr = *attrOr;
+  } else {
+    // Emits a diagnostic if the token is not a well-formed hex string.
+    std::string data;
+    if (failed(parseHexValues(*this, tok, data)))
+      return nullptr;
+    attr = VersionAttr::get(
+        getContext(),
+        ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(data.data()),
+                          data.size()));
+  }
+  consumeToken(Token::string);
+
+  if (failed(parseToken(Token::greater, "expected '>'")))
+    return nullptr;
+  return attr;
+}
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -36,6 +36,14 @@
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.config.getContext(); }
   const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
+  /// Return the version recorded for `dialect` by the `dialect_versions`
+  /// directive, or a null attribute if there was no directive or no version
+  /// entry for this dialect.
+  VersionAttr getDialectVersion(StringRef dialect) const {
+    if (!state.dialectVersions)
+      return {};
+    return state.dialectVersions.get(dialect).dyn_cast_or_null<VersionAttr>();
+  }
 
   /// Parse a comma-separated list of elements up until the specified end token.
   ParseResult
@@ -276,6 +284,10 @@
   /// Parse a strided layout attribute.
   Attribute parseStridedLayoutAttr();
 
+  /// Parse a version attribute.
+  Attribute parseVersionAttr();
+  Attribute parseVersionAttr(StringAttr dialectName);
+
   //===--------------------------------------------------------------------===//
   // Location Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -803,6 +803,25 @@
   if (failed(popSSANameScope()))
     return failure();
 
+  // Parsing is complete, give an opportunity to each dialect to visit the
+  // IR and perform upgrades.
+  if (state.dialectVersions) {
+    for (NamedAttribute dialectVersion : state.dialectVersions) {
+      // Entries for dialects that are not loaded, that do not implement the
+      // interface, or whose value is not a version are skipped.
+      auto version = dialectVersion.getValue().dyn_cast<VersionAttr>();
+      Dialect *dialect =
+          getContext()->getLoadedDialect(dialectVersion.getName());
+      if (!version || !dialect)
+        continue;
+      auto *asmIface = dialect->getRegisteredInterface<OpAsmDialectInterface>();
+      if (!asmIface)
+        continue;
+      if (failed(asmIface->upgradeFromVersion(topLevelOp, version)))
+        return failure();
+    }
+  }
+
   // Verify that the parsed operations are valid.
   if (state.config.shouldVerifyAfterParse() && failed(verify(topLevelOp)))
     return failure();
@@ -2599,6 +2618,32 @@
   // Create a top-level operation to contain the parsed state.
   OwningOpRef<ModuleOp> topLevelOp(ModuleOp::create(parserLoc));
   OperationParser opParser(state, topLevelOp.get());
+
+  // If the input starts with a `dialect_versions` keyword, we expect a
+  // dictionary representing the version of the dialect at the time the IR was
+  // produced. This will be used for possibly upgrading the IR when parsing
+  // completes.
+  if (getToken().is(Token::bare_identifier) &&
+      getToken().getSpelling() == "dialect_versions") {
+    SMLoc versionsLoc = getToken().getLoc();
+    consumeToken();
+    Attribute versions = opParser.parseAttribute();
+    if (!versions)
+      return failure();
+    state.dialectVersions = versions.dyn_cast<DictionaryAttr>();
+    if (!state.dialectVersions)
+      return emitError(versionsLoc,
+                       "expected dictionary attribute for dialect versions");
+    // Reject non-version entries up front, with a diagnostic, rather than
+    // asserting on them when the versions are looked up later.
+    for (NamedAttribute entry : state.dialectVersions) {
+      if (!entry.getValue().isa<VersionAttr>())
+        return emitError(versionsLoc)
+               << "expected version attribute for dialect '"
+               << entry.getName().getValue() << "'";
+    }
+  }
+
   while (true) {
     switch (getToken().getKind()) {
     default:
diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h
--- a/mlir/lib/AsmParser/ParserState.h
+++ b/mlir/lib/AsmParser/ParserState.h
@@ -80,6 +80,10 @@
   // popped when done. At the top-level we start with "builtin" as the
   // default, so that the top-level `module` operation parses as-is.
   SmallVector<StringRef> defaultDialectStack{"builtin"};
+
+  // The dictionary parsed from the `dialect_versions` directive, keyed by
+  // dialect name. Null if the input had no such directive.
+  DictionaryAttr dialectVersions;
 };
 
 } // namespace detail
diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def
--- a/mlir/lib/AsmParser/TokenKinds.def
+++ b/mlir/lib/AsmParser/TokenKinds.def
@@ -122,6 +122,7 @@
 TOK_KEYWORD(type)
 TOK_KEYWORD(unit)
 TOK_KEYWORD(vector)
+TOK_KEYWORD(version)
 
 #undef TOK_MARKER
 #undef TOK_IDENTIFIER
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -61,8 +61,11 @@
   /// section.
   kResourceOffset = 6,
 
+  /// This section contains the versions of each dialect.
+  kDialectVersions = 7,
+
   /// The total number of section types.
- kNumSections = 7, + kNumSections = 8, }; } // namespace Section diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -47,6 +47,8 @@ return "Resource (5)"; case bytecode::Section::kResourceOffset: return "ResourceOffset (6)"; + case bytecode::Section::kDialectVersions: + return "DialectVersions (7)"; default: return ("Unknown (" + Twine(static_cast(sectionID)) + ")").str(); } @@ -63,6 +65,7 @@ return false; case bytecode::Section::kResource: case bytecode::Section::kResourceOffset: + case bytecode::Section::kDialectVersions: return true; default: llvm_unreachable("unknown section ID"); @@ -446,6 +449,8 @@ /// The name of the dialect. StringRef name; + + VersionAttr version; }; /// This struct represents an operation name entry within the bytecode. @@ -1165,6 +1170,13 @@ LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); + //===--------------------------------------------------------------------===// + // Dialect Versions Section + + /// Parse the dialect versions section, if present. + LogicalResult + parseDialectVersionsSection(std::optional> sectionData); + //===--------------------------------------------------------------------===// // Value Processing @@ -1302,6 +1314,11 @@ if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); + // Process the dialect versions section if present. + if (failed(parseDialectVersionsSection( + sectionDatas[bytecode::Section::kDialectVersions]))) + return failure(); + // Process the resource section if present. if (failed(parseResourceSection( sectionDatas[bytecode::Section::kResource], @@ -1440,6 +1457,20 @@ "not all forward unresolved forward operand references"); } + // Run the version upgrade hook of each versioned dialect.
+  for (const auto &dialect : dialects) {
+    // Parsing is complete, give an opportunity to each dialect to visit the
+    // IR and perform upgrades. Dialects that were never loaded are skipped.
+    Dialect *loaded = dialect.dialect.value_or(nullptr);
+    if (!dialect.version || !loaded)
+      continue;
+    auto *asmIface = loaded->getRegisteredInterface<OpAsmDialectInterface>();
+    if (!asmIface)
+      continue;
+    if (failed(asmIface->upgradeFromVersion(*moduleOp, dialect.version)))
+      return failure();
+  }
+
   // Verify that the parsed operations are valid.
   if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp)))
     return failure();
@@ -1673,6 +1704,49 @@
   return defineValues(reader, block->getArguments());
 }
 
+//===--------------------------------------------------------------------===//
+// Dialect Versions Section
+
+LogicalResult BytecodeReader::parseDialectVersionsSection(
+    std::optional<ArrayRef<uint8_t>> sectionData) {
+  // If the dialect versions are absent, there is nothing to do.
+  if (!sectionData.has_value())
+    return success();
+
+  EncodingReader sectionReader(sectionData.value(), fileLoc);
+
+  // Parse the number of dialects in the section.
+  uint64_t numVersionedDialects;
+  if (failed(sectionReader.parseVarInt(numVersionedDialects)))
+    return failure();
+
+  // Parse each of the dialect versions.
+  llvm::StringMap<VersionAttr> versionedDialectMap;
+  for (uint64_t i = 0; i < numVersionedDialects; ++i) {
+    StringRef dialectName;
+    uint64_t dialectVersionAttrSize;
+    if (failed(stringReader.parseString(sectionReader, dialectName)))
+      return failure();
+    if (failed(sectionReader.parseVarInt(dialectVersionAttrSize)))
+      return failure();
+    ArrayRef<uint8_t> bytes;
+    if (failed(sectionReader.parseBytes(dialectVersionAttrSize, bytes)))
+      return failure();
+    versionedDialectMap.insert(
+        {dialectName, VersionAttr::get(getContext(), bytes)});
+  }
+
+  // Record each version on the matching dialect entry. Iterate by reference:
+  // iterating by value would update a temporary copy and drop the version.
+  for (auto &dialect : dialects) {
+    auto it = versionedDialectMap.find(dialect.name);
+    if (it != versionedDialectMap.end())
+      dialect.version = it->second;
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Value Processing
 
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -385,6 +385,11 @@
 
   void writeStringSection(EncodingEmitter &emitter);
 
+  //===--------------------------------------------------------------------===//
+  // Dialect versions
+
+  void writeDialectVersionsSection(EncodingEmitter &emitter);
+
   //===--------------------------------------------------------------------===//
   // Fields
 
@@ -425,6 +430,10 @@
+  // Emit the dialect versions section. It is emitted before the string section
+  // so that every string it references is guaranteed to be serialized.
+  writeDialectVersionsSection(emitter);
+
   // Emit the string section.
   writeStringSection(emitter);
 
   // Write the generated bytecode to the provided output stream.
   emitter.writeTo(os);
 }
@@ -476,6 +485,35 @@
   emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter));
 }
 
+void BytecodeWriter::writeDialectVersionsSection(EncodingEmitter &emitter) {
+  // Collect the (name string index, version) pair of every referenced dialect
+  // that provides a producer version. `asmInterface` is null when the dialect
+  // is not loaded, in which case no version can be emitted for it.
+  llvm::SmallVector<std::pair<size_t, VersionAttr>> dialectVersionPairs;
+  for (DialectNumbering &dialect : numberingState.getDialects()) {
+    if (!dialect.asmInterface)
+      continue;
+    if (VersionAttr version = dialect.asmInterface->getProducerVersion())
+      dialectVersionPairs.emplace_back(stringSection.insert(dialect.name),
+                                       version);
+  }
+
+  // The section is optional: omit it entirely when there is nothing to emit.
+  if (dialectVersionPairs.empty())
+    return;
+
+  EncodingEmitter versionsEmitter;
+  versionsEmitter.emitVarInt(dialectVersionPairs.size());
+  for (const auto &[nameIndex, version] : dialectVersionPairs) {
+    versionsEmitter.emitVarInt(nameIndex);
+    versionsEmitter.emitVarInt(version.size());
+    versionsEmitter.emitBytes(version.getRawData());
+  }
+
+  emitter.emitSection(bytecode::Section::kDialectVersions,
+                      std::move(versionsEmitter));
+}
+
 //===----------------------------------------------------------------------===//
 // Attributes and Types
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -376,6 +376,9 @@
   /// dialect.
   void printResourceHandle(const AsmDialectResourceHandle &resource);
 
+  /// Print the given version attribute as hex to the stream.
+  void printVersionAttr(VersionAttr attr);
+
   void printAffineMap(AffineMap map);
 
   void printAffineExpr(AffineExpr expr,
@@ -1975,6 +1978,14 @@
   state.getDialectResources()[resource.getDialect()].insert(resource);
 }
 
+void AsmPrinter::Impl::printVersionAttr(VersionAttr attr) {
+  ArrayRef<uint8_t> rawData = attr.getRawData();
+  os << "version<";
+  printHexString(ArrayRef<char>(reinterpret_cast<const char *>(rawData.data()),
+                                rawData.size()));
+  os << ">";
+}
+
 /// Returns true if the given dialect symbol data is simple enough to print in
 /// the pretty form. This is essentially when the symbol takes the form:
 ///   identifier (`<` body `>`)?
@@ -2208,6 +2219,8 @@
     os << ">";
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
     printLocation(locAttr);
+  } else if (auto versionAttr = attr.dyn_cast<VersionAttr>()) {
+    printVersionAttr(versionAttr);
   } else {
     llvm::report_fatal_error("Unknown builtin attribute");
   }
@@ -3061,6 +3074,11 @@
   void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
                                  Operation *op);
 
+  /// Print the `dialect_versions` directive listing the producer version of
+  /// every loaded dialect that provides one. Prints nothing if no loaded
+  /// dialect provides a version.
+  void printDialectVersions(Operation *op);
+
   // Contains the stack of default dialects to use when printing regions.
   // A new dialect is pushed to the stack before parsing regions nested under an
   // operation implementing `OpAsmOpInterface`, and popped when done. At the
@@ -3077,6 +3095,9 @@
 } // namespace
 
 void OperationPrinter::printTopLevelOperation(Operation *op) {
+  // Output the dialect versions directive first.
+  printDialectVersions(op);
+
   // Output the aliases at the top level that can't be deferred.
   state.getAliasState().printNonDeferredAliases(*this, newLine);
 
@@ -3159,6 +3180,45 @@
   os << newLine << "  }";
 }
 
+void OperationPrinter::printDialectVersions(Operation *op) {
+  // Collect every loaded dialect that provides a producer version, keeping the
+  // interface around so that it is only looked up once per dialect.
+  struct VersionedDialect {
+    Dialect *dialect;
+    const OpAsmDialectInterface *asmIface;
+    VersionAttr version;
+  };
+  const auto &interfaces = state.getDialectInterfaces();
+  llvm::SmallVector<VersionedDialect> versionedDialects;
+  for (Dialect *dialect : op->getContext()->getLoadedDialects()) {
+    const OpAsmDialectInterface *asmIface = interfaces.getInterfaceFor(dialect);
+    if (!asmIface)
+      continue;
+    if (VersionAttr version = asmIface->getProducerVersion())
+      versionedDialects.push_back({dialect, asmIface, version});
+  }
+
+  if (versionedDialects.empty())
+    return;
+
+  // The directive has to be the very first one in the file. It introduces a
+  // dictionary where the keys are dialect names and the values are attributes
+  // representing the producer version of the given dialect.
+  os << "dialect_versions { ";
+  interleaveComma(versionedDialects, [&](const VersionedDialect &item) {
+    // Prefer the dialect-specific textual form, and fall back to the generic
+    // hex form when the dialect does not provide one.
+    FailureOr<std::string> stringOr =
+        item.asmIface->printVersionAttrAsString(item.version);
+    os << item.dialect->getNamespace() << " = ";
+    if (succeeded(stringOr))
+      os << "version<\"" << *stringOr << "\">";
+    else
+      os << item.version;
+  });
+  os << " }" << newLine;
+}
+
 /// Print a block argument in the usual format of:
 ///   %ssaName : type {attr1=42} loc("here")
 /// where location printing is controlled by the standard internal option.
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -104,6 +104,10 @@
   return NamedAttribute(getStringAttr(name), val);
 }
 
+VersionAttr Builder::getVersionAttr(ArrayRef<uint8_t> data) {
+  return VersionAttr::get(context, data);
+}
+
 UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); }
 
 BoolAttr Builder::getBoolAttr(bool value) {
diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir
--- a/mlir/test/Bytecode/general.mlir
+++ b/mlir/test/Bytecode/general.mlir
@@ -3,6 +3,7 @@
 // Bytecode currently does not support big-endian platforms
 // UNSUPPORTED:
target=s390x-{{.*}}
 
+// CHECK: dialect_versions { test = version<"1.42"> }
 // CHECK-LABEL: "bytecode.test1"
 // CHECK-NEXT:    "bytecode.empty"() : () -> ()
 // CHECK-NEXT:    "bytecode.attributes"() {attra = 10 : i64, attrb = #bytecode.attr} : () -> ()
diff --git a/mlir/test/IR/ir_upgrade.mlir b/mlir/test/IR/ir_upgrade.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/ir_upgrade.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt -split-input-file --verify-diagnostics %s | FileCheck %s
+
+// A version given as a hex string holding the raw bytes (1.41) is accepted.
+dialect_versions { test = version<"0x0100000029000000"> }
+
+// -----
+
+// A version given in the dialect-defined "major.minor" form is accepted.
+dialect_versions { test = version<"1.39"> }
+
+// -----
+
+// Without a directive no upgrade hook runs and the IR is printed unchanged.
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+"test.versionedA"() {dims = 123 : i64} : () -> ()
+
+// -----
+
+// An older producer version renames the "dimensions" attribute to "dims".
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+dialect_versions { test = version<"0x0100000029000000"> }
+"test.versionedA"() {dimensions = 123 : i64} : () -> ()
+
+// -----
+
+// The upgrade leaves an attribute that already uses the new name untouched.
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+dialect_versions { test = version<"0x0100000029000000"> }
+"test.versionedA"() {dims = 123 : i64} : () -> ()
+
+// -----
+
+// An older producer version (1.9) selects the deprecated custom syntax.
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: test.versionedB current_version
+dialect_versions { test = version<"0x0100000009000000"> }
+test.versionedB deprecated_syntax
+
+// -----
+
+// The current producer version (1.42) selects the current custom syntax.
+// CHECK: dialect_versions { test = version<"1.42"> }
+// CHECK: test.versionedB current_version
+dialect_versions { test = version<"0x010000002A000000"> }
+test.versionedB current_version
+
+// -----
+
+// The deprecated syntax is rejected for the current producer version.
+dialect_versions { test = version<"0x010000002A000000"> }
+// expected-error@+1{{custom op 'test.versionedB' expected 'current_version'}}
+test.versionedB deprecated_syntax
+
+// -----
+// expected-error@-2{{current test dialect version is 1.42, can't parse version: 1.43}}
+dialect_versions { test =
version<"1.43"> }
+"test.versionedA"() {dims = 123 : i64} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -47,6 +47,50 @@
   registry.insert<TestDialect>();
 }
 
+//===----------------------------------------------------------------------===//
+// TestDialect version utilities
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// The version of the test dialect, modeled as `major.minor`. The fields carry
+/// a trailing underscore because `major` and `minor` may be defined as macros
+/// by some system headers.
+struct TestDialectVersion {
+  uint32_t major_ = 1;
+  uint32_t minor_ = 42;
+};
+} // namespace
+
+/// Encode `version` as two 32-bit little-endian integers (major, then minor),
+/// so that the encoded bytes do not depend on the host endianness.
+static VersionAttr encodeDialectVersion(MLIRContext *ctx,
+                                        TestDialectVersion version) {
+  SmallVector<uint8_t, 8> bytes;
+  for (uint32_t field : {version.major_, version.minor_})
+    for (unsigned shift = 0; shift < 32; shift += 8)
+      bytes.push_back(static_cast<uint8_t>(field >> shift));
+  return Builder(ctx).getVersionAttr(bytes);
+}
+
+/// Decode a version produced by `encodeDialectVersion`. Returns std::nullopt
+/// if `attribute` is null or does not hold exactly 8 bytes.
+static std::optional<TestDialectVersion>
+decodeDialectVersion(VersionAttr attribute) {
+  if (!attribute || attribute.size() != 8)
+    return std::nullopt;
+  ArrayRef<uint8_t> bytes = attribute.getRawData();
+  auto readLE32 = [&](size_t offset) {
+    uint32_t value = 0;
+    for (unsigned i = 0; i < 4; ++i)
+      value |= static_cast<uint32_t>(bytes[offset + i]) << (8 * i);
+    return value;
+  };
+  TestDialectVersion version;
+  version.major_ = readLE32(0);
+  version.minor_ = readLE32(4);
+  return version;
+}
+
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -164,6 +208,62 @@
     blobManager.buildResources(provider, referencedResources.getArrayRef());
   }
 
+  // Test IR upgrade with dialect version lookup.
+  VersionAttr getProducerVersion() const final {
+    return encodeDialectVersion(getContext(), TestDialectVersion());
+  }
+
+  /// Parse a version written as the string `"major.minor"`. Fails if either
+  /// component is not a decimal integer.
+  FailureOr<VersionAttr> parseVersionAttr(StringRef token) const final {
+    auto [majorStr, minorStr] = token.split('.');
+    TestDialectVersion version;
+    if (majorStr.getAsInteger(10, version.major_) ||
+        minorStr.getAsInteger(10, version.minor_))
+      return failure();
+    return encodeDialectVersion(getContext(), version);
+  }
+
+  /// Print a version as `major.minor`, the form accepted by
+  /// `parseVersionAttr`. Fails if `attr` is not a valid test dialect version.
+  FailureOr<std::string>
+  printVersionAttrAsString(VersionAttr attr) const final {
+    std::optional<TestDialectVersion> version = decodeDialectVersion(attr);
+    if (!version)
+      return failure();
+    return std::to_string(version->major_) + "." +
+           std::to_string(version->minor_);
+  }
+
+  LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                   VersionAttr producerVersion) const final {
+    std::optional<TestDialectVersion> version =
+        decodeDialectVersion(producerVersion);
+    if (!version)
+      return topLevelOp->emitError() << "invalid test dialect version";
+
+    const TestDialectVersion current{};
+    if (version->minor_ == current.minor_)
+      return success();
+    if (version->minor_ > current.minor_) {
+      return topLevelOp->emitError()
+             << "current test dialect version is " << current.major_ << "."
+             << current.minor_
+             << ", can't parse version: " << version->major_ << "."
+             << version->minor_;
+    }
+
+    // Older producers used the name "dimensions" for the "dims" attribute.
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (auto dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+    });
+
+    return success();
+  }
+
 private:
   /// The blob manager for the dialect.
   TestResourceBlobManagerInterface &blobManager;
@@ -1634,6 +1734,25 @@
   setResultRanges(getResult(), range);
 }
 
+//===----------------------------------------------------------------------===//
+// Test dialect_version upgrade by supporting an old syntax
+//===----------------------------------------------------------------------===//
+
+ParseResult TestVersionedOpB::parse(mlir::OpAsmParser &parser,
+                                    mlir::OperationState &state) {
+  // IR without a (valid) version for this dialect is parsed as the current
+  // version.
+  std::optional<TestDialectVersion> version =
+      decodeDialectVersion(parser.getDialectVersion("test"));
+  if (version && version->minor_ < TestDialectVersion().minor_)
+    return parser.parseKeyword("deprecated_syntax");
+  return parser.parseKeyword("current_version");
+}
+
+void TestVersionedOpB::print(OpAsmPrinter &printer) {
+  printer << " current_version";
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestTypeInterfaces.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3149,4 +3149,24 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+  let arguments = (ins
+    // A previous version of this dialect used the name "dimensions" for this
+    // attribute, it got renamed but we support loading old IR through
+    // upgrading, see `upgradeFromVersion()` in `TestOpAsmInterface`.
+    AnyI64Attr:$dims
+  );
+}
+
+// This op will be able to parse based on an old syntax and "auto-upgrade".
+def TestVersionedOpB : TEST_Op<"versionedB"> {
+  let arguments = (ins);
+
+  let hasCustomAssemblyFormat = 1;
+}
+
 #endif // TEST_OPS