diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -1,4 +1,4 @@ -//===- AsmState.h - State class for AsmPrinter ------------------*- C++ -*-===// +//===- AsmState.h - Assembly State Utilities --------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,7 +6,8 @@ // //===----------------------------------------------------------------------===// // -// This file defines the AsmState class. +// This file defines various classes and utilities for interacting with the MLIR +// assembly formats. // //===----------------------------------------------------------------------===// @@ -19,12 +20,319 @@ #include namespace mlir { +class AsmResourcePrinter; class Operation; namespace detail { class AsmStateImpl; } // namespace detail +//===----------------------------------------------------------------------===// +// Resources +//===----------------------------------------------------------------------===// + +/// The following classes enable support for parsing and printing resources +/// within MLIR assembly formats. Resources are a mechanism by which dialects, +/// and external clients, may attach additional information when parsing or +/// printing IR without that information being encoded in the IR itself. +/// Resources are not uniqued within the MLIR context, are not attached directly +/// to any operation, and are solely intended to live and be processed outside +/// of the immediate IR. +/// +/// Resources are encoded using a key-value pair nested within dictionaries +/// anchored either on a dialect, or an externally registered entity. +/// Dictionaries anchored on dialects use the dialect namespace directly, and +/// dictionaries anchored on external entities use a provided unique +/// identifier. 
The resource key is an identifier used to disambiguate the data. +/// The resource value may be stored in various limited forms, but general +/// encodings use a string (human readable) or blob format (binary). Within the +/// textual format, an example may be of the form: +/// +/// {-# +/// // The `dialect_resources` section within the file-level metadata +/// // dictionary is used to contain any dialect resource entries. +/// dialect_resources: { +/// // Here is a dictionary anchored on "foo_dialect", which is a dialect +/// // namespace. +/// foo_dialect: { +/// // `some_dialect_resource` is a key to be interpreted by the dialect, +/// // and used to initialize/configure/etc. +/// some_dialect_resource: "Some important resource value" +/// } +/// }, +/// // The `external_resources` section within the file-level metadata +/// // dictionary is used to contain any non-dialect resource entries. +/// external_resources: { +/// // Here is a dictionary anchored on "mlir_reproducer", which is an +/// // external entity representing MLIR's crash reproducer functionality. +/// mlir_reproducer: { +/// // `pipeline` is an entry that holds a crash reproducer pipeline +/// // resource. +/// pipeline: "func.func(canonicalize,cse)" +/// } +/// } +/// #-} +/// + +//===----------------------------------------------------------------------===// +// Resource Entry + +/// This class is used to build resource entries for use by the printer. Each +/// resource entry is represented using a key/value pair. The provided key must +/// be unique within the current context, which allows for a client to provide +/// resource entries without worrying about overlap with other clients. +class AsmResourceBuilder { +public: + virtual ~AsmResourceBuilder(); + + /// Build a resource entry represented by the given bool. + virtual void buildBool(StringRef key, bool data) = 0; + + /// Build a resource entry represented by the given human-readable string + /// value. 
+ virtual void buildString(StringRef key, StringRef data) = 0; + + /// Build a resource entry represented by the given binary blob data. + virtual void buildBlob(StringRef key, ArrayRef data, + unsigned dataAlignent) = 0; + /// Build a resource entry represented by the given binary blob data. This is + /// a useful overload if the data type is known. Note that this does not + /// support `char` element types to avoid accidentally not providing the + /// expected alignment of data in situations that treat blobs generically. + template + std::enable_if_t::value> buildBlob(StringRef key, + ArrayRef data) { + buildBlob( + key, ArrayRef((const char *)data.data(), data.size() * sizeof(T)), + alignof(T)); + } +}; + +/// This class represents a processed binary blob of data. A resource blob is +/// essentially a collection of data, potentially mutable, with an associated +/// deleter function (used if the data needs to be destroyed). +class AsmResourceBlob { +public: + /// A deleter function that frees a blob given the data and allocation size. + using DeleterFn = llvm::unique_function; + + AsmResourceBlob() = default; + AsmResourceBlob(ArrayRef data, DeleterFn deleter, bool dataIsMutable) + : data(data), deleter(std::move(deleter)), dataIsMutable(dataIsMutable) {} + /// Utility constructor that initializes a blob with a non-char type T. 
+ template + AsmResourceBlob(ArrayRef data, DelT &&deleteFn, bool dataIsMutable) + : data((const char *)data.data(), data.size() * sizeof(T)), + deleter([deleteFn = std::forward(deleteFn)](const void *data, + size_t size) { + return deleteFn((const T *)data, size); + }), + dataIsMutable(dataIsMutable) {} + AsmResourceBlob(AsmResourceBlob &&) = default; + AsmResourceBlob &operator=(AsmResourceBlob &&) = default; + AsmResourceBlob(const AsmResourceBlob &) = delete; + AsmResourceBlob &operator=(const AsmResourceBlob &) = delete; + ~AsmResourceBlob() { + if (deleter) + deleter(data.data(), data.size()); + } + + /// Return the raw underlying data of this blob. + ArrayRef getData() const { return data; } + + /// Return a mutable reference to the raw underlying data of this blob. + /// Asserts that the blob `isMutable`. + MutableArrayRef getMutableData() { + assert(isMutable() && + "cannot access mutable reference to non-mutable data"); + return MutableArrayRef(const_cast(data.data()), data.size()); + } + + /// Return if the data of this blob is mutable. + bool isMutable() const { return dataIsMutable; } + + /// Return the deleter function of this blob. + DeleterFn &getDeleter() { return deleter; } + const DeleterFn &getDeleter() const { return deleter; } + +private: + /// The raw, properly aligned, blob data. + ArrayRef data; + + /// An optional deleter function used to deallocate the underlying data when + /// necessary. + DeleterFn deleter; + + /// Whether the data is mutable. + bool dataIsMutable; +}; + +/// This class represents a single parsed resource entry. +class AsmParsedResourceEntry { +public: + virtual ~AsmParsedResourceEntry(); + + /// Return the key of the resource entry. + virtual StringRef getKey() const = 0; + + /// Emit an error at the location of this entry. + virtual InFlightDiagnostic emitError() const = 0; + + /// Parse the resource entry represented by a boolean. Returns failure if the + /// entry does not correspond to a bool. 
+ virtual FailureOr parseAsBool() const = 0; + + /// Parse the resource entry represented by a human-readable string. Returns + /// failure if the entry does not correspond to a string. + virtual FailureOr parseAsString() const = 0; + + /// The type of an allocator function used to allocate memory for a blob when + /// required. The function is provided a size and alignment, and should return + /// an aligned allocation buffer. + using BlobAllocatorFn = + function_ref; + + /// Parse the resource entry represented by a binary blob. Returns failure if + /// the entry does not correspond to a blob. If the blob needed to be + /// allocated, the given allocator function is invoked. + virtual FailureOr + parseAsBlob(BlobAllocatorFn allocator) const = 0; +}; + +//===----------------------------------------------------------------------===// +// Resource Parser/Printer + +/// This class represents an instance of a resource parser. This class should be +/// implemented by non-dialect clients that want to inject additional resources +/// into MLIR assembly formats. +class AsmResourceParser { +public: + /// Create a new parser with the given identifying name. This name uniquely + /// identifies the entries of this parser, and differentiates them from other + /// contexts. + AsmResourceParser(StringRef name) : name(name.str()) {} + virtual ~AsmResourceParser(); + + /// Return the name of this parser. + StringRef getName() const { return name; } + + /// Parse the given resource entry. Returns failure if the key/data were not + /// valid, or could otherwise not be processed correctly. Any necessary errors + /// should be emitted with the provided entry. + virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) = 0; + + /// Return a resource parser implemented via the given callable, whose form + /// should match that of `parseResource` above. 
+ template + static std::unique_ptr fromCallable(StringRef name, + CallableT &&parseFn) { + struct Processor : public AsmResourceParser { + Processor(StringRef name, CallableT &&parseFn) + : AsmResourceParser(name), parseFn(std::move(parseFn)) {} + LogicalResult parseResource(AsmParsedResourceEntry &entry) override { + return parseFn(entry); + } + + std::decay_t parseFn; + }; + return std::make_unique(name, std::forward(parseFn)); + } + +private: + std::string name; +}; + +/// This class represents an instance of a resource printer. This class should +/// be implemented by non-dialect clients that want to inject additional +/// resources into MLIR assembly formats. +class AsmResourcePrinter { +public: + /// Create a new printer with the given identifying name. This name uniquely + /// identifies the entries of this printer, and differentiates them from + /// other contexts. + AsmResourcePrinter(StringRef name) : name(name.str()) {} + virtual ~AsmResourcePrinter(); + + /// Return the name of this printer. + StringRef getName() const { return name; } + + /// Build any resources to include during printing, utilizing the given + /// top-level root operation to help determine what information to include. + /// Provided data should be registered in the form of a key/data pair, to the + /// given builder. + virtual void buildResources(Operation *op, + AsmResourceBuilder &builder) const = 0; + + /// Return a resource printer implemented via the given callable, whose form + /// should match that of `buildResources` above. 
+ template + static std::unique_ptr fromCallable(StringRef name, + CallableT &&printFn) { + struct Printer : public AsmResourcePrinter { + Printer(StringRef name, CallableT &&printFn) + : AsmResourcePrinter(name), printFn(std::move(printFn)) {} + void buildResources(Operation *op, + AsmResourceBuilder &builder) const override { + printFn(op, builder); + } + + std::decay_t printFn; + }; + return std::make_unique(name, std::forward(printFn)); + } + +private: + std::string name; +}; + +//===----------------------------------------------------------------------===// +// ParserConfig +//===----------------------------------------------------------------------===// + +/// This class represents a configuration for the MLIR assembly parser. It +/// contains all of the necessary state to parse a textual MLIR source file. +class ParserConfig { +public: + ParserConfig(MLIRContext *context) : context(context) { + assert(context && "expected valid MLIR context"); + } + + /// Return the MLIRContext to be used when parsing. + MLIRContext *getContext() const { return context; } + + /// Return the resource parser registered to the given name, or nullptr if no + /// parser with `name` is registered. + AsmResourceParser *getResourceParser(StringRef name) const { + auto it = resourceParsers.find(name); + return it == resourceParsers.end() ? nullptr : it->second.get(); + } + + /// Attach the given resource parser. + void attachResourceParser(std::unique_ptr parser) { + StringRef name = parser->getName(); + auto it = resourceParsers.try_emplace(name, std::move(parser)); + assert(it.second && + "resource parser already registered with the given name"); + } + + /// Attach the given callable resource parser with the given name. 
+ template + std::enable_if_t>::value> + attachResourceParser(StringRef name, CallableT &&parserFn) { + attachResourceParser(AsmResourceParser::fromCallable( + name, std::forward(parserFn))); + } + +private: + MLIRContext *context; + DenseMap> resourceParsers; +}; + +//===----------------------------------------------------------------------===// +// AsmState +//===----------------------------------------------------------------------===// + /// This class provides management for the lifetime of the state used when /// printing the IR. It allows for alleviating the cost of recomputing the /// internal state of the asm printer. @@ -54,6 +362,22 @@ /// state has not been initialized. detail::AsmStateImpl &getImpl() { return *impl; } + //===--------------------------------------------------------------------===// + // Resources + //===--------------------------------------------------------------------===// + + /// Attach the given resource printer to the AsmState. + void attachResourcePrinter(std::unique_ptr printer); + + /// Attach a resource printer, in the form of a callable, to the AsmState. + template + std::enable_if_t>::value> + attachResourcePrinter(StringRef name, CallableT &&printFn) { + attachResourcePrinter(AsmResourcePrinter::fromCallable( + name, std::forward(printFn))); + } + private: AsmState() = delete; diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td --- a/mlir/include/mlir/IR/OpAsmInterface.td +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -6,17 +6,21 @@ // //===----------------------------------------------------------------------===// // -// This file contains Interfaces for interacting with the AsmParser and -// AsmPrinter. +// This file contains interfaces and other utilities for interacting with the +// AsmParser and AsmPrinter. 
// //===----------------------------------------------------------------------===// -#ifndef MLIR_OPASMINTERFACE -#define MLIR_OPASMINTERFACE +#ifndef MLIR_IR_OPASMINTERFACE_TD +#define MLIR_IR_OPASMINTERFACE_TD +include "mlir/IR/AttrTypeBase.td" include "mlir/IR/OpBase.td" -/// Interface for hooking into the OpAsmPrinter and OpAsmParser. +//===----------------------------------------------------------------------===// +// OpAsmOpInterface +//===----------------------------------------------------------------------===// + def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> { let description = [{ This interface provides hooks to interact with the AsmPrinter and AsmParser @@ -105,4 +109,20 @@ ]; } -#endif // MLIR_OPASMINTERFACE +//===----------------------------------------------------------------------===// +// ResourceHandleParameter +//===----------------------------------------------------------------------===// + +/// This parameter represents a handle to a resource of the parent dialect that +/// is encoded into the "dialect_resources" section of the assembly format. The +/// parent dialect is expected to implement the various resource methods +/// defined in `OpAsmDialectInterface`. This parameter expects a C++ +/// `handleType` that derives from `AsmDialectResourceHandle::Base` and +/// implements a derived handle to the desired resource type. 
+class ResourceHandleParameter + : AttrOrTypeParameter { + let parser = "$_parser.parseResourceHandle<" # handleType # ", DialectT>()"; + let printer = "$_printer.printResourceHandle(&getDialect(), $_self)"; +} + +#endif // MLIR_IR_OPASMINTERFACE_TD diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -20,9 +20,71 @@ #include "llvm/Support/SMLoc.h" namespace mlir { - +class AsmParsedResourceEntry; +class AsmResourceBuilder; class Builder; +//===----------------------------------------------------------------------===// +// AsmDialectResourceHandle +//===----------------------------------------------------------------------===// + +/// This class represents an opaque handle to a dialect resource entry. +class AsmDialectResourceHandle { +public: + AsmDialectResourceHandle() = default; + AsmDialectResourceHandle(void *resource, TypeID resourceID) + : resource(resource), opaqueID(resourceID) {} + bool operator==(const AsmDialectResourceHandle &other) const { + return resource == other.resource; + } + + /// A base class used to implement derived dialect resource handles. + template + class Base; + + /// Return an opaque pointer to the referenced resource. + void *getResource() const { return resource; } + + /// Return the type ID of the resource. + TypeID getTypeID() const { return opaqueID; } + +private: + /// The opaque handle to the dialect resource. + void *resource = nullptr; + /// The type of the resource referenced. + TypeID opaqueID; +}; + +/// This class represents a CRTP base class for dialect resource handles. It +/// abstracts away various utilities necessary for defining derived resource +/// handles. +template +class AsmDialectResourceHandle::Base : public AsmDialectResourceHandle { +public: + /// Construct a handle from a pointer to the resource. 
The given pointer + /// should be guaranteed to live beyond the life of this handle. + Base(ResourceT *resource) + : AsmDialectResourceHandle(resource, TypeID::get()) {} + Base(AsmDialectResourceHandle handle) : AsmDialectResourceHandle(handle) { + assert(handle.getTypeID() == TypeID::get()); + } + + /// Return the resource referenced by this handle. + ResourceT *getResource() { return static_cast(resource); } + const ResourceT *getResource() const { + return static_cast(resource); + } + + /// Support llvm style casting. + static bool classof(const AsmDialectResourceHandle *handle) { + return handle->getTypeID() == TypeID::get(); + } +}; + +inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) { + return llvm::hash_value(param.getResource()); +} + //===----------------------------------------------------------------------===// // AsmPrinter //===----------------------------------------------------------------------===// @@ -94,6 +156,10 @@ /// special or non-printable characters in it. virtual void printSymbolName(StringRef symbolRef); + /// Print a handle to the given resource which belongs to the `dialect`. + void printResourceHandle(Dialect *dialect, + const AsmDialectResourceHandle &resource); + /// Print an optional arrow followed by a type list. template void printOptionalArrowTypeList(TypeRange &&types) { @@ -856,6 +922,25 @@ StringRef attrName, NamedAttrList &attrs) = 0; + //===--------------------------------------------------------------------===// + // Resource Parsing + //===--------------------------------------------------------------------===// + + /// Parse a handle to a resource within the assembly format for the given + /// dialect. 
+ template + FailureOr parseResourceHandle() { + SMLoc handleLoc = getCurrentLocation(); + FailureOr handle = + parseResourceHandle(getContext()->getOrLoadDialect()); + if (failed(handle)) + return failure(); + if (auto *result = dyn_cast(&*handle)) + return std::move(*result); + return emitError(handleLoc) << "provided resource handle differs from the " + "expected resource type"; + } + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// @@ -1012,6 +1097,12 @@ /// next token. virtual ParseResult parseXInDimensionList() = 0; +protected: + /// Parse a handle to a resource within the assembly format for the given + /// dialect. + virtual FailureOr + parseResourceHandle(Dialect *dialect) = 0; + private: AsmParser(const AsmParser &) = delete; void operator=(const AsmParser &) = delete; @@ -1324,6 +1415,12 @@ class OpAsmDialectInterface : public DialectInterface::Base { public: + OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} + + //===------------------------------------------------------------------===// + // Aliases + //===------------------------------------------------------------------===// + /// Holds the result of `getAlias` hook call. enum class AliasResult { /// The object (type or attribute) is not supported by the hook @@ -1336,8 +1433,6 @@ FinalAlias }; - OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} - /// Hooks for getting an alias identifier alias for a given symbol, that is /// not necessarily a part of this dialect. The identifier is used in place of /// the symbol when printing textual IR. 
These aliases must not contain `.` or @@ -1348,6 +1443,41 @@ virtual AliasResult getAlias(Type type, raw_ostream &os) const { return AliasResult::NoAlias; } + + //===--------------------------------------------------------------------===// + // Resources + //===--------------------------------------------------------------------===// + + /// Declare a resource with the given key, returning a handle to use for any + /// references of this resource key within the IR during parsing. The result + /// of `getResourceKey` on the returned handle is permitted to be different + /// than `key`. + virtual FailureOr + declareResource(StringRef key) const { + return failure(); + } + + /// Return a key to use for the given resource. This key should uniquely + /// identify this resource within the dialect. + virtual std::string + getResourceKey(const AsmDialectResourceHandle &handle) const { + llvm_unreachable( + "Dialect must implement `getResourceKey` when defining resources"); + } + + /// Hook for parsing resource entries. Returns failure if the entry was not + /// valid, or could otherwise not be processed correctly. Any necessary errors + /// can be emitted via the provided entry. + virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const; + + /// Hook for building resources to use during printing. The given `op` may be + /// inspected to help determine what information to include. + /// `referencedResources` contains all of the resources detected when printing + /// 'op'. + virtual void + buildResources(Operation *op, + const SetVector &referencedResources, + AsmResourceBuilder &builder) const {} }; } // namespace mlir @@ -1358,4 +1488,25 @@ /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 
#include "mlir/IR/OpAsmInterface.h.inc" +namespace llvm { +template <> +struct DenseMapInfo { + static inline mlir::AsmDialectResourceHandle getEmptyKey() { + return {DenseMapInfo::getEmptyKey(), + DenseMapInfo::getEmptyKey()}; + } + static inline mlir::AsmDialectResourceHandle getTombstoneKey() { + return {DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey()}; + } + static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) { + return DenseMapInfo::getHashValue(handle.getResource()); + } + static bool isEqual(const mlir::AsmDialectResourceHandle &lhs, + const mlir::AsmDialectResourceHandle &rhs) { + return lhs.getResource() == rhs.getResource(); + } +}; +} // namespace llvm + #endif diff --git a/mlir/include/mlir/Parser/Parser.h b/mlir/include/mlir/Parser/Parser.h --- a/mlir/include/mlir/Parser/Parser.h +++ b/mlir/include/mlir/Parser/Parser.h @@ -13,6 +13,7 @@ #ifndef MLIR_PARSER_PARSER_H #define MLIR_PARSER_PARSER_H +#include "mlir/IR/AsmState.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include @@ -84,7 +85,7 @@ /// SSA uses and definitions). `asmState` should only be provided if this /// detailed information is desired. LogicalResult parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, - MLIRContext *context, + const ParserConfig &config, LocationAttr *sourceFileLoc = nullptr, AsmParserState *asmState = nullptr); @@ -96,7 +97,7 @@ /// non-null, it is populated with a file location representing the start of the /// source file that is being parsed. LogicalResult parseSourceFile(llvm::StringRef filename, Block *block, - MLIRContext *context, + const ParserConfig &config, LocationAttr *sourceFileLoc = nullptr); /// This parses the file specified by the indicated filename using the provided @@ -111,7 +112,7 @@ /// `asmState` should only be provided if this detailed information is desired. 
LogicalResult parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr, Block *block, - MLIRContext *context, + const ParserConfig &config, LocationAttr *sourceFileLoc = nullptr, AsmParserState *asmState = nullptr); @@ -123,22 +124,22 @@ /// populated with a file location representing the start of the source file /// that is being parsed. LogicalResult parseSourceString(llvm::StringRef sourceStr, Block *block, - MLIRContext *context, + const ParserConfig &config, LocationAttr *sourceFileLoc = nullptr); namespace detail { /// The internal implementation of the templated `parseSourceFile` methods /// below, that simply forwards to the non-templated version. template -inline OwningOpRef parseSourceFile(MLIRContext *ctx, +inline OwningOpRef parseSourceFile(const ParserConfig &config, ParserArgs &&...args) { LocationAttr sourceFileLoc; Block block; - if (failed(parseSourceFile(std::forward(args)..., &block, ctx, + if (failed(parseSourceFile(std::forward(args)..., &block, config, &sourceFileLoc))) return OwningOpRef(); return detail::constructContainerOpForParserIfNecessary( - &block, ctx, sourceFileLoc); + &block, config.getContext(), sourceFileLoc); } } // namespace detail @@ -152,8 +153,8 @@ /// `SingleBlockImplicitTerminator` trait. template inline OwningOpRef -parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context) { - return detail::parseSourceFile(context, sourceMgr); +parseSourceFile(const llvm::SourceMgr &sourceMgr, const ParserConfig &config) { + return detail::parseSourceFile(config, sourceMgr); } /// This parses the file specified by the indicated filename. If the source IR @@ -166,8 +167,8 @@ /// `SingleBlockImplicitTerminator` trait. 
template inline OwningOpRef parseSourceFile(StringRef filename, - MLIRContext *context) { - return detail::parseSourceFile(context, filename); + const ParserConfig &config) { + return detail::parseSourceFile(config, filename); } /// This parses the file specified by the indicated filename using the provided @@ -181,8 +182,8 @@ template inline OwningOpRef parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr, - MLIRContext *context) { - return detail::parseSourceFile(context, filename, sourceMgr); + const ParserConfig &config) { + return detail::parseSourceFile(config, filename, sourceMgr); } /// This parses the provided string containing MLIR. If the source IR contained @@ -195,13 +196,13 @@ /// `SingleBlockImplicitTerminator` trait. template inline OwningOpRef parseSourceString(llvm::StringRef sourceStr, - MLIRContext *context) { + const ParserConfig &config) { LocationAttr sourceFileLoc; Block block; - if (failed(parseSourceString(sourceStr, &block, context, &sourceFileLoc))) + if (failed(parseSourceString(sourceStr, &block, config, &sourceFileLoc))) return OwningOpRef(); return detail::constructContainerOpForParserIfNecessary( - &block, context, sourceFileLoc); + &block, config.getContext(), sourceFileLoc); } /// This parses a single MLIR attribute to an MLIR context if it was valid. If diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -112,6 +112,13 @@ /// The OpAsmOpInterface, see OpAsmInterface.td for more details. 
#include "mlir/IR/OpAsmInterface.cpp.inc" +LogicalResult +OpAsmDialectInterface::parseResource(AsmParsedResourceEntry &entry) const { + return entry.emitError() << "unknown 'resource' key '" << entry.getKey() + << "' for dialect '" << getDialect()->getNamespace() + << "'"; +} + //===----------------------------------------------------------------------===// // OpPrintingFlags //===----------------------------------------------------------------------===// @@ -1254,6 +1261,15 @@ return name; } +//===----------------------------------------------------------------------===// +// Resources +//===----------------------------------------------------------------------===// + +AsmParsedResourceEntry::~AsmParsedResourceEntry() = default; +AsmResourceBuilder::~AsmResourceBuilder() = default; +AsmResourceParser::~AsmResourceParser() = default; +AsmResourcePrinter::~AsmResourcePrinter() = default; + //===----------------------------------------------------------------------===// // AsmState //===----------------------------------------------------------------------===// @@ -1278,6 +1294,17 @@ /// Get the state used for SSA names. SSANameState &getSSANameState() { return nameState; } + /// Return the dialects within the context that implement + /// OpAsmDialectInterface. + DialectInterfaceCollection &getDialectInterfaces() { + return interfaces; + } + + /// Return the non-dialect resource printers. + auto getResourcePrinters() { + return llvm::make_pointee_range(externalResourcePrinters); + } + /// Get the printer flags. const OpPrintingFlags &getPrinterFlags() const { return printerFlags; } @@ -1292,6 +1319,9 @@ /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; + /// A collection of non-dialect resource printers. + SmallVector> externalResourcePrinters; + /// The state used for attribute and type aliases. AliasState aliasState; @@ -1303,6 +1333,9 @@ /// An optional location map to be populated. 
AsmState::LocationMap *locationMap; + + // Allow direct access to the impl fields. + friend AsmState; }; } // namespace detail } // namespace mlir @@ -1352,6 +1385,11 @@ return impl->getPrinterFlags(); } +void AsmState::attachResourcePrinter( + std::unique_ptr printer) { + impl->externalResourcePrinters.emplace_back(std::move(printer)); +} + //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// @@ -1403,6 +1441,15 @@ /// allows for the internal location to use an attribute alias. void printLocation(LocationAttr loc, bool allowAlias = false); + /// Print a reference to the given resource that is owned by the given + /// dialect. + void printResourceHandle(Dialect *dialect, + const AsmDialectResourceHandle &resource) { + auto *interface = cast(dialect); + os << interface->getResourceKey(resource); + dialectResources[dialect].insert(resource); + } + void printAffineMap(AffineMap map); void printAffineExpr(AffineExpr expr, @@ -1462,6 +1509,9 @@ /// A tracker for the number of new lines emitted during printing. NewLineCounter newLine; + + /// A set of dialect resources that were referenced during printing. + DenseMap> dialectResources; }; } // namespace mlir @@ -2216,6 +2266,11 @@ Impl subPrinter(attrNameStr, printerFlags, state); DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); + + // FIXME: Delete this when we no longer require a nested printer. + for (auto &it : subPrinter.dialectResources) + for (const auto &resource : it.second) + dialectResources[it.first].insert(resource); } printDialectSymbol(os, "#", dialect.getNamespace(), attrName); } @@ -2230,6 +2285,11 @@ Impl subPrinter(typeNameStr, printerFlags, state); DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); + + // FIXME: Delete this when we no longer require a nested printer. 
+ for (auto &it : subPrinter.dialectResources) + for (const auto &resource : it.second) + dialectResources[it.first].insert(resource); } printDialectSymbol(os, "!", dialect.getNamespace(), typeName); } @@ -2300,6 +2360,12 @@ ::printSymbolReference(symbolRef, impl->getStream()); } +void AsmPrinter::printResourceHandle(Dialect *dialect, + const AsmDialectResourceHandle &resource) { + assert(impl && "expected AsmPrinter::printResourceHandle to be overriden"); + impl->printResourceHandle(dialect, resource); +} + //===----------------------------------------------------------------------===// // Affine expressions and maps //===----------------------------------------------------------------------===// @@ -2629,6 +2695,51 @@ void printUserIDs(Operation *user, bool prefixComma = false); private: + /// This class represents a resource builder implementation for the MLIR + /// textual assembly format. + class ResourceBuilder : public AsmResourceBuilder { + public: + using ValueFn = function_ref; + using PrintFn = function_ref; + + ResourceBuilder(OperationPrinter &p, PrintFn printFn) + : p(p), printFn(printFn) {} + ~ResourceBuilder() override = default; + + void buildBool(StringRef key, bool data) final { + printFn(key, [&](raw_ostream &os) { p.os << (data ? "true" : "false"); }); + } + + void buildString(StringRef key, StringRef data) final { + printFn(key, [&](raw_ostream &os) { p.printEscapedString(data); }); + } + + void buildBlob(StringRef key, ArrayRef data, + unsigned dataAlignment) final { + printFn(key, [&](raw_ostream &os) { + // Store the blob in a hex string containing the alignment and the data. + os << "\"0x" + << llvm::toHex(StringRef(reinterpret_cast(&dataAlignment), + sizeof(dataAlignment))) + << llvm::toHex(StringRef(data.data(), data.size())) << "\""; + }); + } + + private: + OperationPrinter &p; + PrintFn printFn; + }; + + /// Print the metadata dictionary for the file, eliding it if it is empty. 
+ void printFileMetadataDictionary(Operation *op); + + /// Print the resource sections for the file metadata dictionary. + /// `checkAddMetadataDict` is used to indicate that metadata is going to be + /// added, and the file metadata dictionary should be started if it hasn't + /// yet. + void printResourceFileMetadata(function_ref checkAddMetadataDict, + Operation *op); + // Contains the stack of default dialects to use when printing regions. // A new dialect is pushed to the stack before parsing regions nested under an // operation implementing `OpAsmOpInterface`, and popped when done. At the @@ -2654,6 +2765,76 @@ // Output the aliases at the top level that can be deferred. state->getAliasState().printDeferredAliases(os, newLine); + + // Output any file level metadata. + printFileMetadataDictionary(op); +} + +void OperationPrinter::printFileMetadataDictionary(Operation *op) { + bool sawMetadataEntry = false; + auto checkAddMetadataDict = [&] { + if (!std::exchange(sawMetadataEntry, true)) + os << newLine << "{-#" << newLine; + }; + + // Add the various types of metadata. + printResourceFileMetadata(checkAddMetadataDict, op); + + // If the file dictionary exists, close it. + if (sawMetadataEntry) + os << newLine << "#-}" << newLine; +} + +void OperationPrinter::printResourceFileMetadata( + function_ref checkAddMetadataDict, Operation *op) { + // Functor used to add data entries to the file metadata dictionary. + bool hadResource = false; + auto processProvider = [&](StringRef dictName, StringRef name, auto &provider, + auto &&...providerArgs) { + bool hadEntry = false; + auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) { + checkAddMetadataDict(); + + // Emit the top-level resource entry if we haven't yet. + if (!std::exchange(hadResource, true)) + os << " " << dictName << "_resources: {" << newLine; + // Emit the parent resource entry if we haven't yet. 
+ if (!std::exchange(hadEntry, true)) + os << " " << name << ": {" << newLine; + else + os << "," << newLine; + + os << " " << key << ": "; + valueFn(os); + }; + ResourceBuilder entryBuilder(*this, printFn); + provider.buildResources(op, providerArgs..., entryBuilder); + + if (hadEntry) + os << newLine << " }"; + }; + + // Print the `dialect_resources` section if we have any dialects with + // resources. + for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) { + StringRef name = interface.getDialect()->getNamespace(); + auto it = dialectResources.find(interface.getDialect()); + if (it != dialectResources.end()) + processProvider("dialect", name, interface, it->second); + else + processProvider("dialect", name, interface, + SetVector()); + } + if (hadResource) + os << newLine << " }"; + + // Print the `external_resources` section if we have any external clients with + // resources. + hadResource = false; + for (const auto &printer : state->getResourcePrinters()) + processProvider("external", printer.getName(), printer); + if (hadResource) + os << newLine << " }"; } /// Print a block argument in the usual format of: diff --git a/mlir/lib/Parser/AffineParser.cpp b/mlir/lib/Parser/AffineParser.cpp --- a/mlir/lib/Parser/AffineParser.cpp +++ b/mlir/lib/Parser/AffineParser.cpp @@ -718,7 +718,8 @@ /*RequiresNullTerminator=*/false); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); SymbolState symbolState; - ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr); + ParserConfig config(context); + ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr); Parser parser(state); raw_ostream &os = printDiagnosticInfo ? llvm::errs() : llvm::nulls(); diff --git a/mlir/lib/Parser/AsmParserImpl.h b/mlir/lib/Parser/AsmParserImpl.h --- a/mlir/lib/Parser/AsmParserImpl.h +++ b/mlir/lib/Parser/AsmParserImpl.h @@ -242,16 +242,11 @@ return success(); } - /// Returns true if the current token corresponds to a keyword. 
- bool isCurrentTokenAKeyword() const { - return parser.getToken().isAny(Token::bare_identifier, Token::inttype) || - parser.getToken().isKeyword(); - } - /// Parse the given keyword if present. ParseResult parseOptionalKeyword(StringRef keyword) override { // Check that the current token has the same spelling. - if (!isCurrentTokenAKeyword() || parser.getTokenSpelling() != keyword) + if (!parser.isCurrentTokenAKeyword() || + parser.getTokenSpelling() != keyword) return failure(); parser.consumeToken(); return success(); @@ -260,7 +255,7 @@ /// Parse a keyword, if present, into 'keyword'. ParseResult parseOptionalKeyword(StringRef *keyword) override { // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) + if (!parser.isCurrentTokenAKeyword()) return failure(); *keyword = parser.getTokenSpelling(); @@ -273,7 +268,7 @@ parseOptionalKeyword(StringRef *keyword, ArrayRef allowedKeywords) override { // Check that the current token is a keyword. - if (!isCurrentTokenAKeyword()) + if (!parser.isCurrentTokenAKeyword()) return failure(); StringRef currentKeyword = parser.getTokenSpelling(); @@ -439,6 +434,22 @@ return success(); } + //===--------------------------------------------------------------------===// + // Resource Parsing + //===--------------------------------------------------------------------===// + + /// Parse a handle to a resource within the assembly format. 
+ FailureOr + parseResourceHandle(Dialect *dialect) override { + const auto *interface = dyn_cast(dialect); + if (!interface) { + return parser.emitError() << "dialect '" << dialect->getNamespace() + << "' does not expect resource handles"; + } + StringRef resourceName; + return parser.parseResourceHandle(interface, resourceName); + } + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -207,7 +207,8 @@ inputStr, /*BufferName=*/"", /*RequiresNullTerminator=*/false); sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - ParserState state(sourceMgr, context, symbolState, /*asmState=*/nullptr); + ParserConfig config(context); + ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr); Parser parser(state); Token startTok = parser.getToken(); @@ -237,6 +238,7 @@ /// attribute-alias ::= `#` alias-name /// Attribute Parser::parseExtendedAttr(Type type) { + MLIRContext *ctx = getContext(); Attribute attr = parseExtendedSymbol( *this, Token::hash_identifier, state.symbols.attributeAliasDefinitions, [&](StringRef dialectName, StringRef symbolData, @@ -250,7 +252,7 @@ if (Dialect *dialect = builder.getContext()->getOrLoadDialect(dialectName)) { return parseSymbol( - symbolData, state.context, state.symbols, [&](Parser &parser) { + symbolData, ctx, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); return dialect->parseAttribute(customParser, attrType); }); @@ -258,9 +260,8 @@ // Otherwise, form a new opaque attribute. return OpaqueAttr::getChecked( - [&] { return emitError(loc); }, - StringAttr::get(state.context, dialectName), symbolData, - attrType ? 
attrType : NoneType::get(state.context)); + [&] { return emitError(loc); }, StringAttr::get(ctx, dialectName), + symbolData, attrType ? attrType : NoneType::get(ctx)); }); // Ensure that the attribute has the same type as requested. @@ -280,25 +281,23 @@ /// type-alias ::= `!` alias-name /// Type Parser::parseExtendedType() { + MLIRContext *ctx = getContext(); return parseExtendedSymbol( *this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions, - [&](StringRef dialectName, StringRef symbolData, - SMLoc loc) -> Type { + [&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type { // If we found a registered dialect, then ask it to parse the type. - auto *dialect = state.context->getOrLoadDialect(dialectName); - - if (dialect) { + if (auto *dialect = ctx->getOrLoadDialect(dialectName)) { return parseSymbol( - symbolData, state.context, state.symbols, [&](Parser &parser) { + symbolData, ctx, state.symbols, [&](Parser &parser) { CustomDialectAsmParser customParser(symbolData, parser); return dialect->parseType(customParser); }); } // Otherwise, form a new opaque type. 
- return OpaqueType::getChecked( - [&] { return emitError(loc); }, - StringAttr::get(state.context, dialectName), symbolData); + return OpaqueType::getChecked([&] { return emitError(loc); }, + StringAttr::get(ctx, dialectName), + symbolData); }); } diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -99,6 +99,10 @@ case ')': return formToken(Token::r_paren, tokStart); case '{': + if (*curPtr == '-' && *(curPtr + 1) == '#') { + curPtr += 2; + return formToken(Token::file_metadata_begin, tokStart); + } return formToken(Token::l_brace, tokStart); case '}': return formToken(Token::r_brace, tokStart); @@ -140,12 +144,14 @@ case '@': return lexAtIdentifier(tokStart); - case '!': - LLVM_FALLTHROUGH; - case '^': - LLVM_FALLTHROUGH; case '#': + if (*curPtr == '-' && *(curPtr + 1) == '}') { + curPtr += 2; + return formToken(Token::file_metadata_end, tokStart); + } LLVM_FALLTHROUGH; + case '!': + case '^': case '%': return lexPrefixedIdentifier(tokStart); case '"': diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h --- a/mlir/lib/Parser/Parser.h +++ b/mlir/lib/Parser/Parser.h @@ -28,11 +28,12 @@ Builder builder; - Parser(ParserState &state) : builder(state.context), state(state) {} + Parser(ParserState &state) + : builder(state.config.getContext()), state(state) {} // Helper methods to get stuff from the parser-global state. ParserState &getState() const { return state; } - MLIRContext *getContext() const { return state.context; } + MLIRContext *getContext() const { return state.config.getContext(); } const llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); } /// Parse a comma-separated list of elements up until the specified end token. @@ -153,6 +154,23 @@ const llvm::fltSemantics &semantics, size_t typeSizeInBits); + /// Returns true if the current token corresponds to a keyword. 
+ bool isCurrentTokenAKeyword() const { + return getToken().isAny(Token::bare_identifier, Token::inttype) || + getToken().isKeyword(); + } + + /// Parse a keyword, if present, into 'keyword'. + ParseResult parseOptionalKeyword(StringRef *keyword); + + //===--------------------------------------------------------------------===// + // Resource Parsing + //===--------------------------------------------------------------------===// + + /// Parse a handle to a dialect resource within the assembly format. + FailureOr + parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name); + //===--------------------------------------------------------------------===// // Type Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -13,6 +13,7 @@ #include "Parser.h" #include "AsmParserImpl.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Verifier.h" @@ -291,6 +292,47 @@ return success(); } +ParseResult Parser::parseOptionalKeyword(StringRef *keyword) { + // Check that the current token is a keyword. + if (!isCurrentTokenAKeyword()) + return failure(); + + *keyword = getTokenSpelling(); + consumeToken(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Resource Parsing + +FailureOr +Parser::parseResourceHandle(const OpAsmDialectInterface *dialect, + StringRef &name) { + assert(dialect && "expected valid dialect interface"); + SMLoc nameLoc = getToken().getLoc(); + if (failed(parseOptionalKeyword(&name))) + return emitError("expected identifier key for 'resource' entry"); + auto &resources = getState().symbols.dialectResources; + + // Otherwise, ask the dialect to resolve a reference to this handle. 
This + // allows for us to remap the name of the handle if necessary. + std::pair &entry = + resources[dialect][name]; + if (entry.first.empty()) { + FailureOr result = dialect->declareResource(name); + if (failed(result)) { + return emitError(nameLoc) + << "unknown 'resource' key '" << name << "' for dialect '" + << dialect->getDialect()->getNamespace() << "'"; + } + entry.first = dialect->getResourceKey(*result); + entry.second = *result; + } + + name = entry.first; + return entry.second; +} + //===----------------------------------------------------------------------===// // OperationParser //===----------------------------------------------------------------------===// @@ -2064,17 +2106,103 @@ private: /// Parse an attribute alias declaration. + /// + /// attribute-alias-def ::= '#' alias-name `=` attribute-value + /// ParseResult parseAttributeAliasDef(); - /// Parse an attribute alias declaration. + /// Parse a type alias declaration. + /// + /// type-alias-def ::= '!' alias-name `=` type + /// ParseResult parseTypeAliasDef(); + + /// Parse a top-level file metadata dictionary. + /// + /// file-metadata-dict ::= '{-#' file-metadata-entry* `#-}' + /// + ParseResult parseFileMetadataDictionary(); + + /// Parse a resource metadata dictionary. + ParseResult parseResourceFileMetadata( + function_ref parseBody); + ParseResult parseDialectResourceFileMetadata(); + ParseResult parseExternalResourceFileMetadata(); +}; + +/// This class represents an implementation of a resource entry for the MLIR +/// textual format. 
+class ParsedResourceEntry : public AsmParsedResourceEntry { +public: + ParsedResourceEntry(StringRef key, SMLoc keyLoc, Token value, Parser &p) + : key(key), keyLoc(keyLoc), value(value), p(p) {} + ~ParsedResourceEntry() override = default; + + StringRef getKey() const final { return key; } + + InFlightDiagnostic emitError() const final { return p.emitError(keyLoc); } + + FailureOr parseAsBool() const final { + if (value.is(Token::kw_true)) + return true; + if (value.is(Token::kw_false)) + return false; + return p.emitError(value.getLoc(), + "expected 'true' or 'false' value for key '" + key + + "'"); + } + + FailureOr parseAsString() const final { + if (value.isNot(Token::string)) + return p.emitError(value.getLoc(), + "expected string value for key '" + key + "'"); + return value.getStringValue(); + } + + FailureOr + parseAsBlob(BlobAllocatorFn allocator) const final { + // Blob data within the textual format is represented as a hex string. + // TODO: We could avoid an additional alloc+copy here if we pre-allocated + // the buffer to use during hex processing. + Optional blobData = + value.is(Token::string) ? value.getHexStringValue() : llvm::None; + if (!blobData) + return p.emitError(value.getLoc(), + "expected hex string blob for key '" + key + "'"); + + // Extract the alignment of the blob data, which gets stored at the + // beginning of the string. + if (blobData->size() < sizeof(unsigned)) { + return p.emitError(value.getLoc(), + "expected hex string blob for key '" + key + + "' to encode alignment in first 4 bytes"); + } + unsigned align = 0; + memcpy(&align, blobData->data(), sizeof(unsigned)); + + // Get the data portion of the blob. + StringRef data = StringRef(*blobData).drop_front(sizeof(unsigned)); + if (data.empty()) + return AsmResourceBlob(); + + // Allocate memory for the blob using the provided allocator and copy the + // data into it.
+ AsmResourceBlob blob = allocator(data.size(), align); + assert(llvm::isAddrAligned(llvm::Align(align), blob.getData().data()) && + blob.isMutable() && + "blob allocator did not return a properly aligned address"); + memcpy(blob.getMutableData().data(), data.data(), data.size()); + return blob; + } + +private: + StringRef key; + SMLoc keyLoc; + Token value; + Parser &p; }; } // namespace -/// Parses an attribute alias declaration. -/// -/// attribute-alias-def ::= '#' alias-name `=` attribute-value -/// ParseResult TopLevelOperationParser::parseAttributeAliasDef() { assert(getToken().is(Token::hash_identifier)); StringRef aliasName = getTokenSpelling().drop_front(); @@ -2103,10 +2231,6 @@ return success(); } -/// Parse a type alias declaration. -/// -/// type-alias-def ::= '!' alias-name `=` type -/// ParseResult TopLevelOperationParser::parseTypeAliasDef() { assert(getToken().is(Token::exclamation_identifier)); StringRef aliasName = getTokenSpelling().drop_front(); @@ -2135,6 +2259,108 @@ return success(); } +ParseResult TopLevelOperationParser::parseFileMetadataDictionary() { + consumeToken(Token::file_metadata_begin); + return parseCommaSeparatedListUntil( + Token::file_metadata_end, [&]() -> ParseResult { + // Parse the key of the metadata dictionary. + SMLoc keyLoc = getToken().getLoc(); + StringRef key; + if (failed(parseOptionalKeyword(&key))) + return emitError("expected identifier key in file " + "metadata dictionary"); + if (parseToken(Token::colon, "expected ':'")) + return failure(); + + // Process the metadata entry. 
+ if (key == "dialect_resources") + return parseDialectResourceFileMetadata(); + if (key == "external_resources") + return parseExternalResourceFileMetadata(); + return emitError(keyLoc, "unknown key '" + key + + "' in file metadata dictionary"); + }); +} + +ParseResult TopLevelOperationParser::parseResourceFileMetadata( + function_ref parseBody) { + if (parseToken(Token::l_brace, "expected '{'")) + return failure(); + + return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { + // Parse the top-level name entry. + SMLoc nameLoc = getToken().getLoc(); + StringRef name; + if (failed(parseOptionalKeyword(&name))) + return emitError("expected identifier key for 'resource' entry"); + + if (parseToken(Token::colon, "expected ':'") || + parseToken(Token::l_brace, "expected '{'")) + return failure(); + return parseBody(name, nameLoc); + }); +} + +ParseResult TopLevelOperationParser::parseDialectResourceFileMetadata() { + return parseResourceFileMetadata([&](StringRef name, + SMLoc nameLoc) -> ParseResult { + // Lookup the dialect and check that it can handle a resource entry. + Dialect *dialect = getContext()->getOrLoadDialect(name); + if (!dialect) + return emitError(nameLoc, "dialect '" + name + "' is unknown"); + const auto *handler = dyn_cast(dialect); + if (!handler) { + return emitError() << "unexpected 'resource' section for dialect '" + << dialect->getNamespace() << "'"; + } + + return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { + // Parse the name of the resource entry. 
+ SMLoc keyLoc = getToken().getLoc(); + StringRef key; + if (failed(parseResourceHandle(handler, key)) || + parseToken(Token::colon, "expected ':'")) + return failure(); + Token valueTok = getToken(); + consumeToken(); + + ParsedResourceEntry entry(key, keyLoc, valueTok, *this); + return handler->parseResource(entry); + }); + }); +} + +ParseResult TopLevelOperationParser::parseExternalResourceFileMetadata() { + return parseResourceFileMetadata([&](StringRef name, + SMLoc nameLoc) -> ParseResult { + AsmResourceParser *handler = state.config.getResourceParser(name); + + // TODO: Should we require handling external resources in some scenarios? + if (!handler) { + emitWarning(getEncodedSourceLocation(nameLoc)) + << "ignoring unknown external resources for '" << name << "'"; + } + + return parseCommaSeparatedListUntil(Token::r_brace, [&]() -> ParseResult { + // Parse the name of the resource entry. + SMLoc keyLoc = getToken().getLoc(); + StringRef key; + if (failed(parseOptionalKeyword(&key))) + return emitError( + "expected identifier key for 'external_resources' entry"); + if (parseToken(Token::colon, "expected ':'")) + return failure(); + Token valueTok = getToken(); + consumeToken(); + + if (!handler) + return success(); + ParsedResourceEntry entry(key, keyLoc, valueTok, *this); + return handler->parseResource(entry); + }); + }); +} + ParseResult TopLevelOperationParser::parse(Block *topLevelBlock, Location parserLoc) { // Create a top-level operation to contain the parsed state. @@ -2179,6 +2405,12 @@ if (parseTypeAliasDef()) return failure(); break; + + // Parse a file-level metadata dictionary. 
+ case Token::file_metadata_begin: + if (parseFileMetadataDictionary()) + return failure(); + break; } } } @@ -2186,50 +2418,51 @@ //===----------------------------------------------------------------------===// LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, - Block *block, MLIRContext *context, + Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc, AsmParserState *asmState) { const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); - Location parserLoc = FileLineColLoc::get( - context, sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0); + Location parserLoc = + FileLineColLoc::get(config.getContext(), sourceBuf->getBufferIdentifier(), + /*line=*/0, /*column=*/0); if (sourceFileLoc) *sourceFileLoc = parserLoc; SymbolState aliasState; - ParserState state(sourceMgr, context, aliasState, asmState); + ParserState state(sourceMgr, config, aliasState, asmState); return TopLevelOperationParser(state).parse(block, parserLoc); } LogicalResult mlir::parseSourceFile(llvm::StringRef filename, Block *block, - MLIRContext *context, + const ParserConfig &config, LocationAttr *sourceFileLoc) { llvm::SourceMgr sourceMgr; - return parseSourceFile(filename, sourceMgr, block, context, sourceFileLoc); + return parseSourceFile(filename, sourceMgr, block, config, sourceFileLoc); } LogicalResult mlir::parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr, Block *block, - MLIRContext *context, + const ParserConfig &config, LocationAttr *sourceFileLoc, AsmParserState *asmState) { if (sourceMgr.getNumBuffers() != 0) { // TODO: Extend to support multiple buffers. 
- return emitError(mlir::UnknownLoc::get(context), + return emitError(mlir::UnknownLoc::get(config.getContext()), "only main buffer parsed at the moment"); } auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(filename); if (std::error_code error = fileOrErr.getError()) - return emitError(mlir::UnknownLoc::get(context), + return emitError(mlir::UnknownLoc::get(config.getContext()), "could not open input file " + filename); // Load the MLIR source file. sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), SMLoc()); - return parseSourceFile(sourceMgr, block, context, sourceFileLoc, asmState); + return parseSourceFile(sourceMgr, block, config, sourceFileLoc, asmState); } LogicalResult mlir::parseSourceString(llvm::StringRef sourceStr, Block *block, - MLIRContext *context, + const ParserConfig &config, LocationAttr *sourceFileLoc) { auto memBuffer = MemoryBuffer::getMemBuffer(sourceStr); if (!memBuffer) @@ -2237,5 +2470,5 @@ SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc()); - return parseSourceFile(sourceMgr, block, context, sourceFileLoc); + return parseSourceFile(sourceMgr, block, config, sourceFileLoc); } diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h --- a/mlir/lib/Parser/ParserState.h +++ b/mlir/lib/Parser/ParserState.h @@ -22,12 +22,18 @@ /// This class contains record of any parsed top-level symbols. struct SymbolState { - // A map from attribute alias identifier to Attribute. + /// A map from attribute alias identifier to Attribute. llvm::StringMap attributeAliasDefinitions; - // A map from type alias identifier to Type. + /// A map from type alias identifier to Type. llvm::StringMap typeAliasDefinitions; + /// A map of dialect resource keys to the resolved resource name and handle + /// to use during parsing. + DenseMap>> + dialectResources; + /// A set of locations into the main parser memory buffer for each of the /// active nested parsers. Given that some nested parsers, i.e. 
custom dialect /// parsers, operate on a temporary memory buffer, this provides an anchor @@ -47,11 +53,11 @@ /// This class refers to all of the state maintained globally by the parser, /// such as the current lexer position etc. struct ParserState { - ParserState(const llvm::SourceMgr &sourceMgr, MLIRContext *ctx, + ParserState(const llvm::SourceMgr &sourceMgr, const ParserConfig &config, SymbolState &symbols, AsmParserState *asmState) - : context(ctx), lex(sourceMgr, ctx), curToken(lex.lexToken()), - symbols(symbols), parserDepth(symbols.nestedParserLocs.size()), - asmState(asmState) { + : config(config), lex(sourceMgr, config.getContext()), + curToken(lex.lexToken()), symbols(symbols), + parserDepth(symbols.nestedParserLocs.size()), asmState(asmState) { // Set the top level lexer for the symbol state if one doesn't exist. if (!symbols.topLevelLexer) symbols.topLevelLexer = &lex; @@ -64,8 +70,8 @@ ParserState(const ParserState &) = delete; void operator=(const ParserState &) = delete; - /// The context we're parsing into. - MLIRContext *const context; + /// The configuration used to setup the parser. + const ParserConfig &config; /// The lexer for the source file we're parsing. Lexer lex; diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -129,9 +129,12 @@ // Get the internal string data, without the quotes. StringRef bytes = getSpelling().drop_front().drop_back(); - // Try to extract the binary data from the hex string. + // Try to extract the binary data from the hex string. We expect the hex + // string to start with `0x` and have an even number of hex nibbles (nibbles + // should come in pairs). 
std::string hex; - if (!bytes.consume_front("0x") || !llvm::tryGetFromHex(bytes, hex)) + if (!bytes.consume_front("0x") || (bytes.size() & 1) || + !llvm::tryGetFromHex(bytes, hex)) return llvm::None; return hex; } diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -72,6 +72,9 @@ TOK_PUNCTUATION(star, "*") TOK_PUNCTUATION(vertical_bar, "|") +TOK_PUNCTUATION(file_metadata_begin, "{-#") +TOK_PUNCTUATION(file_metadata_end, "#-}") + // Keywords. These turn "foo" into Token::kw_foo enums. // NOTE: Please key these alphabetized to make it easier to find something in diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -232,8 +232,7 @@ if (failed(parseStridedLayout(offset, strides))) return failure(); // Construct strided affine map. - AffineMap map = - makeStridedLinearLayoutMap(strides, offset, state.context); + AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext()); layout = AffineMapAttr::get(map); } else { // Either it is MemRefLayoutAttrInterface or memory space attribute. diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -25,3 +25,17 @@ // expected-error@below {{Test iterating `APInt`: }} // expected-error@below {{Test iterating `IntegerAttr`: }} arith.constant dense<> : tensor<0xi64> + +// Check that we handle an external constant parsed from the config. 
+// expected-error@below {{Test iterating `uint64_t`: 1, 2, 3}} +// expected-error@below {{Test iterating `APInt`: unable to iterate type}} +// expected-error@below {{Test iterating `IntegerAttr`: unable to iterate type}} +arith.constant #test.e1di64_elements : tensor<3xi64> + +{-# + dialect_resources: { + test: { + blob1: "0x08000000010000000000000002000000000000000300000000000000" + } + } +#-} diff --git a/mlir/test/IR/file-metadata-resources.mlir b/mlir/test/IR/file-metadata-resources.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/file-metadata-resources.mlir @@ -0,0 +1,17 @@ +// RUN: mlir-opt %s -split-input-file | FileCheck %s + +// Check that we only preserve the blob that got referenced. +// CHECK: test: { +// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000" +// CHECK-NEXT: } + +module attributes { test.blob_ref = #test.e1di64_elements } {} + +{-# + dialect_resources: { + test: { + blob1: "0x08000000010000000000000002000000000000000300000000000000", + blob2: "0x08000000040000000000000005000000000000000600000000000000" + } + } +#-} diff --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/invalid-file-metadata.mlir @@ -0,0 +1,142 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// expected-error@+2 {{expected identifier key in file metadata dictionary}} +{-# + +// ----- + +// expected-error@+2 {{expected ':'}} +{-# + key +#-} + +// ----- + +// expected-error@+2 {{unknown key 'some_key' in file metadata dictionary}} +{-# + some_key: {} +#-} + +// ----- + +//===----------------------------------------------------------------------===// +// `dialect_resources` +//===----------------------------------------------------------------------===// + +// expected-error@+2 {{expected '{'}} +{-# + dialect_resources: "value" +#-} + +// ----- + +// expected-error@+3 {{expected identifier key for 'resource' entry}} +{-# + 
dialect_resources: { + 10 + } +#-} + +// ----- + +// expected-error@+3 {{expected ':'}} +{-# + dialect_resources: { + entry "value" + } +#-} + +// ----- + +// expected-error@+3 {{dialect 'foobar' is unknown}} +{-# + dialect_resources: { + foobar: { + entry: "foo" + } + } +#-} + +// ----- + +// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'builtin'}} +{-# + dialect_resources: { + builtin: { + unknown_entry: "foo" + } + } +#-} + +// ----- + +// expected-error@+4 {{expected hex string blob for key 'invalid_blob'}} +{-# + dialect_resources: { + test: { + invalid_blob: 10 + } + } +#-} + +// ----- + +// expected-error@+4 {{expected hex string blob for key 'invalid_blob'}} +{-# + dialect_resources: { + test: { + invalid_blob: "" + } + } +#-} + +// ----- + +// expected-error@+4 {{expected hex string blob for key 'invalid_blob' to encode alignment in first 4 bytes}} +{-# + dialect_resources: { + test: { + invalid_blob: "0x" + } + } +#-} + +// ----- + +//===----------------------------------------------------------------------===// +// `external_resources` +//===----------------------------------------------------------------------===// + +// expected-error@+2 {{expected '{'}} +{-# + external_resources: "value" +#-} + +// ----- + +// expected-error@+3 {{expected identifier key for 'resource' entry}} +{-# + external_resources: { + 10 + } +#-} + +// ----- + +// expected-error@+3 {{expected ':'}} +{-# + external_resources: { + entry "value" + } +#-} + +// ----- + +// expected-warning@+3 {{ignoring unknown external resources for 'foobar'}} +{-# + external_resources: { + foobar: { + entry: "foo" + } + } +#-} diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -17,6 +17,7 @@ include "TestDialect.td" include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinAttributeInterfaces.td" +include 
"mlir/IR/OpAsmInterface.td" include "mlir/IR/SubElementInterfaces.td" // All of the attributes will extend this class. @@ -214,4 +215,29 @@ let assemblyFormat = "`<` $a `>`"; } +// Test simple extern 1D vector using ElementsAttrInterface. +def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ + ElementsAttrInterface + ]> { + let mnemonic = "e1di64_elements"; + let parameters = (ins + AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type, + ResourceHandleParameter<"TestExternalElementsDataHandle">:$handle + ); + let extraClassDeclaration = [{ + /// Return the elements referenced by this attribute. + llvm::ArrayRef getElements() const; + + /// The set of data types that can be iterated by this attribute. + using ContiguousIterableTypesT = std::tuple; + + /// Provide begin iterators for the various iterable types. + // * uint64_t + auto value_begin_impl(OverloadToken) const { + return getElements().begin(); + } + }]; + let assemblyFormat = "`<` $handle `>`"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h --- a/mlir/test/lib/Dialect/Test/TestAttributes.h +++ b/mlir/test/lib/Dialect/Test/TestAttributes.h @@ -25,6 +25,10 @@ #include "TestAttrInterfaces.h.inc" #include "TestOpEnums.h.inc" +namespace test { +struct TestExternalElementsDataHandle; +} // namespace test + #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -194,6 +194,14 @@ return get(getContext(), first, second, third); } +//===----------------------------------------------------------------------===// +// TestExtern1DI64ElementsAttr +//===----------------------------------------------------------------------===// + +ArrayRef TestExtern1DI64ElementsAttr::getElements() const { + return 
getHandle().getData()->getData(); +} + //===----------------------------------------------------------------------===// // Tablegen Generated Definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -21,6 +21,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Traits.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -44,6 +45,69 @@ class RewritePatternSet; } // namespace mlir +//===----------------------------------------------------------------------===// +// External Elements Data +//===----------------------------------------------------------------------===// + +namespace test { +/// This class represents a single external elements instance. It keeps track of +/// the data, and handles deallocation when necessary. +class TestExternalElementsData : public mlir::AsmResourceBlob { +public: + using mlir::AsmResourceBlob::AsmResourceBlob; + TestExternalElementsData(mlir::AsmResourceBlob &&blob) + : mlir::AsmResourceBlob(std::move(blob)) {} + + /// Return the data of this external elements instance. + llvm::ArrayRef getData() const; + + /// Allocate a new external elements instance with the given number of + /// elements. + static TestExternalElementsData allocate(size_t numElements); +}; + +/// A handle used to reference external elements instances. +struct TestExternalElementsDataHandle + : public mlir::AsmDialectResourceHandle::Base< + TestExternalElementsDataHandle, + llvm::StringMapEntry>> { + using Base::Base; + + /// Return a key to use for this handle. + llvm::StringRef getKey() const { return getResource()->getKey(); } + + /// Return the data referenced by this handle. 
+ TestExternalElementsData *getData() const { + return getResource()->getValue().get(); + } +}; + +/// This class acts as a manager for external elements data. It provides API +/// for creating and accessing registered elements data. +class TestExternalElementsDataManager { + using DataMap = llvm::StringMap>; + +public: + /// Return the data registered for the given name, or nullptr if no data is + /// registered. + const TestExternalElementsData *getData(llvm::StringRef name) const; + + /// Register an entry with the provided name, which may be modified if another + /// entry was already inserted with that name. Returns the inserted entry. + std::pair insert(llvm::StringRef name); + + /// Set the data for the given entry, which is expected to exist. + void setData(llvm::StringRef name, TestExternalElementsData &&data); + +private: + llvm::StringMap> dataMap; +}; +} // namespace test + +//===----------------------------------------------------------------------===// +// TestDialect +//===----------------------------------------------------------------------===// + #include "TestOpInterfaces.h.inc" #include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Diagnostics.h" @@ -43,6 +44,55 @@ registry.insert(); } +//===----------------------------------------------------------------------===// +// External Elements Data +//===----------------------------------------------------------------------===// + +ArrayRef TestExternalElementsData::getData() const { + ArrayRef data = AsmResourceBlob::getData(); + return 
ArrayRef((const uint64_t *)data.data(), + data.size() / sizeof(uint64_t)); +} + +TestExternalElementsData +TestExternalElementsData::allocate(size_t numElements) { + return TestExternalElementsData( + llvm::ArrayRef(new uint64_t[numElements], numElements), + [](const uint64_t *data, size_t) { delete[] data; }, + /*dataIsMutable=*/true); +} + +const TestExternalElementsData * +TestExternalElementsDataManager::getData(StringRef name) const { + auto it = dataMap.find(name); + return it != dataMap.end() ? &*it->second : nullptr; +} + +std::pair +TestExternalElementsDataManager::insert(StringRef name) { + auto it = dataMap.try_emplace(name, nullptr); + if (it.second) + return it; + + llvm::SmallString<32> nameStorage(name); + nameStorage.push_back('_'); + size_t nameCounter = 1; + do { + nameStorage += std::to_string(nameCounter++); + auto it = dataMap.try_emplace(nameStorage, nullptr); + if (it.second) + return it; + nameStorage.resize(name.size() + 1); + } while (true); +} + +void TestExternalElementsDataManager::setData(StringRef name, + TestExternalElementsData &&data) { + auto it = dataMap.find(name); + assert(it != dataMap.end() && "data not registered"); + it->second = std::make_unique(std::move(data)); +} + //===----------------------------------------------------------------------===// // TestDialect Interfaces //===----------------------------------------------------------------------===// @@ -63,6 +113,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; + //===------------------------------------------------------------------===// + // Aliases + //===------------------------------------------------------------------===// + AliasResult getAlias(Attribute attr, raw_ostream &os) const final { StringAttr strAttr = attr.dyn_cast(); if (!strAttr) @@ -108,6 +162,52 @@ } return AliasResult::NoAlias; } + + //===------------------------------------------------------------------===// + // Resources + 
//===------------------------------------------------------------------===// + + std::string + getResourceKey(const AsmDialectResourceHandle &handle) const override { + return cast(handle).getKey().str(); + } + + FailureOr + declareResource(StringRef key) const final { + TestDialect *dialect = cast(getDialect()); + TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); + + // Resolve the reference by inserting a new entry into the manager. + auto it = mgr.insert(key).first; + return TestExternalElementsDataHandle(&*it); + } + + LogicalResult parseResource(AsmParsedResourceEntry &entry) const final { + TestDialect *dialect = cast(getDialect()); + TestExternalElementsDataManager &mgr = dialect->getExternalDataManager(); + + // The resource entries are external constant data. + auto blobAllocFn = [](unsigned size, unsigned align) { + assert(align == alignof(uint64_t) && "unexpected data alignment"); + return TestExternalElementsData::allocate(size / sizeof(uint64_t)); + }; + FailureOr blob = entry.parseAsBlob(blobAllocFn); + if (failed(blob)) + return failure(); + + mgr.setData(entry.getKey(), std::move(*blob)); + return success(); + } + + void + buildResources(Operation *op, + const SetVector &referencedResources, + AsmResourceBuilder &provider) const final { + for (const AsmDialectResourceHandle &handle : referencedResources) { + const auto &testHandle = cast(handle); + provider.buildBlob(testHandle.getKey(), testHandle.getData()->getData()); + } + } }; struct TestDialectFoldInterface : public DialectFoldInterface { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.td b/mlir/test/lib/Dialect/Test/TestDialect.td --- a/mlir/test/lib/Dialect/Test/TestDialect.td +++ b/mlir/test/lib/Dialect/Test/TestDialect.td @@ -41,6 +41,12 @@ ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; void printType(::mlir::Type type, ::mlir::DialectAsmPrinter &printer) const override; + + /// Returns the external elements data manager for this 
dialect. + TestExternalElementsDataManager &getExternalDataManager() { + return externalDataManager; + } + private: // Storage for a custom fallback interface. void *fallbackEffectOpInterfaces; @@ -49,6 +55,9 @@ ::llvm::SetVector<::mlir::Type> &stack) const; void printTestType(::mlir::Type type, ::mlir::AsmPrinter &printer, ::llvm::SetVector<::mlir::Type> &stack) const; + + /// An external data manager used to test external elements data. + TestExternalElementsDataManager externalDataManager; }]; } diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -33,7 +33,6 @@ // DEF: return {}; def Test_Dialect: Dialect { -// DECL-NOT: TestDialect // DEF-NOT: TestDialect let name = "TestDialect"; let cppNamespace = "::test"; diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -34,7 +34,6 @@ // DEF: return {}; def Test_Dialect: Dialect { -// DECL-NOT: TestDialect let name = "TestDialect"; let cppNamespace = "::test"; } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -229,6 +229,8 @@ // Inherit constructors from the attribute or type class. defCls.declare(Visibility::Public); defCls.declare("Base::Base"); + defCls.declare("DialectT = " + + def.getDialect().getCppClassName()); // Emit the extra declarations first in case there's a definition in there. if (Optional extraDecl = def.getExtraDecls()) @@ -636,7 +638,11 @@ if (defs.empty()) return false; { - NamespaceEmitter nsEmitter(os, defs.front().getDialect()); + Dialect dialect = defs.front().getDialect(); + NamespaceEmitter nsEmitter(os, dialect); + + // Declare the dialect first so that def classes can reference it. 
+ os << "class " << dialect.getCppClassName() << ";\n\n"; // Declare all the def classes first (in case they reference each other). for (const AttrOrTypeDef &def : defs) diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -9,6 +9,7 @@ add_subdirectory(Dialect) add_subdirectory(Interfaces) add_subdirectory(IR) +add_subdirectory(Parser) add_subdirectory(Pass) add_subdirectory(Support) add_subdirectory(Rewrite) diff --git a/mlir/unittests/Parser/CMakeLists.txt b/mlir/unittests/Parser/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/unittests/Parser/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_unittest(MLIRParserTests + ResourceTest.cpp + + DEPENDS + MLIRTestInterfaceIncGen +) +target_include_directories(MLIRParserTests PRIVATE "${MLIR_BINARY_DIR}/test/lib/Dialect/Test") + +target_link_libraries(MLIRParserTests PRIVATE + MLIRIR + MLIRParser + MLIRTestDialect +) diff --git a/mlir/unittests/Parser/ResourceTest.cpp b/mlir/unittests/Parser/ResourceTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Parser/ResourceTest.cpp @@ -0,0 +1,75 @@ +//===- ResourceTest.cpp -----------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "mlir/Parser/Parser.h" + +#include "gmock/gmock.h" + +using namespace mlir; + +namespace { +TEST(MLIRParser, ResourceKeyConflict) { + std::string moduleStr = R"mlir( + "test.use1"() {attr = #test.e1di64_elements : tensor<3xi64> } : () -> () + + {-# + dialect_resources: { + test: { + blob1: "0x08000000010000000000000002000000000000000300000000000000" + } + } + #-} + )mlir"; + std::string moduleStr2 = R"mlir( + "test.use2"() {attr = #test.e1di64_elements : tensor<3xi64> } : () -> () + + {-# + dialect_resources: { + test: { + blob1: "0x08000000040000000000000005000000000000000600000000000000" + } + } + #-} + )mlir"; + + MLIRContext context; + context.loadDialect(); + + // Parse both modules into the same context so that we ensure the conflicting + // resources have been loaded. + OwningOpRef module1 = + parseSourceString(moduleStr, &context); + OwningOpRef module2 = + parseSourceString(moduleStr2, &context); + ASSERT_TRUE(module1 && module2); + + // Merge the two modules so that we can test printing the remapped resources. + Block *block = module1->getBody(); + block->getOperations().splice(block->end(), + module2->getBody()->getOperations()); + + // Check that conflicting resources were remapped. 
+ std::string outputStr; + { + llvm::raw_string_ostream os(outputStr); + module1->print(os); + } + StringRef output(outputStr); + EXPECT_TRUE( + output.contains("\"test.use1\"() {attr = #test.e1di64_elements")); + EXPECT_TRUE(output.contains( + "blob1: \"0x08000000010000000000000002000000000000000300000000000000\"")); + EXPECT_TRUE(output.contains( + "\"test.use2\"() {attr = #test.e1di64_elements")); + EXPECT_TRUE(output.contains( + "blob1_1: " + "\"0x08000000040000000000000005000000000000000600000000000000\"")); +} +} // namespace