diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -89,17 +89,23 @@ ``` section { - id: byte - length: varint + idAndIsAligned: byte // id | (hasAlign << 7) + length: varint, + + alignment: varint?, + padding: byte[], // Padding bytes are always `0xCB`. + + data: byte[] } ``` -Sections are a mechanism for grouping data within the bytecode. The enable +Sections are a mechanism for grouping data within the bytecode. They enable delayed processing, which is useful for out-of-order processing of data, -lazy-loading, and more. Each section contains a Section ID and a length (which -allowing for skipping over the section). - -TODO: Sections should also carry an optional alignment. Add this when necessary. +lazy-loading, and more. Each section contains a Section ID, whose high bit +indicates if the section has alignment requirements, a length (which allows for +skipping over the section), and an optional alignment. When an alignment is +present, a variable number of padding bytes (0xCB) may appear before the section +data. The alignment of a section must be a power of 2. ## MLIR Encoding @@ -244,6 +250,56 @@ further information. As such, a common encoding idiom is to use a leading `varint` code to indicate how the attribute or type was encoded. +### Resource Section + +Resources are encoded using two [sections](#sections), one section +(`resource_section`) containing the actual encoded representation, and another +section (`resource_offset_section`) containing the information needed to locate +each encoded resource within the previous section.
+ +``` +resource_section { + resources: resource[] +} +resource { + value: resource_bool | resource_string | resource_blob +} +resource_bool { + value: byte +} +resource_string { + value: varint +} +resource_blob { + alignment: varint, + size: varint, + padding: byte[], + blob: byte[] +} + +resource_offset_section { + numExternalResourceGroups: varint, + resourceGroups: resource_group[] +} +resource_group { + key: varint, + numResources: varint, + resources: resource_info[] +} +resource_info { + key: varint, + size: varint, + kind: byte +} +``` + +Resources are grouped by the provider, either an external entity or a dialect, +with each `resource_group` in the offset section containing the corresponding +provider, number of elements, and info for each element within the group. For +each element, we record the key, the value kind, and the encoded size. We avoid +using the direct offset into the `resource_section`, as smaller relative +offsets provide more effective compression. + ### IR Section The IR section contains the encoded form of operations within the bytecode. diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -18,6 +18,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectInterface.h" +#include "mlir/IR/OpImplementation.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Twine.h" @@ -105,6 +106,18 @@ << ", but got: " << baseResult; } + /// Read a handle to a dialect resource.
+ template + FailureOr readResourceHandle() { + FailureOr handle = readResourceHandle(); + if (failed(handle)) + return failure(); + if (auto *result = dyn_cast(&*handle)) + return std::move(*result); + return emitError() << "provided resource handle differs from the " + "expected resource type"; + } + //===--------------------------------------------------------------------===// // Primitives //===--------------------------------------------------------------------===// @@ -129,6 +142,10 @@ /// Read a string from the bytecode. virtual LogicalResult readString(StringRef &result) = 0; + +private: + /// Read a handle to a dialect resource, returning failure if the handle + /// could not be read. This untyped hook is wrapped by the public templated + /// `readResourceHandle`, which additionally checks the concrete handle type. + virtual FailureOr readResourceHandle() = 0; }; //===----------------------------------------------------------------------===// @@ -171,6 +188,10 @@ writeList(types, [this](T type) { writeType(type); }); } + /// Write the given handle to a dialect resource. This is the writer-side + /// counterpart of `DialectBytecodeReader::readResourceHandle`. + virtual void + writeResourceHandle(const AsmDialectResourceHandle &resource) = 0; + //===--------------------------------------------------------------------===// // Primitives //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h --- a/mlir/include/mlir/Bytecode/BytecodeWriter.h +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -13,23 +13,59 @@ #ifndef MLIR_BYTECODE_BYTECODEWRITER_H #define MLIR_BYTECODE_BYTECODEWRITER_H -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/StringRef.h" +#include "mlir/IR/AsmState.h" namespace mlir { class Operation; +/// This class contains the configuration used for the bytecode writer. It +/// controls various aspects of bytecode generation, and contains all of the +/// various bytecode writer hooks. +class BytecodeWriterConfig { +public: + /// `producer` is an optional string that can be used to identify the producer + /// of the bytecode when reading.
It has no functional effect on the bytecode + /// serialization. The producer string is stored by reference (not copied), + /// so it must outlive this configuration. + BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING); + ~BytecodeWriterConfig(); + + /// An internal implementation class that contains the state of the + /// configuration. + struct Impl; + + /// Return an instance of the internal implementation. + const Impl &getImpl() const { return *impl; } + + //===--------------------------------------------------------------------===// + // Resources + //===--------------------------------------------------------------------===// + + /// Attach the given resource printer to the writer configuration. + void attachResourcePrinter(std::unique_ptr printer); + + /// Attach a resource printer, in the form of a callable, to the + /// configuration. + template + std::enable_if_t>::value> + attachResourcePrinter(StringRef name, CallableT &&printFn) { + attachResourcePrinter(AsmResourcePrinter::fromCallable( + name, std::forward(printFn))); + } + +private: + /// A pointer to allocated storage for the impl state. + std::unique_ptr impl; +}; + //===----------------------------------------------------------------------===// // Entry Points //===----------------------------------------------------------------------===// /// Write the bytecode for the given operation to the provided output stream. /// For streams where it matters, the given stream should be in "binary" mode. -/// `producer` is an optional string that can be used to identify the producer -/// of the bytecode when reading. It has no functional effect on the bytecode -/// serialization.
void writeBytecodeToFile(Operation *op, raw_ostream &os, - StringRef producer = "MLIR" LLVM_VERSION_STRING); + const BytecodeWriterConfig &config = {}); } // namespace mlir diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -262,6 +262,17 @@ } }; +/// This enum represents the different kinds of resource values. The bytecode +/// reader decodes a kind directly from a single byte, so the numeric values of +/// the existing enumerators are part of the encoding and should not change. +enum class AsmResourceEntryKind { + /// A blob of data with an accompanying alignment. + Blob, + /// A boolean value. + Bool, + /// A string value. + String, +}; +StringRef toString(AsmResourceEntryKind kind); + /// This class represents a single parsed resource entry. class AsmParsedResourceEntry { public: @@ -273,6 +284,9 @@ /// Emit an error at the location of this entry. virtual InFlightDiagnostic emitError() const = 0; + /// Return the kind of this value. + virtual AsmResourceEntryKind getKind() const = 0; + /// Parse the resource entry represented by a boolean. Returns failure if the /// entry does not correspond to a bool. virtual FailureOr parseAsBool() const = 0; diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -2344,6 +2344,14 @@ InFlightDiagnostic emitError() const final { return p.emitError(keyLoc); } + AsmResourceEntryKind getKind() const final { + if (value.isAny(Token::kw_true, Token::kw_false)) + return AsmResourceEntryKind::Bool; + // Blob values are expected to be spelled as hex string literals, so any + // string literal whose contents start with 0x is classified as a blob. + return value.getSpelling().startswith("\"0x") + ? AsmResourceEntryKind::Blob + : AsmResourceEntryKind::String; + } + FailureOr parseAsBool() const final { if (value.is(Token::kw_true)) return true; diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h --- a/mlir/lib/Bytecode/Encoding.h +++ b/mlir/lib/Bytecode/Encoding.h @@ -25,6 +25,9 @@ enum { /// The current bytecode version. kVersion = 0, + + /// An arbitrary value used to fill alignment padding.
+ kAlignmentByte = 0xCB, }; //===----------------------------------------------------------------------===// @@ -51,8 +54,15 @@ /// and their nested regions/operations. kIR = 4, + /// This section contains the resources of the bytecode. + kResource = 5, + + /// This section contains the information (key, encoded size, and kind) + /// needed to locate each resource within the Resource section. + kResourceOffset = 6, + /// The total number of section types. - kNumSections = 5, + kNumSections = 7, }; } // namespace Section diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -20,6 +20,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/Support/MemoryBufferRef.h" #include "llvm/Support/SaveAndRestore.h" @@ -40,11 +41,32 @@ return "AttrTypeOffset (3)"; case bytecode::Section::kIR: return "IR (4)"; + case bytecode::Section::kResource: + return "Resource (5)"; + case bytecode::Section::kResourceOffset: + return "ResourceOffset (6)"; default: return ("Unknown (" + Twine(static_cast(sectionID)) + ")").str(); } } +/// Returns true if the given top-level section ID is optional, i.e. the +/// section may be absent from the bytecode file. +static bool isSectionOptional(bytecode::Section::ID sectionID) { + switch (sectionID) { + case bytecode::Section::kString: + case bytecode::Section::kDialect: + case bytecode::Section::kAttrType: + case bytecode::Section::kAttrTypeOffset: + case bytecode::Section::kIR: + return false; + case bytecode::Section::kResource: + case bytecode::Section::kResourceOffset: + return true; + default: + llvm_unreachable("unknown section ID"); + } +} + //===----------------------------------------------------------------------===// // EncodingReader //===----------------------------------------------------------------------===// @@ -65,11 +87,34 @@ /// Returns the remaining size of the bytecode.
size_t size() const { return dataEnd - dataIt; } + /// Align the current reader position to the specified alignment. + LogicalResult alignTo(unsigned alignment) { + if (!llvm::isPowerOf2_32(alignment)) + return emitError("expected alignment to be a power-of-two"); + + // Shift the reader position to the next alignment boundary. + while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { + uint8_t padding; + if (failed(parseByte(padding))) + return failure(); + if (padding != bytecode::kAlignmentByte) { + return emitError("expected alignment byte (0xCB), but got: '0x" + + llvm::utohexstr(padding) + "'"); + } + } + + // TODO: Check that the current data pointer is actually at the expected + // alignment. + + return success(); + } + /// Emit an error using the given arguments. template InFlightDiagnostic emitError(Args &&...args) const { return ::emitError(fileLoc).append(std::forward(args)...); } + InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } /// Parse a single byte from the stream. template @@ -101,6 +146,17 @@ return success(); } + /// Parse an aligned blob of data, where the alignment was encoded alongside + /// the data. + LogicalResult parseBlobAndAlignment(ArrayRef &data, + uint64_t &alignment) { + uint64_t dataSize; + if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || + failed(alignTo(alignment))) + return failure(); + return parseBytes(dataSize, data); + } + /// Parse a variable length encoded integer from the byte stream. The first /// encoded byte contains a prefix in the low bits indicating the encoded /// length of the value. This length prefix is a bit sequence of '0's followed @@ -177,13 +233,31 @@ /// contents of the section in `sectionData`. 
LogicalResult parseSection(bytecode::Section::ID §ionID, ArrayRef §ionData) { + uint8_t sectionIDAndHasAlignment; uint64_t length; - if (failed(parseByte(sectionID)) || failed(parseVarInt(length))) + if (failed(parseByte(sectionIDAndHasAlignment)) || + failed(parseVarInt(length))) return failure(); + + // Extract the section ID and whether the section is aligned. The high bit + // of the ID is the alignment flag. + sectionID = static_cast(sectionIDAndHasAlignment & + 0b01111111); + bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; + + // Check that the section is actually valid before trying to process its + // data. if (sectionID >= bytecode::Section::kNumSections) return emitError("invalid section ID: ", unsigned(sectionID)); - // Parse the actua section data now that we have its length. + // Process the section alignment if present: an explicit alignment varint + // followed by padding bytes up to that alignment. + if (hasAlignment) { + uint64_t alignment; + if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) + return failure(); + } + + // Parse the actual section data. return parseBytes(static_cast(length), sectionData); } @@ -346,6 +420,14 @@ return success(); } + /// Return the loaded dialect, or nullptr if the dialect is unknown. This can + /// only be called after `load`. + Dialect *getLoadedDialect() const { + assert(dialect && + "expected `load` to be invoked before `getLoadedDialect`"); + return *dialect; + } + /// The loaded dialect entry. This field is None if we haven't attempted to /// load, nullptr if we failed to load, otherwise the loaded dialect. Optional dialect; @@ -393,6 +475,225 @@ return success(); } +//===----------------------------------------------------------------------===// +// ResourceSectionReader +//===----------------------------------------------------------------------===// + +namespace { +/// This class is used to read the resource section from the bytecode. +class ResourceSectionReader { +public: + /// Initialize the resource section reader with the given section data.
+ LogicalResult initialize(Location fileLoc, const ParserConfig &config, + MutableArrayRef dialects, + StringSectionReader &stringReader, + ArrayRef sectionData, + ArrayRef offsetSectionData); + + /// Parse a dialect resource handle from the resource section. + LogicalResult parseResourceHandle(EncodingReader &reader, + AsmDialectResourceHandle &result) { + return parseEntry(reader, dialectResources, result, "resource handle"); + } + +private: + /// The table of dialect resources within the bytecode file. Resource handle + /// references are resolved against this table by `parseEntry`. + SmallVector dialectResources; +}; + +/// A parsed resource entry whose value is decoded on demand from `reader`, +/// which covers the encoded bytes of a single resource. +class ParsedResourceEntry : public AsmParsedResourceEntry { +public: + ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, + EncodingReader &reader, StringSectionReader &stringReader) + : key(key), kind(kind), reader(reader), stringReader(stringReader) {} + ~ParsedResourceEntry() override = default; + + StringRef getKey() const final { return key; } + + InFlightDiagnostic emitError() const final { return reader.emitError(); } + + AsmResourceEntryKind getKind() const final { return kind; } + + FailureOr parseAsBool() const final { + if (kind != AsmResourceEntryKind::Bool) + return emitError() << "expected a bool resource entry, but found a " + << toString(kind) << " entry instead"; + + bool value; + if (failed(reader.parseByte(value))) + return failure(); + return value; + } + FailureOr parseAsString() const final { + if (kind != AsmResourceEntryKind::String) + return emitError() << "expected a string resource entry, but found a " + << toString(kind) << " entry instead"; + + StringRef string; + if (failed(stringReader.parseString(reader, string))) + return failure(); + return string.str(); + } + + FailureOr + parseAsBlob(BlobAllocatorFn allocator) const final { + if (kind != AsmResourceEntryKind::Blob) + return emitError() << "expected a blob resource entry, but found a " + << toString(kind) << " entry instead"; + + ArrayRef data; + uint64_t alignment; + if (failed(reader.parseBlobAndAlignment(data, alignment))) +
return failure(); + + // Allocate memory for the blob using the provided allocator and copy the + // data into it. + // FIXME: If the current holder of the bytecode can ensure its lifetime + // (e.g. when mmap'd), we should not copy the data. We should use the data + // from the bytecode directly. + AsmResourceBlob blob = allocator(data.size(), alignment); + assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && + blob.isMutable() && + "blob allocator did not return a properly aligned address"); + memcpy(blob.getMutableData().data(), data.data(), data.size()); + return blob; + } + +private: + StringRef key; + AsmResourceEntryKind kind; + EncodingReader &reader; + StringSectionReader &stringReader; +}; +} // namespace + +template +static LogicalResult +parseResourceGroup(Location fileLoc, bool allowEmpty, + EncodingReader &offsetReader, EncodingReader &resourceReader, + StringSectionReader &stringReader, T *handler, + function_ref processKeyFn = {}) { + uint64_t numResources; + if (failed(offsetReader.parseVarInt(numResources))) + return failure(); + + for (uint64_t i = 0; i < numResources; ++i) { + StringRef key; + AsmResourceEntryKind kind; + uint64_t resourceOffset; + ArrayRef data; + if (failed(stringReader.parseString(offsetReader, key)) || + failed(offsetReader.parseVarInt(resourceOffset)) || + failed(offsetReader.parseByte(kind)) || + failed(resourceReader.parseBytes(resourceOffset, data))) + return failure(); + + // Process the resource key. + if ((processKeyFn && failed(processKeyFn(key)))) + return failure(); + + // If the resource data is empty and we allow it, don't error out when + // parsing below, just skip it. + if (allowEmpty && data.empty()) + continue; + + // Ignore the entry if we don't have a valid handler. + if (!handler) + continue; + + // Otherwise, parse the resource value. 
+ EncodingReader entryReader(data, fileLoc); + ParsedResourceEntry entry(key, kind, entryReader, stringReader); + if (failed(handler->parseResource(entry))) + return failure(); + if (!entryReader.empty()) { + return entryReader.emitError( + "unexpected trailing bytes in resource entry '", key, "'"); + } + } + return success(); +} + +LogicalResult +ResourceSectionReader::initialize(Location fileLoc, const ParserConfig &config, + MutableArrayRef dialects, + StringSectionReader &stringReader, + ArrayRef sectionData, + ArrayRef offsetSectionData) { + EncodingReader resourceReader(sectionData, fileLoc); + EncodingReader offsetReader(offsetSectionData, fileLoc); + + // Read the number of external resource providers. + uint64_t numExternalResourceGroups; + if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) + return failure(); + + // Utility functor that dispatches to `parseResourceGroup`, but implicitly + // provides most of the arguments. + auto parseGroup = [&](auto *handler, bool allowEmpty = false, + function_ref keyFn = {}) { + return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, + stringReader, handler, keyFn); + }; + + // Read the external resources from the bytecode. Groups without a + // registered parser are still consumed, but their entries are ignored. + for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { + StringRef key; + if (failed(stringReader.parseString(offsetReader, key))) + return failure(); + + // Get the handler for these resources. + // TODO: Should we require handling external resources in some scenarios? + AsmResourceParser *handler = config.getResourceParser(key); + if (!handler) { + emitWarning(fileLoc) << "ignoring unknown external resources for '" << key + << "'"; + } + + if (failed(parseGroup(handler))) + return failure(); + } + + // Read the dialect resources from the bytecode. Each remaining group in the + // offset section is owned by a dialect.
+ MLIRContext *ctx = fileLoc->getContext(); + while (!offsetReader.empty()) { + BytecodeDialect *dialect; + if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || + failed(dialect->load(resourceReader, ctx))) + return failure(); + Dialect *loadedDialect = dialect->getLoadedDialect(); + if (!loadedDialect) { + return resourceReader.emitError() + << "dialect '" << dialect->name << "' is unknown"; + } + const auto *handler = dyn_cast(loadedDialect); + if (!handler) { + return resourceReader.emitError() + << "unexpected resources for dialect '" << dialect->name << "'"; + } + + // Ensure that each resource is declared before being processed. The + // declared handles are appended to `dialectResources` in order, which is + // the table that resource handle references are resolved against. + auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { + FailureOr handle = + handler->declareResource(key); + if (failed(handle)) { + return resourceReader.emitError() + << "unknown 'resource' key '" << key << "' for dialect '" + << dialect->name << "'"; + } + dialectResources.push_back(*handle); + return success(); + }; + + // Parse the resources for this dialect. We allow empty resources because we + // just treat these as declarations. + if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) + return failure(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // Attribute/Type Reader //===----------------------------------------------------------------------===// @@ -419,8 +720,10 @@ using TypeEntry = Entry; public: - AttrTypeReader(StringSectionReader &stringReader, Location fileLoc) - : stringReader(stringReader), fileLoc(fileLoc) {} + AttrTypeReader(StringSectionReader &stringReader, + ResourceSectionReader &resourceReader, Location fileLoc) + : stringReader(stringReader), resourceReader(resourceReader), + fileLoc(fileLoc) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -483,6 +786,10 @@ /// custom encoded attribute/type entries.
StringSectionReader &stringReader; + /// The resource section reader used to resolve resource references when + /// parsing custom encoded attribute/type entries. + ResourceSectionReader &resourceReader; + /// The set of attribute and type entries. SmallVector attributes; SmallVector types; @@ -494,9 +801,10 @@ class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, - StringSectionReader &stringReader, EncodingReader &reader) + StringSectionReader &stringReader, + ResourceSectionReader &resourceReader, EncodingReader &reader) : attrTypeReader(attrTypeReader), stringReader(stringReader), - reader(reader) {} + resourceReader(resourceReader), reader(reader) {} InFlightDiagnostic emitError(const Twine &msg) override { return reader.emitError(msg); @@ -514,6 +822,13 @@ return attrTypeReader.parseType(reader, result); } + /// Read a reference to a dialect resource, resolved against the resources + /// declared in the resource section. + FailureOr readResourceHandle() override { + AsmDialectResourceHandle handle; + if (failed(resourceReader.parseResourceHandle(reader, handle))) + return failure(); + return handle; + } + //===--------------------------------------------------------------------===// // Primitives //===--------------------------------------------------------------------===// @@ -575,6 +890,7 @@ private: AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; + ResourceSectionReader &resourceReader; EncodingReader &reader; }; } // namespace @@ -707,7 +1023,7 @@ } // Ask the dialect to parse the entry.
- DialectReader dialectReader(*this, stringReader, reader); + DialectReader dialectReader(*this, stringReader, resourceReader, reader); if constexpr (std::is_same_v) entry.entry = entry.dialect->interface->readType(dialectReader); else @@ -724,7 +1040,8 @@ class BytecodeReader { public: BytecodeReader(Location fileLoc, const ParserConfig &config) - : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc), + : config(config), fileLoc(fileLoc), + attrTypeReader(stringReader, resourceReader, fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -761,6 +1078,13 @@ return attrTypeReader.parseType(reader, result); } + //===--------------------------------------------------------------------===// + // Resource Section + + /// Parse the optional resource and resource offset sections. Both sections + /// must either be present or absent together. + LogicalResult + parseResourceSection(Optional> resourceData, + Optional> resourceOffsetData); + //===--------------------------------------------------------------------===// // IR Section @@ -863,6 +1187,9 @@ SmallVector dialects; SmallVector opNames; + /// The reader used to process resources within the bytecode. + ResourceSectionReader resourceReader; + /// The table of strings referenced within the bytecode file. StringSectionReader stringReader; @@ -914,11 +1241,12 @@ } sectionDatas[sectionID] = sectionData; } - // Check that all of the sections were found. + // Check that all of the required sections were found. for (int i = 0; i < bytecode::Section::kNumSections; ++i) { - if (!sectionDatas[i]) { + bytecode::Section::ID sectionID = static_cast(i); + if (!sectionDatas[i] && !isSectionOptional(sectionID)) { return reader.emitError("missing data for top-level section: ", - toString(bytecode::Section::ID(i))); + toString(sectionID)); } } @@ -931,6 +1259,12 @@ if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) return failure(); + // Process the resource section if present.
+ if (failed(parseResourceSection( + sectionDatas[bytecode::Section::kResource], + sectionDatas[bytecode::Section::kResourceOffset]))) + return failure(); + // Process the attribute and type section. if (failed(attrTypeReader.initialize( dialects, *sectionDatas[bytecode::Section::kAttrType], @@ -1008,6 +1342,31 @@ return *opName->opName; } +//===----------------------------------------------------------------------===// +// Resource Section + +LogicalResult BytecodeReader::parseResourceSection( + Optional> resourceData, + Optional> resourceOffsetData) { + // Ensure both sections are either present or not. + if (resourceData.has_value() != resourceOffsetData.has_value()) { + if (resourceOffsetData) + return emitError(fileLoc, "unexpected resource offset section when " + "resource section is not present"); + return emitError( + fileLoc, + "expected resource offset section when resource section is present"); + } + + // If the resource sections are absent, there is nothing to do. + if (!resourceData) + return success(); + + // Initialize the resource reader with the resource sections. + return resourceReader.initialize(fileLoc, config, dialects, stringReader, + *resourceData, *resourceOffsetData); +} + //===----------------------------------------------------------------------===// // IR Section diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -23,6 +23,29 @@ using namespace mlir; using namespace mlir::bytecode::detail; +//===----------------------------------------------------------------------===// +// BytecodeWriterConfig +//===----------------------------------------------------------------------===// + +struct BytecodeWriterConfig::Impl { + Impl(StringRef producer) : producer(producer) {} + + /// The producer of the bytecode. This is a non-owning reference; the + /// referenced string must outlive the configuration. + StringRef producer; + + /// A collection of non-dialect resource printers.
+ SmallVector> externalResourcePrinters; +}; + +BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer) + : impl(std::make_unique(producer)) {} +BytecodeWriterConfig::~BytecodeWriterConfig() = default; + +void BytecodeWriterConfig::attachResourcePrinter( + std::unique_ptr printer) { + impl->externalResourcePrinters.emplace_back(std::move(printer)); +} + //===----------------------------------------------------------------------===// // EncodingEmitter //===----------------------------------------------------------------------===// @@ -56,6 +79,48 @@ currentResult[offset - prevResultSize] = value; } + /// Emit the provided blob of data, which is owned by the caller and is + /// guaranteed to not die before the end of the bytecode process. + void emitOwnedBlob(ArrayRef data) { + // Push the current buffer before adding the provided data. + appendResult(std::move(currentResult)); + appendOwnedResult(data); + } + + /// Emit the provided blob of data that has the given alignment, which is + /// owned by the caller and is guaranteed to not die before the end of the + /// bytecode process. The alignment value is also encoded, making it available + /// on load. The encoding is the alignment and the data size (both varints), + /// then padding up to the alignment, then the data itself. + void emitOwnedBlobAndAlignment(ArrayRef data, uint32_t alignment) { + emitVarInt(alignment); + emitVarInt(data.size()); + + alignTo(alignment); + emitOwnedBlob(data); + } + /// Overload that reinterprets `data` as raw bytes before emitting it. + void emitOwnedBlobAndAlignment(ArrayRef data, uint32_t alignment) { + ArrayRef castedData(reinterpret_cast(data.data()), + data.size()); + emitOwnedBlobAndAlignment(castedData, alignment); + } + + /// Align the emitter to the given alignment. + void alignTo(unsigned alignment) { + if (alignment < 2) + return; + assert(llvm::isPowerOf2_32(alignment) && "expected valid alignment"); + + // Check to see if we need to emit any padding bytes to meet the desired + // alignment.
+ size_t curOffset = size(); + size_t paddingSize = llvm::alignTo(curOffset, alignment) - curOffset; + while (paddingSize--) + emitByte(bytecode::kAlignmentByte); + + // Keep track of the maximum required alignment. + requiredAlignment = std::max(requiredAlignment, alignment); + } + //===--------------------------------------------------------------------===// // Integer Emission @@ -119,15 +184,37 @@ /// Emit a nested section of the given code, whose contents are encoded in the /// provided emitter. void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) { - // Emit the section code and length. + // Emit the section code and length. The high bit of the code is used to + // indicate whether the section alignment is present, so save an offset to + // it. + uint64_t codeOffset = currentResult.size(); emitByte(code); emitVarInt(emitter.size()); + // Integrate the alignment of the section into this emitter if necessary. If + // the section data would start at a misaligned offset, the alignment value + // and padding are emitted explicitly and flagged via the code's high bit. + unsigned emitterAlign = emitter.requiredAlignment; + if (emitterAlign > 1) { + if (size() & (emitterAlign - 1)) { + emitVarInt(emitterAlign); + alignTo(emitterAlign); + + // Indicate that we needed to align the section, the high bit of the + // code field is used for this. + currentResult[codeOffset] |= 0b10000000; + } else { + // Otherwise, if we happen to be at a compatible offset, we just + // remember that we need this alignment. + requiredAlignment = std::max(requiredAlignment, emitterAlign); + } + } + // Push our current buffer and then merge the provided section body into + // ours. appendResult(std::move(currentResult)); for (std::vector &result : emitter.prevResultStorage) - appendResult(std::move(result)); + prevResultStorage.push_back(std::move(result)); + llvm::append_range(prevResultList, emitter.prevResultList); + prevResultSize += emitter.prevResultSize; appendResult(std::move(emitter.currentResult)); }
void appendResult(std::vector &&result) { - prevResultSize += result.size(); + if (result.empty()) + return; prevResultStorage.emplace_back(std::move(result)); - prevResultList.emplace_back(prevResultStorage.back()); + appendOwnedResult(prevResultStorage.back()); + } + void appendOwnedResult(ArrayRef result) { + if (result.empty()) + return; + prevResultSize += result.size(); + prevResultList.emplace_back(result); } /// The result of the emitter currently being built. We refrain from building @@ -157,6 +251,9 @@ /// An up-to-date total size of all of the buffers within `prevResultList`. /// This enables O(1) size checks of the current encoding. size_t prevResultSize = 0; + + /// The highest required alignment for the start of this section. + unsigned requiredAlignment = 1; }; /// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need @@ -250,7 +347,8 @@ BytecodeWriter(Operation *op) : numberingState(op) {} /// Write the bytecode for the given root operation. - void write(Operation *rootOp, raw_ostream &os, StringRef producer); + void write(Operation *rootOp, raw_ostream &os, + const BytecodeWriterConfig::Impl &config); private: //===--------------------------------------------------------------------===// @@ -271,6 +369,12 @@ void writeRegion(EncodingEmitter &emitter, Region *region); void writeIRSection(EncodingEmitter &emitter, Operation *op); + //===--------------------------------------------------------------------===// + // Resources + + void writeResourceSection(Operation *op, EncodingEmitter &emitter, + const BytecodeWriterConfig::Impl &config); + //===--------------------------------------------------------------------===// // Strings @@ -288,7 +392,7 @@ } // namespace void BytecodeWriter::write(Operation *rootOp, raw_ostream &os, - StringRef producer) { + const BytecodeWriterConfig::Impl &config) { EncodingEmitter emitter; // Emit the bytecode file header. 
This is how we identify the output as a @@ -299,7 +403,7 @@ emitter.emitVarInt(bytecode::kVersion); // Emit the producer. - emitter.emitNulTerminatedString(producer); + emitter.emitNulTerminatedString(config.producer); // Emit the dialect section. writeDialectSection(emitter); @@ -310,6 +414,9 @@ // Emit the IR section. writeIRSection(emitter, rootOp); + // Emit the resources section. + writeResourceSection(rootOp, emitter, config); + // Emit the string section. writeStringSection(emitter); @@ -386,6 +493,10 @@ emitter.emitVarInt(numberingState.getNumber(type)); } + void writeResourceHandle(const AsmDialectResourceHandle &resource) override { + emitter.emitVarInt(numberingState.getNumber(resource)); + } + //===--------------------------------------------------------------------===// // Primitives //===--------------------------------------------------------------------===// @@ -613,6 +724,111 @@ emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter)); } +//===----------------------------------------------------------------------===// +// Resources + +namespace { +/// This class represents a resource builder implementation for the MLIR +/// bytecode format. 
+class ResourceBuilder : public AsmResourceBuilder {
+public:
+ using PostProcessFn = function_ref;
+
+ ResourceBuilder(EncodingEmitter &emitter, StringSectionBuilder &stringSection,
+ PostProcessFn postProcessFn)
+ : emitter(emitter), stringSection(stringSection),
+ postProcessFn(postProcessFn) {}
+ ~ResourceBuilder() override = default;
+
+ void buildBlob(StringRef key, ArrayRef data,
+ uint32_t dataAlignment) final {
+ emitter.emitOwnedBlobAndAlignment(data, dataAlignment);
+ postProcessFn(key, AsmResourceEntryKind::Blob);
+ }
+ void buildBool(StringRef key, bool data) final {
+ emitter.emitByte(data);
+ postProcessFn(key, AsmResourceEntryKind::Bool);
+ }
+ void buildString(StringRef key, StringRef data) final {
+ emitter.emitVarInt(stringSection.insert(data));
+ postProcessFn(key, AsmResourceEntryKind::String);
+ }
+
+private:
+ EncodingEmitter &emitter;
+ StringSectionBuilder &stringSection;
+ PostProcessFn postProcessFn;
+};
+} // namespace
+
+void BytecodeWriter::writeResourceSection(
+ Operation *op, EncodingEmitter &emitter,
+ const BytecodeWriterConfig::Impl &config) {
+ EncodingEmitter resourceEmitter;
+ EncodingEmitter resourceOffsetEmitter;
+ uint64_t prevOffset = 0;
+ SmallVector>
+ curResourceEntries;
+
+ // Functor used to record the size (i.e. the offset relative to the previous
+ // resource) of a resource of `kind` defined by `key`.
+ auto appendResourceOffset = [&](StringRef key, AsmResourceEntryKind kind) {
+ uint64_t curOffset = resourceEmitter.size();
+ curResourceEntries.emplace_back(key, kind, curOffset - prevOffset);
+ prevOffset = curOffset;
+ };
+
+ // Functor used to emit a resource group defined by `key`.
+ auto emitResourceGroup = [&](uint64_t key) {
+ resourceOffsetEmitter.emitVarInt(key);
+ resourceOffsetEmitter.emitVarInt(curResourceEntries.size());
+ for (auto [entryKey, kind, size] : curResourceEntries) {
+ resourceOffsetEmitter.emitVarInt(stringSection.insert(entryKey));
+ resourceOffsetEmitter.emitVarInt(size);
+ resourceOffsetEmitter.emitByte(kind);
+ }
+ };
+
+ // Builder used to emit resources.
+ ResourceBuilder entryBuilder(resourceEmitter, stringSection,
+ appendResourceOffset);
+
+ // Emit the external resource entries.
+ resourceOffsetEmitter.emitVarInt(config.externalResourcePrinters.size());
+ for (const auto &printer : config.externalResourcePrinters) {
+ curResourceEntries.clear();
+ printer->buildResources(op, entryBuilder);
+ emitResourceGroup(stringSection.insert(printer->getName()));
+ }
+
+ // Emit the dialect resource entries.
+ for (DialectNumbering &dialect : numberingState.getDialects()) {
+ if (!dialect.asmInterface)
+ continue;
+ curResourceEntries.clear();
+ dialect.asmInterface->buildResources(op, dialect.resources, entryBuilder);
+
+ // Emit the declaration resources for this dialect; these didn't get emitted
+ // by the interface. These resources don't have data attached, so just use a
+ // "blob" kind as a placeholder.
+ for (const auto &resource : dialect.resourceMap)
+ if (resource.second->isDeclaration)
+ appendResourceOffset(resource.first, AsmResourceEntryKind::Blob);
+
+ // Emit the resource group for this dialect.
+ if (!curResourceEntries.empty())
+ emitResourceGroup(dialect.number);
+ }
+
+ // Elide the resource sections if no groups were emitted, i.e. if the only
+ // thing written is the single-byte varint for a zero external group count.
+ if (config.externalResourcePrinters.empty() && resourceOffsetEmitter.size() == 1)
+ return;
+ emitter.emitSection(bytecode::Section::kResourceOffset,
+ std::move(resourceOffsetEmitter));
+ emitter.emitSection(bytecode::Section::kResource, std::move(resourceEmitter));
+}
+
 //===----------------------------------------------------------------------===//
 // Strings
 
