diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -830,3 +830,15 @@
 that are directly usable by any other dialect in MLIR. These types cover a range
 from primitive integer and floating-point values, attribute dictionaries, dense
 multi-dimensional arrays, and more.
+
+### IR Versioning
+
+Dialects can opt in to handle versioning in a custom way. The input IR may
+start with the keyword `dialect_versions`, followed by a dictionary attribute
+where the keys are dialect names and the values are attributes representing
+the producer version for the given dialect. There is no restriction on what
+kind of attribute a dialect uses to model its versioning.
+
+When parsing, the version is made available through the parser. A dialect for
+which a version is present can use it to upgrade the IR after parsing, by
+implementing `OpAsmDialectInterface::upgradeFromVersion`.
diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -394,6 +394,11 @@
   /// the encoding process is not efficient.
   virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
 
+  /// Return the attribute describing the version of the provided dialect name,
+  /// or a null attribute if there is none. The version is provided by the
+  /// `dialect_versions` directive at the very beginning of the parsed input.
+  virtual Attribute getDialectVersion(StringRef dialect) = 0;
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
@@ -1350,6 +1355,24 @@
     return AliasResult::NoAlias;
   }
 
+  /// Hook provided by the dialect to emit a version when printing. The
+  /// attribute will be available when parsing back, and the dialect
+  /// implementation will be able to use it to load IR from older versions.
+  /// This management is entirely under the responsibility of the individual
+  /// dialects, and limited by possible changes in the MLIR syntax.
+  /// There is no restriction on what kind of attribute a dialect is using to
+  /// model its versioning.
+  virtual Attribute getProducerVersion() const { return {}; }
+
+  /// Hook invoked after parsing completes, if a version directive was present
+  /// and included an entry for the current dialect. This hook offers the
+  /// opportunity to the dialect to visit the IR and upgrade constructs emitted
+  /// by the version of the dialect corresponding to the provided version
+  /// attribute.
+  virtual LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                           Attribute producerVersion) const {
+    return success();
+  }
 };
 
 } // namespace mlir
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2487,6 +2487,32 @@
 } // namespace
 
 void OperationPrinter::printTopLevelOperation(Operation *op) {
+  // Retrieve the version to print for this dialect, if any.
+  auto getVersion = [&](Dialect *dialect) -> Attribute {
+    const OpAsmDialectInterface *asmIface = state->getOpAsmInterface(dialect);
+    if (!asmIface)
+      return {};
+    return asmIface->getProducerVersion();
+  };
+
+  // If any dialect has a version to print, print the `dialect_versions`
+  // directive. This directive has to be the very first directive in the file
+  // and introduces a dictionary where the keys are dialect names and the values
+  // are attributes representing the producer version for the given dialect.
+  // There is no restriction on what kind of attribute a dialect is using to
+  // model its versioning.
+  SmallVector<std::pair<StringRef, Attribute>> dialectVersions;
+  for (Dialect *dialect : op->getContext()->getLoadedDialects())
+    if (Attribute version = getVersion(dialect))
+      dialectVersions.emplace_back(dialect->getNamespace(), version);
+  if (!dialectVersions.empty()) {
+    os << "dialect_versions { ";
+    interleaveComma(dialectVersions, [&](const auto &entry) {
+      os << entry.first << " = " << entry.second;
+    });
+    os << " }" << newLine;
+  }
+
   // Output the aliases at the top level that can't be deferred.
   state->getAliasState().printNonDeferredAliases(os, newLine);
 
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -60,6 +60,11 @@
     return parser.getEncodedSourceLocation(loc);
   }
 
+  /// Return the version recorded for `dialect` by `dialect_versions`, if any.
+  Attribute getDialectVersion(StringRef dialect) override {
+    return parser.getDialectVersion(dialect);
+  }
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -34,6 +34,13 @@
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.context; }
   const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
+  /// Return the version attribute recorded for `dialect` by the
+  /// `dialect_versions` directive, or a null attribute if there is none.
+  Attribute getDialectVersion(StringRef dialect) const {
+    if (!state.dialectVersions)
+      return {};
+    return state.dialectVersions.get(dialect);
+  }
 
   /// Parse a comma-separated list of elements up until the specified end token.
   ParseResult
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -583,6 +583,23 @@
   if (failed(popSSANameScope()))
     return failure();
 
+  // Parsing is complete, give an opportunity to each dialect to visit the
+  // IR and perform upgrades.
+  if (state.dialectVersions) {
+    for (NamedAttribute dialectVersion : state.dialectVersions) {
+      Dialect *dialect =
+          getContext()->getLoadedDialect(dialectVersion.getName());
+      if (!dialect)
+        continue;
+      auto *asmIface = dialect->getRegisteredInterface<OpAsmDialectInterface>();
+      if (!asmIface)
+        continue;
+      if (failed(asmIface->upgradeFromVersion(topLevelOp,
+                                              dialectVersion.getValue())))
+        return failure();
+    }
+  }
+
   // Verify that the parsed operations are valid.
   if (failed(verify(topLevelOp)))
     return failure();
@@ -2168,6 +2185,23 @@
   // Create a top-level operation to contain the parsed state.
   OwningOpRef<ModuleOp> topLevelOp(ModuleOp::create(parserLoc));
   OperationParser opParser(state, topLevelOp.get());
