diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -394,6 +394,11 @@
   /// the encoding process is not efficient.
   virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
 
+  /// Return the version recorded for `dialect` in the `dialect_versions`
+  /// header of the file being parsed, or a null attribute if the file has no
+  /// such header or the header has no entry for `dialect`.
+  virtual Attribute getDialectVersion(StringRef dialect) = 0;
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
@@ -1347,6 +1352,19 @@
     return AliasResult::NoAlias;
   }
 
+  /// Return the version of this dialect to record in the `dialect_versions`
+  /// header when printing. A null attribute (the default) marks the dialect
+  /// as unversioned, and it is left out of the header.
+  virtual Attribute getProducerVersion() const { return {}; }
+
+  /// Called after the whole file has been parsed, once for each loaded
+  /// dialect listed in the `dialect_versions` header, with the version the
+  /// producer recorded. `topLevelOp` holds the parsed IR and may be rewritten
+  /// in place. Returning failure aborts parsing. The default does nothing.
+  virtual LogicalResult upgradeFromVersion(Operation *topLevelOp,
+                                           Attribute version) {
+    return success();
+  }
 };
 } // namespace mlir
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2491,6 +2491,25 @@
 } // namespace
 
 void OperationPrinter::printTopLevelOperation(Operation *op) {
+  // Collect the producer version of every loaded dialect that reports one.
+  // The parser reads this header back as a dictionary attribute and passes
+  // each entry to that dialect's `upgradeFromVersion` hook.
+  SmallVector<std::pair<StringRef, Attribute>> dialectVersions;
+  for (Dialect *dialect : op->getContext()->getLoadedDialects()) {
+    const OpAsmDialectInterface *asmIface = state->getOpAsmInterface(dialect);
+    if (!asmIface)
+      continue;
+    if (Attribute version = asmIface->getProducerVersion())
+      dialectVersions.emplace_back(dialect->getNamespace(), version);
+  }
+  if (!dialectVersions.empty()) {
+    os << "dialect_versions { ";
+    interleaveComma(dialectVersions, [&](const auto &entry) {
+      os << entry.first << " = " << entry.second;
+    });
+    os << " }" << newLine;
+  }
+
   // Output the aliases at the top level that can't be deferred.
   state->getAliasState().printNonDeferredAliases(os, newLine);
 
diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h
--- a/mlir/lib/Parser/AsmParserImpl.h
+++ b/mlir/lib/Parser/AsmParserImpl.h
@@ -60,6 +60,11 @@
     return parser.getEncodedSourceLocation(loc);
   }
 
+  /// Forward to the parser's view of the file's `dialect_versions` header.
+  Attribute getDialectVersion(StringRef dialect) override {
+    return parser.getDialectVersion(dialect);
+  }
+
   //===--------------------------------------------------------------------===//
   // Token Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -34,6 +34,14 @@
   ParserState &getState() const { return state; }
   MLIRContext *getContext() const { return state.context; }
   const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
+
+  /// Return the version recorded for `dialect` in the `dialect_versions`
+  /// header, or a null attribute if there is no header or no such entry.
+  Attribute getDialectVersion(StringRef dialect) const {
+    if (!state.dialectVersions)
+      return {};
+    return state.dialectVersions.get(dialect);
+  }
 
   /// Parse a comma-separated list of elements up until the specified end token.
   ParseResult
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -2154,6 +2154,21 @@
   // Create a top-level operation to contain the parsed state.
   OwningOpRef<ModuleOp> topLevelOp(ModuleOp::create(parserLoc));
   OperationParser opParser(state, topLevelOp.get());
+
+  // An optional `dialect_versions { <dialect> = <attr>, ... }` header records
+  // the dialect versions of the producer of this file.
+  if (getToken().getSpelling() == "dialect_versions") {
+    consumeToken();
+    llvm::SMLoc versionsLoc = getToken().getLoc();
+    Attribute versions = opParser.parseAttribute();
+    if (!versions)
+      return failure();
+    state.dialectVersions = versions.dyn_cast<DictionaryAttr>();
+    if (!state.dialectVersions)
+      return emitError(versionsLoc,
+                       "expected dictionary attribute for dialect versions");
+  }
+
   while (true) {
     switch (getToken().getKind()) {
     default:
@@ -2173,6 +2188,24 @@
       auto &destOps = topLevelBlock->getOperations();
+
+      // Run the dialect upgrade hooks while the parsed IR is still held by
+      // `topLevelOp`: the splice below moves every operation out of it.
+      if (state.dialectVersions) {
+        for (NamedAttribute dialectVersion : state.dialectVersions) {
+          Dialect *dialect =
+              getContext()->getLoadedDialect(dialectVersion.getName());
+          if (!dialect)
+            continue;
+          auto *asmIface = dyn_cast<OpAsmDialectInterface>(dialect);
+          if (!asmIface)
+            continue;
+          if (failed(asmIface->upgradeFromVersion(topLevelOp.get(),
+                                                  dialectVersion.getValue())))
+            return failure();
+        }
+      }
+
       destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
                      parsedOps, parsedOps.begin(), parsedOps.end());
       return success();
     }
 
diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -89,6 +89,10 @@
   // popped when done. At the top-level we start with "builtin" as the
   // default, so that the top-level `module` operation parses as-is.
   SmallVector<StringRef> defaultDialectStack{"builtin"};
+
+  /// The dictionary from the `dialect_versions` file header, or null if the
+  /// file being parsed has no such header.
+  DictionaryAttr dialectVersions;
 };
 
 } // namespace detail
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -99,6 +99,11 @@
     }
     return AliasResult::NoAlias;
   }
+
+  // Report a fixed version so tests exercise the `dialect_versions` header.
+  Attribute getProducerVersion() const final {
+    return Builder(getDialect()->getContext()).getI32IntegerAttr(42);
+  }
 };
 
 struct TestDialectFoldInterface : public DialectFoldInterface {