@@ -627,7 +843,7 @@
 //===----------------------------------------------------------------------===//
 
 void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
- StringRef producer) {
+ const BytecodeWriterConfig &config) {
 BytecodeWriter writer(op);
- writer.write(op, os, producer);
+ writer.write(op, os, config.getImpl());
 }
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
--- a/mlir/lib/Bytecode/Writer/IRNumbering.h
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -14,8 +14,10 @@
 #ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
 #define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
 
-#include "mlir/IR/OperationSupport.h"
+#include "mlir/IR/OpImplementation.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/StringMap.h"
 
 namespace mlir {
 class BytecodeDialectInterface;
@@ -76,6 +78,25 @@
 unsigned refCount = 1;
 };
 
+//===----------------------------------------------------------------------===//
+// Dialect Resource Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for a dialect resource.
+struct DialectResourceNumbering {
+ explicit DialectResourceNumbering(std::string key) : key(std::move(key)) {}
+
+ /// The key used to reference this resource.
+ std::string key;
+
+ /// The number assigned to this resource.
+ unsigned number = 0;
+
+ /// A flag indicating if this resource is only a declaration, not a full
+ /// definition.
+ bool isDeclaration = true;
+};
+
 //===----------------------------------------------------------------------===//
 // Dialect Numbering
 //===----------------------------------------------------------------------===//