+
+  // If the input starts with a `dialect_versions` keyword, we expect a
+  // dictionary representing the version of each dialect at the time the IR was
+  // produced. This will be used for possibly upgrading the IR when parsing
+  // completes.
+  if (getToken().getSpelling() == "dialect_versions") {
+    consumeToken();
+    llvm::SMLoc versionsLoc = getToken().getLoc();
+    Attribute versions = opParser.parseAttribute();
+    if (!versions)
+      return failure();
+    state.dialectVersions = versions.dyn_cast<DictionaryAttr>();
+    if (!state.dialectVersions)
+      return emitError(versionsLoc,
+                       "expected a dictionary attribute for dialect versions");
+  }
+
   while (true) {
     switch (getToken().getKind()) {
     default:
diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -89,6 +89,9 @@
   // popped when done. At the top-level we start with "builtin" as the
   // default, so that the top-level `module` operation parses as-is.
   SmallVector<StringRef> defaultDialectStack{"builtin"};
+
+  // The versions parsed from the `dialect_versions` directive, if present.
+  DictionaryAttr dialectVersions;
 };
 
 } // namespace detail
diff --git a/mlir/test/IR/ir_upgrade.mlir b/mlir/test/IR/ir_upgrade.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/ir_upgrade.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -split-input-file --verify-diagnostics %s | FileCheck %s
+
+dialect_versions { test = 41 }
+
+// -----
+
+// CHECK: dialect_versions { test = 42 : i32 }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+"test.versionedA"() {dims = 123 : i64} : () -> ()
+
+// -----
+
+// CHECK: dialect_versions { test = 42 : i32 }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+dialect_versions { test = 41 }
+"test.versionedA"() {dimensions = 123 : i64} : () -> ()
+
+// -----
+
+// CHECK: dialect_versions { test = 42 : i32 }
+// CHECK: "test.versionedA"() {dims = 123 : i64} : () -> ()
+dialect_versions { test = 41 }
+"test.versionedA"() {dims = 123 : i64} : () -> ()
+
+// -----
+
+// CHECK: dialect_versions { test = 42 : i32 }
+// CHECK: test.versionedB current_version
+dialect_versions { test = 41 }
+test.versionedB deprecated_syntax
+
+// -----
+
+// CHECK: dialect_versions { test = 42 : i32 }
+// CHECK: test.versionedB current_version
+dialect_versions { test = 42 }
+test.versionedB current_version
+
+// -----
+
+// Without a `dialect_versions` directive, the current syntax is expected.
+// CHECK: dialect_versions { test = 42 : i32 }
+// CHECK: test.versionedB current_version
+test.versionedB current_version
+
+// -----
+
+dialect_versions { test = 42 }
+// expected-error@+1{{custom op 'test.versionedB' expected 'current_version'}}
+test.versionedB deprecated_syntax
+
+// -----
+// expected-error@-2{{Current test dialect version is 42, can't parse version: 43 : i64}}
+dialect_versions { test = 43 }
+"test.versionedA"() {dims = 123 : i64} : () -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -99,6 +99,34 @@
     }
     return AliasResult::NoAlias;
   }
+
+  Attribute getProducerVersion() const final {
+    return Builder(getDialect()->getContext()).getI32IntegerAttr(42);
+  }
+  LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                   Attribute producerVersion) const final {
+    auto version = producerVersion.dyn_cast<IntegerAttr>();
+    if (!version) {
+      return topLevelOp->emitError()
+             << "Expected an IntegerAttr for test dialect version, got: "
+             << producerVersion;
+    }
+    if (version.getInt() == 42)
+      return success();
+    if (version.getInt() > 42) {
+      return topLevelOp->emitError()
+             << "Current test dialect version is 42, can't parse version: "
+             << version;
+    }
+    topLevelOp->walk([](TestVersionedOpA op) {
+      if (auto dims = op->getAttr("dimensions")) {
+        op->removeAttr("dimensions");
+        op->setAttr("dims", dims);
+      }
+    });
+
+    return success();
+  }
 };
 
 struct TestDialectFoldInterface : public DialectFoldInterface {
@@ -1209,6 +1237,24 @@
       /*printBlockTerminators=*/false);
 }
 
+//===----------------------------------------------------------------------===//
+// Test dialect_versions upgrade by supporting an old syntax
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseVersionedOp(OpAsmParser &parser,
+                                    OperationState &state) {
+  if (auto version =
+          parser.getDialectVersion("test").dyn_cast_or_null<IntegerAttr>()) {
+    if (version.getInt() < 42)
+      return parser.parseKeyword("deprecated_syntax");
+  }
+  return parser.parseKeyword("current_version");
+}
+
+static void print(TestVersionedOpB op, OpAsmPrinter &printer) {
+  printer << " current_version";
+}
+
 #include "TestOpEnums.cpp.inc"
 #include "TestOpInterfaces.cpp.inc"
 #include "TestOpStructs.cpp.inc"
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2600,4 +2600,24 @@
 def TestEffectsOpB : TEST_Op<"op_with_effects_b",
     [MemoryEffects<[MemWrite]>]>;
 
+//===----------------------------------------------------------------------===//
+// Test Ops to upgrade based on the dialect versions
+//===----------------------------------------------------------------------===//
+
+def TestVersionedOpA : TEST_Op<"versionedA"> {
+  let arguments = (ins
+    // A previous version of this dialect used the name "dimensions" for this
+    // attribute, it got renamed but we support loading old IR through
+    // upgrading, see `upgradeFromVersion()` in `TestOpAsmInterface`.
+    AnyI64Attr:$dims
+  );
+}
+
+// This op will be able to parse based on an old syntax and "auto-upgrade".
+def TestVersionedOpB : TEST_Op<"versionedB"> {
+  let arguments = (ins);
+  let parser = [{ return ::parseVersionedOp(parser, result); }];
+  let printer = [{ return ::print(*this, p); }];
+}
+
 #endif // TEST_OPS