diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md new file mode 100644 --- /dev/null +++ b/mlir/docs/BytecodeFormat.md @@ -0,0 +1,314 @@ +# MLIR Bytecode Format + +This document describes the MLIR bytecode format and its encoding. + +[TOC] + +## Magic Number + +MLIR uses the following four-byte magic number to indicate bytecode files: + +'\[‘M’<sub>8</sub>, ‘L’<sub>8</sub>, ‘ï’<sub>8</sub>, ‘R’<sub>8</sub>\]' + +In hex: + +'\[‘4D’<sub>8</sub>, ‘4C’<sub>8</sub>, ‘EF’<sub>8</sub>, ‘52’<sub>8</sub>\]' + +## Format Overview + +An MLIR Bytecode file consists of a byte stream, with a few simple +structural concepts layered on top. + +### Primitives + +#### Fixed-Width Integers + +``` + byte ::= `0x00`...`0xFF` +``` + +Fixed width integers are unsigned integers of a known byte size. The values are +stored in little-endian byte order. + +TODO: Add larger fixed width integers as necessary. + +#### Variable-Width Integers + +Variable width integers, or `VarInt`s, provide a compact representation for +integers. Each encoded VarInt consists of one to nine bytes, which together +represent a single 64-bit value. The MLIR bytecode utilizes the "PrefixVarInt" +encoding for VarInts. This encoding is a variant of the +[LEB128 ("Little-Endian Base 128")](https://en.wikipedia.org/wiki/LEB128) +encoding, where each byte of the encoding provides up to 7 bits for the value, +with the remaining bit used to store a tag indicating the number of bytes used +for the encoding. This means that small unsigned integers (less than 2^7) may be +stored in one byte, unsigned integers less than 2^14 may be stored in two bytes, +etc. + +The first byte of the encoding includes a length prefix in the low bits. This +prefix is a bit sequence of '0's followed by a terminal '1', or the end of the +byte. The number of '0' bits indicates the number of _additional_ bytes, not +including the prefix byte, used to encode the value. All of the remaining bits +in the first byte, along with all of the bits in the additional bytes, provide +the value of the integer. 
Below are the various possible encodings of the prefix +byte: + +``` +xxxxxxx1: 7 value bits, the encoding uses 1 byte +xxxxxx10: 14 value bits, the encoding uses 2 bytes +xxxxx100: 21 value bits, the encoding uses 3 bytes +xxxx1000: 28 value bits, the encoding uses 4 bytes +xxx10000: 35 value bits, the encoding uses 5 bytes +xx100000: 42 value bits, the encoding uses 6 bytes +x1000000: 49 value bits, the encoding uses 7 bytes +10000000: 56 value bits, the encoding uses 8 bytes +00000000: 64 value bits, the encoding uses 9 bytes +``` + +#### Strings + +Strings are blobs of characters with an associated length. + +### Sections + +``` +section { + id: byte + length: varint +} +``` + +Sections are a mechanism for grouping data within the bytecode. They enable +delayed processing, which is useful for out-of-order processing of data, +lazy-loading, and more. Each section contains a Section ID and a length (which +allows skipping over the section). + +TODO: Sections should also carry an optional alignment. Add this when necessary. + +## MLIR Encoding + +Given the generic structure of MLIR, the bytecode encoding is actually fairly +simple. It effectively maps to the core components of MLIR. + +### Top Level Structure + +The top-level structure of the bytecode contains the 4-byte "magic number", a +version number, a null-terminated producer string, and a list of sections. Each +section is currently only expected to appear once within a bytecode file. + +``` +bytecode { + magic: "MLïR", + version: varint, + producer: string, + sections: section[] +} +``` + +### String Section + +``` +strings { + numStrings: varint, + reverseStringLengths: varint[], + stringData: byte[] +} +``` + +The string section contains a table of strings referenced within the bytecode, +more easily enabling string sharing. This section is encoded first with the +total number of strings, followed by the sizes of each of the individual strings +in reverse order. 
The remaining encoding contains a single blob containing all +of the strings concatenated together. + +### Dialect Section + +The dialect section of the bytecode contains all of the dialects referenced +within the encoded IR, and some information about the components of those +dialects that were also referenced. + +``` +dialect_section { + numDialects: varint, + dialectNames: varint[], + opNames: op_name_group[] +} + +op_name_group { + dialect: varint, + numOpNames: varint, + opNames: varint[] +} +``` + +Dialects are encoded as indexes to the name string within the string section. +Operation names are encoded in groups by dialect, with each group containing the +dialect, the number of operation names, and the array of indexes to each name +within the string section. + +### Attribute/Type Sections + +Attributes and types are encoded using two [sections](#sections), one section +(`attr_type_section`) containing the actual encoded representation, and another +section (`attr_type_offset_section`) containing the offsets of each encoded +attribute/type into the previous section. This structure allows for attributes +and types to always be lazily loaded on demand. + +``` +attr_type_section { + attrs: attribute[], + types: type[] +} +attr_type_offset_section { + numAttrs: varint, + numTypes: varint, + offsets: attr_type_offset_group[] +} + +attr_type_offset_group { + dialect: varint, + numElements: varint, + offsets: varint[] // (offset << 1) | (hasCustomEncoding) +} + +attribute { + encoding: ... +} +type { + encoding: ... +} +``` + +Each `offset` in the `attr_type_offset_section` above is the size of the +encoding for the attribute or type and a flag indicating if the encoding uses +the textual assembly format, or a custom bytecode encoding. We avoid using the +direct offset into the `attr_type_section`, as smaller relative offsets +provide more effective compression. 
Attributes and types are grouped by +dialect, with each `attr_type_offset_group` in the offset section containing the +corresponding parent dialect, number of elements, and offsets for each element +within the group. + +#### Attribute/Type Encodings + +In the abstract, an attribute/type is encoded in one of two possible ways: via +its assembly format, or via a custom dialect-defined encoding. + +##### Assembly Format Fallback + +In the case where a dialect does not define a method for encoding the attribute +or type, the textual assembly format of that attribute or type is used as a +fallback. For example, a type of `!bytecode.type` would be encoded as the null +terminated string "!bytecode.type". This ensures that every attribute and type +may be encoded, even if the owning dialect has not yet opted in to a more +efficient serialization. + +TODO: We shouldn't redundantly encode the dialect name here, we should use a +reference to the parent dialect instead. + +##### Dialect Defined Encoding + +TODO: This is not yet supported. + +### IR Section + +The IR section contains the encoded form of operations within the bytecode. + +#### Operation Encoding + +``` +op { + name: varint, + encodingMask: byte, + location: varint, + + attrDict: varint?, + + numResults: varint?, + resultTypes: varint[], + + numOperands: varint?, + operands: varint[], + + numSuccessors: varint?, + successors: varint[], + + regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove) + regions: region[] +} +``` + +The encoding of an operation is important because this is generally the most +commonly appearing structure in the bytecode. A single encoding is used for +every type of operation. Given this prevalence, many of the fields of an +operation are optional. The `encodingMask` field is a bitmask which indicates +which of the components of the operation are present. + +##### Location + +The location is encoded as the index of the location within the attribute table. 
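As a rough illustration of how the `encodingMask` byte gates the optional fields of an `op`, the sketch below computes a mask from a summary of an operation. The bit values mirror the `OpEncodingMask` enum that this patch adds in `Encoding.h`; the `OpSummary` struct and the `computeEncodingMask` function are hypothetical names introduced only for this example and are not part of the patch.

```cpp
#include <cstdint>

// Mask bits, mirroring the OpEncodingMask values added in Encoding.h.
enum OpEncodingMask : uint8_t {
  kHasAttrs = 0b00000001,
  kHasResults = 0b00000010,
  kHasOperands = 0b00000100,
  kHasSuccessors = 0b00001000,
  kHasInlineRegions = 0b00010000,
};

// Hypothetical summary of the components of an operation (illustration only).
struct OpSummary {
  bool hasAttrs;
  unsigned numResults;
  unsigned numOperands;
  unsigned numSuccessors;
  unsigned numRegions;
};

// Hypothetical helper: set one bit per component that is present. A reader
// would test the same bits to decide which optional fields follow.
uint8_t computeEncodingMask(const OpSummary &op) {
  uint8_t mask = 0;
  if (op.hasAttrs)
    mask |= kHasAttrs;
  if (op.numResults != 0)
    mask |= kHasResults;
  if (op.numOperands != 0)
    mask |= kHasOperands;
  if (op.numSuccessors != 0)
    mask |= kHasSuccessors;
  if (op.numRegions != 0)
    mask |= kHasInlineRegions;
  return mask;
}
```

An operation with no attributes, results, operands, successors, or regions gets a mask of zero, so only its name, mask, and location are written.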
+ +##### Attributes + +If the operation has attributes, the index of the operation attribute dictionary +within the attribute table is encoded. + +##### Results + +If the operation has results, the number of results and the indexes of the +result types within the type table are encoded. + +##### Operands + +If the operation has operands, the number of operands and the value index of +each operand are encoded. This value index is the relative ordering of the +definition of that value from the start of the first ancestor isolated region. + +##### Successors + +If the operation has successors, the number of successors and the indexes of the +successor blocks within the parent region are encoded. + +##### Regions + +If the operation has regions, the number of regions and if the regions are +isolated from above are encoded together in a single varint. Afterwards, each +region is encoded inline. + +#### Region Encoding + +``` +region { + numBlocks: varint, + + numValues: varint?, + blocks: block[] +} +``` + +A region is encoded first with the number of blocks within. If the region is +non-empty, the number of values defined directly within the region is encoded, +followed by the blocks of the region. + +#### Block Encoding + +``` +block { + encoding: varint, // (numOps << 1) | (hasBlockArgs) + arguments: block_arguments?, // Optional based on encoding + ops : op[] +} + +block_arguments { + numArgs: varint?, + args: block_argument[] +} + +block_argument { + typeIndex: varint, + location: varint +} +``` + +A block is encoded with an array of operations and block arguments. The first +field is an encoding that combines the number of operations in the block, with a +flag indicating if the block has arguments. 
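The PrefixVarInt layout described in the Variable-Width Integers section of the format document can be sketched as a small standalone encoder and decoder. This is a minimal sketch that builds the value byte by byte, so it does not rely on host endianness. The names `emitVarInt` and `readVarInt` are hypothetical and chosen for this example; they are not the functions added by this patch, and the decoder assumes well-formed input rather than reporting errors.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: append `value` to `out` using the PrefixVarInt layout.
void emitVarInt(std::vector<uint8_t> &out, uint64_t value) {
  // Try the 1- to 8-byte forms; a form of n bytes carries 7 * n value bits.
  for (unsigned numBytes = 1; numBytes <= 8; ++numBytes) {
    if ((value >> (7 * numBytes)) == 0) {
      // Shift the value past the prefix bits and set the terminal '1' bit,
      // which leaves (numBytes - 1) zero bits below it.
      uint64_t encoded = (value << numBytes) | (uint64_t(1) << (numBytes - 1));
      for (unsigned i = 0; i < numBytes; ++i)
        out.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
      return;
    }
  }
  // 9-byte form: an all-zero prefix byte, then the raw value, little-endian.
  out.push_back(0);
  for (unsigned i = 0; i < 8; ++i)
    out.push_back(static_cast<uint8_t>(value >> (8 * i)));
}

// Hypothetical helper: decode one VarInt starting at `pos` and advance `pos`.
// Assumes the input is well formed and long enough (no bounds checks).
uint64_t readVarInt(const std::vector<uint8_t> &in, size_t &pos) {
  uint8_t first = in[pos++];
  if (first == 0) {
    // All-zero prefix byte: the next 8 bytes hold the full 64-bit value.
    uint64_t value = 0;
    for (unsigned i = 0; i < 8; ++i)
      value |= static_cast<uint64_t>(in[pos++]) << (8 * i);
    return value;
  }
  // Count the zero bits below the terminal '1': the number of extra bytes.
  unsigned extra = 0;
  while (((first >> extra) & 1) == 0)
    ++extra;
  uint64_t raw = first;
  for (unsigned i = 1; i <= extra; ++i)
    raw |= static_cast<uint64_t>(in[pos++]) << (8 * i);
  // Drop the prefix bits to recover the value.
  return raw >> (extra + 1);
}
```

The flagged fields in the format, such as `(numOps << 1) | (hasBlockArgs)`, need no extra machinery under this scheme: the writer packs the flag into the low bit before calling the encoder, and the reader masks it off after decoding.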
diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeReader.h @@ -0,0 +1,34 @@ +//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines interfaces to read MLIR bytecode files/streams. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEREADER_H +#define MLIR_BYTECODE_BYTECODEREADER_H + +#include "mlir/IR/AsmState.h" +#include "mlir/Support/LLVM.h" + +namespace llvm { +class MemoryBufferRef; +} // namespace llvm + +namespace mlir { +/// Returns true if the given buffer starts with the magic bytes that signal +/// MLIR bytecode. +bool isBytecode(llvm::MemoryBufferRef buffer); + +/// Read the operations defined within the given memory buffer, containing MLIR +/// bytecode, into the provided block. +LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, + const ParserConfig &config); +} // namespace mlir + +#endif // MLIR_BYTECODE_BYTECODEREADER_H diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -0,0 +1,36 @@ +//===- BytecodeWriter.h - MLIR Bytecode Writer ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines interfaces to write MLIR bytecode files/streams. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEWRITER_H +#define MLIR_BYTECODE_BYTECODEWRITER_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +class Operation; + +//===----------------------------------------------------------------------===// +// Entry Points +//===----------------------------------------------------------------------===// + +/// Write the bytecode for the given operation to the provided output stream. +/// For streams where it matters, the given stream should be in "binary" mode. +/// `producer` is an optional string that can be used to identify the producer +/// of the bytecode when reading. It has no functional effect on the bytecode +/// serialization. 
+void writeBytecodeToFile(Operation *op, raw_ostream &os, + StringRef producer = "MLIR" LLVM_VERSION_STRING); + +} // namespace mlir + +#endif // MLIR_BYTECODE_BYTECODEWRITER_H diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -642,11 +642,11 @@ OperationState(Location location, OperationName name); OperationState(Location location, OperationName name, ValueRange operands, - TypeRange types, ArrayRef<NamedAttribute> attributes, + TypeRange types, ArrayRef<NamedAttribute> attributes = {}, BlockRange successors = {}, MutableArrayRef<std::unique_ptr<Region>> regions = {}); OperationState(Location location, StringRef name, ValueRange operands, - TypeRange types, ArrayRef<NamedAttribute> attributes, + TypeRange types, ArrayRef<NamedAttribute> attributes = {}, BlockRange successors = {}, MutableArrayRef<std::unique_ptr<Region>> regions = {}); diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -50,13 +50,15 @@ /// - preloadDialectsInContext will trigger the upfront loading of all /// dialects from the global registry in the MLIRContext. This option is /// deprecated and will be removed soon. +/// - emitBytecode will generate bytecode output instead of text. LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer, const PassPipelineCLParser &passPipeline, DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext = false); + bool preloadDialectsInContext = false, + bool emitBytecode = false); /// Support a callback to setup the pass manager. 
/// - passManagerSetupFn is the callback invoked to setup the pass manager to @@ -67,7 +69,8 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext = false); + bool preloadDialectsInContext = false, + bool emitBytecode = false); /// Implementation for tools like `mlir-opt`. /// - toolName is used for the header displayed by `--help`. diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Reader) +add_subdirectory(Writer) diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Encoding.h @@ -0,0 +1,81 @@ +//===- Encoding.h - MLIR binary format encoding information -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines enum values describing the structure of MLIR bytecode +// files. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_BYTECODE_ENCODING_H +#define LIB_MLIR_BYTECODE_ENCODING_H + +#include <cstdint> + +namespace mlir { +namespace bytecode { +//===----------------------------------------------------------------------===// +// General constants +//===----------------------------------------------------------------------===// + +enum { + /// The current bytecode version. 
+ kVersion = 0, +}; + +//===----------------------------------------------------------------------===// +// Sections +//===----------------------------------------------------------------------===// + +namespace Section { +enum ID : uint8_t { + /// This section contains strings referenced within the bytecode. + kString = 0, + + /// This section contains the dialects referenced within an IR module. + kDialect = 1, + + /// This section contains the attributes and types referenced within an IR + /// module. + kAttrType = 2, + + /// This section contains the offsets for the attribute and types within the + /// AttrType section. + kAttrTypeOffset = 3, + + /// This section contains the list of operations serialized into the bytecode, + /// and their nested regions/operations. + kIR = 4, + + /// The total number of section types. + kNumSections = 5, +}; +} // namespace Section + +//===----------------------------------------------------------------------===// +// IR Section +//===----------------------------------------------------------------------===// + +/// This enum represents a mask of all of the potential components of an +/// operation. This mask is used when encoding an operation to indicate which +/// components are present in the bytecode. +namespace OpEncodingMask { +enum : uint8_t { + // clang-format off + kHasAttrs = 0b00000001, + kHasResults = 0b00000010, + kHasOperands = 0b00000100, + kHasSuccessors = 0b00001000, + kHasInlineRegions = 0b00010000, + // clang-format on +}; +} // namespace OpEncodingMask + +} // namespace bytecode +} // namespace mlir + +#endif diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -0,0 +1,1222 @@ +//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// TODO: Support for big-endian architectures. +// TODO: Properly preserve use lists of values. + +#include "mlir/Bytecode/BytecodeReader.h" +#include "../Encoding.h" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/Verifier.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/SaveAndRestore.h" + +#define DEBUG_TYPE "mlir-bytecode-reader" + +using namespace mlir; + +/// Stringify the given section ID. +static std::string toString(bytecode::Section::ID sectionID) { + switch (sectionID) { + case bytecode::Section::kString: + return "String (0)"; + case bytecode::Section::kDialect: + return "Dialect (1)"; + case bytecode::Section::kAttrType: + return "AttrType (2)"; + case bytecode::Section::kAttrTypeOffset: + return "AttrTypeOffset (3)"; + case bytecode::Section::kIR: + return "IR (4)"; + default: + return ("Unknown (" + Twine(sectionID) + ")").str(); + } +} + +//===----------------------------------------------------------------------===// +// EncodingReader +//===----------------------------------------------------------------------===// + +namespace { +class EncodingReader { +public: + explicit EncodingReader(ArrayRef contents, Location fileLoc) + : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} + explicit EncodingReader(StringRef contents, Location fileLoc) + : EncodingReader({reinterpret_cast(contents.data()), + contents.size()}, + fileLoc) {} + + /// Returns true if the entire section has been read. + bool empty() const { return dataIt == dataEnd; } + + /// Returns the remaining size of the bytecode. 
+ size_t size() const { return dataEnd - dataIt; } + + /// Emit an error using the given arguments. + template + LogicalResult emitError(Args &&...args) const { + return ::emitError(fileLoc).append(std::forward(args)...); + } + + /// Parse a single byte from the stream. + template + LogicalResult parseByte(T &value) { + if (empty()) + return emitError("attempting to parse a byte at the end of the bytecode"); + value = static_cast(*dataIt++); + return success(); + } + /// Parse a range of bytes of 'length' into the given result. + LogicalResult parseBytes(size_t length, ArrayRef &result) { + if (length > size()) { + return emitError("attempting to parse ", length, " bytes when only ", + size(), " remain"); + } + result = {dataIt, length}; + dataIt += length; + return success(); + } + /// Parse a range of bytes of 'length' into the given result, which can be + /// assumed to be large enough to hold `length`. + LogicalResult parseBytes(size_t length, uint8_t *result) { + if (length > size()) { + return emitError("attempting to parse ", length, " bytes when only ", + size(), " remain"); + } + memcpy(result, dataIt, length); + dataIt += length; + return success(); + } + + /// Parse a variable length encoded integer from the byte stream. The first + /// encoded byte contains a prefix in the low bits indicating the encoded + /// length of the value. This length prefix is a bit sequence of '0's followed + /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes + /// (not including the prefix byte). All remaining bits in the first byte, + /// along with all of the bits in additional bytes, provide the value of the + /// integer encoded in little-endian order. + LogicalResult parseVarInt(uint64_t &result) { + // Parse the first byte of the encoding, which contains the length prefix. + if (failed(parseByte(result))) + return failure(); + + // Handle the overwhelmingly common case where the value is stored in a + // single byte. 
In this case, the first bit is the `1` marker bit. + if (LLVM_LIKELY(result & 1)) { + result >>= 1; + return success(); + } + + // Handle the overwhelming uncommon case where the value required all 8 + // bytes (i.e. a really really big number). In this case, the marker byte is + // all zeros: `00000000`. + if (LLVM_UNLIKELY(result == 0)) + return parseBytes(sizeof(result), reinterpret_cast(&result)); + return parseMultiByteVarInt(result); + } + + /// Parse a variable length encoded integer whose low bit is used to encode an + /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. + LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { + if (failed(parseVarInt(result))) + return failure(); + flag = result & 1; + result >>= 1; + return success(); + } + + /// Skip the first `length` bytes within the reader. + LogicalResult skipBytes(size_t length) { + if (length > size()) { + return emitError("attempting to skip ", length, " bytes when only ", + size(), " remain"); + } + dataIt += length; + return success(); + } + + /// Parse a null-terminated string into `result` (without including the NUL + /// terminator). + LogicalResult parseNullTerminatedString(StringRef &result) { + const char *startIt = (const char *)dataIt; + const char *nulIt = (const char *)memchr(startIt, 0, size()); + if (!nulIt) + return emitError( + "malformed null-terminated string, no null character found"); + + result = StringRef(startIt, nulIt - startIt); + dataIt = (const uint8_t *)nulIt + 1; + return success(); + } + + /// Parse a section header, placing the kind of section in `sectionID` and the + /// contents of the section in `sectionData`. 
+ LogicalResult parseSection(bytecode::Section::ID &sectionID, + ArrayRef<uint8_t> &sectionData) { + size_t length; + if (failed(parseByte(sectionID)) || failed(parseVarInt(length))) + return failure(); + if (sectionID >= bytecode::Section::kNumSections) + return emitError("invalid section ID: ", unsigned(sectionID)); + + // Parse the actual section data now that we have its length. + return parseBytes(length, sectionData); + } + +private: + /// Parse a variable length encoded integer from the byte stream. This method + /// is a fallback when the number of bytes used to encode the value is greater + /// than 1, but less than the max (9). The provided `result` value can be + /// assumed to already contain the first byte of the value. + /// NOTE: This method is marked noinline to avoid pessimizing the common case + /// of single byte encoding. + LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) { + // Count the number of trailing zeros in the marker byte, this indicates the + // number of trailing bytes that are part of the value. We use `uint32_t` + // here because we only care about the first byte, and so that we actually + // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop + // implementation). + uint32_t numBytes = + llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined); + assert(numBytes > 0 && numBytes <= 7 && + "unexpected number of trailing zeros in varint encoding"); + + // Parse in the remaining bytes of the value. + if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))) + return failure(); + + // Shift out the low-order bits that were used to mark how the value was + // encoded. + result >>= (numBytes + 1); + return success(); + } + + /// The current data iterator, and an iterator to the end of the buffer. + const uint8_t *dataIt, *dataEnd; + + /// A location for the bytecode used to report errors. + Location fileLoc; +}; +} // namespace + +/// Resolve an index into the given entry list. 
`entry` may either be a +/// reference, in which case it is assigned to the corresponding value in +/// `entries`, or a pointer, in which case it is assigned to the address of the +/// element in `entries`. +template +static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, + uint64_t index, T &entry, + StringRef entryStr) { + if (index >= entries.size()) + return reader.emitError("invalid ", entryStr, " index: ", index); + + // If the provided entry is a pointer, resolve to the address of the entry. + if constexpr (std::is_convertible_v, T>) + entry = entries[index]; + else + entry = &entries[index]; + return success(); +} + +/// Parse and resolve an index into the given entry list. +template +static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, + T &entry, StringRef entryStr) { + uint64_t entryIdx; + if (failed(reader.parseVarInt(entryIdx))) + return failure(); + return resolveEntry(reader, entries, entryIdx, entry, entryStr); +} + +//===----------------------------------------------------------------------===// +// BytecodeDialect +//===----------------------------------------------------------------------===// + +namespace { +/// This struct represents a dialect entry within the bytecode. +struct BytecodeDialect { + /// Load the dialect into the provided context if it hasn't been loaded yet. + /// Returns failure if the dialect couldn't be loaded *and* the provided + /// context does not allow unregistered dialects. The provided reader is used + /// for error emission if necessary. + LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { + if (dialect) + return success(); + Dialect *loadedDialect = ctx->getOrLoadDialect(name); + if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { + return reader.emitError( + "dialect '", name, + "' is unknown. 
If this is intended, please call " + "allowUnregisteredDialects() on the MLIRContext, or use " + "-allow-unregistered-dialect with the MLIR tool used."); + } + dialect = loadedDialect; + return success(); + } + + /// The loaded dialect entry. This field is None if we haven't attempted to + /// load, nullptr if we failed to load, otherwise the loaded dialect. + Optional dialect; + + /// The name of the dialect. + StringRef name; +}; + +/// This struct represents an operation name entry within the bytecode. +struct BytecodeOperationName { + BytecodeOperationName(BytecodeDialect *dialect, StringRef name) + : dialect(dialect), name(name) {} + + /// The loaded operation name, or None if it hasn't been processed yet. + Optional opName; + + /// The dialect that owns this operation name. + BytecodeDialect *dialect; + + /// The name of the operation, without the dialect prefix. + StringRef name; +}; +} // namespace + +/// Parse a single dialect group encoded in the byte stream. +static LogicalResult parseDialectGrouping( + EncodingReader &reader, MutableArrayRef dialects, + function_ref entryCallback) { + // Parse the dialect and the number of entries in the group. + BytecodeDialect *dialect; + if (failed(parseEntry(reader, dialects, dialect, "dialect"))) + return failure(); + uint64_t numEntries; + if (failed(reader.parseVarInt(numEntries))) + return failure(); + + for (uint64_t i = 0; i < numEntries; ++i) + if (failed(entryCallback(dialect))) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Attribute/Type Reader +//===----------------------------------------------------------------------===// + +namespace { +/// This class provides support for reading attribute and type entries from the +/// bytecode. Attribute and Type entries are read lazily on demand, so we use +/// this reader to manage when to actually parse them from the bytecode. 
+class AttrTypeReader { + /// This class represents a single attribute or type entry. + template + struct Entry { + /// The entry, or null if it hasn't been resolved yet. + T entry = {}; + /// The parent dialect of this entry. + BytecodeDialect *dialect = nullptr; + /// A flag indicating if the entry was encoded using a custom encoding, + /// instead of using the textual assembly format. + bool hasCustomEncoding = false; + /// The raw data of this entry in the bytecode. + ArrayRef data; + }; + using AttrEntry = Entry; + using TypeEntry = Entry; + +public: + AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {} + + /// Initialize the attribute and type information within the reader. + LogicalResult initialize(MutableArrayRef dialects, + ArrayRef sectionData, + ArrayRef offsetSectionData); + + /// Resolve the attribute or type at the given index. Returns nullptr on + /// failure. + Attribute resolveAttribute(size_t index) { + return resolveEntry(attributes, index, "Attribute"); + } + Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } + +private: + /// Resolve the given entry at `index`. + template + T resolveEntry(SmallVectorImpl> &entries, size_t index, + StringRef entryType); + + /// Parse the value defined within the given reader. `code` indicates how the + /// entry was encoded. + LogicalResult parseEntry(EncodingReader &reader, bool hasCustomEncoding, + Attribute &result); + LogicalResult parseEntry(EncodingReader &reader, bool hasCustomEncoding, + Type &result); + + /// The set of attribute and type entries. + SmallVector attributes; + SmallVector types; + + /// A location used for error emission. + Location fileLoc; +}; +} // namespace + +LogicalResult +AttrTypeReader::initialize(MutableArrayRef dialects, + ArrayRef sectionData, + ArrayRef offsetSectionData) { + EncodingReader offsetReader(offsetSectionData, fileLoc); + + // Parse the number of attribute and type entries. 
+ uint64_t numAttributes, numTypes; + if (failed(offsetReader.parseVarInt(numAttributes)) || + failed(offsetReader.parseVarInt(numTypes))) + return failure(); + attributes.resize(numAttributes); + types.resize(numTypes); + + // A functor used to accumulate the offsets for the entries in the given + // range. + uint64_t currentOffset = 0; + auto parseEntries = [&](auto &&range) { + size_t currentIndex = 0, endIndex = range.size(); + + // Parse an individual entry. + auto parseEntryFn = [&](BytecodeDialect *dialect) { + auto &entry = range[currentIndex++]; + + uint64_t entrySize; + if (failed(offsetReader.parseVarIntWithFlag(entrySize, + entry.hasCustomEncoding))) + return failure(); + + // Verify that the offset is actually valid. + if (currentOffset + entrySize > sectionData.size()) { + return offsetReader.emitError( + "Attribute or Type entry offset points past the end of section"); + } + + entry.data = sectionData.slice(currentOffset, entrySize); + entry.dialect = dialect; + currentOffset += entrySize; + return success(); + }; + while (currentIndex != endIndex) + if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) + return failure(); + return success(); + }; + + // Process each of the attributes, and then the types. + if (failed(parseEntries(attributes)) || failed(parseEntries(types))) + return failure(); + + // Ensure that we read everything from the section. + if (!offsetReader.empty()) { + return offsetReader.emitError( + "unexpected trailing data in the Attribute/Type offset section"); + } + return success(); +} + +template +T AttrTypeReader::resolveEntry(SmallVectorImpl> &entries, size_t index, + StringRef entryType) { + if (index >= entries.size()) { + emitError(fileLoc) << "invalid " << entryType << " index: " << index; + return {}; + } + + // If the entry has already been resolved, there is nothing left to do. + Entry &entry = entries[index]; + if (entry.entry) + return entry.entry; + + // Parse the entry. 
+ EncodingReader reader(entry.data, fileLoc); + if (failed(parseEntry(reader, entry.hasCustomEncoding, entry.entry))) + return T(); + if (!reader.empty()) { + (void)reader.emitError("unexpected trailing bytes after " + entryType + + " entry"); + return T(); + } + return entry.entry; +} + +LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, + bool hasCustomEncoding, + Attribute &result) { + // Handle the fallback case, where the attribute was encoded using its + // assembly format. + if (!hasCustomEncoding) { + StringRef attrStr; + if (failed(reader.parseNullTerminatedString(attrStr))) + return failure(); + + size_t numRead = 0; + if (!(result = parseAttribute(attrStr, fileLoc->getContext(), numRead))) + return failure(); + if (numRead != attrStr.size()) { + return reader.emitError( + "trailing characters found after Attribute assembly format: ", + attrStr.drop_front(numRead)); + } + return success(); + } + + return reader.emitError("unexpected Attribute encoding"); +} + +LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, + bool hasCustomEncoding, Type &result) { + // Handle the fallback case, where the type was encoded using its + // assembly format. + if (!hasCustomEncoding) { + StringRef typeStr; + if (failed(reader.parseNullTerminatedString(typeStr))) + return failure(); + + size_t numRead = 0; + if (!(result = parseType(typeStr, fileLoc->getContext(), numRead))) + return failure(); + if (numRead != typeStr.size()) { + return reader.emitError( + "trailing characters found after Type assembly format: " + + typeStr.drop_front(numRead)); + } + return success(); + } + + return reader.emitError("unexpected Type encoding"); +} + +//===----------------------------------------------------------------------===// +// Bytecode Reader +//===----------------------------------------------------------------------===// + +namespace { +/// This class is used to read a bytecode buffer and translate it into MLIR. 
+class BytecodeReader { +public: + BytecodeReader(Location fileLoc, const ParserConfig &config) + : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc), + // Use the builtin unrealized conversion cast operation to represent + // forward references to values that aren't yet defined. + forwardRefOpState(UnknownLoc::get(config.getContext()), + "builtin.unrealized_conversion_cast", ValueRange(), + NoneType::get(config.getContext())) {} + + /// Read the bytecode defined within `buffer` into the given block. + LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); + +private: + /// Return the context for this config. + MLIRContext *getContext() const { return config.getContext(); } + + /// Parse the bytecode version. + LogicalResult parseVersion(EncodingReader &reader); + + //===--------------------------------------------------------------------===// + // Dialect Section + + LogicalResult parseDialectSection(ArrayRef sectionData); + + /// Parse an operation name reference using the given reader. + FailureOr parseOpName(EncodingReader &reader); + + //===--------------------------------------------------------------------===// + // Attribute/Type Section + + /// Parse an attribute or type using the given reader. Returns nullptr in the + /// case of failure. + Attribute parseAttribute(EncodingReader &reader); + Type parseType(EncodingReader &reader); + + template + T parseAttribute(EncodingReader &reader) { + if (Attribute attr = parseAttribute(reader)) { + if (auto derivedAttr = attr.dyn_cast()) + return derivedAttr; + (void)reader.emitError("expected attribute of type: ", + llvm::getTypeName(), ", but got: ", attr); + } + return T(); + } + + //===--------------------------------------------------------------------===// + // IR Section + + /// This struct represents the current read state of a range of regions. This + /// struct is used to enable iterative parsing of regions. 
+ struct RegionReadState { + RegionReadState(Operation *op, bool isIsolatedFromAbove) + : RegionReadState(op->getRegions(), isIsolatedFromAbove) {} + RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove) + : curRegion(regions.begin()), endRegion(regions.end()), + isIsolatedFromAbove(isIsolatedFromAbove) {} + + /// The current regions being read. + MutableArrayRef<Region>::iterator curRegion, endRegion; + + /// The number of values defined immediately within this region. + unsigned numValues = 0; + + /// The current blocks of the region being read. + SmallVector<Block *> curBlocks; + Region::iterator curBlock = {}; + + /// The number of operations remaining to be read from the current block + /// being read. + uint64_t numOpsRemaining = 0; + + /// A flag indicating if the regions being read are isolated from above. + bool isIsolatedFromAbove = false; + }; + + LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); + LogicalResult parseRegions(EncodingReader &reader, + std::vector<RegionReadState> &regionStack, + RegionReadState &readState); + FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, + RegionReadState &readState, + bool &isIsolatedFromAbove); + + LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState); + LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); + LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); + + //===--------------------------------------------------------------------===// + // String Section + + LogicalResult parseStringSection(ArrayRef<uint8_t> sectionData); + + /// Parse a shared string from the string section. The shared string is + /// encoded using an index to a corresponding string in the string section.
+ LogicalResult parseSharedString(EncodingReader &reader, StringRef &result) { + return parseEntry(reader, strings, result, "string"); + } + + //===--------------------------------------------------------------------===// + // Value Processing + + /// Parse an operand reference using the given reader. Returns nullptr in the + /// case of failure. + Value parseOperand(EncodingReader &reader); + + /// Sequentially define the given value range. + LogicalResult defineValues(EncodingReader &reader, ValueRange values); + + /// Create a value to use for a forward reference. + Value createForwardRef(); + + //===--------------------------------------------------------------------===// + // Fields + + /// This class represents a single value scope, in which a value scope is + /// delimited by isolated from above regions. + struct ValueScope { + /// Push a new region state onto this scope, reserving enough values for + /// those defined within the current region of the provided state. + void push(RegionReadState &readState) { + nextValueIDs.push_back(values.size()); + values.resize(values.size() + readState.numValues); + } + + /// Pop the values defined for the current region within the provided region + /// state. + void pop(RegionReadState &readState) { + values.resize(values.size() - readState.numValues); + nextValueIDs.pop_back(); + } + + /// The set of values defined in this scope. + std::vector values; + + /// The ID for the next defined value for each region currently being + /// processed in this scope. + SmallVector nextValueIDs; + }; + + /// The configuration of the parser. + const ParserConfig &config; + + /// A location to use when emitting errors. + Location fileLoc; + + /// The reader used to process attributes and types within the bytecode. + AttrTypeReader attrTypeReader; + + /// The version of the bytecode being read. + uint64_t version = 0; + + /// The producer of the bytecode being read.
+ StringRef producer; + + /// The table of IR units referenced within the bytecode file. + SmallVector dialects; + SmallVector opNames; + + /// The table of strings referenced within the bytecode file. + SmallVector strings; + + /// The current set of available IR value scopes. + std::vector valueScopes; + /// A block containing the set of operations defined to create forward + /// references. + Block forwardRefOps; + /// A block containing previously created, and no longer used, forward + /// reference operations. + Block openForwardRefOps; + /// An operation state used when instantiating forward references. + OperationState forwardRefOpState; +}; +} // namespace + +LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { + EncodingReader reader(buffer.getBuffer(), fileLoc); + + // Skip over the bytecode header, this should have already been checked. + if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) + return failure(); + // Parse the bytecode version and producer. + if (failed(parseVersion(reader)) || + failed(reader.parseNullTerminatedString(producer))) + return failure(); + + // Add a diagnostic handler that attaches a note that includes the original + // producer of the bytecode. + ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { + diag.attachNote() << "in bytecode version " << version + << " produced by: " << producer; + return failure(); + }); + + // Parse the raw data for each of the top-level sections of the bytecode. + Optional> sectionDatas[bytecode::Section::kNumSections]; + while (!reader.empty()) { + // Read the next section from the bytecode. + bytecode::Section::ID sectionID; + ArrayRef sectionData; + if (failed(reader.parseSection(sectionID, sectionData))) + return failure(); + + // Check for duplicate sections, we only expect one instance of each. 
+ if (sectionDatas[sectionID]) { + return reader.emitError("duplicate top-level section: ", + toString(sectionID)); + } + sectionDatas[sectionID] = sectionData; + } + // Check that all of the sections were found. + for (int i = 0; i < bytecode::Section::kNumSections; ++i) { + if (!sectionDatas[i]) { + return reader.emitError("missing data for top-level section: ", + toString(bytecode::Section::ID(i))); + } + } + + // Process the string section first. + if (failed(parseStringSection(*sectionDatas[bytecode::Section::kString]))) + return failure(); + + // Process the dialect section. + if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) + return failure(); + + // Process the attribute and type section. + if (failed(attrTypeReader.initialize( + dialects, *sectionDatas[bytecode::Section::kAttrType], + *sectionDatas[bytecode::Section::kAttrTypeOffset]))) + return failure(); + + // Finally, process the IR section. + return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); +} + +LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { + if (failed(reader.parseVarInt(version))) + return failure(); + + // Validate the bytecode version. + uint64_t currentVersion = bytecode::kVersion; + if (version < currentVersion) { + return reader.emitError("bytecode version ", version, + " is older than the current version of ", + currentVersion, ", and upgrade is not supported"); + } + if (version > currentVersion) { + return reader.emitError("bytecode version ", version, + " is newer than the current version ", + currentVersion); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Dialect Section + +LogicalResult +BytecodeReader::parseDialectSection(ArrayRef sectionData) { + EncodingReader sectionReader(sectionData, fileLoc); + + // Parse the number of dialects in the section. 
+ uint64_t numDialects; + if (failed(sectionReader.parseVarInt(numDialects))) + return failure(); + dialects.resize(numDialects); + + // Parse each of the dialects. + for (uint64_t i = 0; i < numDialects; ++i) + if (failed(parseSharedString(sectionReader, dialects[i].name))) + return failure(); + + // Parse the operation names, which are grouped by dialect. + auto parseOpName = [&](BytecodeDialect *dialect) { + StringRef opName; + if (failed(parseSharedString(sectionReader, opName))) + return failure(); + opNames.emplace_back(dialect, opName); + return success(); + }; + while (!sectionReader.empty()) + if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) + return failure(); + return success(); +} + +FailureOr BytecodeReader::parseOpName(EncodingReader &reader) { + BytecodeOperationName *opName = nullptr; + if (failed(parseEntry(reader, opNames, opName, "operation name"))) + return failure(); + + // Check to see if this operation name has already been resolved. If we + // haven't, load the dialect and build the operation name. + if (!opName->opName) { + if (failed(opName->dialect->load(reader, getContext()))) + return failure(); + opName->opName.emplace((opName->dialect->name + "." 
+ opName->name).str(), + getContext()); + } + return *opName->opName; +} + +//===----------------------------------------------------------------------===// +// Attribute/Type Section + +Attribute BytecodeReader::parseAttribute(EncodingReader &reader) { + uint64_t attrIdx; + if (failed(reader.parseVarInt(attrIdx))) + return Attribute(); + return attrTypeReader.resolveAttribute(attrIdx); +} + +Type BytecodeReader::parseType(EncodingReader &reader) { + uint64_t typeIdx; + if (failed(reader.parseVarInt(typeIdx))) + return Type(); + return attrTypeReader.resolveType(typeIdx); +} + +//===----------------------------------------------------------------------===// +// IR Section + +LogicalResult BytecodeReader::parseIRSection(ArrayRef sectionData, + Block *block) { + EncodingReader reader(sectionData, fileLoc); + + // A stack of operation regions currently being read from the bytecode. + std::vector regionStack; + + // Parse the top-level block using a temporary module operation. + OwningOpRef moduleOp = ModuleOp::create(fileLoc); + regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true); + regionStack.back().curBlocks.push_back(moduleOp->getBody()); + regionStack.back().curBlock = regionStack.back().curRegion->begin(); + if (failed(parseBlock(reader, regionStack.back()))) + return failure(); + valueScopes.emplace_back(ValueScope()); + valueScopes.back().push(regionStack.back()); + + // Iteratively parse regions until everything has been resolved. + while (!regionStack.empty()) + if (failed(parseRegions(reader, regionStack, regionStack.back()))) + return failure(); + if (!forwardRefOps.empty()) { + return reader.emitError( + "not all forward operand references were resolved"); + } + + // Verify that the parsed operations are valid. + if (failed(verify(*moduleOp))) + return failure(); + + // Splice the parsed operations over to the provided top-level block.
+ auto &parsedOps = moduleOp->getBody()->getOperations(); + auto &destOps = block->getOperations(); + destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()), + parsedOps, parsedOps.begin(), parsedOps.end()); + return success(); +} + +LogicalResult +BytecodeReader::parseRegions(EncodingReader &reader, + std::vector<RegionReadState> &regionStack, + RegionReadState &readState) { + // Read the regions of this operation. + for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { + // If the current block hasn't been set up yet, parse the header for this + // region. + if (readState.curBlock == Region::iterator()) { + if (failed(parseRegion(reader, readState))) + return failure(); + + // If the region is empty, there is nothing more to do. + if (readState.curRegion->empty()) + continue; + } + + // Parse the blocks within the region. + do { + while (readState.numOpsRemaining--) { + // Read in the next operation. We don't read its regions directly, we + // handle those afterwards as necessary. + bool isIsolatedFromAbove = false; + FailureOr<Operation *> op = + parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); + if (failed(op)) + return failure(); + + // If the op has regions, add it to the stack for processing. + if ((*op)->getNumRegions()) { + regionStack.emplace_back(*op, isIsolatedFromAbove); + + // If the op is isolated from above, push a new value scope. + if (isIsolatedFromAbove) + valueScopes.emplace_back(ValueScope()); + return success(); + } + } + + // Move to the next block of the region. + if (++readState.curBlock == readState.curRegion->end()) + break; + if (failed(parseBlock(reader, readState))) + return failure(); + } while (true); + + // Reset the current block and any values reserved for this region. + readState.curBlock = {}; + valueScopes.back().pop(readState); + } + + // When the regions have been fully parsed, pop them off of the read stack. If + // the regions were isolated from above, we also pop the last value scope.
+ regionStack.pop_back(); + if (readState.isIsolatedFromAbove) + valueScopes.pop_back(); + return success(); +} + +FailureOr +BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, + RegionReadState &readState, + bool &isIsolatedFromAbove) { + // Parse the name of the operation. + FailureOr opName = parseOpName(reader); + if (failed(opName)) + return failure(); + + // Parse the operation mask, which indicates which components of the operation + // are present. + uint8_t opMask; + if (failed(reader.parseByte(opMask))) + return failure(); + + /// Parse the location. + LocationAttr opLoc = parseAttribute(reader); + if (!opLoc) + return failure(); + + // With the location and name resolved, we can start building the operation + // state. + OperationState opState(opLoc, *opName); + + // Parse the attributes of the operation. + if (opMask & bytecode::OpEncodingMask::kHasAttrs) { + DictionaryAttr dictAttr = parseAttribute(reader); + if (!dictAttr) + return failure(); + opState.attributes = dictAttr; + } + + /// Parse the results of the operation. + if (opMask & bytecode::OpEncodingMask::kHasResults) { + uint64_t numResults; + if (failed(reader.parseVarInt(numResults))) + return failure(); + opState.types.resize(numResults); + for (int i = 0, e = numResults; i < e; ++i) + if (!(opState.types[i] = parseType(reader))) + return failure(); + } + + /// Parse the operands of the operation. + if (opMask & bytecode::OpEncodingMask::kHasOperands) { + uint64_t numOperands; + if (failed(reader.parseVarInt(numOperands))) + return failure(); + opState.operands.resize(numOperands); + for (int i = 0, e = numOperands; i < e; ++i) + if (!(opState.operands[i] = parseOperand(reader))) + return failure(); + } + + /// Parse the successors of the operation. 
+ if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { + uint64_t numSuccs; + if (failed(reader.parseVarInt(numSuccs))) + return failure(); + opState.successors.resize(numSuccs); + for (int i = 0, e = numSuccs; i < e; ++i) { + if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], + "successor"))) + return failure(); + } + } + + /// Parse the regions of the operation. + if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { + uint64_t numRegions; + if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) + return failure(); + + opState.regions.reserve(numRegions); + for (int i = 0, e = numRegions; i < e; ++i) + opState.regions.push_back(std::make_unique()); + } + + // Create the operation at the back of the current block. + Operation *op = Operation::create(opState); + readState.curBlock->push_back(op); + + // If the operation had results, update the value references. + if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) + return failure(); + + return op; +} + +LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, + RegionReadState &readState) { + // Parse the number of blocks in the region. + uint64_t numBlocks; + if (failed(reader.parseVarInt(numBlocks))) + return failure(); + + // If the region is empty, there is nothing else to do. + if (numBlocks == 0) + return success(); + + // Parse the number of values defined in this region. + uint64_t numValues; + if (failed(reader.parseVarInt(numValues))) + return failure(); + readState.numValues = numValues; + + // Create the blocks within this region. We do this before processing so that + // we can rely on the blocks existing when creating operations. + readState.curBlocks.clear(); + readState.curBlocks.reserve(numBlocks); + for (uint64_t i = 0; i < numBlocks; ++i) { + readState.curBlocks.push_back(new Block()); + readState.curRegion->push_back(readState.curBlocks.back()); + } + + // Prepare the current value scope for this region. 
+ valueScopes.back().push(readState); + + // Parse the entry block of the region. + readState.curBlock = readState.curRegion->begin(); + return parseBlock(reader, readState); +} + +LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, + RegionReadState &readState) { + bool hasArgs; + if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) + return failure(); + + // Parse the arguments of the block. + if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) + return failure(); + + // We don't parse the operations of the block here, that's done elsewhere. + return success(); +} + +LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, + Block *block) { + // Parse the number of arguments. + uint64_t numArgs; + if (failed(reader.parseVarInt(numArgs))) + return failure(); + + SmallVector argTypes; + SmallVector argLocs; + argTypes.reserve(numArgs); + argLocs.reserve(numArgs); + + while (numArgs--) { + Type argType = parseType(reader); + if (!argType) + return failure(); + LocationAttr argLoc = parseAttribute(reader); + if (!argLoc) + return failure(); + + argTypes.push_back(argType); + argLocs.push_back(argLoc); + } + block->addArguments(argTypes, argLocs); + return defineValues(reader, block->getArguments()); +} + +//===----------------------------------------------------------------------===// +// String Section + +LogicalResult +BytecodeReader::parseStringSection(ArrayRef sectionData) { + EncodingReader stringReader(sectionData, fileLoc); + + // Parse the number of strings in the section. + uint64_t numStrings; + if (failed(stringReader.parseVarInt(numStrings))) + return failure(); + strings.resize(numStrings); + + // Parse each of the strings. The sizes of the strings are encoded in reverse + // order, so that's the order we populate the table.
+ size_t stringDataEndOffset = sectionData.size(); + size_t totalStringDataSize = 0; + for (StringRef &string : llvm::reverse(strings)) { + uint64_t stringSize; + if (failed(stringReader.parseVarInt(stringSize))) + return failure(); + if (stringDataEndOffset < stringSize) { + return stringReader.emitError( + "string size exceeds the available data size"); + } + + // Extract the string from the data, dropping the null character. + size_t stringOffset = stringDataEndOffset - stringSize; + string = StringRef( + reinterpret_cast(sectionData.data() + stringOffset), + stringSize - 1); + stringDataEndOffset = stringOffset; + + // Update the total string data size. + totalStringDataSize += stringSize; + } + + // Check that the only remaining data was for the strings + if (stringReader.size() != totalStringDataSize) { + return stringReader.emitError("unexpected trailing data between the " + "offsets for strings and their data"); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Value Processing + +Value BytecodeReader::parseOperand(EncodingReader &reader) { + std::vector &values = valueScopes.back().values; + Value *value = nullptr; + if (failed(parseEntry(reader, values, value, "value"))) + return Value(); + + // Create a new forward reference if necessary. + if (!*value) + *value = createForwardRef(); + return *value; +} + +LogicalResult BytecodeReader::defineValues(EncodingReader &reader, + ValueRange newValues) { + ValueScope &valueScope = valueScopes.back(); + std::vector &values = valueScope.values; + + unsigned &valueID = valueScope.nextValueIDs.back(); + unsigned valueIDEnd = valueID + newValues.size(); + if (valueIDEnd > values.size()) { + return reader.emitError( + "value index range was outside of the expected range for " + "the parent region, got [", + valueID, ", ", valueIDEnd, "), but the maximum index was ", + values.size() - 1); + } + + // Assign the values and update any forward references. 
+ for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { + Value newValue = newValues[i]; + + // Check to see if a definition for this value already exists. + if (Value oldValue = std::exchange(values[valueID], newValue)) { + Operation *forwardRefOp = oldValue.getDefiningOp(); + + // Assert that this is a forward reference operation. Given how we compute + // definition ids (incrementally as we parse), it shouldn't be possible + // for the value to be defined any other way. + assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && + "value index was already defined?"); + + oldValue.replaceAllUsesWith(newValue); + forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); + } + } + return success(); +} + +Value BytecodeReader::createForwardRef() { + // Check for an available existing operation to use. Otherwise, create a new + // fake operation to use for the reference. + if (!openForwardRefOps.empty()) { + Operation *op = &openForwardRefOps.back(); + op->moveBefore(&forwardRefOps, forwardRefOps.end()); + } else { + forwardRefOps.push_back(Operation::create(forwardRefOpState)); + } + return forwardRefOps.back().getResult(0); +} + +//===----------------------------------------------------------------------===// +// Entry Points +//===----------------------------------------------------------------------===// + +bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { + return buffer.getBuffer().startswith("ML\xefR"); +} + +LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, + const ParserConfig &config) { + Location sourceFileLoc = + FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), + /*line=*/0, /*column=*/0); + if (!isBytecode(buffer)) { + return emitError(sourceFileLoc, + "input buffer is not an MLIR bytecode file"); + } + + BytecodeReader reader(sourceFileLoc, config); + return reader.read(buffer, block); +} diff --git a/mlir/lib/Bytecode/Reader/CMakeLists.txt
b/mlir/lib/Bytecode/Reader/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Reader/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(MLIRBytecodeReader + BytecodeReader.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode + + LINK_LIBS PUBLIC + MLIRAsmParser + MLIRIR + MLIRSupport + ) diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -0,0 +1,520 @@ +//===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "../Encoding.h" +#include "IRNumbering.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/CachedHashString.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "mlir-bytecode-writer" + +using namespace mlir; +using namespace mlir::bytecode::detail; + +//===----------------------------------------------------------------------===// +// EncodingEmitter +//===----------------------------------------------------------------------===// + +namespace { +/// This class functions as the underlying encoding emitter for the bytecode +/// writer. This class is a bit different compared to other types of encoders; +/// it does not use a single buffer, but instead may contain several buffers +/// (some owned by the writer, and some not) that get concatted during the final +/// emission. 
+class EncodingEmitter { +public: + EncodingEmitter() = default; + EncodingEmitter(const EncodingEmitter &) = delete; + EncodingEmitter &operator=(const EncodingEmitter &) = delete; + + /// Write the current contents to the provided stream. + void writeTo(raw_ostream &os) const; + + /// Return the current size of the encoded buffer. + size_t size() const { return prevResultSize + currentResult.size(); } + + //===--------------------------------------------------------------------===// + // Emission + //===--------------------------------------------------------------------===// + + /// Backpatch a byte in the result buffer at the given offset. + void patchByte(uint64_t offset, uint8_t value) { + assert(offset < size() && offset >= prevResultSize && + "cannot patch previously emitted data"); + currentResult[offset - prevResultSize] = value; + } + + //===--------------------------------------------------------------------===// + // Integer Emission + + /// Emit a single byte. + template + void emitByte(T byte) { + currentResult.push_back(static_cast(byte)); + } + + /// Emit a range of bytes. + void emitBytes(ArrayRef bytes) { + llvm::append_range(currentResult, bytes); + } + + /// Emit a variable length integer. The first encoded byte contains a prefix + /// in the low bits indicating the encoded length of the value. This length + /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits + /// indicate the number of _additional_ bytes (not including the prefix byte). + /// All remaining bits in the first byte, along with all of the bits in + /// additional bytes, provide the value of the integer encoded in + /// little-endian order. + void emitVarInt(uint64_t value) { + // In the most common case, the value can be represented in a single byte. + // Given how hot this case is, explicitly handle that here. 
+ if ((value >> 7) == 0) + return emitByte((value << 1) | 0x1); + emitMultiByteVarInt(value); + } + + /// Emit a variable length integer whose low bit is used to encode the + /// provided flag, i.e. encoded as: (value << 1) | (flag ? 1 : 0). + void emitVarIntWithFlag(uint64_t value, bool flag) { + emitVarInt((value << 1) | (flag ? 1 : 0)); + } + + //===--------------------------------------------------------------------===// + // String Emission + + /// Emit the given string as a nul terminated string. + void emitNulTerminatedString(StringRef str) { + emitString(str); + emitByte(0); + } + + /// Emit the given string without a nul terminator. + void emitString(StringRef str) { + emitBytes({reinterpret_cast(str.data()), str.size()}); + } + + //===--------------------------------------------------------------------===// + // Section Emission + + /// Emit a nested section of the given code, whose contents are encoded in the + /// provided emitter. + void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) { + // Emit the section code and length. + emitByte(code); + emitVarInt(emitter.size()); + + // Push our current buffer and then merge the provided section body into + // ours. + appendResult(std::move(currentResult)); + for (std::vector &result : emitter.prevResultStorage) + appendResult(std::move(result)); + appendResult(std::move(emitter.currentResult)); + } + +private: + /// Emit the given value using a variable width encoding. This method is a + /// fallback when the number of bytes needed to encode the value is greater + /// than 1. We mark it noinline here so that the single byte hot path isn't + /// pessimized. + LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value); + + /// Append a new result buffer to the current contents. 
+ void appendResult(std::vector &&result) { + prevResultSize += result.size(); + prevResultStorage.emplace_back(std::move(result)); + prevResultList.emplace_back(prevResultStorage.back()); + } + + /// The result of the emitter currently being built. We refrain from building + /// a single buffer to simplify emitting sections, large data, and more. The + /// result is thus represented using multiple distinct buffers, some of which + /// we own (via prevResultStorage), and some of which are just pointers into + /// externally owned buffers. + std::vector currentResult; + std::vector> prevResultList; + std::vector> prevResultStorage; + + /// An up-to-date total size of all of the buffers within `prevResultList`. + /// This enables O(1) size checks of the current encoding. + size_t prevResultSize = 0; +}; + +/// A simple raw_ostream wrapper around an EncodingEmitter. This removes the need +/// to go through an intermediate buffer when interacting with code that wants a +/// raw_ostream. +class raw_emitter_ostream : public raw_ostream { +public: + explicit raw_emitter_ostream(EncodingEmitter &emitter) : emitter(emitter) { + SetUnbuffered(); + } + +private: + void write_impl(const char *ptr, size_t size) override { + emitter.emitBytes({reinterpret_cast(ptr), size}); + } + uint64_t current_pos() const override { return emitter.size(); } + + /// The section being emitted to. + EncodingEmitter &emitter; +}; +} // namespace + +void EncodingEmitter::writeTo(raw_ostream &os) const { + for (auto &prevResult : prevResultList) + os.write((const char *)prevResult.data(), prevResult.size()); + os.write((const char *)currentResult.data(), currentResult.size()); +} + +void EncodingEmitter::emitMultiByteVarInt(uint64_t value) { + // Compute the number of bytes needed to encode the value. Each byte can hold + // up to 7-bits of data. We only check up to the number of bits we can encode + // in the first byte (8).
+ uint64_t it = value >> 7; + for (size_t numBytes = 2; numBytes < 9; ++numBytes) { + if (LLVM_LIKELY(it >>= 7) == 0) { + uint64_t encodedValue = (value << 1) | 0x1; + encodedValue <<= (numBytes - 1); + emitBytes({reinterpret_cast<const uint8_t *>(&encodedValue), numBytes}); + return; + } + } + + // If the value is too large to encode in a single byte, emit a special all + // zero marker byte and splat the value directly. + emitByte(0); + emitBytes({reinterpret_cast<const uint8_t *>(&value), sizeof(value)}); +} + +//===----------------------------------------------------------------------===// +// Bytecode Writer +//===----------------------------------------------------------------------===// + +namespace { +class BytecodeWriter { +public: + BytecodeWriter(Operation *op) : numberingState(op) {} + + /// Write the bytecode for the given root operation. + void write(Operation *rootOp, raw_ostream &os, StringRef producer); + +private: + //===--------------------------------------------------------------------===// + // Dialects + + void writeDialectSection(EncodingEmitter &emitter); + + //===--------------------------------------------------------------------===// + // Attributes and Types + + void writeAttrTypeSection(EncodingEmitter &emitter); + + //===--------------------------------------------------------------------===// + // Operations + + void writeBlock(EncodingEmitter &emitter, Block *block); + void writeOp(EncodingEmitter &emitter, Operation *op); + void writeRegion(EncodingEmitter &emitter, Region *region); + void writeIRSection(EncodingEmitter &emitter, Operation *op); + + //===--------------------------------------------------------------------===// + // Strings + + void writeStringSection(EncodingEmitter &emitter); + + /// Get the number for the given shared string, that is contained within the + /// string section.
+ size_t getSharedStringNumber(StringRef str); + + //===--------------------------------------------------------------------===// + // Fields + + /// The IR numbering state generated for the root operation. + IRNumberingState numberingState; + + /// A set of strings referenced within the bytecode. The value of the map is + /// the number assigned to the string (see getSharedStringNumber). + llvm::MapVector<llvm::CachedHashStringRef, size_t> strings; +}; +} // namespace + +void BytecodeWriter::write(Operation *rootOp, raw_ostream &os, + StringRef producer) { + EncodingEmitter emitter; + + // Emit the bytecode file header. This is how we identify the output as a + // bytecode file. + emitter.emitString("ML\xefR"); + + // Emit the bytecode version. + emitter.emitVarInt(bytecode::kVersion); + + // Emit the producer. + emitter.emitNulTerminatedString(producer); + + // Emit the dialect section. + writeDialectSection(emitter); + + // Emit the attributes and types section. + writeAttrTypeSection(emitter); + + // Emit the IR section. + writeIRSection(emitter, rootOp); + + // Emit the string section. + writeStringSection(emitter); + + // Write the generated bytecode to the provided output stream. + emitter.writeTo(os); +} + +//===----------------------------------------------------------------------===// +// Dialects + +/// Write the given entries in contiguous groups with the same parent dialect. +/// Each dialect sub-group is encoded with the parent dialect and number of +/// elements, followed by the encoding for the entries. The given callback is +/// invoked to encode each individual entry. +template <typename EntriesT, typename EntryCallbackT> +static void writeDialectGrouping(EncodingEmitter &emitter, EntriesT &&entries, + EntryCallbackT &&callback) { + for (auto it = entries.begin(), e = entries.end(); it != e;) { + auto groupStart = it++; + + // Find the end of the group that shares the same parent dialect.
+ DialectNumbering *currentDialect = groupStart->dialect; + it = std::find_if(it, e, [&](const auto &entry) { + return entry.dialect != currentDialect; + }); + + // Emit the dialect and number of elements. + emitter.emitVarInt(currentDialect->number); + emitter.emitVarInt(std::distance(groupStart, it)); + + // Emit the entries within the group. + for (auto &entry : llvm::make_range(groupStart, it)) + callback(entry); + } +} + +void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { + EncodingEmitter dialectEmitter; + + // Emit the referenced dialects. + auto dialects = numberingState.getDialects(); + dialectEmitter.emitVarInt(llvm::size(dialects)); + for (DialectNumbering &dialect : dialects) + dialectEmitter.emitVarInt(getSharedStringNumber(dialect.name)); + + // Emit the referenced operation names grouped by dialect. + auto emitOpName = [&](OpNameNumbering &name) { + dialectEmitter.emitVarInt(getSharedStringNumber(name.name.stripDialect())); + }; + writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName); + + emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter)); +} + +//===----------------------------------------------------------------------===// +// Attributes and Types + +void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { + EncodingEmitter attrTypeEmitter; + EncodingEmitter offsetEmitter; + offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes())); + offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes())); + + // A functor used to emit an attribute or type entry. + uint64_t prevOffset = 0; + auto emitAttrOrType = [&](auto &entry) { + // TODO: Allow dialects to provide more optimal implementations of attribute + // and type encodings. + bool hasCustomEncoding = false; + + // Emit the entry using the textual format. + raw_emitter_ostream(attrTypeEmitter) << entry.getValue(); + attrTypeEmitter.emitByte(0); + + // Record the offset of this entry. 
+ uint64_t curOffset = attrTypeEmitter.size(); + offsetEmitter.emitVarIntWithFlag(curOffset - prevOffset, hasCustomEncoding); + prevOffset = curOffset; + }; + + // Emit the attribute and type entries for each dialect. + writeDialectGrouping(offsetEmitter, numberingState.getAttributes(), + emitAttrOrType); + writeDialectGrouping(offsetEmitter, numberingState.getTypes(), + emitAttrOrType); + + // Emit the sections to the stream. + emitter.emitSection(bytecode::Section::kAttrTypeOffset, + std::move(offsetEmitter)); + emitter.emitSection(bytecode::Section::kAttrType, std::move(attrTypeEmitter)); +} + +//===----------------------------------------------------------------------===// +// Operations + +void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) { + ArrayRef<BlockArgument> args = block->getArguments(); + bool hasArgs = !args.empty(); + + // Emit the number of operations in this block, and if it has arguments. We + // use the low bit of the operation count to indicate if the block has + // arguments. + unsigned numOps = numberingState.getOperationCount(block); + emitter.emitVarIntWithFlag(numOps, hasArgs); + + // Emit the arguments of the block. + if (hasArgs) { + emitter.emitVarInt(args.size()); + for (BlockArgument arg : args) { + emitter.emitVarInt(numberingState.getNumber(arg.getType())); + emitter.emitVarInt(numberingState.getNumber(arg.getLoc())); + } + } + + // Emit the operations within the block. + for (Operation &op : *block) + writeOp(emitter, &op); +} + +void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) { + emitter.emitVarInt(numberingState.getNumber(op->getName())); + + // Emit a mask for the operation components. We need to fill this in later + // (when we actually know what needs to be emitted), so emit a placeholder for + // now. + uint64_t maskOffset = emitter.size(); + uint8_t opEncodingMask = 0; + emitter.emitByte(0); + + // Emit the location for this operation.
+ emitter.emitVarInt(numberingState.getNumber(op->getLoc())); + + // Emit the attributes of this operation. + DictionaryAttr attrs = op->getAttrDictionary(); + if (!attrs.empty()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; + emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary())); + } + + // Emit the result types of the operation. + if (unsigned numResults = op->getNumResults()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasResults; + emitter.emitVarInt(numResults); + for (Type type : op->getResultTypes()) + emitter.emitVarInt(numberingState.getNumber(type)); + } + + // Emit the operands of the operation. + if (unsigned numOperands = op->getNumOperands()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasOperands; + emitter.emitVarInt(numOperands); + for (Value operand : op->getOperands()) + emitter.emitVarInt(numberingState.getNumber(operand)); + } + + // Emit the successors of the operation. + if (unsigned numSuccessors = op->getNumSuccessors()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors; + emitter.emitVarInt(numSuccessors); + for (Block *successor : op->getSuccessors()) + emitter.emitVarInt(numberingState.getNumber(successor)); + } + + // Check for regions. + unsigned numRegions = op->getNumRegions(); + if (numRegions) + opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions; + + // Update the mask for the operation. + emitter.patchByte(maskOffset, opEncodingMask); + + // With the mask emitted, we can now emit the regions of the operation. We do + // this after mask emission to avoid offset complications that may arise by + // emitting the regions first (e.g. if the regions are huge, backpatching the + // op encoding mask is more annoying). 
+ if (numRegions) { + bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>(); + emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove); + + for (Region &region : op->getRegions()) + writeRegion(emitter, &region); + } +} + +void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) { + // If the region is empty, we only need to emit the number of blocks (which is + // zero). + if (region->empty()) + return emitter.emitVarInt(/*numBlocks*/ 0); + + // Emit the number of blocks and values within the region. + unsigned numBlocks, numValues; + std::tie(numBlocks, numValues) = numberingState.getBlockValueCount(region); + emitter.emitVarInt(numBlocks); + emitter.emitVarInt(numValues); + + // Emit the blocks within the region. + for (Block &block : *region) + writeBlock(emitter, &block); +} + +void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) { + EncodingEmitter irEmitter; + + // Write the IR section the same way as a block with no arguments. Note that + // the low-bit of the operation count for a block is used to indicate if the + // block has arguments, which in this case is always false. + irEmitter.emitVarIntWithFlag(/*numOps*/ 1, /*hasArgs*/ false); + + // Emit the operations. + writeOp(irEmitter, op); + + emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter)); +} + +//===----------------------------------------------------------------------===// +// Strings + +void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) { + EncodingEmitter stringEmitter; + stringEmitter.emitVarInt(strings.size()); + + // Emit the sizes in reverse order, so that we don't need to backpatch an + // offset to the string data or have a separate section. + for (const auto &it : llvm::reverse(strings)) + stringEmitter.emitVarInt(it.first.size() + 1); + // Emit the string data itself.
+ for (const auto &it : strings) + stringEmitter.emitNulTerminatedString(it.first.val()); + + emitter.emitSection(bytecode::Section::kString, std::move(stringEmitter)); +} + +size_t BytecodeWriter::getSharedStringNumber(StringRef str) { + auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()}); + return it.first->second; +} + +//===----------------------------------------------------------------------===// +// Entry Points +//===----------------------------------------------------------------------===// + +void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os, + StringRef producer) { + BytecodeWriter writer(op); + writer.write(op, os, producer); +} diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(MLIRBytecodeWriter + BytecodeWriter.cpp + IRNumbering.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + ) diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -0,0 +1,193 @@ +//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains various utilities that number IR structures in preparation +// for bytecode emission. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H +#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H + +#include "mlir/IR/OperationSupport.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { +class BytecodeWriterConfig; + +namespace bytecode { +namespace detail { +struct DialectNumbering; + +//===----------------------------------------------------------------------===// +// Attribute and Type Numbering +//===----------------------------------------------------------------------===// + +/// This class represents a numbering entry for an Attribute or Type. +struct AttrTypeNumbering { + AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {} + + /// The concrete value. + PointerUnion<Attribute, Type> value; + + /// The number assigned to this value. + unsigned number = 0; + + /// The number of references to this value. + unsigned refCount = 1; + + /// The dialect of this value. + DialectNumbering *dialect = nullptr; +}; +struct AttributeNumbering : public AttrTypeNumbering { + AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {} + Attribute getValue() const { return value.get<Attribute>(); } +}; +struct TypeNumbering : public AttrTypeNumbering { + TypeNumbering(Type value) : AttrTypeNumbering(value) {} + Type getValue() const { return value.get<Type>(); } +}; + +//===----------------------------------------------------------------------===// +// OpName Numbering +//===----------------------------------------------------------------------===// + +/// This class represents the numbering entry of an operation name. +struct OpNameNumbering { + OpNameNumbering(DialectNumbering *dialect, OperationName name) + : dialect(dialect), name(name) {} + + /// The dialect of this value. + DialectNumbering *dialect; + + /// The concrete name. + OperationName name; + + /// The number assigned to this name. + unsigned number = 0; + + /// The number of references to this name.
+ unsigned refCount = 1; +}; + +//===----------------------------------------------------------------------===// +// Dialect Numbering +//===----------------------------------------------------------------------===// + +/// This class represents a numbering entry for a Dialect. +struct DialectNumbering { + DialectNumbering(StringRef name, unsigned number) + : name(name), number(number) {} + + /// The namespace of the dialect. + StringRef name; + + /// The number assigned to the dialect. + unsigned number; + + /// The loaded dialect, or nullptr if the dialect isn't loaded. + Dialect *dialect = nullptr; +}; + +//===----------------------------------------------------------------------===// +// IRNumberingState +//===----------------------------------------------------------------------===// + +/// This class manages numbering IR entities in preparation for bytecode +/// emission. +class IRNumberingState { +public: + IRNumberingState(Operation *op); + + /// Return the numbered dialects. + auto getDialects() { + return llvm::make_pointee_range(llvm::make_second_range(dialects)); + } + auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); } + auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); } + auto getTypes() { return llvm::make_pointee_range(orderedTypes); } + + /// Return the number for the given IR unit.
+ unsigned getNumber(Attribute attr) { + assert(attrs.count(attr) && "attribute not numbered"); + return attrs[attr]->number; + } + unsigned getNumber(Block *block) { + assert(blockIDs.count(block) && "block not numbered"); + return blockIDs[block]; + } + unsigned getNumber(OperationName opName) { + assert(opNames.count(opName) && "opName not numbered"); + return opNames[opName]->number; + } + unsigned getNumber(Type type) { + assert(types.count(type) && "type not numbered"); + return types[type]->number; + } + unsigned getNumber(Value value) { + assert(valueIDs.count(value) && "value not numbered"); + return valueIDs[value]; + } + + /// Return the block and value counts of the given region. + std::pair<unsigned, unsigned> getBlockValueCount(Region *region) { + assert(regionBlockValueCounts.count(region) && "region not numbered"); + return regionBlockValueCounts[region]; + } + + /// Return the number of operations in the given block. + unsigned getOperationCount(Block *block) { + assert(blockOperationCounts.count(block) && "block not numbered"); + return blockOperationCounts[block]; + } + +private: + /// Number the given IR unit for bytecode emission. + void number(Attribute attr); + void number(Block &block); + DialectNumbering &numberDialect(Dialect *dialect); + DialectNumbering &numberDialect(StringRef dialect); + void number(Operation &op); + void number(OperationName opName); + void number(Region &region); + void number(Type type); + + /// Mapping from IR to the respective numbering entries. + DenseMap<Attribute, AttributeNumbering *> attrs; + DenseMap<OperationName, OpNameNumbering *> opNames; + DenseMap<Type, TypeNumbering *> types; + DenseMap<Dialect *, DialectNumbering *> registeredDialects; + llvm::MapVector<StringRef, DialectNumbering *> dialects; + std::vector<AttributeNumbering *> orderedAttrs; + std::vector<OpNameNumbering *> orderedOpNames; + std::vector<TypeNumbering *> orderedTypes; + + /// Allocators used for the various numbering entries.
+ llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator; + llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator; + llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator; + llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator; + + /// The value ID for each Block and Value. + DenseMap<Block *, unsigned> blockIDs; + DenseMap<Value, unsigned> valueIDs; + + /// The number of operations in each block. + DenseMap<Block *, unsigned> blockOperationCounts; + + /// A map from region to the number of blocks and values within that region. + DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts; + + /// The next value ID to assign when numbering. + unsigned nextValueID = 0; +}; +} // namespace detail +} // namespace bytecode +} // namespace mlir + +#endif diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -0,0 +1,251 @@ +//===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRNumbering.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" + +using namespace mlir; +using namespace mlir::bytecode::detail; + +//===----------------------------------------------------------------------===// +// IR Numbering +//===----------------------------------------------------------------------===// + +/// Group and sort the elements of the given range by their parent dialect. This +/// grouping is applied to sub-sections of the range defined by how many bytes +/// it takes to encode a varint index to that sub-section.
+template <typename T> +static void groupByDialectPerByte(T range) { + if (range.empty()) + return; + + // A functor used to sort by a given dialect, with a desired dialect to be + // ordered first (to better enable sharing of dialects across byte groups). + auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs, + const auto &rhs) { + if (lhs->dialect->number == dialectToOrderFirst) + return rhs->dialect->number != dialectToOrderFirst; + return lhs->dialect->number < rhs->dialect->number; + }; + + unsigned dialectToOrderFirst = 0; + size_t elementsInByteGroup = 0; + auto iterRange = range; + for (unsigned i = 1; i < 9; ++i) { + // Update the number of elements in the current byte grouping. Reminder + // that varint encodes 7-bits per byte, so that's how we compute the + // number of elements in each byte grouping. The shift is done on a 64-bit + // value, as it reaches 56 bits and would overflow a plain `int`. + elementsInByteGroup = (1ULL << (7 * i)) - elementsInByteGroup; + + // Slice out the sub-set of elements that are in the current byte grouping + // to be sorted. + auto byteSubRange = iterRange.take_front(elementsInByteGroup); + iterRange = iterRange.drop_front(byteSubRange.size()); + + // Sort the sub range for this byte. + llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) { + return sortByDialect(dialectToOrderFirst, lhs, rhs); + }); + + // Update the dialect to order first to be the dialect at the end of the + // current grouping. This seeks to allow larger dialect groupings across + // byte boundaries. + dialectToOrderFirst = byteSubRange.back()->dialect->number; + + // If the data range is now empty, we are done. + if (iterRange.empty()) + break; + } + + // Assign the entry numbers based on the sort order. + for (auto &entry : llvm::enumerate(range)) + entry.value()->number = entry.index(); +} + +IRNumberingState::IRNumberingState(Operation *op) { + // Number the root operation. + number(*op); + + // Push all of the regions of the root operation onto the worklist.
+ SmallVector<std::pair<Region *, unsigned>, 8> numberContext; + for (Region &region : op->getRegions()) + numberContext.emplace_back(&region, nextValueID); + + // Iteratively process each of the nested regions. + while (!numberContext.empty()) { + Region *region; + std::tie(region, nextValueID) = numberContext.pop_back_val(); + number(*region); + + // Traverse into nested regions. + for (Operation &op : region->getOps()) { + // Isolated regions don't share value numbers with their parent, so we can + // start numbering these regions at zero. + unsigned opFirstValueID = + op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID; + for (Region &region : op.getRegions()) + numberContext.emplace_back(&region, opFirstValueID); + } + } + + // Number each of the dialects. For now this is just in the order they were + // found, given that the number of dialects on average is small enough to fit + // within a single byte (128). If we ever have real world use cases that have + // a huge number of dialects, this could be made more intelligent. + for (auto &it : llvm::enumerate(dialects)) + it.value().second->number = it.index(); + + // Number each of the recorded components within each dialect. + + // First sort by ref count so that the most referenced elements are first. We + // try to bias more heavily used elements to the front. This allows for more + // frequently referenced things to be encoded using smaller varints. + auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) { + return lhs->refCount > rhs->refCount; + }; + llvm::stable_sort(orderedAttrs, sortByRefCountFn); + llvm::stable_sort(orderedOpNames, sortByRefCountFn); + llvm::stable_sort(orderedTypes, sortByRefCountFn); + + // After that, we apply a secondary ordering based on the parent dialect. This + // ordering is applied to sub-sections of the element list defined by how many + // bytes it takes to encode a varint index to that sub-section. This allows + // for more efficiently encoding components of the same dialect (e.g.
we only + // have to encode the dialect reference once). + groupByDialectPerByte(llvm::makeMutableArrayRef(orderedAttrs)); + groupByDialectPerByte(llvm::makeMutableArrayRef(orderedOpNames)); + groupByDialectPerByte(llvm::makeMutableArrayRef(orderedTypes)); +} + +void IRNumberingState::number(Attribute attr) { + auto it = attrs.insert({attr, nullptr}); + if (!it.second) { + ++it.first->second->refCount; + return; + } + auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr); + it.first->second = numbering; + orderedAttrs.push_back(numbering); + + // Check for OpaqueAttr, which is a dialect-specific attribute that didn't + // have a registered dialect when it got created. We don't want to encode this + // as the builtin OpaqueAttr, we want to encode it as if the dialect was + // actually loaded. + if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) + numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace()); + else + numbering->dialect = &numberDialect(&attr.getDialect()); +} + +void IRNumberingState::number(Block &block) { + // Number the arguments of the block. + for (BlockArgument arg : block.getArguments()) { + valueIDs.try_emplace(arg, nextValueID++); + number(arg.getLoc()); + number(arg.getType()); + } + + // Number the operations in this block.
+ unsigned &numOps = blockOperationCounts[&block]; + for (Operation &op : block) { + number(op); + ++numOps; + } +} + +auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & { + DialectNumbering *&numbering = registeredDialects[dialect]; + if (!numbering) { + numbering = &numberDialect(dialect->getNamespace()); + numbering->dialect = dialect; + } + return *numbering; +} + +auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & { + DialectNumbering *&numbering = dialects[dialect]; + if (!numbering) { + numbering = new (dialectAllocator.Allocate()) + DialectNumbering(dialect, dialects.size() - 1); + } + return *numbering; +} + +void IRNumberingState::number(Region &region) { + if (region.empty()) + return; + size_t firstValueID = nextValueID; + + // Number the blocks within this region. + size_t blockCount = 0; + for (auto &it : llvm::enumerate(region)) { + blockIDs.try_emplace(&it.value(), it.index()); + number(it.value()); + ++blockCount; + } + + // Remember the number of blocks and values in this region. + regionBlockValueCounts.try_emplace(&region, blockCount, + nextValueID - firstValueID); +} + +void IRNumberingState::number(Operation &op) { + // Number the components of an operation that won't be numbered elsewhere + // (e.g. we don't number operands, regions, or successors here). + number(op.getName()); + for (OpResult result : op.getResults()) { + valueIDs.try_emplace(result, nextValueID++); + number(result.getType()); + } + + // Only number the operation's dictionary if it isn't empty.
+ DictionaryAttr dictAttr = op.getAttrDictionary(); + if (!dictAttr.empty()) + number(dictAttr); + + number(op.getLoc()); +} + +void IRNumberingState::number(OperationName opName) { + OpNameNumbering *&numbering = opNames[opName]; + if (numbering) { + ++numbering->refCount; + return; + } + DialectNumbering *dialectNumber = nullptr; + if (Dialect *dialect = opName.getDialect()) + dialectNumber = &numberDialect(dialect); + else + dialectNumber = &numberDialect(opName.getDialectNamespace()); + + numbering = + new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName); + orderedOpNames.push_back(numbering); +} + +void IRNumberingState::number(Type type) { + auto it = types.insert({type, nullptr}); + if (!it.second) { + ++it.first->second->refCount; + return; + } + auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type); + it.first->second = numbering; + orderedTypes.push_back(numbering); + + // Check for OpaqueType, which is a dialect-specific type that didn't have a + // registered dialect when it got created. We don't want to encode this as the + // builtin OpaqueType, we want to encode it as if the dialect was actually + // loaded. 
+ if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) + numbering->dialect = &numberDialect(opaqueType.getDialectNamespace()); + else + numbering->dialect = &numberDialect(&type.getDialect()); +} diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(Analysis) add_subdirectory(AsmParser) +add_subdirectory(Bytecode) add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(IR) diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt --- a/mlir/lib/Parser/CMakeLists.txt +++ b/mlir/lib/Parser/CMakeLists.txt @@ -6,5 +6,6 @@ LINK_LIBS PUBLIC MLIRAsmParser + MLIRBytecodeReader MLIRIR ) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -12,6 +12,7 @@ #include "mlir/Parser/Parser.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Bytecode/BytecodeReader.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; @@ -25,6 +26,8 @@ sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0); } + if (isBytecode(*sourceBuf)) + return readBytecodeFile(*sourceBuf, block, config); return parseAsmSourceFile(sourceMgr, block, config); } diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt --- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt @@ -5,6 +5,7 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-opt LINK_LIBS PUBLIC + MLIRBytecodeWriter MLIRPass MLIRParser MLIRSupport diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/AsmState.h" #include
"mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" @@ -47,7 +48,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, SourceMgr &sourceMgr, MLIRContext *context, - PassPipelineFn passManagerSetupFn) { + PassPipelineFn passManagerSetupFn, + bool emitBytecode) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -86,8 +88,12 @@ // Print the output. TimingScope outputTiming = timing.nest("Output"); - module->print(os); - os << '\n'; + if (emitBytecode) { + writeBytecodeToFile(module->getOperation(), os); + } else { + module->print(os); + os << '\n'; + } return success(); } @@ -97,8 +103,8 @@ processBuffer(raw_ostream &os, std::unique_ptr ownedBuffer, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - PassPipelineFn passManagerSetupFn, DialectRegistry ®istry, - llvm::ThreadPool *threadPool) { + bool emitBytecode, PassPipelineFn passManagerSetupFn, + DialectRegistry ®istry, llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -122,7 +128,7 @@ if (!verifyDiagnostics) { SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passManagerSetupFn); + &context, passManagerSetupFn, emitBytecode); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); @@ -131,7 +137,7 @@ // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passManagerSetupFn); + passManagerSetupFn, emitBytecode); // Verify the diagnostic handler to make sure that each of the diagnostics // matched. 
@@ -144,7 +150,8 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext) { + bool preloadDialectsInContext, + bool emitBytecode) { // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying @@ -163,8 +170,8 @@ raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, passManagerSetupFn, registry, - threadPool); + preloadDialectsInContext, emitBytecode, + passManagerSetupFn, registry, threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, splitInputFile, /*insertMarkerInOutput=*/true); @@ -176,7 +183,8 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext) { + bool preloadDialectsInContext, + bool emitBytecode) { auto passManagerSetupFn = [&](PassManager &pm) { auto errorHandler = [&](const Twine &msg) { emitError(UnknownLoc::get(pm.getContext())) << msg; @@ -186,7 +194,8 @@ }; return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn, registry, splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext); + allowUnregisteredDialects, preloadDialectsInContext, + emitBytecode); } LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, @@ -224,6 +233,10 @@ "show-dialects", cl::desc("Print the list of registered dialects"), cl::init(false)); + static cl::opt<bool> emitBytecode( + "emit-bytecode", cl::desc("Emit bytecode when generating output"), + cl::init(false)); + InitLLVM y(argc, argv); // Register any command line options.
@@ -268,7 +281,8 @@ if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext))) + allowUnregisteredDialects, preloadDialectsInContext, + emitBytecode))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful. diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/general.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt -allow-unregistered-dialect -emit-bytecode %s | mlir-opt -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: "bytecode.test1" +// CHECK-NEXT: "bytecode.empty"() : () -> () +// CHECK-NEXT: "bytecode.attributes"() {attra = 10 : i64, attrb = #bytecode.attr} : () -> () +// CHECK-NEXT: test.graph_region { +// CHECK-NEXT: "bytecode.operands"(%[[RESULTS:.*]]#0, %[[RESULTS]]#1, %[[RESULTS]]#2) : (i32, i64, i32) -> () +// CHECK-NEXT: %[[RESULTS]]:3 = "bytecode.results"() : () -> (i32, i64, i32) +// CHECK-NEXT: } +// CHECK-NEXT: "bytecode.branch"()[^[[BLOCK:.*]]] : () -> () +// CHECK-NEXT: ^[[BLOCK]](%[[ARG0:.*]]: i32, %[[ARG1:.*]]: !bytecode.int, %[[ARG2:.*]]: !pdl.operation): +// CHECK-NEXT: "bytecode.regions"() ({ +// CHECK-NEXT: "bytecode.operands"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (i32, !bytecode.int, !pdl.operation) -> () +// CHECK-NEXT: "bytecode.return"() : () -> () +// CHECK-NEXT: }) : () -> () +// CHECK-NEXT: "bytecode.return"() : () -> () +// CHECK-NEXT: }) : () -> () + +"bytecode.test1"() ({ + "bytecode.empty"() : () -> () + "bytecode.attributes"() {attra = 10, attrb = #bytecode.attr} : () -> () + test.graph_region { + "bytecode.operands"(%results#0, %results#1, %results#2) : (i32, i64, i32) -> () + %results:3 = "bytecode.results"() : () -> (i32, i64, i32) + } + "bytecode.branch"()[^secondBlock] : () -> () + +^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation): + "bytecode.regions"() ({ + 
"bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> () + "bytecode.return"() : () -> () + }) : () -> () + "bytecode.return"() : () -> () +}) : () -> () diff --git a/mlir/test/Bytecode/invalid/invalid-attr_type_offset_section-large_offset.mlirbc b/mlir/test/Bytecode/invalid/invalid-attr_type_offset_section-large_offset.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@&1 | FileCheck %s --check-prefix=DIALECT_STR +// DIALECT_STR: invalid string index: 15 + +//===--------------------------------------------------------------------===// +// OpName +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-dialect_section-opname_dialect.mlirbc 2>&1 | FileCheck %s --check-prefix=OPNAME_DIALECT +// OPNAME_DIALECT: invalid dialect index: 7 + +// RUN: not mlir-opt %S/invalid-dialect_section-opname_string.mlirbc 2>&1 | FileCheck %s --check-prefix=OPNAME_STR +// OPNAME_STR: invalid string index: 31 diff --git a/mlir/test/Bytecode/invalid/invalid-ir_section-attr.mlirbc b/mlir/test/Bytecode/invalid/invalid-ir_section-attr.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@&1 | FileCheck %s --check-prefix=OP_NAME +// OP_NAME: invalid operation name index: 14 + +//===--------------------------------------------------------------------===// +// Loc + +// RUN: not mlir-opt %S/invalid-ir_section-loc.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=OP_LOC +// OP_LOC: expected attribute of type: {{.*}}, but got: {attra = 10 : i64, attrb = #bytecode.attr} + +//===--------------------------------------------------------------------===// +// Attr + +// RUN: not mlir-opt %S/invalid-ir_section-attr.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=OP_ATTR +// OP_ATTR: 
expected attribute of type: {{.*}}, but got: loc(unknown) + +//===--------------------------------------------------------------------===// +// Operands + +// RUN: not mlir-opt %S/invalid-ir_section-operands.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=OP_OPERANDS +// OP_OPERANDS: invalid value index: 6 + +// RUN: not mlir-opt %S/invalid-ir_section-forwardref.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=FORWARD_REF +// FORWARD_REF: not all forward unresolved forward operand references + +//===--------------------------------------------------------------------===// +// Results + +// RUN: not mlir-opt %S/invalid-ir_section-results.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=OP_RESULTS +// OP_RESULTS: value index range was outside of the expected range for the parent region, got [3, 6), but the maximum index was 2 + +//===--------------------------------------------------------------------===// +// Successors + +// RUN: not mlir-opt %S/invalid-ir_section-successors.mlirbc -allow-unregistered-dialect 2>&1 | FileCheck %s --check-prefix=OP_SUCCESSORS +// OP_SUCCESSORS: invalid successor index: 3 diff --git a/mlir/test/Bytecode/invalid/invalid-string_section-count.mlirbc b/mlir/test/Bytecode/invalid/invalid-string_section-count.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@&1 | FileCheck %s --check-prefix=COUNT +// COUNT: attempting to parse a byte at the end of the bytecode + +//===--------------------------------------------------------------------===// +// Invalid String +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-string_section-no_string.mlirbc 2>&1 | FileCheck %s --check-prefix=NO_STRING +// NO_STRING: attempting to parse a byte at the end of the bytecode + +// RUN: not mlir-opt 
%S/invalid-string_section-large_string.mlirbc 2>&1 | FileCheck %s --check-prefix=LARGE_STRING +// LARGE_STRING: string size exceeds the available data size + +//===--------------------------------------------------------------------===// +// Trailing data +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-string_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA +// TRAILING_DATA: unexpected trailing data between the offsets for strings and their data diff --git a/mlir/test/Bytecode/invalid/invalid-structure-producer.mlirbc b/mlir/test/Bytecode/invalid/invalid-structure-producer.mlirbc new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/invalid/invalid-structure-producer.mlirbc @@ -0,0 +1 @@ +MLïRÿ \ No newline at end of file diff --git a/mlir/test/Bytecode/invalid/invalid-structure-section-duplicate.mlirbc b/mlir/test/Bytecode/invalid/invalid-structure-section-duplicate.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@&1 | FileCheck %s --check-prefix=VERSION +// VERSION: bytecode version 127 is newer than the current version 0 + +//===--------------------------------------------------------------------===// +// Producer +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-structure-producer.mlirbc 2>&1 | FileCheck %s --check-prefix=PRODUCER +// PRODUCER: malformed null-terminated string, no null character found + +//===--------------------------------------------------------------------===// +// Section +//===--------------------------------------------------------------------===// + +//===--------------------------------------------------------------------===// +// Missing + +// RUN: not mlir-opt %S/invalid-structure-section-missing.mlirbc 2>&1 | FileCheck %s --check-prefix=SECTION_MISSING +// 
SECTION_MISSING: missing data for top-level section: String (0) + +//===--------------------------------------------------------------------===// +// ID + +// RUN: not mlir-opt %S/invalid-structure-section-id-unknown.mlirbc 2>&1 | FileCheck %s --check-prefix=SECTION_ID_UNKNOWN +// SECTION_ID_UNKNOWN: invalid section ID: 255 + +//===--------------------------------------------------------------------===// +// Length + +// RUN: not mlir-opt %S/invalid-structure-section-length.mlirbc 2>&1 | FileCheck %s --check-prefix=SECTION_LENGTH +// SECTION_LENGTH: attempting to parse a byte at the end of the bytecode + +//===--------------------------------------------------------------------===// +// Duplicate + +// RUN: not mlir-opt %S/invalid-structure-section-duplicate.mlirbc 2>&1 | FileCheck %s --check-prefix=SECTION_DUPLICATE +// SECTION_DUPLICATE: duplicate top-level section: String (0) diff --git a/mlir/test/Bytecode/invalid/invalid_attr_type_offset_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_offset_section.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/invalid/invalid_attr_type_offset_section.mlir @@ -0,0 +1,16 @@ +// This file contains various failure test cases related to the structure of +// the attribute/type offset section. 
+ +//===--------------------------------------------------------------------===// +// Offset +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-attr_type_offset_section-large_offset.mlirbc 2>&1 | FileCheck %s --check-prefix=LARGE_OFFSET +// LARGE_OFFSET: Attribute or Type entry offset points past the end of section + +//===--------------------------------------------------------------------===// +// Trailing Data +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-attr_type_offset_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA +// TRAILING_DATA: unexpected trailing data in the Attribute/Type offset section diff --git a/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/invalid/invalid_attr_type_section.mlir @@ -0,0 +1,16 @@ +// This file contains various failure test cases related to the structure of +// the attribute/type offset section. + +//===--------------------------------------------------------------------===// +// Index +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-attr_type_section-index.mlirbc 2>&1 | FileCheck %s --check-prefix=INDEX +// INDEX: invalid Attribute index: 3 + +//===--------------------------------------------------------------------===// +// Trailing Data +//===--------------------------------------------------------------------===// + +// RUN: not mlir-opt %S/invalid-attr_type_section-trailing_data.mlirbc 2>&1 | FileCheck %s --check-prefix=TRAILING_DATA +// TRAILING_DATA: trailing characters found after Attribute assembly format: trailing