@@ -93,6 +114,15 @@
 
 /// The bytecode dialect interface of the dialect if defined.
 const BytecodeDialectInterface *interface = nullptr;
+
+ /// The asm dialect interface of the dialect if defined.
+ const OpAsmDialectInterface *asmInterface = nullptr;
+
+ /// The resources of this dialect that are referenced by the numbered IR.
+ SetVector resources;
+
+ /// A mapping from resource key to the corresponding resource numbering entry.
+ llvm::MapVector resourceMap;
 };
 
 //===----------------------------------------------------------------------===//
@@ -134,6 +164,10 @@
 assert(valueIDs.count(value) && "value not numbered");
 return valueIDs[value];
 }
+ unsigned getNumber(const AsmDialectResourceHandle &resource) {
+ assert(dialectResources.count(resource) && "resource not numbered");
+ return dialectResources[resource]->number;
+ }
 
 /// Return the block and value counts of the given region.
 std::pair getBlockValueCount(Region *region) {
@@ -162,6 +196,12 @@
 void number(Region ®ion);
 void number(Type type);
 
+ /// Number the given dialect resources.
+ void number(Dialect *dialect, ArrayRef resources);
+
+ /// Assign final numbers to the referenced dialect resources.
+ void finalizeDialectResourceNumberings(Operation *rootOp);
+
 /// Mapping from IR to the respective numbering entries.
 DenseMap attrs;
 DenseMap opNames;
@@ -172,10 +212,16 @@
 std::vector orderedOpNames;
 std::vector orderedTypes;
 
+ /// A mapping from dialect resource handle to the numbering for the referenced
+ /// resource.
+ llvm::DenseMap
+ dialectResources;
+
 /// Allocators used for the various numbering entries.
llvm::SpecificBumpPtrAllocator attrAllocator;
 llvm::SpecificBumpPtrAllocator dialectAllocator;
 llvm::SpecificBumpPtrAllocator opNameAllocator;
+ llvm::SpecificBumpPtrAllocator resourceAllocator;
 llvm::SpecificBumpPtrAllocator typeAllocator;
 
 /// The value ID for each Block and Value.
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -9,6 +9,7 @@
 #include "IRNumbering.h"
 #include "mlir/Bytecode/BytecodeImplementation.h"
 #include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 
@@ -24,6 +25,9 @@
 
 void writeAttribute(Attribute attr) override { state.number(attr); }
 void writeType(Type type) override { state.number(type); }
+ void writeResourceHandle(const AsmDialectResourceHandle &resource) override {
+ state.number(resource.getDialect(), resource);
+ }
 
 /// Stubbed out methods that are not used for numbering.
 void writeVarInt(uint64_t) override {}
@@ -148,6 +152,9 @@
 groupByDialectPerByte(llvm::makeMutableArrayRef(orderedAttrs));
 groupByDialectPerByte(llvm::makeMutableArrayRef(orderedOpNames));
 groupByDialectPerByte(llvm::makeMutableArrayRef(orderedTypes));
+
+ // Finalize the numbering of the dialect resources.
+ finalizeDialectResourceNumberings(op);
 }
 
 void IRNumberingState::number(Attribute attr) {
@@ -174,12 +181,23 @@
 // dummy writing to number any nested components.
 if (const auto *interface = numbering->dialect->interface) {
 // TODO: We don't allow custom encodings for mutable attributes right now.
- if (attr.hasTrait())
- return;
-
- NumberingDialectWriter writer(*this);
- (void)interface->writeAttribute(attr, writer);
+ if (!attr.hasTrait()) {
+ NumberingDialectWriter writer(*this);
+ if (succeeded(interface->writeAttribute(attr, writer)))
+ return;
+ }
 }
+ // If this attribute will be emitted using the fallback, number the nested
+ // dialect resources. We don't number everything (e.g. no nested
+ // attributes/types), because we don't want to encode things we won't decode
+ // (the textual format can't really share much).
+ AsmState tempState(attr.getContext());
+ llvm::raw_null_ostream dummyOS;
+ attr.print(dummyOS, tempState);
+
+ // Number the dialect resources referenced by the printed form.
+ for (const auto &it : tempState.getDialectResources())
+ number(it.getFirst(), it.getSecond().getArrayRef());
 }
 
 void IRNumberingState::number(Block &block) {
@@ -203,6 +221,7 @@
 if (!numbering) {
 numbering = &numberDialect(dialect->getNamespace());
 numbering->interface = dyn_cast(dialect);
+ numbering->asmInterface = dyn_cast(dialect);
 }
 return *numbering;
 }
