diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -1,4 +1,4 @@ -//===- AsmState.h - State class for AsmPrinter ------------------*- C++ -*-===// +//===- AsmState.h - Assembly State Utilities --------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file defines the AsmState class. +// This file defines various classes and utilities for interacting with the MLIR +// assembly formats. // //===----------------------------------------------------------------------===// @@ -19,12 +20,222 @@ #include namespace mlir { +class AsmConfigPrinter; class Operation; namespace detail { class AsmStateImpl; } // namespace detail +//===----------------------------------------------------------------------===// +// Config +//===----------------------------------------------------------------------===// + +/// The following classes enable support for parsing and printing configurations +/// within MLIR assembly formats. Configurations are a mechanism by which +/// dialects, and external clients, may attach additional information when +/// parsing/printing IR without that information being encoded in the IR itself. +/// Configurations are not uniqued within the MLIR context, are not attached +/// directly to any operation, and are solely intended to live and be processed +/// outside of the immediate IR. +/// +/// External configurations are encoded using a key-value pair nested within +/// dictionaries anchored either on a dialect, or an externally registered +/// entity. Dictionaries anchored on dialects use the dialect namespace +/// directly, and dictionaries anchored on external entities anchor on an +/// identifier wrapped within `<>`. 
The configuration key is an identifier used +/// to disambiguate the data. The configuration value may be stored in various +/// limited forms, but general encodings use a string (human readable) or blob +/// format (binary). Within the textual format, an example may be of the form: +/// +/// {-# +/// // The `config` section within the file-level metadata dictionary is used +/// // to contain any configuration entries. +/// config: { +/// // Here is a dictionary anchored on "mlir_reproducer", which is an +/// // external entity representing MLIR's crash reproducer functionality. +/// // External entity anchors are wrapped in `<>` to differentiate them +/// // from dialect names. +/// <mlir_reproducer>: { +/// // `pipeline` is an entry that holds a crash reproducer pipeline +/// // configuration. +/// pipeline: "func.func(canonicalize,cse)" +/// }, +/// // Here is a dictionary anchored on "foo_dialect", which is a dialect +/// // namespace. +/// foo_dialect: { +/// // `some_dialect_config` is a key to be interpreted by the dialect, +/// // and used to initialize/configure/etc. +/// some_dialect_config: "Some important config value" +/// } +/// } +/// #-} +/// + +//===----------------------------------------------------------------------===// +// Config Entry + +/// This class is used to build config entries for the printer. Each config +/// entry is represented using a key/value pair. The provided key must be unique +/// within the current context, which allows for a client to provide +/// configuration entries without worrying about overlap with other clients. +class AsmConfigBuilder { +public: + virtual ~AsmConfigBuilder(); + + /// Build a config entry represented by the given bool. + virtual void buildBool(StringRef key, bool data) = 0; + + /// Build a config entry represented by the given human-readable string + /// value. + virtual void buildString(StringRef key, StringRef data) = 0; + + /// Build a config entry represented by the given binary blob data. 
+ virtual void buildBlob(StringRef key, ArrayRef data, + unsigned dataAlignent) = 0; + /// A useful overload if the data type is known. Note that this does not + /// support `char` element types to avoid accidentally not providing the + /// expected alignment of data in situations that treat blobs generically. + template + std::enable_if_t::value> buildBlob(StringRef key, + ArrayRef data) { + buildBlob( + key, ArrayRef((const char *)data.data(), data.size() * sizeof(T)), + alignof(T)); + } +}; + +/// This class represents a single parsed config entry. +class AsmConfigEntry { +public: + virtual ~AsmConfigEntry(); + + /// Return the key of the config entry. + virtual StringRef getKey() const = 0; + + /// Emit an error at the location of this entry. + virtual InFlightDiagnostic emitError() const = 0; + + /// Parse the config entry represented by a boolean. Returns failure if the + /// entry does not correspond to a bool. + virtual FailureOr parseAsBool() const = 0; + + /// Parse the config entry represented by a human-readable string. Returns + /// failure if the entry does not correspond to a string. + virtual FailureOr parseAsString() const = 0; + + /// The type of an allocator function used to allocate memory for a blob when + /// required. The function is provided a size and alignment, and should return + /// an aligned allocation buffer. + using BlobAllocatorFn = + function_ref(unsigned size, unsigned align)>; + + /// This class represents a processed blob of data. + struct Blob { + /// The raw, properly aligned, blob data. + MutableArrayRef data; + + /// A flag indicating if data was allocated using the provided allocation + /// function, false if the data did not need to be allocated (e.g. such as + /// when mmap'd). + bool dataWasAllocated = false; + }; + + /// Parse the config entry represented by a binary blob. Returns failure if + /// the entry does not correspond to a blob. If the blob needed to be + /// allocated, the given allocator function is invoked. 
+ virtual FailureOr parseAsBlob(BlobAllocatorFn allocator) const = 0; +}; + +//===----------------------------------------------------------------------===// +// Config + +/// This class represents an instance of a configuration parser. This class +/// should be implemented by non-dialect clients that want to inject +/// additional configurations into MLIR assembly formats. +class AsmConfigParser { +public: + /// Create a new parser with the given identifying name. This name uniquely + /// identifies the entries of this parser, and differentiates them from other + /// contexts. + AsmConfigParser(StringRef name) : name(name.str()) {} + virtual ~AsmConfigParser(); + + /// Return the name of this parser. + StringRef getName() const { return name; } + + /// Parse the given configuration entry. Returns failure if the key/data were + /// not valid, or could otherwise not be processed correctly. Any necessary + /// errors should be emitted with the provided entry. + virtual LogicalResult parseConfig(AsmConfigEntry &entry) = 0; + + /// Return a configuration parser implemented via the given callable, whose + /// form should match that of `parseConfig` above. + template + static std::unique_ptr fromCallable(StringRef name, + CallableT &&parseFn) { + struct Processor : public AsmConfigParser { + Processor(StringRef name, CallableT &&parseFn) + : AsmConfigParser(name), parseFn(std::move(parseFn)) {} + LogicalResult parseConfig(AsmConfigEntry &entry) override { + return parseFn(entry); + } + + std::decay_t parseFn; + }; + return std::make_unique(name, std::forward(parseFn)); + } + +private: + std::string name; +}; + +/// This class represents an instance of a configuration printer. This class +/// should be implemented by non-dialect clients that want to inject additional +/// configurations into MLIR assembly formats. +class AsmConfigPrinter { +public: + /// Create a new printer with the given identifying name. 
This name uniquely + /// identifies the entries of this printer, and differentiates them from + /// other contexts. + AsmConfigPrinter(StringRef name) : name(name.str()) {} + virtual ~AsmConfigPrinter(); + + /// Return the name of this printer. + StringRef getName() const { return name; } + + /// Build any configurations to include during printing, utilizing the given + /// top-level root operation to help determine what information to include. + /// Provided data should be registered in the form of a key/data pair, to the + /// given builder. + virtual void buildConfigs(Operation *op, AsmConfigBuilder &builder) const = 0; + + /// Return a configuration printer implemented via the given callable, whose + /// form should match that of `buildConfigs` above. + template + static std::unique_ptr fromCallable(StringRef name, + CallableT &&printFn) { + struct Printer : public AsmConfigPrinter { + Printer(StringRef name, CallableT &&printFn) + : AsmConfigPrinter(name), printFn(std::move(printFn)) {} + void buildConfigs(Operation *op, + AsmConfigBuilder &builder) const override { + printFn(op, builder); + } + + std::decay_t printFn; + }; + return std::make_unique(name, std::forward(printFn)); + } + +private: + std::string name; +}; + +//===----------------------------------------------------------------------===// +// AsmState +//===----------------------------------------------------------------------===// + /// This class provides management for the lifetime of the state used when /// printing the IR. It allows for alleviating the cost of recomputing the /// internal state of the asm printer. @@ -54,6 +265,23 @@ /// state has not been initialized. detail::AsmStateImpl &getImpl() { return *impl; } + //===--------------------------------------------------------------------===// + // Config + //===--------------------------------------------------------------------===// + + /// Attach the given configuration printer to the AsmState. 
+ void attachConfigPrinter(std::unique_ptr printer); + + /// Attach a configuration printer, in the form of a callable, to the + /// AsmState. + template + std::enable_if_t>::value> + attachConfigPrinter(StringRef name, CallableT &&printFn) { + attachConfigPrinter( + AsmConfigPrinter::fromCallable(name, std::forward(printFn))); + } + private: AsmState() = delete; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -20,7 +20,8 @@ #include "llvm/Support/SMLoc.h" namespace mlir { - +class AsmConfigEntry; +class AsmConfigBuilder; class Builder; //===----------------------------------------------------------------------===// @@ -1324,6 +1325,12 @@ class OpAsmDialectInterface : public DialectInterface::Base { public: + OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} + + //===------------------------------------------------------------------===// + // Aliases + //===------------------------------------------------------------------===// + /// Holds the result of `getAlias` hook call. enum class AliasResult { /// The object (type or attribute) is not supported by the hook @@ -1336,8 +1343,6 @@ FinalAlias }; - OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} - /// Hooks for getting an alias identifier alias for a given symbol, that is /// not necessarily a part of this dialect. The identifier is used in place of /// the symbol when printing textual IR. These aliases must not contain `.` or @@ -1348,6 +1353,20 @@ virtual AliasResult getAlias(Type type, raw_ostream &os) const { return AliasResult::NoAlias; } + + //===------------------------------------------------------------------===// + // Config + //===------------------------------------------------------------------===// + + /// Hook for parsing configuration entries. 
Returns failure if the entry was + /// not valid, or could otherwise not be processed correctly. Any necessary + /// errors can be emitted via the provided entry. + virtual LogicalResult parseConfig(AsmConfigEntry &entry) const; + + /// Hook for building configurations to use during printing. The given + /// top-level root operation may be inspected to help determine what + /// information to include. + virtual void buildConfigs(Operation *op, AsmConfigBuilder &builder) const {} }; } // namespace mlir diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -13,6 +13,7 @@ #ifndef MLIR_PARSER_PARSER_H #define MLIR_PARSER_PARSER_H +#include "mlir/IR/AsmState.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include @@ -88,9 +89,33 @@ /// Return asm parser state to be used when parsing. AsmParserState *getAsmParserState() const { return asmState; } + /// Return the configuration parser registered to the given name, or nullptr + /// if no parser with `name` is registered. + AsmConfigParser *getConfigParser(StringRef name) const { + auto it = configParsers.find(name); + return it == configParsers.end() ? nullptr : it->second.get(); + } + + /// Attach the given configuration parser. + void attachConfigParser(std::unique_ptr parser) { + StringRef name = parser->getName(); + auto it = configParsers.try_emplace(name, std::move(parser)); + assert(it.second && "config parser already registered with the given name"); + } + + /// Attach the given callable configuration parser with the given name. 
+ template + std::enable_if_t>::value> + attachConfigParser(StringRef name, CallableT &&parserFn) { + attachConfigParser( + AsmConfigParser::fromCallable(name, std::forward(parserFn))); + } + private: MLIRContext *context; AsmParserState *asmState; + DenseMap> configParsers; }; /// This parses the file specified by the indicated SourceMgr and appends parsed diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -112,6 +112,12 @@ /// The OpAsmOpInterface, see OpAsmInterface.td for more details. #include "mlir/IR/OpAsmInterface.cpp.inc" +LogicalResult OpAsmDialectInterface::parseConfig(AsmConfigEntry &entry) const { + return entry.emitError() << "unknown 'config' key '" << entry.getKey() + << "' for dialect '" << getDialect()->getNamespace() + << "'"; +} + //===----------------------------------------------------------------------===// // OpPrintingFlags //===----------------------------------------------------------------------===// @@ -1254,6 +1260,15 @@ return name; } +//===----------------------------------------------------------------------===// +// External Config +//===----------------------------------------------------------------------===// + +AsmConfigEntry::~AsmConfigEntry() = default; +AsmConfigBuilder::~AsmConfigBuilder() = default; +AsmConfigParser::~AsmConfigParser() = default; +AsmConfigPrinter::~AsmConfigPrinter() = default; + //===----------------------------------------------------------------------===// // AsmState //===----------------------------------------------------------------------===// @@ -1278,6 +1293,17 @@ /// Get the state used for SSA names. SSANameState &getSSANameState() { return nameState; } + /// Return the dialects within the context that implement + /// OpAsmDialectInterface. + DialectInterfaceCollection &getDialectInterfaces() { + return interfaces; + } + + /// Return the non-dialect config printers. 
+ auto getConfigPrinters() { + return llvm::make_pointee_range(externalConfigPrinters); + } + /// Get the printer flags. const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } @@ -1292,6 +1318,9 @@ /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; + /// A collection of non-dialect configuration printers. + SmallVector> externalConfigPrinters; + /// The state used for attribute and type aliases. AliasState aliasState; @@ -1303,6 +1332,9 @@ /// An optional location map to be populated. AsmState::LocationMap *locationMap; + + // Allow direct access to the impl fields. + friend AsmState; }; } // namespace detail } // namespace mlir @@ -1352,6 +1384,10 @@ return impl->getPrinterFlags(); } +void AsmState::attachConfigPrinter(std::unique_ptr printer) { + impl->externalConfigPrinters.emplace_back(std::move(printer)); +} + //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// @@ -2629,6 +2665,51 @@ void printUserIDs(Operation *user, bool prefixComma = false); private: + /// This class represents a config builder implementation for the MLIR textual + /// assembly format. + class ConfigBuilder : public AsmConfigBuilder { + public: + using ValueFn = function_ref; + using PrintFn = function_ref; + + ConfigBuilder(OperationPrinter &p, PrintFn printFn) + : p(p), printFn(printFn) {} + ~ConfigBuilder() override = default; + + void buildBool(StringRef key, bool data) final { + printFn(key, [&](raw_ostream &os) { p.os << (data ? "true" : "false"); }); + } + + void buildString(StringRef key, StringRef data) final { + printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); }); + } + + void buildBlob(StringRef key, ArrayRef data, + unsigned dataAlignment) final { + printFn(key, [&](raw_ostream &os) { + // Store the blob in a hex string containing the alignment and the data. 
+ os << "\"0x" + << llvm::toHex(StringRef(reinterpret_cast(&dataAlignment), + sizeof(dataAlignment))) + << llvm::toHex(StringRef(data.data(), data.size())) << "\""; + }); + } + + private: + OperationPrinter &p; + PrintFn printFn; + }; + + /// Print the metadata dictionary for the file, eliding it if it is empty. + void printFileMetadataDictionary(Operation *op); + + /// Print the config section for the file metadata dictionary. + /// `checkAddMetadataDict` is used to indicate that metadata is going to be + /// added, and the file metadata dictionary should be started if it hasn't + /// yet. + void printConfigFileMetadata(function_ref checkAddMetadataDict, + Operation *op); + // Contains the stack of default dialects to use when printing regions. // A new dialect is pushed to the stack before parsing regions nested under an // operation implementing `OpAsmOpInterface`, and popped when done. At the @@ -2654,6 +2735,63 @@ // Output the aliases at the top level that can be deferred. state->getAliasState().printDeferredAliases(os, newLine); + + // Output any file level metadata. + printFileMetadataDictionary(op); +} + +void OperationPrinter::printFileMetadataDictionary(Operation *op) { + bool sawMetadataEntry = false; + auto checkAddMetadataDict = [&] { + if (!std::exchange(sawMetadataEntry, true)) + os << newLine << "{-#" << newLine; + }; + + // Add the various types of metadata. + printConfigFileMetadata(checkAddMetadataDict, op); + + // If the file dictionary exists, close it. + if (sawMetadataEntry) + os << newLine << "#-}" << newLine; +} + +void OperationPrinter::printConfigFileMetadata( + function_ref checkAddMetadataDict, Operation *op) { + // Functor used to add data entries to the file metadata dictionary. 
+ bool hadConfig = false; + auto processProvider = [&](const Twine &name, auto &provider) { + bool hadEntry = false; + auto printFn = [&](StringRef key, ConfigBuilder::ValueFn valueFn) { + checkAddMetadataDict(); + + // Emit the top-level config entry if we haven't yet. + if (!std::exchange(hadConfig, true)) + os << " config: {" << newLine; + // Emit the parent config entry if we haven't yet. + if (!std::exchange(hadEntry, true)) + os << " " << name << ": {" << newLine; + else + os << "," << newLine; + + os << " " << key << ": "; + valueFn(os); + }; + ConfigBuilder entryBuilder(*this, printFn); + provider.buildConfigs(op, entryBuilder); + + if (hadEntry) + os << newLine << " }"; + }; + + // Check each of the dialects for config data. + for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) + processProvider(interface.getDialect()->getNamespace(), interface); + // Check for any non-dialect providers. + for (const auto &printer : state->getConfigPrinters()) + processProvider("<" + printer.getName() + ">", printer); + + if (hadConfig) + os << newLine << " }"; } /// Print a block argument in the usual format of: diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h --- a/mlir/lib/Parser/AsmParserImpl.h +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -242,16 +242,11 @@ return success(); } - /// Returns true if the current token corresponds to a keyword. - bool isCurrentTokenAKeyword() const { - return parser.getToken().isAny(Token::bare_identifier, Token::inttype) || - parser.getToken().isKeyword(); - } - /// Parse the given keyword if present. ParseResult parseOptionalKeyword(StringRef keyword) override { // Check that the current token has the same spelling. 
- if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) + if (!parser.isCurrentTokenAKeyword() || + parser.getTokenSpelling() != keyword) return failure(); parser.consumeToken(); return success(); @@ -260,7 +255,7 @@ /// Parse a keyword, if present, into 'keyword'. ParseResult parseOptionalKeyword(StringRef *keyword) override { // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) + if (!parser.isCurrentTokenAKeyword()) return failure(); *keyword = parser.getTokenSpelling(); @@ -273,7 +268,7 @@ parseOptionalKeyword(StringRef *keyword, ArrayRef allowedKeywords) override { // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) + if (!parser.isCurrentTokenAKeyword()) return failure(); StringRef currentKeyword = parser.getTokenSpelling(); diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -99,6 +99,10 @@ case ')': return formToken(Token::r_paren, tokStart); case '{': + if (*curPtr == '-' && *(curPtr + 1) == '#') { + curPtr += 2; + return formToken(Token::file_metadata_begin, tokStart); + } return formToken(Token::l_brace, tokStart); case '}': return formToken(Token::r_brace, tokStart); @@ -140,12 +144,14 @@ case '@': return lexAtIdentifier(tokStart); - case '!': - LLVM_FALLTHROUGH; - case '^': - LLVM_FALLTHROUGH; case '#': + if (*curPtr == '-' && *(curPtr + 1) == '}') { + curPtr += 2; + return formToken(Token::file_metadata_end, tokStart); + } LLVM_FALLTHROUGH; + case '!': + case '^': case '%': return lexPrefixedIdentifier(tokStart); case '"': diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -154,6 +154,15 @@ const llvm::fltSemantics &semantics, size_t typeSizeInBits); + /// Returns true if the current token corresponds to a keyword. 
+ bool isCurrentTokenAKeyword() const { + return getToken().isAny(Token::bare_identifier, Token::inttype) || + getToken().isKeyword(); + } + + /// Parse a keyword, if present, into 'keyword'. + ParseResult parseOptionalKeyword(StringRef *keyword); + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -13,6 +13,7 @@ #include "Parser.h" #include "AsmParserImpl.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" @@ -291,6 +292,16 @@ return success(); } +ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { + // Check that the current token is a keyword. + if (!isCurrentTokenAKeyword()) + return failure(); + + *keyword = getTokenSpelling(); + consumeToken(); + return success(); +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// @@ -2065,17 +2076,98 @@ private: /// Parse an attribute alias declaration. + /// + /// attribute-alias-def ::= '#' alias-name `=` attribute-value + /// ParseResult parseAttributeAliasDef(); - /// Parse an attribute alias declaration. + /// Parse a type alias declaration. + /// + /// type-alias-def ::= '!' alias-name `=` type + /// ParseResult parseTypeAliasDef(); + + /// Parse a top-level file metadata dictionary. + /// + /// file-metadata-dict ::= '{-#' file-metadata-entry* `#-}' + /// + ParseResult parseFileMetadataDictionary(); + + /// Parse an instance of the 'config' file metadata. + ParseResult parseConfigFileMetadata(); +}; + +/// This class represents an implementation of a config entry for the MLIR +/// textual format. 
+class ParsedConfigEntry : public AsmConfigEntry { +public: + ParsedConfigEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p) + : key(key), keyLoc(keyLoc), value(value), p(p) {} + ~ParsedConfigEntry() override = default; + + StringRef getKey() const final { return key; } + + InFlightDiagnostic emitError() const final { return p.emitError(keyLoc); } + + FailureOr parseAsBool() const final { + if (value.is(Token::kw_true)) + return true; + if (value.is(Token::kw_false)) + return false; + return p.emitError(value.getLoc(), + "expected 'true' or 'false' value for key '" + key + + "'"); + } + + FailureOr parseAsString() const final { + if (value.isNot(Token::string)) + return p.emitError(value.getLoc(), + "expected string value for key '" + key + "'"); + return value.getStringValue(); + } + + FailureOr parseAsBlob(BlobAllocatorFn allocator) const final { + // Blob data within the textual format is represented as a hex string. + // TODO: We could avoid an additional alloc+copy here if we pre-allocated + // the buffer to use during hex processing. + Optional blobData = + value.is(Token::string) ? value.getHexStringValue() : llvm::None; + if (!blobData) + return p.emitError(value.getLoc(), + "expected hex string blob for key '" + key + "'"); + + // Extract the alignment of the blob data, which gets stored at the + // beginning of the string. + if (blobData->size() < sizeof(unsigned)) { + return p.emitError(value.getLoc(), + "expected hex string blob for key '" + key + + "' to encode alignment in first 4 bytes"); + } + unsigned align = 0; + memcpy(&align, blobData->data(), sizeof(unsigned)); + + // Get the data portion of the blob. + StringRef data = StringRef(*blobData).drop_front(sizeof(unsigned)); + if (data.empty()) + return Blob{llvm::None, /*dataWasAllocated=*/false}; + + // Allocate memory for the blob using the provided allocator and copy the + // data into it. 
+ Blob blob = {allocator(data.size(), align), /*dataWasAllocated=*/true}; + assert(llvm::isAddrAligned(llvm::Align(align), blob.data.data()) && + "blob allocator did not return a properly aligned address"); + memcpy(blob.data.data(), data.data(), data.size()); + return blob; + } + +private: + StringRef key; + SMLoc keyLoc; + Token value; + Parser &p; }; } // namespace -/// Parses an attribute alias declaration. -/// -/// attribute-alias-def ::= '#' alias-name `=` attribute-value -/// ParseResult TopLevelOperationParser::parseAttributeAliasDef() { assert(getToken().is(Token::hash_identifier)); StringRef aliasName = getTokenSpelling().drop_front(); @@ -2104,10 +2196,6 @@ return success(); } -/// Parse a type alias declaration. -/// -/// type-alias-def ::= '!' alias-name `=` type -/// ParseResult TopLevelOperationParser::parseTypeAliasDef() { assert(getToken().is(Token::exclamation_identifier)); StringRef aliasName = getTokenSpelling().drop_front(); @@ -2136,6 +2224,86 @@ return success(); } +ParseResult TopLevelOperationParser::parseFileMetadataDictionary() { + consumeToken(Token::file_metadata_begin); + return parseCommaSeparatedListUntil( + Token::file_metadata_end, [&]() -> ParseResult { + // Parse the key of the metadata dictionary. + SMLoc keyLoc = getToken().getLoc(); + StringRef key; + if (failed(parseOptionalKeyword(&key))) + return emitError("expected identifier key in file " + "metadata dictionary"); + if (parseToken(Token::colon, "expected ':'")) + return failure(); + + // Process the metadata entry. + if (key == "config") + return parseConfigFileMetadata(); + return emitError(keyLoc, "unknown key '" + key + + "' in file metadata dictionary"); + }); +} + +ParseResult TopLevelOperationParser::parseConfigFileMetadata() { + if (parseToken(Token::l_brace, "expected '{'")) + return failure(); + + return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { + // Parse the name, which is either 'dialect-name' or '<external-name>'. 
+ bool isExternalKey = consumeIf(Token::less); + SMLoc nameLoc = getToken().getLoc(); + StringRef name; + if (failed(parseOptionalKeyword(&name))) + return emitError("expected identifier name for 'config' entry"); + if (isExternalKey && parseToken(Token::greater, "expected '>'")) + return failure(); + if (parseToken(Token::colon, "expected ':'") || + parseToken(Token::l_brace, "expected '{'")) + return failure(); + + PointerUnion handler; + if (isExternalKey) { + handler = state.config.getConfigParser(name); + } else { + // Lookup the dialect and check that it can handle the configuration + // entry. + Dialect *dialect = getContext()->getOrLoadDialect(name); + if (!dialect) + return emitError(nameLoc, "dialect '" + name + "' is unknown"); + + handler = dyn_cast(dialect); + if (!handler) { + return emitError() << "unexpected 'config' section for dialect '" + << dialect->getNamespace() << "'"; + } + } + + return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { + SMLoc keyLoc = getToken().getLoc(); + StringRef key; + if (failed(parseOptionalKeyword(&key))) + return emitError("expected identifier key for 'config' entry"); + if (parseToken(Token::colon, "expected ':'")) + return failure(); + Token valueTok = getToken(); + consumeToken(); + + // FIXME: What is our policy on unknown configurations? The current code + // asserts that dialect configs can be handled, but external configs don't + // need to be. We should define a set policy here. + if (!handler) + return success(); + + ParsedConfigEntry entry(key, keyLoc, valueTok, *this); + if (const auto *iface = handler.dyn_cast()) + return iface->parseConfig(entry); + auto *processor = handler.get(); + return processor->parseConfig(entry); + }); + }); +} + ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, Location parserLoc) { // Create a top-level operation to contain the parsed state. 
@@ -2180,6 +2348,12 @@ if (parseTypeAliasDef()) return failure(); break; + + // Parse a file-level metadata dictionary. + case Token::file_metadata_begin: + if (parseFileMetadataDictionary()) + return failure(); + break; } } } diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -129,9 +129,12 @@ // Get the internal string data, without the quotes. StringRef bytes = getSpelling().drop_front().drop_back(); - // Try to extract the binary data from the hex string. + // Try to extract the binary data from the hex string. We expect the hex + // string to start with `0x` and have an even number of hex nibbles (nibbles + // should come in pairs). std::string hex; - if (!bytes.consume_front("0x") || !llvm::tryGetFromHex(bytes, hex)) + if (!bytes.consume_front("0x") || (bytes.size() & 1) || + !llvm::tryGetFromHex(bytes, hex)) return llvm::None; return hex; } diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -72,6 +72,9 @@ TOK_PUNCTUATION(star, "*") TOK_PUNCTUATION(vertical_bar, "|") +TOK_PUNCTUATION(file_metadata_begin, "{-#") +TOK_PUNCTUATION(file_metadata_end, "#-}") + // Keywords. These turn "foo" into Token::kw_foo enums. // NOTE: Please key these alphabetized to make it easier to find something in diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -25,3 +25,17 @@ // expected-error@below {{Test iterating `APInt`: }} // expected-error@below {{Test iterating `IntegerAttr`: }} arith.constant dense<> : tensor<0xi64> + +// Check that we handle an external constant parsed from the config. 
+// expected-error@below {{Test iterating `uint64_t`: 1, 2, 3}} +// expected-error@below {{Test iterating `APInt`: unable to iterate type}} +// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} +arith.constant #test.e1di64_elements<"blob1"> : tensor<3xi64> + +{-# + config: { + test: { + blob1: "0x08000000010000000000000002000000000000000300000000000000" + } + } +#-} diff --git a/mlir/test/IR/file-metadata-config.mlir b/mlir/test/IR/file-metadata-config.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/file-metadata-config.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// Check that we only preserve the blob that got referenced. +// CHECK: test: { +// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000" +// CHECK-NEXT: } + +module attributes { test.blob_ref = #test.e1di64_elements<"blob1"> } {} + +{-# + config: { + test: { + blob1: "0x08000000010000000000000002000000000000000300000000000000", + blob2: "0x08000000040000000000000005000000000000000600000000000000" + } + } +#-} diff --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/invalid-file-metadata.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+2 {{expected identifier key in file metadata dictionary}} +{-# + +// ----- + +// expected-error@+2 {{expected ':'}} +{-# + key +#-} + +// ----- + +// expected-error@+2 {{unknown key 'some_key' in file metadata dictionary}} +{-# + some_key: {} +#-} + +// ----- + +//===----------------------------------------------------------------------===// +// `config` +//===----------------------------------------------------------------------===// + +// expected-error@+2 {{expected '{'}} +{-# + config: "value" +#-} + +// ----- + +// expected-error@+3 {{expected identifier name for 'config' entry}} +{-# + config: { + 10 + } +#-} + +// ----- + +// 
expected-error@+3 {{expected '>'}} +{-# + config: { + { + let mnemonic = "e1di64_elements"; + let parameters = (ins + AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type, + + // A string key to the data, which is actually stored within the + // dialect. + StringRefParameter<>:$key + ); + let extraClassDeclaration = [{ + /// Return the elements referenced by this attribute. + llvm::ArrayRef getElements() const; + + /// The set of data types that can be iterated by this attribute. + using ContiguousIterableTypesT = std::tuple; + + /// Provide begin iterators for the various iterable types. + // * uint64_t + auto value_begin_impl(OverloadToken) const { + return getElements().begin(); + } + }]; + let assemblyFormat = "`<` $key `>`"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -194,6 +194,20 @@ return get(getContext(), first, second, third); } +//===----------------------------------------------------------------------===// +// TestExtern1DI64ElementsAttr +//===----------------------------------------------------------------------===// + +ArrayRef TestExtern1DI64ElementsAttr::getElements() const { + auto &mgr = static_cast(getDialect()).getExternalDataManager(); + const TestExternalElementsData *data = mgr.getData(getKey()); + + // TODO: If ElementsAttr had failable accessors, we could gracefully handle + // this case. 
+ assert(data && "expected external data to be registered"); + return data->data; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -43,6 +43,57 @@ class RewritePatternSet; } // namespace mlir +//===----------------------------------------------------------------------===// +// External Elements Data +//===----------------------------------------------------------------------===// + +namespace test { +/// This class represents a single external elements instance. It keeps track of +/// the data, and handles deallocation when necessary. +struct TestExternalElementsData { + TestExternalElementsData(llvm::MutableArrayRef data, + bool dataIsOwned) + : data(data), dataIsOwned(dataIsOwned) {} + ~TestExternalElementsData() { + if (dataIsOwned) + free(data.data()); + } + + /// The raw underlying data for the attribute. + llvm::MutableArrayRef data; + /// A boolean indicating if the data is owned and should be freed. + bool dataIsOwned; +}; + +/// This class acts as a manager for external elements data. It provides API +/// for creating and accessing registered elements data. +class TestExternalElementsDataManager { +public: + /// Return the data registered for the given name, or nullptr if no data is + /// registered. + const TestExternalElementsData *getData(llvm::StringRef name) const { + auto it = dataMap.find(name); + return it != dataMap.end() ? &*it->second : nullptr; + } + + /// Register the provided data with the given name. Asserts that no data is + /// already registered with the name. 
+ void insertData(llvm::StringRef name, + std::unique_ptr data) { + auto it = dataMap.try_emplace(name, std::move(data)); + (void)it; + assert(it.second && "data already registered"); + } + +private: + llvm::StringMap> dataMap; +}; +} // namespace test + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + #include "TestOpInterfaces.h.inc" #include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" @@ -57,6 +58,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; + //===------------------------------------------------------------------===// + // Aliases + //===------------------------------------------------------------------===// + AliasResult getAlias(Attribute attr, raw_ostream &os) const final { StringAttr strAttr = attr.dyn_cast(); if (!strAttr) @@ -102,6 +107,41 @@ } return AliasResult::NoAlias; } + + //===------------------------------------------------------------------===// + // Config + //===------------------------------------------------------------------===// + + LogicalResult parseConfig(AsmConfigEntry &entry) const final { + TestDialect *dialect = cast(getDialect()); + TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); + + // The config entries are external constant data. 
+ auto blobAllocFn = [](unsigned size, unsigned align) { + assert(align == alignof(uint64_t) && "unexpected data alignment"); + return MutableArrayRef((char *)malloc(size), size); + }; + FailureOr blob = entry.parseAsBlob(blobAllocFn); + if (failed(blob)) + return failure(); + + MutableArrayRef data((uint64_t *)blob->data.data(), + blob->data.size() / sizeof(uint64_t)); + mgr.insertData(entry.getKey(), std::make_unique( + data, blob->dataWasAllocated)); + return success(); + } + + void buildConfigs(Operation *op, AsmConfigBuilder &provider) const final { + SetVector usedExternalData; + op->walk([&](Operation *op) { + for (NamedAttribute attr : op->getAttrs()) + if (auto data = attr.getValue().dyn_cast()) + usedExternalData.insert(data); + }); + for (TestExtern1DI64ElementsAttr data : usedExternalData) + provider.buildBlob(data.getKey(), data.getElements()); + } }; struct TestDialectFoldInterface : public DialectFoldInterface { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -41,6 +41,12 @@ ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const override; + + /// Returns the external elements data manager for this dialect. + TestExternalElementsDataManager &getExternalDataManager() { + return externalDataManager; + } + private: // Storage for a custom fallback interface. void *fallbackEffectOpInterfaces; @@ -49,6 +55,9 @@ ::llvm::SetVector<::mlir::Type> &stack) const; void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer, ::llvm::SetVector<::mlir::Type> &stack) const; + + /// An external data manager used to test external elements data. + TestExternalElementsDataManager externalDataManager; }]; }