diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -1,4 +1,4 @@ -//===- AsmState.h - State class for AsmPrinter ------------------*- C++ -*-===// +//===- AsmState.h - Assembly State Utilities --------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file defines the AsmState class. +// This file defines various classes and utilites for interacting with the MLIR +// assembly formats. // //===----------------------------------------------------------------------===// @@ -19,12 +20,190 @@ #include namespace mlir { +class AsmExternalConfigProvider; class Operation; namespace detail { class AsmStateImpl; } // namespace detail +//===----------------------------------------------------------------------===// +// External Config +//===----------------------------------------------------------------------===// + +/// The following classes enable support for providing and processing external +/// configurations within MLIR assembly formats. This is a mechanism with which +/// dialects, and external clients, may attach additional information when +/// printing IR without that information being encoded in the IR itself. +/// External configurations are not uniqued within the MLIR context, are not +/// attached directly to any operation, and are solely intended to live and be +/// processed outside of the immediate IR. There are many potential uses of this +/// functionality, for example MLIR's pass crash reproducer utilizes this to +/// attach the pass configuration executing when a crash occurs. 
Other types of +/// uses may be embedding large amounts of binary data, such as weights in ML +/// applications, that shouldn't be copied directly into the MLIR context, but +/// need to be kept adjacent to the IR. +/// +/// External configurations are encoded using a key-value pair nested within a +/// dictionary anchored by name either on a dialect, or an externally registered +/// entity. The key is an identifier used to disambiguate the data. The value +/// may be stored in various limited forms, but general encodings use a string +/// (human readable) or blob format (binary). Within the textual format, an +/// example may be of the form: +/// +/// {-# +/// // The `config` section within the file-level metadata dictionary is used +/// // to contain any external configuration entries. +/// config: { +/// // Here is a dictionary anchored on "mlir_reproducer", which is an +/// // external entity representing MLIR's crash reproducer functionality. +/// // External entity anchors are wrapped in `<>` to differentiate them +/// // from dialect names. +/// : { +/// // `pipeline` is an entry that holds a crash reproducer pipeline +/// // configuration. +/// pipeline: "func.func(canonicalize,cse)" +/// }, +/// // Here is a dictionary anchored on "foo_dialect", which is a dialect +/// // namespace. +/// foo_dialect: { +/// // `some_dialect_config` is a key to be interpreted by the dialect, +/// // and used to initialize/configure/etc. +/// some_dialect_config: "Some important config value" +/// } +/// } +/// #-} +/// + +//===----------------------------------------------------------------------===// +// External Config Entry + +/// This class is used to build external config entries to the printer. Each +/// external config entry has a corresponding key/value pair. The provided key +/// must be unique within the current context, this allows a dialect "foo" +/// providing external configs to not worry about overlap with data provided by +/// other dialects. 
+class AsmExternalConfigBuilder { +public: + virtual ~AsmExternalConfigBuilder(); + + /// Provide an external config entry represented by the given bool. + virtual void provideBool(StringRef key, bool data) = 0; + + /// Provide an external config entry represented by the given human-readable + /// string value. + virtual void provideString(StringRef key, StringRef data) = 0; + + /// Provide an external config entry represented by the given binary blob + /// data and its required alignment (in bytes). + virtual void provideBlob(StringRef key, ArrayRef data, + unsigned dataAlignment) = 0; + /// A useful overload if the data type is known. Note that this does not + /// support `char` element types to avoid accidentally not providing the + /// expected alignment of data in situations that treat blobs generically. + template + std::enable_if_t::value> + provideBlob(StringRef key, ArrayRef data) { + provideBlob( + key, ArrayRef((const char *)data.data(), data.size() * sizeof(T)), + alignof(T)); + } +}; + +/// This class represents a single parsed external config entry. +class AsmExternalConfigEntry { +public: + virtual ~AsmExternalConfigEntry(); + + /// Return the key of the external config entry. + virtual StringRef getKey() const = 0; + + /// Emit an error at the location of this entry. + virtual InFlightDiagnostic emitError() const = 0; + + /// Process the external config entry represented by a boolean. + /// Returns failure if the entry does not correspond to a bool. + virtual FailureOr processAsBool() const = 0; + + /// Process the external config entry represented by a human-readable string. + /// Returns failure if the entry does not correspond to a string. + virtual FailureOr processAsString() const = 0; + + /// The type of an allocator function used to allocate memory for a blob when + /// required. The function is provided a size and alignment, and should return + /// an aligned allocation buffer.
+ using BlobAllocatorFn = + function_ref(unsigned size, unsigned align)>; + + /// This class represents a processed blob of data. + struct Blob { + /// The raw, properly aligned, blob data. + MutableArrayRef data; + + /// A flag indicating if data was allocated using the provided allocation + /// function, false if the data did not need to be allocated (e.g. + /// when mmap'd). + bool dataWasAllocated = false; + }; + + /// Process the external config entry represented by a binary blob. Returns + /// failure if the entry does not correspond to a blob. If the blob needed to + /// be allocated, the given allocator function is invoked. + virtual FailureOr processAsBlob(BlobAllocatorFn allocator) const = 0; +}; + +//===----------------------------------------------------------------------===// +// External Config + +class AsmExternalConfigProcessor { +public: + /// Create a new processor with the given identifying name. This name uniquely + /// identifies the entries of this processor, and differentiates them from + /// other contexts. Errors encountered while processing should be emitted + /// via `AsmExternalConfigEntry::emitError`. + AsmExternalConfigProcessor(StringRef name) : name(name.str()) {} + virtual ~AsmExternalConfigProcessor(); + + /// Return the name of this processor. + StringRef getName() const { return name; } + + /// Process the given external config found during parsing. Returns failure if + /// the key/data were not valid, or could otherwise not be processed + /// correctly. + virtual LogicalResult + processExternalConfig(AsmExternalConfigEntry &entry) = 0; + +private: + std::string name; +}; + +class AsmExternalConfigProvider { +public: + /// Create a new provider with the given identifying name. This name uniquely + /// identifies the entries of this provider, and differentiates them from + /// other contexts. + AsmExternalConfigProvider(StringRef name) : name(name.str()) {} + virtual ~AsmExternalConfigProvider(); + + /// Return the name of this provider.
+ StringRef getName() const { return name; } + + /// Provides external config to include during printing, utilizing the given + /// top-level root operation to help determine what information to include. + /// Provided data should be registered in the form of a key/data pair, to the + /// given provider. + virtual void + provideExternalConfig(Operation *op, + AsmExternalConfigBuilder &provider) const = 0; + +private: + std::string name; +}; + +//===----------------------------------------------------------------------===// +// AsmState +//===----------------------------------------------------------------------===// + /// This class provides management for the lifetime of the state used when /// printing the IR. It allows for alleviating the cost of recomputing the /// internal state of the asm printer. @@ -54,6 +233,39 @@ /// state has not been initialized. detail::AsmStateImpl &getImpl() { return *impl; } + //===--------------------------------------------------------------------===// + // External Config + //===--------------------------------------------------------------------===// + + /// Attach an external config provider to the AsmState to provide external + /// config during printing. + void attachExternalConfigProvider( + std::unique_ptr provider); + + /// Attach an external config provider in the form of a callable to the + /// AsmState to provide external config during printing. 
+ template + std::enable_if_t>::value> + attachExternalConfigProvider(StringRef name, CallableT &&providerFn) { + struct Provider : public AsmExternalConfigProvider { + Provider(StringRef name, CallableT &&providerFn) + : AsmExternalConfigProvider(name), + providerFn(std::forward<CallableT>(providerFn)) {} + + void + provideExternalConfig(Operation *op, + AsmExternalConfigBuilder &provider) const override { + providerFn(op, provider); + } + + std::decay_t providerFn; + }; + attachExternalConfigProvider( + std::make_unique(name, std::forward(providerFn))); + } + private: AsmState() = delete; diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -20,7 +20,8 @@ #include "llvm/Support/SMLoc.h" namespace mlir { - +class AsmExternalConfigEntry; +class AsmExternalConfigBuilder; class Builder; //===----------------------------------------------------------------------===// @@ -1324,6 +1325,12 @@ class OpAsmDialectInterface : public DialectInterface::Base { public: + OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} + + //===------------------------------------------------------------------===// + // Aliases + //===------------------------------------------------------------------===// + /// Holds the result of `getAlias` hook call. enum class AliasResult { /// The object (type or attribute) is not supported by the hook @@ -1336,8 +1343,6 @@ FinalAlias }; - OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} - /// Hooks for getting an alias identifier alias for a given symbol, that is /// not necessarily a part of this dialect. The identifier is used in place of /// the symbol when printing textual IR.
These aliases must not contain `.` or @@ -1348,6 +1353,23 @@ virtual AliasResult getAlias(Type type, raw_ostream &os) const { return AliasResult::NoAlias; } + + //===------------------------------------------------------------------===// + // External Config + //===------------------------------------------------------------------===// + + /// Hook for processing external configurations found during parsing. Returns + /// failure if the entry was not valid, or could otherwise not be processed + /// correctly. Any necessary errors can be emitted via the provided entry. + virtual LogicalResult + processExternalConfig(AsmExternalConfigEntry &entry) const; + + /// Hook for providing external configurations to include during printing. + /// The given top-level root operation may be inspected to help determine what + /// information to include. + virtual void provideExternalConfig(Operation *op, + AsmExternalConfigBuilder &provider) const { + } }; } // namespace mlir diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -13,6 +13,7 @@ #ifndef MLIR_PARSER_PARSER_H #define MLIR_PARSER_PARSER_H +#include "mlir/IR/AsmState.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include @@ -88,9 +89,49 @@ /// Return asm parser state to be used when parsing. AsmParserState *getAsmParserState() const { return asmState; } + /// Return the configuration processor registered to the given name, or + /// nullptr if no processor with `name` is registered. + AsmExternalConfigProcessor *getExternalConfigProcessor(StringRef name) const { + auto it = configProcessors.find(name); + return it == configProcessors.end() ? nullptr : it->second.get(); + } + + /// Attach a new external configuration processor with a unique name.
+ void attachExternalConfigProcessor( + std::unique_ptr processor) { + StringRef name = processor->getName(); + auto it = configProcessors.try_emplace(name, std::move(processor)); + (void)it; + assert(it.second && "configuration processor already registered"); + } + + /// Attach an external config processor in the form of a callable with the + /// given name. + template + std::enable_if_t>::value> + attachExternalConfigProcessor(StringRef name, CallableT &&processorFn) { + struct Processor : public AsmExternalConfigProcessor { + Processor(StringRef name, CallableT &&processorFn) + : AsmExternalConfigProcessor(name), + processorFn(std::forward<CallableT>(processorFn)) {} + + LogicalResult + processExternalConfig(AsmExternalConfigEntry &entry) override { + return processorFn(entry); + } + + std::decay_t processorFn; + }; + attachExternalConfigProcessor(std::make_unique( + name, std::forward(processorFn))); + } + private: MLIRContext *context; AsmParserState *asmState; + DenseMap> + configProcessors; }; /// This parses the file specified by the indicated SourceMgr and appends parsed diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -112,6 +112,13 @@ /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
#include "mlir/IR/OpAsmInterface.cpp.inc" +LogicalResult OpAsmDialectInterface::processExternalConfig( + AsmExternalConfigEntry &entry) const { + return entry.emitError() << "unknown 'config' key '" << entry.getKey() + << "' for dialect '" << getDialect()->getNamespace() + << "'"; +} + //===----------------------------------------------------------------------===// // OpPrintingFlags //===----------------------------------------------------------------------===// @@ -1254,6 +1261,15 @@ return name; } +//===----------------------------------------------------------------------===// +// External Config +//===----------------------------------------------------------------------===// + +AsmExternalConfigEntry::~AsmExternalConfigEntry() = default; +AsmExternalConfigBuilder::~AsmExternalConfigBuilder() = default; +AsmExternalConfigProcessor::~AsmExternalConfigProcessor() = default; +AsmExternalConfigProvider::~AsmExternalConfigProvider() = default; + //===----------------------------------------------------------------------===// // AsmState //===----------------------------------------------------------------------===// @@ -1278,6 +1294,17 @@ /// Get the state used for SSA names. SSANameState &getSSANameState() { return nameState; } + /// Return the dialects within the context that implement + /// OpAsmDialectInterface. + DialectInterfaceCollection &getDialectInterfaces() { + return interfaces; + } + + /// Return the non-dialect external config providers. + auto getExternalConfigProviders() { + return llvm::make_pointee_range(externalConfigProviders); + } + /// Get the printer flags. const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } @@ -1292,6 +1319,10 @@ /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; + /// A collection of non-dialect external data providers. + SmallVector> + externalConfigProviders; + /// The state used for attribute and type aliases. 
AliasState aliasState; @@ -1303,6 +1334,9 @@ /// An optional location map to be populated. AsmState::LocationMap *locationMap; + + // Allow direct access to the impl fields. + friend AsmState; }; } // namespace detail } // namespace mlir @@ -1352,6 +1386,11 @@ return impl->getPrinterFlags(); } +void AsmState::attachExternalConfigProvider( + std::unique_ptr provider) { + impl->externalConfigProviders.emplace_back(std::move(provider)); +} + //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// @@ -2629,6 +2668,55 @@ void printUserIDs(Operation *user, bool prefixComma = false); private: + /// This class represents an external config builder implementation for the + /// MLIR textual assembly format. + class ExternalConfigBuilder : public AsmExternalConfigBuilder { + public: + using ValueFn = function_ref; + using ProvideFn = function_ref; + + ExternalConfigBuilder(OperationPrinter &p, ProvideFn provideFn) + : p(p), provideFn(provideFn) {} + ~ExternalConfigBuilder() override = default; + + void provideBool(StringRef key, bool data) final { + provideFn(key, + [&](raw_ostream &os) { p.os << (data ? "true" : "false"); }); + } + + void provideString(StringRef key, StringRef data) final { + provideFn(key, [&](raw_ostream &os) { p.printEscapedString(data); }); + } + + /// Provide an external data entry represented by the given binary blob + /// data. + void provideBlob(StringRef key, ArrayRef data, + unsigned dataAlignment) final { + provideFn(key, [&](raw_ostream &os) { + // Store the blob in a hex string containing the alignment and the data. 
+ os << "\"0x" + << llvm::toHex(StringRef(reinterpret_cast(&dataAlignment), + sizeof(dataAlignment))) + << llvm::toHex(StringRef(data.data(), data.size())) << "\""; + }); + } + + private: + OperationPrinter &p; + ProvideFn provideFn; + }; + + /// Print the metadata dictionary for the file, eliding it if it is empty. + void printFileMetadataDictionary(Operation *op); + + /// Print the config section for the file metadata dictionary. + /// `checkAddMetadataDict` is used to indicate that metadata is going to be + /// added, and the file metadata dictionary should be started if it hasn't + /// yet. + void + printFileExternalConfigMetadata(function_ref checkAddMetadataDict, + Operation *op); + // Contains the stack of default dialects to use when printing regions. // A new dialect is pushed to the stack before parsing regions nested under an // operation implementing `OpAsmOpInterface`, and popped when done. At the @@ -2654,6 +2742,64 @@ // Output the aliases at the top level that can be deferred. state->getAliasState().printDeferredAliases(os, newLine); + + // Output any file level metadata. + printFileMetadataDictionary(op); +} + +void OperationPrinter::printFileMetadataDictionary(Operation *op) { + bool sawMetadataEntry = false; + auto checkAddMetadataDict = [&] { + if (!std::exchange(sawMetadataEntry, true)) + os << newLine << "{-#" << newLine; + }; + + // Add the various types of metadata. + printFileExternalConfigMetadata(checkAddMetadataDict, op); + + // If the file dictionary exists, close it. + if (sawMetadataEntry) + os << newLine << "#-}" << newLine; +} + +void OperationPrinter::printFileExternalConfigMetadata( + function_ref checkAddMetadataDict, Operation *op) { + // Functor used to add external data entries to the file metadata dictionary. 
+ bool hadExternalConfig = false; + auto processProvider = [&](const Twine &name, auto &provider) { + bool hadEntry = false; + auto provideFn = [&](StringRef key, + ExternalConfigBuilder::ValueFn valueFn) { + checkAddMetadataDict(); + + // Emit the external config entry if we haven't yet. + if (!std::exchange(hadExternalConfig, true)) + os << " config: {" << newLine; + // Emit the parent config entry if we haven't yet. + if (!std::exchange(hadEntry, true)) + os << " " << name << ": {" << newLine; + else + os << "," << newLine; + + os << " " << key << ": "; + valueFn(os); + }; + ExternalConfigBuilder entryProvider(*this, provideFn); + provider.provideExternalConfig(op, entryProvider); + + if (hadEntry) + os << newLine << " }"; + }; + + // Check each of the dialects for external data. + for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) + processProvider(interface.getDialect()->getNamespace(), interface); + // Check for any non-dialect providers. + for (const auto &provider : state->getExternalConfigProviders()) + processProvider("<" + provider.getName() + ">", provider); + + if (hadExternalConfig) + os << newLine << " }"; } /// Print a block argument in the usual format of: diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h --- a/mlir/lib/Parser/AsmParserImpl.h +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -242,16 +242,11 @@ return success(); } - /// Returns true if the current token corresponds to a keyword. - bool isCurrentTokenAKeyword() const { - return parser.getToken().isAny(Token::bare_identifier, Token::inttype) || - parser.getToken().isKeyword(); - } - /// Parse the given keyword if present. ParseResult parseOptionalKeyword(StringRef keyword) override { // Check that the current token has the same spelling. 
- if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) + if (!parser.isCurrentTokenAKeyword() || + parser.getTokenSpelling() != keyword) return failure(); parser.consumeToken(); return success(); @@ -260,7 +255,7 @@ /// Parse a keyword, if present, into 'keyword'. ParseResult parseOptionalKeyword(StringRef *keyword) override { // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) + if (!parser.isCurrentTokenAKeyword()) return failure(); *keyword = parser.getTokenSpelling(); @@ -273,7 +268,7 @@ parseOptionalKeyword(StringRef *keyword, ArrayRef allowedKeywords) override { // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) + if (!parser.isCurrentTokenAKeyword()) return failure(); StringRef currentKeyword = parser.getTokenSpelling(); diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -99,6 +99,10 @@ case ')': return formToken(Token::r_paren, tokStart); case '{': + if (*curPtr == '-' && *(curPtr + 1) == '#') { + curPtr += 2; + return formToken(Token::file_metadata_begin, tokStart); + } return formToken(Token::l_brace, tokStart); case '}': return formToken(Token::r_brace, tokStart); @@ -140,12 +144,14 @@ case '@': return lexAtIdentifier(tokStart); - case '!': - LLVM_FALLTHROUGH; - case '^': - LLVM_FALLTHROUGH; case '#': + if (*curPtr == '-' && *(curPtr + 1) == '}') { + curPtr += 2; + return formToken(Token::file_metadata_end, tokStart); + } LLVM_FALLTHROUGH; + case '!': + case '^': case '%': return lexPrefixedIdentifier(tokStart); case '"': diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -154,6 +154,15 @@ const llvm::fltSemantics &semantics, size_t typeSizeInBits); + /// Returns true if the current token corresponds to a keyword. 
+ bool isCurrentTokenAKeyword() const { + return getToken().isAny(Token::bare_identifier, Token::inttype) || + getToken().isKeyword(); + } + + /// Parse a keyword, if present, into 'keyword'. + ParseResult parseOptionalKeyword(StringRef *keyword); + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -13,6 +13,7 @@ #include "Parser.h" #include "AsmParserImpl.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" @@ -22,6 +23,7 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/bit.h" +#include "llvm/Support/Endian.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/SourceMgr.h" #include @@ -291,6 +293,16 @@ return success(); } +ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { + // Check that the current token is a keyword. + if (!isCurrentTokenAKeyword()) + return failure(); + + *keyword = getTokenSpelling(); + consumeToken(); + return success(); +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// @@ -2065,17 +2077,98 @@ private: /// Parse an attribute alias declaration. + /// + /// attribute-alias-def ::= '#' alias-name `=` attribute-value + /// ParseResult parseAttributeAliasDef(); - /// Parse an attribute alias declaration. + /// Parse a type alias declaration. + /// + /// type-alias-def ::= '!' alias-name `=` type + /// ParseResult parseTypeAliasDef(); + + /// Parse a top-level file metadata dictionary. 
+ /// + /// file-metadata-dict ::= '{-#' (file-metadata-entry + /// (',' file-metadata-entry)*)? '#-}' + ParseResult parseFileMetadataDictionary(); + + /// Parse an instance of the 'config' file metadata. + ParseResult parseExternalConfigFileMetadata(); +}; + +/// This class represents an implementation of an external config entry for the +/// MLIR textual format. +class ParsedExternalConfigEntry : public AsmExternalConfigEntry { +public: + ParsedExternalConfigEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p) + : key(key), keyLoc(keyLoc), value(value), p(p) {} + ~ParsedExternalConfigEntry() override = default; + + StringRef getKey() const final { return key; } + + InFlightDiagnostic emitError() const final { return p.emitError(keyLoc); } + + FailureOr processAsBool() const final { + if (value.is(Token::kw_true)) + return true; + if (value.is(Token::kw_false)) + return false; + return p.emitError(value.getLoc(), + "expected 'true' or 'false' value for key '" + key + + "'"); + } + + FailureOr processAsString() const final { + if (value.isNot(Token::string)) + return p.emitError(value.getLoc(), + "expected string value for key '" + key + "'"); + return value.getStringValue(); + } + + FailureOr processAsBlob(BlobAllocatorFn allocator) const final { + // Blob data within the textual format is represented as a hex string. + // TODO: We could avoid an additional alloc+copy here if we pre-allocated + // the buffer to use during hex parsing. + Optional blobData = + value.is(Token::string) ? value.getHexStringValue() : llvm::None; + if (!blobData) + return p.emitError(value.getLoc(), + "expected hex string blob for key '" + key + "'"); + + // The alignment is stored in host byte order in the first sizeof(unsigned) + // decoded bytes, matching the raw bytes the printer emits.
+ if (blobData->size() < sizeof(unsigned)) + return p.emitError(value.getLoc(), "expected hex string blob for key '" + + key + "' to encode alignment in first 4 bytes"); + unsigned align = 0; + memcpy(&align, blobData->data(), sizeof(unsigned)); + if (!llvm::isPowerOf2_32(align)) + return p.emitError(value.getLoc(), + "expected power-of-2 alignment for key '" + key + "'"); + + // Get the data portion of the blob. + StringRef data = StringRef(*blobData).drop_front(sizeof(unsigned)); + if (data.empty()) + return Blob{llvm::None, /*dataWasAllocated=*/false}; + + // Allocate memory for the blob using the provided allocator and copy the + // data into it. + Blob blob = {allocator(data.size(), align), /*dataWasAllocated=*/true}; + assert(llvm::isAddrAligned(llvm::Align(align), blob.data.data()) && + "blob allocator did not return an aligned address"); + memcpy(blob.data.data(), data.data(), data.size()); + return blob; + } + +private: + StringRef key; + SMLoc keyLoc; + Token value; + Parser &p; +}; } // namespace -/// Parses an attribute alias declaration. -/// -/// attribute-alias-def ::= '#' alias-name `=` attribute-value -/// ParseResult TopLevelOperationParser::parseAttributeAliasDef() { assert(getToken().is(Token::hash_identifier)); StringRef aliasName = getTokenSpelling().drop_front(); @@ -2104,10 +2197,6 @@ return success(); } -/// Parse a type alias declaration. -/// -/// type-alias-def ::= '!' alias-name `=` type -/// ParseResult TopLevelOperationParser::parseTypeAliasDef() { assert(getToken().is(Token::exclamation_identifier)); StringRef aliasName = getTokenSpelling().drop_front(); @@ -2136,6 +2225,84 @@ return success(); } +ParseResult TopLevelOperationParser::parseFileMetadataDictionary() { + consumeToken(Token::file_metadata_begin); + return parseCommaSeparatedListUntil( + Token::file_metadata_end, [&]() -> ParseResult { + // Parse the key of the metadata dictionary.
+ SMLoc keyLoc = getToken().getLoc(); + StringRef key; + if (failed(parseOptionalKeyword(&key))) + return emitError("expected identifier key in file " + "metadata dictionary"); + if (parseToken(Token::colon, "expected ':'")) + return failure(); + + // Parse the value of the metadata instance and process it. + if (key == "config") + return parseExternalConfigFileMetadata(); + return emitError(keyLoc, "unknown key '" + key + + "' in file metadata dictionary"); + }); +} + +ParseResult TopLevelOperationParser::parseExternalConfigFileMetadata() { + if (parseToken(Token::l_brace, "expected '{'")) + return failure(); + + return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { + // Parse the name, which is either 'dialect-name' or ''. + bool isExternalKey = consumeIf(Token::less); + SMLoc nameLoc = getToken().getLoc(); + StringRef name; + if (failed(parseOptionalKeyword(&name))) + return emitError("expected identifier name for 'config' entry"); + if (isExternalKey && parseToken(Token::greater, "expected '>'")) + return failure(); + if (parseToken(Token::colon, "expected ':'") || + parseToken(Token::l_brace, "expected '{'")) + return failure(); + + PointerUnion + handler; + if (isExternalKey) { + handler = state.config.getExternalConfigProcessor(name); + } else { + // Lookup the dialect and ask it to process the data. 
+ Dialect *dialect = getContext()->getOrLoadDialect(name); + if (!dialect) + return emitError(nameLoc, "dialect '" + name + "' is unknown"); + + handler = dyn_cast(dialect); + if (!handler) + return emitError() << "unexpected 'config' section for dialect '" + << dialect->getNamespace() << "'"; + } + + return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { + SMLoc keyLoc = getToken().getLoc(); + StringRef key; + if (failed(parseOptionalKeyword(&key))) + return emitError("expected identifier key for 'config' entry"); + if (parseToken(Token::colon, "expected ':'")) + return failure(); + Token valueTok = getToken(); + if (valueTok.isAny(Token::eof, Token::error)) + return emitError("expected value for 'config' entry"); + consumeToken(); + + // FIXME: What should we do here? Ignore the configurations? + if (!handler) + return success(); + ParsedExternalConfigEntry entry(key, keyLoc, valueTok, *this); + if (const auto *iface = handler.dyn_cast()) + return iface->processExternalConfig(entry); + auto *processor = handler.get(); + return processor->processExternalConfig(entry); + }); + }); +} + ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, Location parserLoc) { // Create a top-level operation to contain the parsed state. @@ -2180,6 +2347,12 @@ if (parseTypeAliasDef()) return failure(); break; + + // Parse a file-level metadata dictionary. + case Token::file_metadata_begin: + if (parseFileMetadataDictionary()) + return failure(); + break; } } } diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -129,9 +129,12 @@ // Get the internal string data, without the quotes. StringRef bytes = getSpelling().drop_front().drop_back(); - // Try to extract the binary data from the hex string. + // Try to extract the binary data from the hex string. We expect the hex + // string to start with `0x` and have an even number of hex nibbles (nibbles + // should come in pairs).
std::string hex; - if (!bytes.consume_front("0x") || !llvm::tryGetFromHex(bytes, hex)) + if (!bytes.consume_front("0x") || (bytes.size() & 1) || + !llvm::tryGetFromHex(bytes, hex)) return llvm::None; return hex; } diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -72,6 +72,9 @@ TOK_PUNCTUATION(star, "*") TOK_PUNCTUATION(vertical_bar, "|") +TOK_PUNCTUATION(file_metadata_begin, "{-#") +TOK_PUNCTUATION(file_metadata_end, "#-}") + // Keywords. These turn "foo" into Token::kw_foo enums. // NOTE: Please key these alphabetized to make it easier to find something in diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -25,3 +25,17 @@ // expected-error@below {{Test iterating `APInt`: }} // expected-error@below {{Test iterating `IntegerAttr`: }} arith.constant dense<> : tensor<0xi64> + +// Check that we handle an external constant parsed from the config. +// expected-error@below {{Test iterating `uint64_t`: 1, 2, 3}} +// expected-error@below {{Test iterating `APInt`: unable to iterate type}} +// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} +arith.constant #test.e1di64_elements<"blob1"> : tensor<3xi64> + +{-# + config: { + test: { + blob1: "0x08000000010000000000000002000000000000000300000000000000" + } + } +#-} diff --git a/mlir/test/IR/file-metadata-config.mlir b/mlir/test/IR/file-metadata-config.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/file-metadata-config.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// Check that we only preserve the blob that got referenced. 
+// CHECK: test: { +// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000" +// CHECK-NEXT: } + +module attributes { test.blob_ref = #test.e1di64_elements<"blob1"> } {} + +{-# + config: { + test: { + blob1: "0x08000000010000000000000002000000000000000300000000000000", + blob2: "0x08000000040000000000000005000000000000000600000000000000" + } + } +#-} diff --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/invalid-file-metadata.mlir @@ -0,0 +1,111 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+2 {{expected identifier key in file metadata dictionary}} +{-# + +// ----- + +// expected-error@+2 {{expected ':'}} +{-# + key +#-} + +// ----- + +// expected-error@+2 {{unknown key 'some_key' in file metadata dictionary}} +{-# + some_key: {} +#-} + +// ----- + +//===----------------------------------------------------------------------===// +// `config` +//===----------------------------------------------------------------------===// + +// expected-error@+2 {{expected '{'}} +{-# + config: "value" +#-} + +// ----- + +// expected-error@+3 {{expected identifier name for 'config' entry}} +{-# + config: { + 10 + } +#-} + +// ----- + +// expected-error@+3 {{expected '>'}} +{-# + config: { + { + let mnemonic = "e1di64_elements"; + let parameters = (ins + AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type, + + // A string key to the data, which is actually stored within the + // dialect. + StringRefParameter<>:$key + ); + let extraClassDeclaration = [{ + /// Return the elements referenced by this attribute. + llvm::ArrayRef getElements() const; + + /// The set of data types that can be iterated by this attribute. + using ContiguousIterableTypesT = std::tuple; + + /// Provide begin iterators for the various iterable types. 
+ // * uint64_t + auto value_begin_impl(OverloadToken) const { + return getElements().begin(); + } + }]; + let assemblyFormat = "`<` $key `>`"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -194,6 +194,17 @@ return get(getContext(), first, second, third); } +//===----------------------------------------------------------------------===// +// TestExtern1DI64ElementsAttr +//===----------------------------------------------------------------------===// + +ArrayRef TestExtern1DI64ElementsAttr::getElements() const { + auto &mgr = static_cast(getDialect()).getExternalDataManager(); + const TestExternalElementsData *data = mgr.getData(getKey()); + assert(data && "expected external data to be registered"); + return data->data; +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -43,6 +43,57 @@ class RewritePatternSet; } // namespace mlir +//===----------------------------------------------------------------------===// +// External Elements Data +//===----------------------------------------------------------------------===// + +namespace test { +/// This class represents a single external elements instance. It keeps track of +/// data, and handles deallocation when necessary. 
+struct TestExternalElementsData { + TestExternalElementsData(llvm::MutableArrayRef data, + bool dataIsOwned) + : data(data), dataIsOwned(dataIsOwned) {} + ~TestExternalElementsData() { + if (dataIsOwned) + free(data.data()); + } + + /// Instances may own `data` (released by the destructor above), so a copy + /// would free the same buffer twice. Copying is therefore disabled. + TestExternalElementsData(const TestExternalElementsData &) = delete; + TestExternalElementsData & + operator=(const TestExternalElementsData &) = delete; + + /// The raw underlying data for the attribute. + llvm::MutableArrayRef data; + /// A boolean indicating if the data is owned and should be freed. + bool dataIsOwned; +}; + +/// This class acts as a manager for external elements data. It provides API +/// for creating and accessing registered elements data. +class TestExternalElementsDataManager { +public: + /// Return the data registered for the given name, or nullptr if no data is + /// registered. + const TestExternalElementsData *getData(llvm::StringRef name) const { + auto it = dataMap.find(name); + return it != dataMap.end() ? &*it->second : nullptr; + } + + /// Register the provided data with the given name. Asserts that no data is + /// already registered with the name. + void insertData(llvm::StringRef name, + std::unique_ptr data) { + auto it = dataMap.try_emplace(name, std::move(data)); + (void)it; + assert(it.second && "data already registered"); + } + +private: + llvm::StringMap> dataMap; +}; +} // namespace test + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + #include "TestOpInterfaces.h.inc" #include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" @@ -57,6
+58,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; + //===------------------------------------------------------------------===// + // Aliases + //===------------------------------------------------------------------===// + AliasResult getAlias(Attribute attr, raw_ostream &os) const final { StringAttr strAttr = attr.dyn_cast(); if (!strAttr) @@ -102,6 +107,44 @@ } return AliasResult::NoAlias; } + + //===------------------------------------------------------------------===// + // External Config + //===------------------------------------------------------------------===// + + /// Decodes a `config` entry for this dialect as a blob, views its bytes as + /// an array of uint64_t values, and registers it with the dialect's + /// external data manager under the entry key. + LogicalResult + processExternalConfig(AsmExternalConfigEntry &entry) const final { + TestDialect *dialect = cast(getDialect()); + TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); + + // The config entries are external constant data. + // Buffers requested by processAsBlob come from malloc(). The resulting + // blob data is handed to TestExternalElementsData together with + // `dataWasAllocated`, which decides whether its destructor calls free(). + // NOTE(review): the malloc() result is not checked for null. + auto blobAllocFn = [](unsigned size, unsigned align) { + assert(align == alignof(uint64_t) && "unexpected data alignment"); + return MutableArrayRef((char *)malloc(size), size); + }; + FailureOr blob = + entry.processAsBlob(blobAllocFn); + if (failed(blob)) + return failure(); + + // NOTE(review): the element count uses integer division, so trailing + // bytes beyond a multiple of sizeof(uint64_t) are silently dropped; + // consider rejecting such blobs. insertData() also asserts if the same + // key is registered twice. + MutableArrayRef data((uint64_t *)blob->data.data(), + blob->data.size() / sizeof(uint64_t)); + mgr.insertData(entry.getKey(), std::make_unique( + data, blob->dataWasAllocated)); + return success(); + } + + /// Re-emits, via `provider`, the blob of every TestExtern1DI64ElementsAttr + /// used directly as an attribute value on `op` or the operations nested in + /// it, so unreferenced entries are not printed. NOTE(review): attributes + /// nested inside other attributes (e.g. array or dictionary attributes) + /// are not inspected. + void provideExternalConfig(Operation *op, + AsmExternalConfigBuilder &provider) const final { + SetVector usedExternalData; + op->walk([&](Operation *op) { + for (NamedAttribute attr : op->getAttrs()) + if (auto data = attr.getValue().dyn_cast()) + usedExternalData.insert(data); + }); + for (TestExtern1DI64ElementsAttr data : usedExternalData) + provider.provideBlob(data.getKey(), data.getElements()); + } }; struct TestDialectFoldInterface : public DialectFoldInterface { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td --- 
a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -41,6 +41,12 @@ ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const override; + + /// Returns the external elements data manager for this dialect. + TestExternalElementsDataManager &getExternalDataManager() { + return externalDataManager; + } + private: // Storage for a custom fallback interface. void *fallbackEffectOpInterfaces; @@ -49,6 +55,9 @@ ::llvm::SetVector<::mlir::Type> &stack) const; void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer, ::llvm::SetVector<::mlir::Type> &stack) const; + + /// An external data manager used to test external elements data. + TestExternalElementsDataManager externalDataManager; }]; }