@@ -292,10 +311,92 @@
 // writing to number any nested components.
 if (const auto *interface = numbering->dialect->interface) {
 // TODO: We don't allow custom encodings for mutable types right now.
- if (type.hasTrait())
+ if (!type.hasTrait()) {
+ NumberingDialectWriter writer(*this);
+ if (succeeded(interface->writeType(type, writer)))
+ return;
+ }
+ }
+ // If this type will be emitted using the fallback, number the nested dialect
+ // resources. We don't number everything (e.g. no nested attributes/types),
+ // because we don't want to encode things we won't decode (the textual format
+ // can't really share much).
+ AsmState tempState(type.getContext());
+ llvm::raw_null_ostream dummyOS;
+ type.print(dummyOS, tempState);
+
+ // Number the dialect resources referenced by the printed form.
+ for (const auto &it : tempState.getDialectResources())
+ number(it.getFirst(), it.getSecond().getArrayRef());
+}
+
+void IRNumberingState::number(Dialect *dialect,
+ ArrayRef resources) {
+ DialectNumbering &dialectNumber = numberDialect(dialect);
+ assert(
+ dialectNumber.asmInterface &&
+ "expected dialect owning a resource to implement OpAsmDialectInterface");
+
+ for (const auto &resource : resources) {
+ // Skip resources that were already numbered, but keep processing the rest.
+ if (!dialectNumber.resources.insert(resource))
- return;
- NumberingDialectWriter writer(*this);
- (void)interface->writeType(type, writer);
+ continue;
+ auto *numbering =
+ new (resourceAllocator.Allocate()) DialectResourceNumbering(
+ dialectNumber.asmInterface->getResourceKey(resource));
+ dialectNumber.resourceMap.insert({numbering->key, numbering});
+ dialectResources.try_emplace(resource, numbering);
+ }
+}
+
+namespace {
+/// A dummy resource builder used to number dialect resources.
+struct NumberingResourceBuilder : public AsmResourceBuilder {
+ NumberingResourceBuilder(DialectNumbering *dialect, unsigned &nextResourceID)
+ : dialect(dialect), nextResourceID(nextResourceID) {}
+ ~NumberingResourceBuilder() override = default;
+
+ void buildBlob(StringRef key, ArrayRef, uint32_t) final {
+ numberEntry(key);
+ }
+ void buildBool(StringRef key, bool) final { numberEntry(key); }
+ void buildString(StringRef key, StringRef) final {
+ // TODO: We could pre-number the value string here as well.
+ numberEntry(key);
+ }
+
+ /// Number the dialect entry for the given key.
+ void numberEntry(StringRef key) {
+ // TODO: We could pre-number resource key strings here as well.
+
+ auto it = dialect->resourceMap.find(key);
+ if (it != dialect->resourceMap.end()) {
+ it->second->number = nextResourceID++;
+ it->second->isDeclaration = false;
+ }
+ }
+
+ DialectNumbering *dialect;
+ unsigned &nextResourceID;
+};
+} // namespace
+
+void IRNumberingState::finalizeDialectResourceNumberings(Operation *rootOp) {
+ unsigned nextResourceID = 0;
+ for (DialectNumbering &dialect : getDialects()) {
+ if (!dialect.asmInterface)
+ continue;
+ NumberingResourceBuilder entryBuilder(&dialect, nextResourceID);
+ dialect.asmInterface->buildResources(rootOp, dialect.resources,
+ entryBuilder);
+
+ // Number any resources that weren't built by the dialect. This can happen
+ // if there was no backing data to the resource, but we still want these
+ // resource references to roundtrip, so we number them and indicate that the
+ // data is missing.
+ for (const auto &it : dialect.resourceMap)
+ if (it.second->isDeclaration)
+ it.second->number = nextResourceID++;
 }
 }
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1271,6 +1271,18 @@
 AsmResourceParser::~AsmResourceParser() = default;
 AsmResourcePrinter::~AsmResourcePrinter() = default;
 
+StringRef mlir::toString(AsmResourceEntryKind kind) {
+ switch (kind) {
+ case AsmResourceEntryKind::Blob:
+ return "blob";
+ case AsmResourceEntryKind::Bool:
+ return "bool";
+ case AsmResourceEntryKind::String:
+ return "string";
+ }
+ llvm_unreachable("unknown AsmResourceEntryKind");
+}
+
 //===----------------------------------------------------------------------===//
 // AsmState
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include
"mlir/IR/Diagnostics.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -116,6 +117,12 @@
 /// UnknownLoc {
 /// }
 kUnknownLoc = 15,
+
+ /// DenseResourceElementsAttr {
+ /// type: Type,
+ /// handle: ResourceHandle
+ /// }
+ kDenseResourceElementsAttr = 16,
 };
 
 /// This enum contains marker codes used to indicate which type is currently
@@ -272,6 +279,8 @@
 Attribute readAttribute(DialectBytecodeReader &reader) const override;
 
 ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const;
+ DenseResourceElementsAttr
+ readDenseResourceElementsAttr(DialectBytecodeReader &reader) const;
 DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const;
 FloatAttr readFloatAttr(DialectBytecodeReader &reader) const;
 IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const;
@@ -289,6 +298,8 @@
 LogicalResult writeAttribute(Attribute attr,
 DialectBytecodeWriter &writer) const override;
 void write(ArrayAttr attr, DialectBytecodeWriter &writer) const;
+ void write(DenseResourceElementsAttr attr,
+ DialectBytecodeWriter &writer) const;
 void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const;
 void write(IntegerAttr attr, DialectBytecodeWriter &writer) const;
 void write(FloatAttr attr, DialectBytecodeWriter &writer) const;
@@ -381,6 +392,8 @@
 return readNameLoc(reader);
 case builtin_encoding::kUnknownLoc:
 return UnknownLoc::get(getContext());
+ case builtin_encoding::kDenseResourceElementsAttr:
+ return readDenseResourceElementsAttr(reader);
 default:
 reader.emitError() << "unknown builtin attribute code: " << code;
 return Attribute();
@@ -390,9 +403,12 @@
 LogicalResult BuiltinDialectBytecodeInterface::writeAttribute(
 Attribute attr, DialectBytecodeWriter &writer) const {
 return TypeSwitch(attr)
- .Case([&](auto attr) {
+ .Case([&](auto attr) {
+ write(attr, writer);
+ return success();
+ })
+ .Case([&](auto attr) {
 write(attr, writer);
 return success();
 })
@@ -425,6 +441,31 @@
writer.writeAttributes(attr.getValue());
 }
 
+//===----------------------------------------------------------------------===//
+// DenseResourceElementsAttr
+
+DenseResourceElementsAttr
+BuiltinDialectBytecodeInterface::readDenseResourceElementsAttr(
+ DialectBytecodeReader &reader) const {
+ ShapedType type;
+ if (failed(reader.readType(type)))
+ return DenseResourceElementsAttr();
+
+ FailureOr handle =
+ reader.readResourceHandle();
+ if (failed(handle))
+ return DenseResourceElementsAttr();
+
+ return DenseResourceElementsAttr::get(type, *handle);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+ DenseResourceElementsAttr attr, DialectBytecodeWriter &writer) const {
+ writer.writeVarInt(builtin_encoding::kDenseResourceElementsAttr);
+ writer.writeType(attr.getType());
+ writer.writeResourceHandle(attr.getRawHandle());
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 
diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir
--- a/mlir/test/Bytecode/invalid/invalid-structure.mlir
+++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir
@@ -32,7 +32,7 @@
 // ID
 
 // RUN: not mlir-opt %S/invalid-structure-section-id-unknown.mlirbc 2>&1 | FileCheck %s --check-prefix=SECTION_ID_UNKNOWN
-// SECTION_ID_UNKNOWN: invalid section ID: 255
+// SECTION_ID_UNKNOWN: invalid section ID: 127
 
 //===--------------------------------------------------------------------===//
 // Length
diff --git a/mlir/test/Bytecode/resources.mlir b/mlir/test/Bytecode/resources.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/resources.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s
+
+// Bytecode currently does not support big-endian platforms.
+// UNSUPPORTED: s390x-
+
+// CHECK-LABEL: @TestDialectResources
+module @TestDialectResources attributes {
+ // CHECK: bytecode.test = dense_resource : tensor<2xui32>
+ // CHECK: bytecode.test2 =
dense_resource : tensor<4xf64>
+ // CHECK: bytecode.test3 = dense_resource : tensor<4xf64>
+ bytecode.test = dense_resource : tensor<2xui32>,
+ bytecode.test2 = dense_resource : tensor<4xf64>,
+ bytecode.test3 = dense_resource : tensor<4xf64>
+} {}
+
+// CHECK: builtin: {
+// CHECK-NEXT: resource: "0x08000000010000000000000002000000000000000300000000000000"
+// CHECK-NEXT: resource_2: "0x08000000010000000000000002000000000000000300000000000000"
+
+{-#
+ dialect_resources: {
+ builtin: {
+ resource: "0x08000000010000000000000002000000000000000300000000000000",
+ resource_2: "0x08000000010000000000000002000000000000000300000000000000"
+ }
+ }
+#-}