diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
new file mode 100644
--- /dev/null
+++ b/mlir/docs/BytecodeFormat.md
@@ -0,0 +1,317 @@
+# MLIR Bytecode Format
+
+This document describes the MLIR bytecode format and its encoding.
+
+[TOC]
+
+## Magic Number
+
+MLIR uses the following four-byte magic number to indicate bytecode files:
+
+`[‘M’8, ‘L’8, ‘ï’8, ‘R’8]`
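+
+The sketch below shows how a reader could recognize bytecode by comparing the
+first four bytes of a buffer against the magic number. It is not the
+`isBytecode` implementation. `looksLikeBytecode` is a hypothetical name, and the
+sketch assumes `ï` is stored as the single byte `0xEF` (its Latin-1 code point),
+which a four-byte magic number requires.
+
+```cpp
+#include <string_view>
+
+// Hypothetical helper: returns true if `buffer` starts with the magic number.
+// Assumes the 'ï' byte is 0xEF (Latin-1), since the magic is four bytes long.
+bool looksLikeBytecode(std::string_view buffer) {
+  constexpr std::string_view kMagic("ML\xEFR", 4);
+  // substr() clamps to the buffer size, so short buffers simply compare false.
+  return buffer.substr(0, kMagic.size()) == kMagic;
+}
+```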
+
+## Format Overview
+
+An MLIR bytecode file consists of a byte stream, with a few simple structural
+concepts layered on top.
+
+### Primitives
+
+#### Fixed-Width Integers
+
+```
+ byte ::= `0x00`...`0xFF`
+```
+
+Fixed-width integers are unsigned integers of a known byte size. The values are
+stored in little-endian byte order.
+
+TODO: Add larger fixed width integers as necessary.
+
+#### Variable-Width Integers
+
+Variable-width integers, or `VarInt`s, provide a compact representation for
+integers. Each encoded VarInt consists of one to nine bytes, which together
+represent a single 64-bit value. The MLIR bytecode uses the "PrefixVarInt"
+encoding for VarInts. This encoding is a variant of the
+[LEB128 ("Little-Endian Base 128")](https://en.wikipedia.org/wiki/LEB128)
+encoding: each byte of the encoding provides up to 7 bits of the value, but
+instead of a continuation bit in every byte, the tag bits that indicate the
+number of bytes used are all stored in the low bits of the first byte. This
+means that small unsigned integers (less than 2^7) may be stored in one byte,
+unsigned integers less than 2^14 may be stored in two bytes, etc.
+
+The first byte of the encoding includes a length prefix in the low bits. This
+prefix is a sequence of '0' bits followed by a terminating '1' bit, or eight '0'
+bits when the whole first byte is zero. The number of '0' bits indicates the
+number of _additional_ bytes, not including the prefix byte, used to encode the
+value. All of the remaining bits in the first byte, along with all of the bits
+in the additional bytes, provide the value of the integer. Below are the various
+possible encodings of the prefix byte:
+
+```
+xxxxxxx1: 7 value bits, the encoding uses 1 byte
+xxxxxx10: 14 value bits, the encoding uses 2 bytes
+xxxxx100: 21 value bits, the encoding uses 3 bytes
+xxxx1000: 28 value bits, the encoding uses 4 bytes
+xxx10000: 35 value bits, the encoding uses 5 bytes
+xx100000: 42 value bits, the encoding uses 6 bytes
+x1000000: 49 value bits, the encoding uses 7 bytes
+10000000: 56 value bits, the encoding uses 8 bytes
+00000000: 64 value bits, the encoding uses 9 bytes
+```
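+
+The prefix table above translates directly into code. The following standalone
+sketch (the names `encodeVarInt` and `decodeVarInt` are hypothetical, not part
+of the MLIR API) picks the shortest encoding with enough value bits, and decodes
+by counting the trailing zeros of the prefix byte:
+
+```cpp
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+// Append `value` to `out` using the PrefixVarInt scheme.
+void encodeVarInt(uint64_t value, std::vector<uint8_t> &out) {
+  // An n-byte encoding (n <= 8) carries 7 * n value bits; use the smallest
+  // n that has room for `value`.
+  for (unsigned n = 1; n <= 8; ++n) {
+    if ((value >> (7 * n)) != 0)
+      continue;
+    // Value bits sit above an n-bit prefix: (n - 1) zeros, then a one.
+    uint64_t encoded = (value << n) | (uint64_t(1) << (n - 1));
+    for (unsigned i = 0; i < n; ++i)
+      out.push_back(uint8_t(encoded >> (8 * i))); // little-endian
+    return;
+  }
+  // More than 56 significant bits: an all-zero prefix byte, then the full
+  // 64-bit value in 8 little-endian bytes.
+  out.push_back(0);
+  for (unsigned i = 0; i < 8; ++i)
+    out.push_back(uint8_t(value >> (8 * i)));
+}
+
+// Decode one PrefixVarInt starting at `data[pos]`, advancing `pos` past it.
+// Returns std::nullopt if the buffer ends in the middle of the encoding.
+std::optional<uint64_t> decodeVarInt(const std::vector<uint8_t> &data,
+                                     size_t &pos) {
+  if (pos >= data.size())
+    return std::nullopt;
+  // The number of trailing zeros in the prefix byte is the number of
+  // additional bytes; an all-zero prefix byte means 8 additional bytes.
+  uint8_t prefix = data[pos];
+  unsigned extra = 0;
+  while (extra < 8 && ((prefix >> extra) & 1) == 0)
+    ++extra;
+  if (data.size() - pos < size_t(extra) + 1)
+    return std::nullopt;
+
+  uint64_t raw = 0;
+  if (extra == 8) {
+    // The 8 bytes after the prefix hold the value verbatim.
+    for (unsigned i = 0; i < 8; ++i)
+      raw |= uint64_t(data[pos + 1 + i]) << (8 * i);
+    pos += 9;
+    return raw;
+  }
+  // Read all n bytes little-endian, then shift out the n prefix bits.
+  unsigned n = extra + 1;
+  for (unsigned i = 0; i < n; ++i)
+    raw |= uint64_t(data[pos + i]) << (8 * i);
+  pos += n;
+  return raw >> n;
+}
+```
+
+For example, `300` needs more than 7 bits, so it takes the two-byte form
+`(300 << 2) | 0b10`, which is stored little-endian as the bytes `0xB2 0x04`.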
+
+#### Strings
+
+Strings are blobs of characters with an associated length.
+
+### Sections
+
+```
+section {
+ id: byte
+ length: varint
+}
+```
+
+Sections are a mechanism for grouping data within the bytecode. They enable
+delayed processing, which is useful for out-of-order processing of data,
+lazy-loading, and more. Each section contains a section ID and a length, which
+allows a reader to skip over the section.
+
+TODO: Sections should also carry an optional alignment. Add this when necessary.
+
+## MLIR Encoding
+
+Given the generic structure of MLIR, the bytecode encoding is fairly simple: it
+effectively maps to the core components of MLIR.
+
+### Top Level Structure
+
+The top-level structure of the bytecode contains the 4-byte "magic number", a
+version number, a producer string, and a list of sections. Each section is
+currently only expected to appear once within a bytecode file.
+
+```
+bytecode {
+ magic: "MLïR",
+ version: varint,
+ producer: string,
+ sections: section[]
+}
+```
+
+### String Section
+
+```
+strings {
+ numStrings: varint,
+ reverseStringLengths: varint[],
+ stringData: byte[]
+}
+```
+
+The string section contains a table of strings referenced within the bytecode,
+which allows strings to be shared. This section is encoded first with the total
+number of strings, followed by the sizes of each of the individual strings in
+reverse order. The remaining encoding contains a single blob containing all of
+the strings concatenated together.
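+
+One likely motivation for the reverse order is that the sizes are themselves
+VarInts, so the offset at which the string blob begins is not known until every
+size has been read. Reading the sizes last-string-first lets a reader carve each
+string off the end of the section immediately. The sketch below illustrates
+this under stated assumptions: the sizes have already been decoded into a
+vector in the order they appear in the section, each size is the exact byte
+length of its string, and `sliceStringTable` is a hypothetical name.
+
+```cpp
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <string_view>
+#include <vector>
+
+// Rebuild the string table from the sizes as stored (last string first) and
+// the raw bytes of the whole string section. Only the end of `section` is
+// needed, because the string blob is its tail.
+std::optional<std::vector<std::string_view>>
+sliceStringTable(const std::vector<uint64_t> &reverseLengths,
+                 std::string_view section) {
+  std::vector<std::string_view> strings(reverseLengths.size());
+  size_t end = section.size();
+  for (size_t i = 0; i < reverseLengths.size(); ++i) {
+    uint64_t length = reverseLengths[i];
+    // A full reader would also check that the blob does not overlap the
+    // encoded sizes; this sketch only guards against running off the front.
+    if (length > end)
+      return std::nullopt;
+    end -= length;
+    // reverseLengths[0] describes the *last* string in the table.
+    strings[strings.size() - 1 - i] = section.substr(end, length);
+  }
+  return strings;
+}
+```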
+
+### Dialect Section
+
+The dialect section of the bytecode contains all of the dialects referenced
+within the encoded IR, and some information about the components of those
+dialects that were also referenced.
+
+```
+dialect_section {
+ numDialects: varint,
+ dialectNames: varint[],
+ opNames: op_name_group[]
+}
+
+op_name_group {
+ dialect: varint,
+ numOpNames: varint,
+ opNames: varint[]
+}
+```
+
+Dialects are encoded as indexes to the name string within the string section.
+Operation names are encoded in groups by dialect, with each group containing the
+index of the dialect within `dialectNames`, the number of operation names, and
+the array of indexes to each name within the string section.
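+
+To make the indexing concrete, the following sketch turns a decoded dialect
+section into fully qualified operation names. It assumes that the section's
+VarInts have already been decoded into `fields`, that the string table is
+available as `strings`, that op name groups continue until the end of the
+section (the layout above stores no group count), and that a full name is the
+dialect name and operation name joined by `.`. `readOpNames` is hypothetical.
+
+```cpp
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+std::optional<std::vector<std::string>>
+readOpNames(const std::vector<uint64_t> &fields,
+            const std::vector<std::string> &strings) {
+  size_t pos = 0;
+  // Fetch the next decoded VarInt, or nullopt if the section is exhausted.
+  auto next = [&]() -> std::optional<uint64_t> {
+    if (pos >= fields.size())
+      return std::nullopt;
+    return fields[pos++];
+  };
+
+  // numDialects, then one string-table index per dialect name.
+  std::optional<uint64_t> numDialects = next();
+  if (!numDialects)
+    return std::nullopt;
+  std::vector<const std::string *> dialects;
+  for (uint64_t i = 0; i < *numDialects; ++i) {
+    std::optional<uint64_t> nameIdx = next();
+    if (!nameIdx || *nameIdx >= strings.size())
+      return std::nullopt;
+    dialects.push_back(&strings[*nameIdx]);
+  }
+
+  // Each group: dialect index, op count, then string-table indexes.
+  std::vector<std::string> opNames;
+  while (pos < fields.size()) {
+    std::optional<uint64_t> dialect = next();
+    std::optional<uint64_t> numOps = next();
+    if (!dialect || !numOps || *dialect >= dialects.size())
+      return std::nullopt;
+    for (uint64_t i = 0; i < *numOps; ++i) {
+      std::optional<uint64_t> nameIdx = next();
+      if (!nameIdx || *nameIdx >= strings.size())
+        return std::nullopt;
+      opNames.push_back(*dialects[*dialect] + "." + strings[*nameIdx]);
+    }
+  }
+  return opNames;
+}
+```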
+
+### Attribute/Type Sections
+
+Attributes and types are encoded using two [sections](#sections): one section
+(`attr_type_section`) containing the actual encoded representation, and another
+section (`attr_type_offset_section`) containing the offsets of each encoded
+attribute/type into the former section. This structure allows attributes and
+types to always be loaded lazily, on demand.
+
+```
+attr_type_section {
+ attrs: attribute[],
+ types: type[]
+}
+attr_type_offset_section {
+ numAttrs: varint,
+ numTypes: varint,
+ offsets: attr_type_offset_group[]
+}
+
+attr_type_offset_group {
+ dialect: varint,
+ numElements: varint,
+ offsets: varint[]
+}
+
+enum attrTypeEncoding: byte {
+ kAsmForm = 0
+}
+
+attribute {
+ code: attrTypeEncoding,
+ encoding: ...
+}
+type {
+ code: attrTypeEncoding,
+ encoding: ...
+}
+```
+
+Each `offset` in the `attr_type_offset_section` above is the size of the
+encoding for the attribute or type. We avoid using the direct offset into the
+`attr_type_section`, as smaller relative offsets provide more effective
+compression. Attributes and types are grouped by dialect, with each
+`attr_type_offset_group` in the offset section containing the corresponding
+parent dialect, number of elements, and offsets for each element within the
+group.
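+
+Since only sizes are stored, each entry's absolute offset is a running sum of
+the sizes before it. The sketch below performs that conversion with bounds
+checking against the size of `attr_type_section`. `OffsetGroup`, `EntrySlice`,
+and `computeSlices` are hypothetical; the groups are assumed to be passed in
+section order (attribute groups first, then type groups), so that one running
+offset covers both, matching the `attr_type_section` layout above.
+
+```cpp
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+// Hypothetical view of one attr_type_offset_group: a dialect index plus the
+// encoded size of each entry in the group.
+struct OffsetGroup {
+  uint64_t dialect;
+  std::vector<uint64_t> sizes;
+};
+
+// Where one attribute or type lives inside attr_type_section.
+struct EntrySlice {
+  uint64_t dialect; // index of the owning dialect
+  uint64_t offset;  // absolute start within attr_type_section
+  uint64_t size;    // encoded size, as stored in the offset section
+};
+
+std::optional<std::vector<EntrySlice>>
+computeSlices(const std::vector<OffsetGroup> &groups, uint64_t sectionSize) {
+  std::vector<EntrySlice> slices;
+  uint64_t current = 0; // running offset, never exceeds sectionSize
+  for (const OffsetGroup &group : groups) {
+    for (uint64_t size : group.sizes) {
+      // Written as a subtraction so the check itself cannot overflow.
+      if (size > sectionSize - current)
+        return std::nullopt;
+      slices.push_back({group.dialect, current, size});
+      current += size;
+    }
+  }
+  return slices;
+}
+```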
+
+#### Attribute/Type Encodings
+
+In the previous section, the forms of `attribute` and `type` both start with a
+`code` field. This field indicates how the attribute or type was encoded. In the
+abstract, an attribute/type is encoded in one of two possible ways: via its
+assembly format, or via a custom dialect-defined encoding.
+
+##### Assembly Format Fallback
+
+In the case where a dialect does not define a method for encoding the attribute
+or type, the textual assembly format of that attribute or type is used as a
+fallback. For example, a type of `!bytecode.type` would be encoded as the
+`kAsmForm` code followed by the null-terminated string "!bytecode.type". This
+ensures that every attribute and type may be encoded, even if the owning dialect
+has not yet opted in to a more efficient serialization.
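+
+A minimal sketch of such a fallback entry, assuming the entry consists of just
+the `kAsmForm` code byte (value `0`, per the `attrTypeEncoding` enum above)
+followed by the NUL-terminated assembly text. The helper names are
+hypothetical. The decoder rejects bytes after the first NUL, since those would
+be trailing data in the entry.
+
+```cpp
+#include <cstdint>
+#include <optional>
+#include <string_view>
+#include <vector>
+
+// Mirrors `AttrTypeCode::kAsmForm` in Encoding.h.
+constexpr uint8_t kAsmForm = 0;
+
+// Encode an attribute or type via its textual assembly format.
+std::vector<uint8_t> encodeAsmFallback(std::string_view asmText) {
+  std::vector<uint8_t> entry;
+  entry.push_back(kAsmForm);
+  entry.insert(entry.end(), asmText.begin(), asmText.end());
+  entry.push_back(0); // NUL terminator
+  return entry;
+}
+
+// Recover the assembly text, rejecting any bytes after the first NUL.
+std::optional<std::string_view>
+decodeAsmFallback(const std::vector<uint8_t> &entry) {
+  if (entry.size() < 2 || entry.front() != kAsmForm || entry.back() != 0)
+    return std::nullopt;
+  std::string_view text(reinterpret_cast<const char *>(entry.data()) + 1,
+                        entry.size() - 2);
+  // An embedded NUL would leave trailing bytes after the string.
+  if (text.find('\0') != std::string_view::npos)
+    return std::nullopt;
+  return text;
+}
+```
+
+Under these assumptions, the size recorded for `!bytecode.type` in the offset
+section would be 16 bytes: one code byte, 14 characters, and the terminator.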
+
+TODO: We shouldn't redundantly encode the dialect name here, we should use a
+reference to the parent dialect instead.
+
+##### Dialect Defined Encoding
+
+TODO: This is not yet supported.
+
+### IR Section
+
+The IR section contains the encoded form of operations within the bytecode.
+
+#### Operation Encoding
+
+```
+op {
+ name: varint,
+ encodingMask: byte,
+ location: varint,
+
+ attrDict: varint?,
+
+ numResults: varint?,
+ resultTypes: varint[],
+
+ numOperands: varint?,
+ operands: varint[],
+
+ numSuccessors: varint?,
+ successors: varint[],
+
+ regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove)
+ regions: region[]
+}
+```
+
+Operations are generally the most commonly appearing structure in the bytecode,
+which makes their encoding especially important. A single encoding is used for
+every kind of operation, and, given this prevalence, many of its fields are
+optional. The `encodingMask` field is a bitmask indicating which of the optional
+components of the operation are present.
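+
+As an illustration of how a writer could compute the mask, the sketch below
+sets one bit per optional component. The bit values are copied from
+`OpEncodingMask` in `Encoding.h`; `OpSummary` and `computeEncodingMask` are
+hypothetical stand-ins for a writer's view of an operation. Components whose bit
+is clear are simply omitted from the encoding.
+
+```cpp
+#include <cstdint>
+
+// Bit values copied from `bytecode::OpEncodingMask` in Encoding.h.
+enum : uint8_t {
+  kHasAttrs = 0b00000001,
+  kHasResults = 0b00000010,
+  kHasOperands = 0b00000100,
+  kHasSuccessors = 0b00001000,
+  kHasInlineRegions = 0b00010000,
+};
+
+// Hypothetical stand-in for the parts of an operation the mask describes.
+struct OpSummary {
+  bool hasAttrs;
+  unsigned numResults;
+  unsigned numOperands;
+  unsigned numSuccessors;
+  unsigned numRegions;
+};
+
+// Set one bit for each optional component that the operation actually has.
+uint8_t computeEncodingMask(const OpSummary &op) {
+  uint8_t mask = 0;
+  if (op.hasAttrs)
+    mask |= kHasAttrs;
+  if (op.numResults != 0)
+    mask |= kHasResults;
+  if (op.numOperands != 0)
+    mask |= kHasOperands;
+  if (op.numSuccessors != 0)
+    mask |= kHasSuccessors;
+  if (op.numRegions != 0)
+    mask |= kHasInlineRegions;
+  return mask;
+}
+```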
+
+##### Location
+
+The location is encoded as the index of the location within the attribute table.
+
+##### Attributes
+
+If the operation has attributes, the index of the operation attribute dictionary
+within the attribute table is encoded.
+
+##### Results
+
+If the operation has results, the number of results and the indexes of the
+result types within the type table are encoded.
+
+##### Operands
+
+If the operation has operands, the number of operands and the value index of
+each operand are encoded. This value index is the position of the value's
+definition, counted from the start of the closest ancestor region that is
+isolated from above.
+
+##### Successors
+
+If the operation has successors, the number of successors and the indexes of the
+successor blocks within the parent region are encoded.
+
+##### Regions
+
+If the operation has regions, the number of regions and whether the regions are
+isolated from above are encoded together in a single varint. Afterwards, each
+region is encoded inline.
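+
+The `regionEncoding` field, like the block `encoding` field, packs a count and a
+one-bit flag into a single value before it is written as a VarInt; on the read
+side, `EncodingReader::parseVarIntWithFlag` undoes this. A minimal sketch of the
+packing, with hypothetical helper names and the assumption that the count is
+below 2^63 so the shift cannot lose bits:
+
+```cpp
+#include <cstdint>
+#include <utility>
+
+// Pack a count and a one-bit flag as (count << 1) | flag. Assumes
+// count < 2^63 so that the shift does not drop any bits.
+uint64_t packWithFlag(uint64_t count, bool flag) {
+  return (count << 1) | (flag ? 1 : 0);
+}
+
+// Inverse of packWithFlag: returns {count, flag}.
+std::pair<uint64_t, bool> unpackWithFlag(uint64_t encoded) {
+  return {encoded >> 1, (encoded & 1) != 0};
+}
+```
+
+Packing the flag into the low bit keeps the common case compact: an operation
+with fewer than 64 regions still fits in a single-byte VarInt.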
+
+#### Region Encoding
+
+```
+region {
+ numBlocks: varint,
+
+ numValues: varint?,
+ blocks: block[]
+}
+```
+
+A region is encoded first with the number of blocks it contains. If the region
+is non-empty, the number of values defined directly within the region is
+encoded, followed by the blocks of the region.
+
+#### Block Encoding
+
+```
+block {
+ encoding: varint, // (numOps << 1) | (hasBlockArgs)
+ arguments: block_arguments?, // Optional based on encoding
+ ops : op[]
+}
+
+block_arguments {
+ numArgs: varint?,
+ args: block_argument[]
+}
+
+block_argument {
+ typeIndex: varint,
+ location: varint
+}
+```
+
+A block is encoded with an array of operations and optional block arguments.
+The first field is an encoding that combines the number of operations in the
+block with a flag indicating if the block has arguments.
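+
+The sketch below decodes the header of a block: it unpacks the operation count
+and the `hasBlockArgs` flag from the first field and, only when the flag is
+set, reads the argument list. It assumes the block's VarInts have already been
+decoded into `fields`. `readBlockHeader` and its result types are hypothetical,
+and the operations that follow the header are not parsed here.
+
+```cpp
+#include <cstddef>
+#include <cstdint>
+#include <optional>
+#include <vector>
+
+struct BlockArg {
+  uint64_t typeIndex; // index into the type table
+  uint64_t location;  // index of the location attribute
+};
+
+struct BlockHeader {
+  uint64_t numOps;
+  std::vector<BlockArg> args;
+};
+
+// Read one block header starting at fields[pos], advancing `pos` past it.
+std::optional<BlockHeader> readBlockHeader(const std::vector<uint64_t> &fields,
+                                           size_t &pos) {
+  auto next = [&]() -> std::optional<uint64_t> {
+    if (pos >= fields.size())
+      return std::nullopt;
+    return fields[pos++];
+  };
+
+  // encoding = (numOps << 1) | hasBlockArgs
+  std::optional<uint64_t> encoding = next();
+  if (!encoding)
+    return std::nullopt;
+  BlockHeader header{*encoding >> 1, {}};
+  if ((*encoding & 1) == 0)
+    return header; // no block_arguments present
+
+  std::optional<uint64_t> numArgs = next();
+  if (!numArgs)
+    return std::nullopt;
+  for (uint64_t i = 0; i < *numArgs; ++i) {
+    std::optional<uint64_t> type = next();
+    std::optional<uint64_t> loc = next();
+    if (!type || !loc)
+      return std::nullopt;
+    header.args.push_back({*type, *loc});
+  }
+  return header;
+}
+```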
diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -0,0 +1,34 @@
+//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines interfaces to read MLIR bytecode files/streams.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEREADER_H
+#define MLIR_BYTECODE_BYTECODEREADER_H
+
+#include "mlir/IR/AsmState.h"
+#include "mlir/Support/LLVM.h"
+
+namespace llvm {
+class MemoryBufferRef;
+} // namespace llvm
+
+namespace mlir {
+/// Returns true if the given buffer starts with the magic bytes that signal
+/// MLIR bytecode.
+bool isBytecode(llvm::MemoryBufferRef buffer);
+
+/// Read the operations defined within the given memory buffer, containing MLIR
+/// bytecode, into the provided block.
+LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config);
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEREADER_H
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -0,0 +1,36 @@
+//===- BytecodeWriter.h - MLIR Bytecode Writer ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines interfaces to write MLIR bytecode files/streams.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEWRITER_H
+#define MLIR_BYTECODE_BYTECODEWRITER_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+class Operation;
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+/// Write the bytecode for the given operation to the provided output stream.
+/// For streams where it matters, the given stream should be in "binary" mode.
+/// `producer` is an optional string that can be used to identify the producer
+/// of the bytecode when reading. It has no functional effect on the bytecode
+/// serialization.
+void writeBytecodeToFile(Operation *op, raw_ostream &os,
+ StringRef producer = "MLIR" LLVM_VERSION_STRING);
+
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEWRITER_H
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -642,11 +642,11 @@
OperationState(Location location, OperationName name);
OperationState(Location location, OperationName name, ValueRange operands,
- TypeRange types, ArrayRef<NamedAttribute> attributes,
+ TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
OperationState(Location location, StringRef name, ValueRange operands,
- TypeRange types, ArrayRef<NamedAttribute> attributes,
+ TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -50,13 +50,15 @@
/// - preloadDialectsInContext will trigger the upfront loading of all
/// dialects from the global registry in the MLIRContext. This option is
/// deprecated and will be removed soon.
+/// - emitBytecode will generate bytecode output instead of text.
LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext = false);
+ bool preloadDialectsInContext = false,
+ bool emitBytecode = false);
/// Support a callback to setup the pass manager.
/// - passManagerSetupFn is the callback invoked to setup the pass manager to
@@ -67,7 +69,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext = false);
+ bool preloadDialectsInContext = false,
+ bool emitBytecode = false);
/// Implementation for tools like `mlir-opt`.
/// - toolName is used for the header displayed by `--help`.
diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Reader)
+add_subdirectory(Writer)
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -0,0 +1,91 @@
+//===- Encoding.h - MLIR binary format encoding information -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines enum values describing the structure of MLIR bytecode
+// files.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_BYTECODE_ENCODING_H
+#define LIB_MLIR_BYTECODE_ENCODING_H
+
+#include <cstdint>
+
+namespace mlir {
+namespace bytecode {
+//===----------------------------------------------------------------------===//
+// General constants
+//===----------------------------------------------------------------------===//
+
+enum {
+ /// The current bytecode version.
+ kVersion = 0,
+};
+
+//===----------------------------------------------------------------------===//
+// Sections
+//===----------------------------------------------------------------------===//
+
+namespace Section {
+enum ID : uint8_t {
+ /// This section contains strings referenced within the bytecode.
+ kString = 0,
+
+ /// This section contains the dialects referenced within an IR module.
+ kDialect = 1,
+
+ /// This section contains the attributes and types referenced within an IR
+ /// module.
+ kAttrType = 2,
+
+ /// This section contains the offsets for the attribute and types within the
+ /// AttrType section.
+ kAttrTypeOffset = 3,
+
+ /// This section contains the list of operations serialized into the bytecode,
+ /// and their nested regions/operations.
+ kIR = 4,
+
+ /// The total number of section types.
+ kNumSections = 5,
+};
+} // namespace Section
+
+//===----------------------------------------------------------------------===//
+// AttrType Section
+//===----------------------------------------------------------------------===//
+
+enum class AttrTypeCode : uint8_t {
+ /// This code represents an attribute or type represented in the textual
+ /// assembly format.
+ kAsmForm,
+};
+
+//===----------------------------------------------------------------------===//
+// IR Section
+//===----------------------------------------------------------------------===//
+
+/// This enum represents a mask of all of the potential components of an
+/// operation. This mask is used when encoding an operation to indicate which
+/// components are present in the bytecode.
+namespace OpEncodingMask {
+enum : uint8_t {
+ // clang-format off
+ kHasAttrs = 0b00000001,
+ kHasResults = 0b00000010,
+ kHasOperands = 0b00000100,
+ kHasSuccessors = 0b00001000,
+ kHasInlineRegions = 0b00010000,
+ // clang-format on
+};
+} // namespace OpEncodingMask
+
+} // namespace bytecode
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -0,0 +1,1200 @@
+//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// TODO: Support for big-endian architectures.
+// TODO: Properly preserve use lists of values.
+
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "../Encoding.h"
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Verifier.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/SaveAndRestore.h"
+
+#define DEBUG_TYPE "mlir-bytecode-reader"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// EncodingReader
+//===----------------------------------------------------------------------===//
+
+namespace {
+class EncodingReader {
+public:
+ explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
+ : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
+ explicit EncodingReader(StringRef contents, Location fileLoc)
+ : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
+ contents.size()},
+ fileLoc) {}
+
+ /// Returns true if the entire section has been read.
+ bool empty() const { return dataIt == dataEnd; }
+
+ /// Returns the remaining size of the bytecode.
+ size_t size() const { return dataEnd - dataIt; }
+
+ /// Emit an error using the given arguments.
+ template <typename... Args>
+ LogicalResult emitError(Args &&...args) const {
+ return ::emitError(fileLoc).append(std::forward<Args>(args)...);
+ }
+
+ /// Parse a single byte from the stream.
+ template <typename T>
+ LogicalResult parseByte(T &value) {
+ if (empty())
+ return emitError("attempting to parse a byte at the end of the bytecode");
+ value = static_cast<T>(*dataIt++);
+ return success();
+ }
+ /// Parse a range of bytes of 'length' into the given result.
+ LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
+ if (length > size()) {
+ return emitError("attempting to parse ", length, " bytes when only ",
+ size(), " remain");
+ }
+ result = {dataIt, length};
+ dataIt += length;
+ return success();
+ }
+ /// Parse a range of bytes of 'length' into the given result, which can be
+ /// assumed to be large enough to hold `length`.
+ LogicalResult parseBytes(size_t length, uint8_t *result) {
+ if (length > size()) {
+ return emitError("attempting to parse ", length, " bytes when only ",
+ size(), " remain");
+ }
+ memcpy(result, dataIt, length);
+ dataIt += length;
+ return success();
+ }
+
+ /// Parse a variable length encoded integer from the byte stream. The first
+ /// encoded byte contains a prefix in the low bits indicating the encoded
+ /// length of the value. This length prefix is a bit sequence of '0's followed
+ /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
+ /// (not including the prefix byte). All remaining bits in the first byte,
+ /// along with all of the bits in additional bytes, provide the value of the
+ /// integer encoded in little-endian order.
+ LogicalResult parseVarInt(uint64_t &result) {
+ // Parse the first byte of the encoding, which contains the length prefix.
+ if (failed(parseByte(result)))
+ return failure();
+
+ // Handle the overwhelmingly common case where the value is stored in a
+ // single byte. In this case, the low bit is the `1` marker bit.
+ if (LLVM_LIKELY(result & 1)) {
+ result >>= 1;
+ return success();
+ }
+
+ // Handle the overwhelmingly uncommon case where the value requires the full
+ // 9-byte encoding (i.e. a really really big number). In this case, the
+ // marker byte is all zeros: `00000000`, and the value is stored in the 8
+ // bytes that follow.
+ if (LLVM_UNLIKELY(result == 0))
+ return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result));
+ return parseMultiByteVarInt(result);
+ }
+
+ /// Parse a variable length encoded integer whose low bit is used to encode an
+ /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`.
+ LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) {
+ if (failed(parseVarInt(result)))
+ return failure();
+ flag = result & 1;
+ result >>= 1;
+ return success();
+ }
+
+ /// Skip the first `length` bytes within the reader.
+ LogicalResult skipBytes(size_t length) {
+ if (length > size()) {
+ return emitError("attempting to skip ", length, " bytes when only ",
+ size(), " remain");
+ }
+ dataIt += length;
+ return success();
+ }
+
+ /// Parse a null-terminated string into `result` (without including the NUL
+ /// terminator).
+ LogicalResult parseNullTerminatedString(StringRef &result) {
+ const char *startIt = (const char *)dataIt;
+ const char *nulIt = (const char *)memchr(startIt, 0, size());
+ if (!nulIt)
+ return emitError("malformed null-terminated string, no NUL found");
+
+ result = StringRef(startIt, nulIt - startIt);
+ dataIt = (const uint8_t *)nulIt + 1;
+ return success();
+ }
+
+ /// Parse a section header, placing the kind of section in `sectionID` and the
+ /// contents of the section in `sectionData`.
+ LogicalResult parseSection(bytecode::Section::ID &sectionID,
+ ArrayRef<uint8_t> &sectionData) {
+ uint64_t length;
+ if (failed(parseByte(sectionID)) || failed(parseVarInt(length)))
+ return failure();
+ if (sectionID >= bytecode::Section::kNumSections)
+ return emitError("invalid section ID: ", unsigned(sectionID));
+
+ // Parse the actual section data now that we have its length.
+ return parseBytes(length, sectionData);
+ }
+
+private:
+ /// Parse a variable length encoded integer from the byte stream. This method
+ /// is a fallback when the number of bytes used to encode the value is greater
+ /// than 1, but less than the max (9). The provided `result` value can be
+ /// assumed to already contain the first byte of the value.
+ /// NOTE: This method is marked noinline to avoid pessimizing the common case
+ /// of single byte encoding.
+ LLVM_ATTRIBUTE_NOINLINE LogicalResult parseMultiByteVarInt(uint64_t &result) {
+ // Count the number of trailing zeros in the marker byte, this indicates the
+ // number of trailing bytes that are part of the value. We use `uint32_t`
+ // here because we only care about the first byte, and so that we actually
+ // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
+ // implementation).
+ uint32_t numBytes =
+ llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined);
+ assert(numBytes > 0 && numBytes <= 7 &&
+ "unexpected number of trailing zeros in varint encoding");
+
+ // Parse in the remaining bytes of the value.
+ if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1)))
+ return failure();
+
+ // Shift out the low-order bits that were used to mark how the value was
+ // encoded.
+ result >>= (numBytes + 1);
+ return success();
+ }
+
+ /// The current data iterator, and an iterator to the end of the buffer.
+ const uint8_t *dataIt, *dataEnd;
+
+ /// A location for the bytecode used to report errors.
+ Location fileLoc;
+};
+} // namespace
+
+/// Resolve an index into the given entry list. `entry` may either be a
+/// reference, in which case it is assigned to the corresponding value in
+/// `entries`, or a pointer, in which case it is assigned to the address of the
+/// element in `entries`.
+template <typename RangeT, typename T>
+static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries,
+ uint64_t index, T &entry,
+ StringRef entryStr) {
+ if (index >= entries.size())
+ return reader.emitError("invalid ", entryStr, " index: ", index);
+
+ // Copy the element if it converts to `T`; otherwise (e.g. when `T` is a
+ // pointer) resolve to the address of the element.
+ if constexpr (std::is_convertible_v<decltype(entries[index]), T>)
+ entry = entries[index];
+ else
+ entry = &entries[index];
+ return success();
+}
+
+/// Parse and resolve an index into the given entry list.
+template <typename RangeT, typename T>
+static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries,
+ T &entry, StringRef entryStr) {
+ uint64_t entryIdx;
+ if (failed(reader.parseVarInt(entryIdx)))
+ return failure();
+ return resolveEntry(reader, entries, entryIdx, entry, entryStr);
+}
+
+//===----------------------------------------------------------------------===//
+// BytecodeDialect
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This struct represents a dialect entry within the bytecode.
+struct BytecodeDialect {
+ /// Load the dialect into the provided context if it hasn't been loaded yet.
+ /// Returns failure if the dialect couldn't be loaded *and* the provided
+ /// context does not allow unregistered dialects. The provided reader is used
+ /// for error emission if necessary.
+ LogicalResult load(EncodingReader &reader, MLIRContext *ctx) {
+ if (dialect)
+ return success();
+ Dialect *loadedDialect = ctx->getOrLoadDialect(name);
+ if (!loadedDialect && !ctx->allowsUnregisteredDialects()) {
+ return reader.emitError(
+ "dialect '", name,
+ "' is unknown. If this is intended, please call "
+ "allowUnregisteredDialects() on the MLIRContext, or use "
+ "-allow-unregistered-dialect with the MLIR tool used.");
+ }
+ dialect = loadedDialect;
+ return success();
+ }
+
+ /// The loaded dialect entry. This field is None if we haven't attempted to
+ /// load, nullptr if we failed to load, otherwise the loaded dialect.
+ Optional<Dialect *> dialect;
+
+ /// The name of the dialect.
+ StringRef name;
+};
+
+/// This struct represents an operation name entry within the bytecode.
+struct BytecodeOperationName {
+ BytecodeOperationName(BytecodeDialect *dialect, StringRef name)
+ : dialect(dialect), name(name) {}
+
+ /// The loaded operation name, or None if it hasn't been processed yet.
+ Optional<OperationName> opName;
+
+ /// The dialect that owns this operation name.
+ BytecodeDialect *dialect;
+
+ /// The name of the operation, without the dialect prefix.
+ StringRef name;
+};
+} // namespace
+
+/// Parse a single dialect group encoded in the byte stream.
+static LogicalResult parseDialectGrouping(
+ EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects,
+ function_ref<LogicalResult(BytecodeDialect *)> entryCallback) {
+ // Parse the dialect and the number of entries in the group.
+ BytecodeDialect *dialect;
+ if (failed(parseEntry(reader, dialects, dialect, "dialect")))
+ return failure();
+ uint64_t numEntries;
+ if (failed(reader.parseVarInt(numEntries)))
+ return failure();
+
+ for (uint64_t i = 0; i < numEntries; ++i)
+ if (failed(entryCallback(dialect)))
+ return failure();
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute/Type Reader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides support for reading attribute and type entries from the
+/// bytecode. Attribute and Type entries are read lazily on demand, so we use
+/// this reader to manage when to actually parse them from the bytecode.
+class AttrTypeReader {
+ /// This class represents a single attribute or type entry.
+ template <typename T>
+ struct Entry {
+ /// The entry, or null if it hasn't been resolved yet.
+ T entry = {};
+ /// The parent dialect of this entry.
+ BytecodeDialect *dialect = nullptr;
+ /// The raw data of this entry in the bytecode.
+ ArrayRef<uint8_t> data;
+ };
+ using AttrEntry = Entry<Attribute>;
+ using TypeEntry = Entry<Type>;
+
+public:
+ AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
+
+ /// Initialize the attribute and type information within the reader.
+ LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
+
+ /// Resolve the attribute or type at the given index. Returns nullptr on
+ /// failure.
+ Attribute resolveAttribute(size_t index) {
+ return resolveEntry(attributes, index, "Attribute");
+ }
+ Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
+
+private:
+ /// Resolve the given entry at `index`.
+ template <typename T>
+ T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
+ StringRef entryType);
+
+ /// Parse the value defined within the given reader. `code` indicates how the
+ /// entry was encoded.
+ LogicalResult parseEntry(EncodingReader &reader, bytecode::AttrTypeCode code,
+ Attribute &result);
+ LogicalResult parseEntry(EncodingReader &reader, bytecode::AttrTypeCode code,
+ Type &result);
+
+ /// The set of attribute and type entries.
+ SmallVector<AttrEntry> attributes;
+ SmallVector<TypeEntry> types;
+
+ /// A location used for error emission.
+ Location fileLoc;
+};
+} // namespace
+
+LogicalResult
+AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData) {
+ EncodingReader offsetReader(offsetSectionData, fileLoc);
+
+ // Parse the number of attribute and type entries.
+ uint64_t numAttributes, numTypes;
+ if (failed(offsetReader.parseVarInt(numAttributes)) ||
+ failed(offsetReader.parseVarInt(numTypes)))
+ return failure();
+ attributes.resize(numAttributes);
+ types.resize(numTypes);
+
+ // A functor used to accumulate the offsets for the entries in the given
+ // range.
+ uint64_t currentOffset = 0;
+ auto parseEntries = [&](auto &&range) {
+ size_t currentIndex = 0, endIndex = range.size();
+
+ // Parse an individual entry.
+ auto parseEntryFn = [&](BytecodeDialect *dialect) {
+ uint64_t entrySize;
+ if (failed(offsetReader.parseVarInt(entrySize)))
+ return failure();
+
+ // Verify that the offset is actually valid.
+ if (currentOffset + entrySize > sectionData.size()) {
+ return offsetReader.emitError(
+ "Attribute or Type entry offset points past the end of section");
+ }
+
+ auto &entry = range[currentIndex++];
+ entry.data = sectionData.slice(currentOffset, entrySize);
+ entry.dialect = dialect;
+ currentOffset += entrySize;
+ return success();
+ };
+ while (currentIndex != endIndex)
+ if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn)))
+ return failure();
+ return success();
+ };
+
+ // Process each of the attributes, and then the types.
+ if (failed(parseEntries(attributes)) || failed(parseEntries(types)))
+ return failure();
+
+ // Ensure that we read everything from the section.
+ if (!offsetReader.empty()) {
+ return offsetReader.emitError(
+ "unexpected trailing data in the Attribute/Type offset section");
+ }
+ return success();
+}
+
+template <typename T>
+T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
+ StringRef entryType) {
+ if (index >= entries.size()) {
+ emitError(fileLoc) << "invalid " << entryType << " index: " << index;
+ return {};
+ }
+
+ // If the entry has already been resolved, there is nothing left to do.
+ Entry<T> &entry = entries[index];
+ if (entry.entry)
+ return entry.entry;
+
+ // Parse the entry. Each entry starts with a specific code that indicates how
+ // it is represented.
+ EncodingReader reader(entry.data, fileLoc);
+ bytecode::AttrTypeCode code;
+ if (failed(reader.parseByte(code)) ||
+ failed(parseEntry(reader, code, entry.entry)))
+ return T();
+ if (!reader.empty()) {
+ (void)reader.emitError("unexpected trailing bytes after " + entryType +
+ " entry");
+ return T();
+ }
+ return entry.entry;
+}
+
+LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader,
+ bytecode::AttrTypeCode code,
+ Attribute &result) {
+ // Handle the fallback case, where the attribute was encoded using its
+ // assembly format.
+ if (code == bytecode::AttrTypeCode::kAsmForm) {
+ StringRef attrStr;
+ if (failed(reader.parseNullTerminatedString(attrStr)))
+ return failure();
+
+ size_t numRead = 0;
+ if (!(result = parseAttribute(attrStr, fileLoc->getContext(), numRead)))
+ return failure();
+ if (numRead != attrStr.size()) {
+ return reader.emitError(
+ "trailing characters found after Attribute assembly format: ",
+ attrStr.drop_front(numRead));
+ }
+ return success();
+ }
+
+ return reader.emitError("unexpected Attribute encoding: ",
+ static_cast<uint64_t>(code));
+}
+
+LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader,
+ bytecode::AttrTypeCode code,
+ Type &result) {
+ // Handle the fallback case, where the type was encoded using its
+ // assembly format.
+ if (code == bytecode::AttrTypeCode::kAsmForm) {
+ StringRef typeStr;
+ if (failed(reader.parseNullTerminatedString(typeStr)))
+ return failure();
+
+ size_t numRead = 0;
+ if (!(result = parseType(typeStr, fileLoc->getContext(), numRead)))
+ return failure();
+ if (numRead != typeStr.size()) {
+ return reader.emitError(
+ "trailing characters found after Type assembly format: " +
+ typeStr.drop_front(numRead));
+ }
+ return success();
+ }
+
+ return reader.emitError("unexpected Type encoding: ",
+ static_cast<uint64_t>(code));
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Reader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to read a bytecode buffer and translate it into MLIR.
+class BytecodeReader {
+public:
+ BytecodeReader(Location fileLoc, const ParserConfig &config)
+ : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
+ // Use the builtin unrealized conversion cast operation to represent
+ // forward references to values that aren't yet defined.
+ forwardRefOpState(UnknownLoc::get(config.getContext()),
+ "builtin.unrealized_conversion_cast", ValueRange(),
+ NoneType::get(config.getContext())) {}
+
+ /// Read the bytecode defined within `buffer` into the given block.
+ LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
+
+private:
+ /// Return the context for this config.
+ MLIRContext *getContext() const { return config.getContext(); }
+
+ /// Parse the bytecode version.
+ LogicalResult parseVersion(EncodingReader &reader);
+
+ //===--------------------------------------------------------------------===//
+ // Dialect Section
+
+ LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
+
+ /// Parse an operation name reference using the given reader.
+ FailureOr<OperationName> parseOpName(EncodingReader &reader);
+
+ //===--------------------------------------------------------------------===//
+ // Attribute/Type Section
+
+ /// Parse an attribute or type using the given reader. Returns nullptr in the
+ /// case of failure.
+ Attribute parseAttribute(EncodingReader &reader);
+ Type parseType(EncodingReader &reader);
+
+ template <typename T>
+ T parseAttribute(EncodingReader &reader) {
+ if (Attribute attr = parseAttribute(reader)) {
+ if (auto derivedAttr = attr.dyn_cast<T>())
+ return derivedAttr;
+ (void)reader.emitError("expected attribute of type: ",
+ llvm::getTypeName<T>(), ", but got: ", attr);
+ }
+ return T();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // IR Section
+
+ /// This struct represents the current read state of a range of regions. This
+ /// struct is used to enable iterative parsing of regions.
+ struct RegionReadState {
+ RegionReadState(Operation *op, bool isIsolatedFromAbove)
+ : RegionReadState(op->getRegions(), isIsolatedFromAbove) {}
+ RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove)
+ : curRegion(regions.begin()), endRegion(regions.end()),
+ isIsolatedFromAbove(isIsolatedFromAbove) {}
+
+ /// The current regions being read.
+ MutableArrayRef<Region>::iterator curRegion, endRegion;
+
+ /// The number of values defined immediately within this region.
+ unsigned numValues = 0;
+
+ /// The current blocks of the region being read.
+ SmallVector<Block *> curBlocks;
+ Region::iterator curBlock = {};
+
+ /// The number of operations remaining to be read from the current block
+ /// being read.
+ unsigned numOpsRemaining = 0;
+
+ /// A flag indicating if the regions being read are isolated from above.
+ bool isIsolatedFromAbove = false;
+ };
+
+ LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
+ LogicalResult parseRegions(EncodingReader &reader,
+ std::vector<RegionReadState> &regionStack,
+ RegionReadState &readState);
+ FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
+ RegionReadState &readState,
+ bool &isIsolatedFromAbove);
+
+ LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState);
+ LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
+ LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
+
+ //===--------------------------------------------------------------------===//
+ // String Section
+
+ LogicalResult parseStringSection(ArrayRef<uint8_t> sectionData);
+
+ /// Parse a shared string from the string section. The shared string is
+ /// encoded using an index to a corresponding string in the string section.
+ LogicalResult parseSharedString(EncodingReader &reader, StringRef &result) {
+ return parseEntry(reader, strings, result, "string");
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Value Processing
+
+ /// Parse an operand reference using the given reader. Returns nullptr in the
+ /// case of failure.
+ Value parseOperand(EncodingReader &reader);
+
+ /// Sequentially define the given value range.
+ LogicalResult defineValues(EncodingReader &reader, ValueRange values);
+
+ /// Create a value to use for a forward reference.
+ Value createForwardRef();
+
+ //===--------------------------------------------------------------------===//
+ // Fields
+
+ /// This class represents a single value scope. A value scope is delimited
+ /// by regions that are isolated from above.
+ struct ValueScope {
+ /// Push a new region state onto this scope, reserving enough values for
+ /// those defined within the current region of the provided state.
+ void push(RegionReadState &readState) {
+ nextValueIDs.push_back(values.size());
+ values.resize(values.size() + readState.numValues);
+ }
+
+ /// Pop the values defined for the current region within the provided region
+ /// state.
+ void pop(RegionReadState &readState) {
+ values.resize(values.size() - readState.numValues);
+ nextValueIDs.pop_back();
+ }
+
+ /// The set of values defined in this scope.
+ std::vector<Value> values;
+
+ /// The ID for the next defined value for each region currently being
+ /// processed in this scope.
+ SmallVector<unsigned> nextValueIDs;
+ };
+
+ /// The configuration of the parser.
+ const ParserConfig &config;
+
+ /// A location to use when emitting errors.
+ Location fileLoc;
+
+ /// The reader used to process attributes and types within the bytecode.
+ AttrTypeReader attrTypeReader;
+
+ /// The version of the bytecode being read.
+ uint64_t version = 0;
+
+ /// The producer of the bytecode being read.
+ StringRef producer;
+
+ /// The table of IR units referenced within the bytecode file.
+ SmallVector<BytecodeDialect> dialects;
+ SmallVector<BytecodeOperationName> opNames;
+
+ /// The table of strings referenced within the bytecode file.
+ SmallVector<StringRef> strings;
+
+ /// The current set of available IR value scopes.
+ std::vector<ValueScope> valueScopes;
+ /// A block containing the set of operations defined to create forward
+ /// references.
+ Block forwardRefOps;
+ /// A block containing previously created, and no longer used, forward
+ /// reference operations.
+ Block openForwardRefOps;
+ /// An operation state used when instantiating forward references.
+ OperationState forwardRefOpState;
+};
+} // namespace
+
+LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
+ EncodingReader reader(buffer.getBuffer(), fileLoc);
+
+ // Skip over the bytecode magic number; it should have already been checked.
+ if (failed(reader.skipBytes(StringRef("ML\xefR").size())))
+ return failure();
+ // Parse the bytecode version and producer.
+ if (failed(parseVersion(reader)) ||
+ failed(reader.parseNullTerminatedString(producer)))
+ return failure();
+
+ // Add a diagnostic handler that attaches a note that includes the original
+ // producer of the bytecode.
+ ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) {
+ diag.attachNote() << "in bytecode file produced by: " << producer;
+ return failure();
+ });
+
+ // Parse the raw data for each of the top-level sections of the bytecode.
+ Optional<ArrayRef<uint8_t>> sectionDatas[bytecode::Section::kNumSections];
+ while (!reader.empty()) {
+ // Read the next section from the bytecode.
+ bytecode::Section::ID sectionID;
+ ArrayRef<uint8_t> sectionData;
+ if (failed(reader.parseSection(sectionID, sectionData)))
+ return failure();
+
+ // Check for duplicate sections; we only expect one instance of each.
+ if (sectionDatas[sectionID])
+ return reader.emitError("duplicate top-level section ID: ", sectionID);
+ sectionDatas[sectionID] = sectionData;
+ }
+ // Check that all of the sections were found.
+ for (int i = 0; i < bytecode::Section::kNumSections; ++i)
+ if (!sectionDatas[i])
+ return reader.emitError("missing data for top-level section: ", i);
+
+ // Process the string section first.
+ if (failed(parseStringSection(*sectionDatas[bytecode::Section::kString])))
+ return failure();
+
+ // Process the dialect section.
+ if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect])))
+ return failure();
+
+ // Process the attribute and type section.
+ if (failed(attrTypeReader.initialize(
+ dialects, *sectionDatas[bytecode::Section::kAttrType],
+ *sectionDatas[bytecode::Section::kAttrTypeOffset])))
+ return failure();
+
+ // Finally, process the IR section.
+ return parseIRSection(*sectionDatas[bytecode::Section::kIR], block);
+}
+
+LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) {
+ if (failed(reader.parseVarInt(version)))
+ return failure();
+
+ // Validate the bytecode version.
+ if (version < bytecode::kVersion) {
+ return reader.emitError(
+ "bytecode version ", version, " is older than the current version of ",
+ bytecode::kVersion, ", and upgrade is not supported.");
+ }
+ if (version > bytecode::kVersion) {
+ return reader.emitError("bytecode version ", version,
+ " is newer than the current version ",
+ bytecode::kVersion, ".");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Dialect Section
+
+LogicalResult
+BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
+ EncodingReader sectionReader(sectionData, fileLoc);
+
+ // Parse the number of dialects in the section.
+ uint64_t numDialects;
+ if (failed(sectionReader.parseVarInt(numDialects)))
+ return failure();
+ dialects.resize(numDialects);
+
+ // Parse each of the dialects.
+ for (uint64_t i = 0; i < numDialects; ++i)
+ if (failed(parseSharedString(sectionReader, dialects[i].name)))
+ return failure();
+
+ // Parse the operation names, which are grouped by dialect.
+ auto parseOpName = [&](BytecodeDialect *dialect) {
+ StringRef opName;
+ if (failed(parseSharedString(sectionReader, opName)))
+ return failure();
+ opNames.emplace_back(dialect, opName);
+ return success();
+ };
+ while (!sectionReader.empty())
+ if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
+ return failure();
+ return success();
+}
+
+FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
+ BytecodeOperationName *opName = nullptr;
+ if (failed(parseEntry(reader, opNames, opName, "operation name")))
+ return failure();
+
+ // Check to see if this operation name has already been resolved. If we
+ // haven't, load the dialect and build the operation name.
+ if (!opName->opName) {
+ if (failed(opName->dialect->load(reader, getContext())))
+ return failure();
+ opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
+ getContext());
+ }
+ return *opName->opName;
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute/Type Section
+
+Attribute BytecodeReader::parseAttribute(EncodingReader &reader) {
+ uint64_t attrIdx;
+ if (failed(reader.parseVarInt(attrIdx)))
+ return Attribute();
+ return attrTypeReader.resolveAttribute(attrIdx);
+}
+
+Type BytecodeReader::parseType(EncodingReader &reader) {
+ uint64_t typeIdx;
+ if (failed(reader.parseVarInt(typeIdx)))
+ return Type();
+ return attrTypeReader.resolveType(typeIdx);
+}
+
+//===----------------------------------------------------------------------===//
+// IR Section
+
+LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
+ Block *block) {
+ EncodingReader reader(sectionData, fileLoc);
+
+ // A stack of operation regions currently being read from the bytecode.
+ std::vector<RegionReadState> regionStack;
+
+ // Parse the top-level block using a temporary module operation.
+ OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
+ regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true);
+ regionStack.back().curBlocks.push_back(moduleOp->getBody());
+ regionStack.back().curBlock = regionStack.back().curRegion->begin();
+ if (failed(parseBlock(reader, regionStack.back())))
+ return failure();
+ valueScopes.emplace_back(ValueScope());
+ valueScopes.back().push(regionStack.back());
+
+ // Iteratively parse regions until everything has been resolved.
+ while (!regionStack.empty())
+ if (failed(parseRegions(reader, regionStack, regionStack.back())))
+ return failure();
+ if (!forwardRefOps.empty()) {
+ return reader.emitError(
+ "not all forward operand references were resolved");
+ }
+
+ // Verify that the parsed operations are valid.
+ if (failed(verify(*moduleOp)))
+ return failure();
+
+ // Splice the parsed operations over to the provided top-level block.
+ auto &parsedOps = moduleOp->getBody()->getOperations();
+ auto &destOps = block->getOperations();
+ destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
+ parsedOps, parsedOps.begin(), parsedOps.end());
+ return success();
+}
+
+LogicalResult
+BytecodeReader::parseRegions(EncodingReader &reader,
+ std::vector<RegionReadState> &regionStack,
+ RegionReadState &readState) {
+ // Read the regions of this operation.
+ for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
+ // If the current block hasn't been set up yet, parse the header for this
+ // region.
+ if (readState.curBlock == Region::iterator()) {
+ if (failed(parseRegion(reader, readState)))
+ return failure();
+
+ // If the region is empty, there is nothing more to do.
+ if (readState.curRegion->empty())
+ continue;
+ }
+
+ // Parse the blocks within the region.
+ do {
+ while (readState.numOpsRemaining--) {
+ // Read in the next operation. We don't read its regions directly, we
+ // handle those afterwards as necessary.
+ bool isIsolatedFromAbove = false;
+ FailureOr<Operation *> op =
+ parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
+ if (failed(op))
+ return failure();
+
+ // If the op has regions, add it to the stack for processing.
+ if ((*op)->getNumRegions()) {
+ regionStack.emplace_back(*op, isIsolatedFromAbove);
+
+ // If the op is isolated from above, push a new value scope.
+ if (isIsolatedFromAbove)
+ valueScopes.emplace_back(ValueScope());
+ return success();
+ }
+ }
+
+ // Move to the next block of the region.
+ if (++readState.curBlock == readState.curRegion->end())
+ break;
+ if (failed(parseBlock(reader, readState)))
+ return failure();
+ } while (true);
+
+ // Reset the current block and any values reserved for this region.
+ readState.curBlock = {};
+ valueScopes.back().pop(readState);
+ }
+
+ // Once fully parsed, pop the value scope if the regions were isolated from
+ // above, then pop them off the read stack (which invalidates `readState`).
+ if (readState.isIsolatedFromAbove)
+ valueScopes.pop_back();
+ regionStack.pop_back();
+ return success();
+}
+
+FailureOr<Operation *>
+BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
+ RegionReadState &readState,
+ bool &isIsolatedFromAbove) {
+ // Parse the name of the operation.
+ FailureOr<OperationName> opName = parseOpName(reader);
+ if (failed(opName))
+ return failure();
+
+ // Parse the operation mask, which indicates which components of the operation
+ // are present.
+ uint8_t opMask;
+ if (failed(reader.parseByte(opMask)))
+ return failure();
+
+ // Parse the location.
+ LocationAttr opLoc = parseAttribute<LocationAttr>(reader);
+ if (!opLoc)
+ return failure();
+
+ // With the location and name resolved, we can start building the operation
+ // state.
+ OperationState opState(opLoc, *opName);
+
+ // Parse the attributes of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
+ DictionaryAttr dictAttr = parseAttribute<DictionaryAttr>(reader);
+ if (!dictAttr)
+ return failure();
+ opState.attributes = dictAttr;
+ }
+
+ // Parse the results of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasResults) {
+ uint64_t numResults;
+ if (failed(reader.parseVarInt(numResults)))
+ return failure();
+ opState.types.resize(numResults);
+ for (int i = 0, e = numResults; i < e; ++i)
+ if (!(opState.types[i] = parseType(reader)))
+ return failure();
+ }
+
+ // Parse the operands of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasOperands) {
+ uint64_t numOperands;
+ if (failed(reader.parseVarInt(numOperands)))
+ return failure();
+ opState.operands.resize(numOperands);
+ for (int i = 0, e = numOperands; i < e; ++i)
+ if (!(opState.operands[i] = parseOperand(reader)))
+ return failure();
+ }
+
+ // Parse the successors of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
+ uint64_t numSuccs;
+ if (failed(reader.parseVarInt(numSuccs)))
+ return failure();
+ opState.successors.resize(numSuccs);
+ for (int i = 0, e = numSuccs; i < e; ++i) {
+ if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i],
+ "successor")))
+ return failure();
+ }
+ }
+
+ // Parse the regions of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
+ uint64_t numRegions;
+ if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove)))
+ return failure();
+
+ opState.regions.reserve(numRegions);
+ for (int i = 0, e = numRegions; i < e; ++i)
+ opState.regions.push_back(std::make_unique<Region>());
+ }
+
+ // Create the operation at the back of the current block.
+ Operation *op = Operation::create(opState);
+ readState.curBlock->push_back(op);
+
+ // If the operation had results, update the value references.
+ if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
+ return failure();
+
+ return op;
+}
+
+LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
+ RegionReadState &readState) {
+ // Parse the number of blocks in the region.
+ uint64_t numBlocks;
+ if (failed(reader.parseVarInt(numBlocks)))
+ return failure();
+
+ // If the region is empty, there is nothing else to do.
+ if (numBlocks == 0)
+ return success();
+
+ // Parse the number of values defined in this region.
+ uint64_t numValues;
+ if (failed(reader.parseVarInt(numValues)))
+ return failure();
+ readState.numValues = numValues;
+
+ // Create the blocks within this region. We do this before processing so that
+ // we can rely on the blocks existing when creating operations.
+ readState.curBlocks.clear();
+ readState.curBlocks.reserve(numBlocks);
+ for (uint64_t i = 0; i < numBlocks; ++i) {
+ readState.curBlocks.push_back(new Block());
+ readState.curRegion->push_back(readState.curBlocks.back());
+ }
+
+ // Prepare the current value scope for this region.
+ valueScopes.back().push(readState);
+
+ // Parse the entry block of the region.
+ readState.curBlock = readState.curRegion->begin();
+ return parseBlock(reader, readState);
+}
+
+LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
+ RegionReadState &readState) {
+ // Parse the number of operations in the block.
+ uint64_t numOps;
+ if (failed(reader.parseVarInt(numOps)))
+ return failure();
+ // Extract the low bit of the operation count, which indicates whether the
+ // block has arguments.
+ bool hasArgs = numOps & 1;
+ readState.numOpsRemaining = numOps >> 1;
+
+ // Parse the arguments of the block.
+ if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
+ return failure();
+
+ // We don't parse the operations of the block here, that's done elsewhere.
+ return success();
+}
+
+LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
+ Block *block) {
+ // Parse the number of arguments.
+ uint64_t numArgs;
+ if (failed(reader.parseVarInt(numArgs)))
+ return failure();
+
+ SmallVector<Type> argTypes;
+ SmallVector<Location> argLocs;
+ argTypes.reserve(numArgs);
+ argLocs.reserve(numArgs);
+
+ while (numArgs--) {
+ Type argType = parseType(reader);
+ if (!argType)
+ return failure();
+ LocationAttr argLoc = parseAttribute<LocationAttr>(reader);
+ if (!argLoc)
+ return failure();
+
+ argTypes.push_back(argType);
+ argLocs.push_back(argLoc);
+ }
+ block->addArguments(argTypes, argLocs);
+ return defineValues(reader, block->getArguments());
+}
+
+//===----------------------------------------------------------------------===//
+// String Section
+
+LogicalResult
+BytecodeReader::parseStringSection(ArrayRef<uint8_t> sectionData) {
+ EncodingReader stringReader(sectionData, fileLoc);
+
+ // Parse the number of strings in the section.
+ uint64_t numStrings;
+ if (failed(stringReader.parseVarInt(numStrings)))
+ return failure();
+ strings.resize(numStrings);
+
+ // Parse each of the strings. The sizes of the strings are encoded in reverse
+ // order, so that's the order we populate the table.
+ size_t stringDataEndOffset = sectionData.size();
+ size_t totalStringDataSize = 0;
+ for (StringRef &string : llvm::reverse(strings)) {
+ uint64_t stringSize;
+ if (failed(stringReader.parseVarInt(stringSize)))
+ return failure();
+ if (stringDataEndOffset < stringSize) {
+ return stringReader.emitError(
+ "string size exceeds the available data size");
+ }
+
+ // Extract the string from the data, dropping the null character.
+ size_t stringOffset = stringDataEndOffset - stringSize;
+ string = StringRef(
+ reinterpret_cast<const char *>(sectionData.data() + stringOffset),
+ stringSize - 1);
+ stringDataEndOffset = stringOffset;
+
+ // Update the total string data size.
+ totalStringDataSize += stringSize;
+ }
+
+ // Check that the only data remaining after the sizes is the string data.
+ if (stringReader.size() != totalStringDataSize) {
+ return stringReader.emitError("unexpected trailing data between the "
+ "offsets for strings and their data");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Value Processing
+
+Value BytecodeReader::parseOperand(EncodingReader &reader) {
+ std::vector<Value> &values = valueScopes.back().values;
+ Value *value = nullptr;
+ if (failed(parseEntry(reader, values, value, "value")))
+ return Value();
+
+ // Create a new forward reference if necessary.
+ if (!*value)
+ *value = createForwardRef();
+ return *value;
+}
+
+LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
+ ValueRange newValues) {
+ ValueScope &valueScope = valueScopes.back();
+ std::vector<Value> &values = valueScope.values;
+
+ unsigned &valueID = valueScope.nextValueIDs.back();
+ unsigned valueIDEnd = valueID + newValues.size();
+ if (valueIDEnd > values.size()) {
+ return reader.emitError(
+ "value index range was outside of the expected range for "
+ "the parent region, got [",
+ valueID, ", ", valueIDEnd, "), but the maximum index was ",
+ values.size() - 1);
+ }
+
+ // Assign the values and update any forward references.
+ for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
+ Value newValue = newValues[i];
+
+ // Check to see if a definition for this value already exists.
+ if (Value oldValue = std::exchange(values[valueID], newValue)) {
+ Operation *forwardRefOp = oldValue.getDefiningOp();
+ if (!forwardRefOp || forwardRefOp->getBlock() != &forwardRefOps) {
+ return reader.emitError("value index ", valueID,
+ " was already defined");
+ }
+
+ oldValue.replaceAllUsesWith(newValue);
+ forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
+ }
+ }
+ return success();
+}
+
+Value BytecodeReader::createForwardRef() {
+ // Check for an available existing operation to use. Otherwise, create a new
+ // fake operation to use for the reference.
+ if (!openForwardRefOps.empty()) {
+ Operation *op = &openForwardRefOps.back();
+ op->moveBefore(&forwardRefOps, forwardRefOps.end());
+ } else {
+ forwardRefOps.push_back(Operation::create(forwardRefOpState));
+ }
+ return forwardRefOps.back().getResult(0);
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
+ return buffer.getBuffer().startswith("ML\xefR");
+}
+
+LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config) {
+ Location sourceFileLoc =
+ FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
+ /*line=*/0, /*column=*/0);
+ if (!isBytecode(buffer)) {
+ return emitError(sourceFileLoc,
+ "input buffer is not an MLIR bytecode file");
+ }
+
+ BytecodeReader reader(sourceFileLoc, config);
+ return reader.read(buffer, block);
+}
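The string section consumed by `parseStringSection` above stores a string count, then each string's size (including its null terminator) in reverse order, followed by the null-terminated string data in forward order. The standalone C++ sketch below round-trips that layout. It is an illustration under stated assumptions, not the reader's actual code: all helper names are hypothetical, and it only handles the one-byte VarInt form `(value << 1) | 1` used by `emitVarInt`, so every count and size must be below 128.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical helper: the one-byte PrefixVarInt form, (value << 1) | 1.
void emitSmallVarInt(std::vector<uint8_t> &out, uint64_t value) {
  assert(value < 128 && "sketch only supports one-byte VarInts");
  out.push_back(static_cast<uint8_t>((value << 1) | 0x1));
}

// Hypothetical helper: read a one-byte VarInt, failing on any other form.
std::optional<uint64_t> readSmallVarInt(const std::vector<uint8_t> &in,
                                        size_t &pos) {
  if (pos >= in.size() || !(in[pos] & 0x1))
    return std::nullopt;
  return in[pos++] >> 1;
}

// Encode: count, sizes in reverse order, then null-terminated data in order.
std::vector<uint8_t> encodeStringSection(const std::vector<std::string> &strs) {
  std::vector<uint8_t> out;
  emitSmallVarInt(out, strs.size());
  for (auto it = strs.rbegin(); it != strs.rend(); ++it)
    emitSmallVarInt(out, it->size() + 1); // +1 for the null terminator.
  for (const std::string &str : strs) {
    out.insert(out.end(), str.begin(), str.end());
    out.push_back(0);
  }
  return out;
}

// Decode, following the structure of BytecodeReader::parseStringSection.
std::optional<std::vector<std::string>>
decodeStringSection(const std::vector<uint8_t> &data) {
  size_t pos = 0;
  std::optional<uint64_t> numStrings = readSmallVarInt(data, pos);
  if (!numStrings)
    return std::nullopt;

  std::vector<std::string> strs(*numStrings);
  size_t dataEndOffset = data.size(), totalDataSize = 0;
  for (auto it = strs.rbegin(); it != strs.rend(); ++it) {
    std::optional<uint64_t> size = readSmallVarInt(data, pos);
    // This sketch also rejects a zero size, which cannot hold a terminator.
    if (!size || *size == 0 || *size > dataEndOffset)
      return std::nullopt;
    size_t offset = dataEndOffset - *size;
    it->assign(reinterpret_cast<const char *>(data.data() + offset), *size - 1);
    dataEndOffset = offset;
    totalDataSize += *size;
  }
  // Everything after the size table must be string data.
  if (data.size() - pos != totalDataSize)
    return std::nullopt;
  return strs;
}
```

For `{"a", "bc"}` the encoder produces the bytes `5 7 5 'a' 0 'b' 'c' 0`: a count of 2, the size of `"bc"` (3) before the size of `"a"` (2), then the data. Walking backwards from the end of the section lets the decoder recover each string's offset from the sizes alone.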
diff --git a/mlir/lib/Bytecode/Reader/CMakeLists.txt b/mlir/lib/Bytecode/Reader/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Reader/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(MLIRBytecodeReader
+ BytecodeReader.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode
+
+ LINK_LIBS PUBLIC
+ MLIRAsmParser
+ MLIRIR
+ MLIRSupport
+ )
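`EncodingEmitter::emitVarInt` in the writer below inlines only the one-byte case and defers longer values to `emitMultiByteVarInt`, whose body is not part of this excerpt. As a hedged sketch of the PrefixVarInt scheme described in its doc comment (a run of `0` bits terminated by a `1` in the low bits of the first byte, one `0` per additional byte, payload in little-endian order), the standalone functions below encode and decode a value. The function names are hypothetical, and the nine-byte form for values needing more than 56 bits (a `0x00` prefix byte followed by the raw 64-bit value) is an assumption drawn from the format description, not a copy of the writer.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: encode `value` as a PrefixVarInt. With N bytes (N <= 8),
// the first byte holds N-1 zero bits and a terminating 1 bit, leaving 7*N
// payload bits in total.
std::vector<uint8_t> encodePrefixVarInt(uint64_t value) {
  std::vector<uint8_t> out;
  for (unsigned numBytes = 1; numBytes <= 8; ++numBytes) {
    if ((value >> (7 * numBytes)) != 0)
      continue;
    uint64_t encoded = (value << numBytes) | (uint64_t(1) << (numBytes - 1));
    for (unsigned i = 0; i < numBytes; ++i)
      out.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
    return out;
  }
  // Assumed nine-byte form: an all-zero prefix byte (8 additional bytes)
  // followed by the full 64-bit value in little-endian order.
  out.push_back(0);
  for (unsigned i = 0; i < 8; ++i)
    out.push_back(static_cast<uint8_t>(value >> (8 * i)));
  return out;
}

// Hypothetical helper: decode a PrefixVarInt starting at `pos`, advancing `pos`
// past it. Returns std::nullopt if the buffer is truncated.
std::optional<uint64_t> decodePrefixVarInt(const std::vector<uint8_t> &in,
                                           size_t &pos) {
  if (pos >= in.size())
    return std::nullopt;
  // The number of trailing zero bits is the number of additional bytes.
  uint8_t first = in[pos];
  unsigned numExtra = 0;
  while (numExtra < 8 && !(first & (1u << numExtra)))
    ++numExtra;
  size_t numBytes = numExtra + 1;
  if (in.size() - pos < numBytes)
    return std::nullopt;

  uint64_t raw = 0;
  if (numExtra == 8) {
    // Nine-byte form: the payload is the 8 bytes after the prefix byte.
    for (unsigned i = 0; i < 8; ++i)
      raw |= uint64_t(in[pos + 1 + i]) << (8 * i);
    pos += numBytes;
    return raw;
  }
  for (size_t i = 0; i < numBytes; ++i)
    raw |= uint64_t(in[pos + i]) << (8 * i);
  pos += numBytes;
  // Drop the length prefix (numExtra zeros plus the terminating 1).
  return raw >> numBytes;
}
```

For example, `encodePrefixVarInt(128)` yields the two bytes `0x02 0x02`: the low `0` bit announces one additional byte, the next `1` bit terminates the prefix, and the remaining 14 bits hold 128. Values below 128 take the single-byte path that `emitVarInt` handles inline.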
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -0,0 +1,519 @@
+//===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "../Encoding.h"
+#include "IRNumbering.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/CachedHashString.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/Debug.h"
+#include
+
+#define DEBUG_TYPE "mlir-bytecode-writer"
+
+using namespace mlir;
+using namespace mlir::bytecode::detail;
+
+//===----------------------------------------------------------------------===//
+// EncodingEmitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class functions as the underlying encoding emitter for the bytecode
+/// writer. This class is a bit different compared to other types of encoders;
+/// it does not use a single buffer, but instead may contain several buffers
+/// (some owned by the writer, and some not) that get concatenated during the
+/// final emission.
+class EncodingEmitter {
+public:
+ EncodingEmitter() = default;
+ EncodingEmitter(const EncodingEmitter &) = delete;
+ EncodingEmitter &operator=(const EncodingEmitter &) = delete;
+
+ /// Write the current contents to the provided stream.
+ void writeTo(raw_ostream &os) const;
+
+ /// Return the current size of the encoded buffer.
+ size_t size() const { return prevResultSize + currentResult.size(); }
+
+ //===--------------------------------------------------------------------===//
+ // Emission
+ //===--------------------------------------------------------------------===//
+
+ /// Return a raw pointer into the result buffer at the specified offset.
+ uint8_t *getRawPointer(uint64_t offset) {
+ assert(offset < size() && offset >= prevResultSize &&
+ "cannot get pointer to previously emitted data");
+ return currentResult.data() + (offset - prevResultSize);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Integer Emission
+
+ /// Emit a single byte.
+ template <typename T>
+ void emitByte(T byte) {
+ currentResult.push_back(static_cast<uint8_t>(byte));
+ }
+
+ /// Emit a range of bytes.
+ void emitBytes(ArrayRef<uint8_t> bytes) {
+ llvm::append_range(currentResult, bytes);
+ }
+
+ /// Emit a variable length integer. The first encoded byte contains a prefix
+ /// in the low bits indicating the encoded length of the value. This length
+ /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits
+ /// indicate the number of _additional_ bytes (not including the prefix byte).
+ /// All remaining bits in the first byte, along with all of the bits in
+ /// additional bytes, provide the value of the integer encoded in
+ /// little-endian order.
+ void emitVarInt(uint64_t value) {
+ // In the most common case, the value can be represented in a single byte.
+ // Given how hot this case is, explicitly handle that here.
+ if ((value >> 7) == 0)
+ return emitByte((value << 1) | 0x1);
+ emitMultiByteVarInt(value);
+ }
+
+ /// Emit a variable length integer whose low bit is used to encode the
+ /// provided flag, i.e. encoded as: (value << 1) | (flag ? 1 : 0).
+ void emitVarIntWithFlag(uint64_t value, bool flag) {
+ emitVarInt((value << 1) | (flag ? 1 : 0));
+ }
+
+ //===--------------------------------------------------------------------===//
+ // String Emission
+
+ /// Emit the given string as a nul terminated string.
+ void emitNulTerminatedString(StringRef str) {
+ emitString(str);
+ emitByte(0);
+ }
+
+ /// Emit the given string without a nul terminator.
+ void emitString(StringRef str) {
+ emitBytes({reinterpret_cast<const uint8_t *>(str.data()), str.size()});
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Section Emission
+
+ /// Emit a nested section of the given code, whose contents are encoded in the
+ /// provided emitter.
+ void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) {
+ // Emit the section code and length.
+ emitByte(code);
+ emitVarInt(emitter.size());
+
+ // Push our current buffer and then merge the provided section body into
+ // ours.
+ appendResult(std::move(currentResult));
+ for (std::vector<uint8_t> &result : emitter.prevResultStorage)
+ appendResult(std::move(result));
+ appendResult(std::move(emitter.currentResult));
+ }
+
+private:
+ /// Emit the given value using a variable width encoding. This method is a
+ /// fallback when the number of bytes needed to encode the value is greater
+ /// than 1. We mark it noinline here so that the single byte hot path isn't
+ /// pessimized.
+ LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value);
+
+ /// Append a new result buffer to the current contents.
+ void appendResult(std::vector<uint8_t> &&result) {
+ prevResultSize += result.size();
+ prevResultStorage.emplace_back(std::move(result));
+ prevResultList.emplace_back(prevResultStorage.back());
+ }
+
+ /// The result of the emitter currently being built. We refrain from building
+ /// a single buffer to simplify emitting sections, large data, and more. The
+ /// result is thus represented using multiple distinct buffers, some of which
+ /// we own (via prevResultStorage), and some of which are just pointers into
+ /// externally owned buffers.
+ std::vector<uint8_t> currentResult;
+ std::vector<ArrayRef<uint8_t>> prevResultList;
+ std::vector<std::vector<uint8_t>> prevResultStorage;
+
+ /// An up-to-date total size of all of the buffers within `prevResultList`.
+ /// This enables O(1) size checks of the current encoding.
+ size_t prevResultSize = 0;
+};
+
+/// A simple raw_ostream wrapper around an EncodingEmitter. This removes the
+/// need to go through an intermediate buffer when interacting with code that
+/// wants a raw_ostream.
+class raw_emitter_ostream : public raw_ostream {
+public:
+ explicit raw_emitter_ostream(EncodingEmitter &emitter) : emitter(emitter) {
+ SetUnbuffered();
+ }
+
+private:
+ void write_impl(const char *ptr, size_t size) override {
+ emitter.emitBytes({reinterpret_cast<const uint8_t *>(ptr), size});
+ }
+ uint64_t current_pos() const override { return emitter.size(); }
+
+ /// The section being emitted to.
+ EncodingEmitter &emitter;
+};
+} // namespace
+
+void EncodingEmitter::writeTo(raw_ostream &os) const {
+ for (auto &prevResult : prevResultList)
+ os.write((const char *)prevResult.data(), prevResult.size());
+ os.write((const char *)currentResult.data(), currentResult.size());
+}
+
+void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
+ // Compute the number of bytes needed to encode the value. Each byte can hold
+ // up to 7-bits of data. The length prefix must fit in the first byte, so we
+ // only check up to an 8 byte encoding (56 bits of data).
+ uint64_t it = value >> 7;
+ for (size_t numBytes = 2; numBytes < 9; ++numBytes) {
+ if (LLVM_LIKELY((it >>= 7) == 0)) {
+ uint64_t encodedValue = (value << 1) | 0x1;
+ encodedValue <<= (numBytes - 1);
+ // Note: this copies the value in host byte order, which only matches the
+ // little-endian format on little-endian hosts.
+ emitBytes({reinterpret_cast<uint8_t *>(&encodedValue), numBytes});
+ return;
+ }
+ }
+
+ // If the value needs more than 56 bits, its length can't be described by a
+ // prefix in the first byte. Emit a special all zero marker byte and splat the
+ // 64-bit value directly (also in host byte order).
+ emitByte(0);
+ emitBytes({reinterpret_cast<uint8_t *>(&value), sizeof(value)});
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Writer
+//===----------------------------------------------------------------------===//
+
+namespace {
+class BytecodeWriter {
+public:
+ BytecodeWriter(Operation *op) : numberingState(op) {}
+
+ /// Write the bytecode for the given root operation.
+ void write(Operation *rootOp, raw_ostream &os, StringRef producer);
+
+private:
+ //===--------------------------------------------------------------------===//
+ // Dialects
+
+ void writeDialectSection(EncodingEmitter &emitter);
+
+ //===--------------------------------------------------------------------===//
+ // Attributes and Types
+
+ void writeAttrTypeSection(EncodingEmitter &emitter);
+
+ //===--------------------------------------------------------------------===//
+ // Operations
+
+ void writeBlock(EncodingEmitter &emitter, Block *block);
+ void writeOp(EncodingEmitter &emitter, Operation *op);
+ void writeRegion(EncodingEmitter &emitter, Region *region);
+ void writeIRSection(EncodingEmitter &emitter, Operation *op);
+
+ //===--------------------------------------------------------------------===//
+ // Strings
+
+ void writeStringSection(EncodingEmitter &emitter);
+
+ /// Get the number for the given shared string, that is contained within the
+ /// string section.
+ size_t getSharedStringNumber(StringRef str);
+
+ //===--------------------------------------------------------------------===//
+ // Fields
+
+ /// The IR numbering state generated for the root operation.
+ IRNumberingState numberingState;
+
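+ /// The strings referenced within the bytecode, in the order they were first
+ /// referenced. The value of the map is the number assigned to each string,
+ /// i.e. its index within the string section.
+ llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;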
+};
+} // namespace
+
+void BytecodeWriter::write(Operation *rootOp, raw_ostream &os,
+ StringRef producer) {
+ EncodingEmitter emitter;
+
+ // Emit the bytecode file header. This is how we identify the output as a
+ // bytecode file.
+ emitter.emitString("ML\xefR");
+
+ // Emit the bytecode version.
+ emitter.emitVarInt(bytecode::kVersion);
+
+ // Emit the producer.
+ emitter.emitNulTerminatedString(producer);
+
+ // Emit the dialect section.
+ writeDialectSection(emitter);
+
+ // Emit the attributes and types section.
+ writeAttrTypeSection(emitter);
+
+ // Emit the IR section.
+ writeIRSection(emitter, rootOp);
+
+ // Emit the string section.
+ writeStringSection(emitter);
+
+ // Write the generated bytecode to the provided output stream.
+ emitter.writeTo(os);
+}
+
+//===----------------------------------------------------------------------===//
+// Dialects
+
+/// Write the given entries in contiguous groups with the same parent dialect.
+/// Each dialect sub-group is encoded with the parent dialect and number of
+/// elements, followed by the encoding for the entries. The given callback is
+/// invoked to encode each individual entry.
+template <typename EntriesT, typename EntryCallbackT>
+static void writeDialectGrouping(EncodingEmitter &emitter, EntriesT &&entries,
+ EntryCallbackT &&callback) {
+ for (auto it = entries.begin(), e = entries.end(); it != e;) {
+ auto groupStart = it++;
+
+ // Find the end of the group that shares the same parent dialect.
+ DialectNumbering *currentDialect = groupStart->dialect;
+ it = std::find_if(it, e, [&](const auto &entry) {
+ return entry.dialect != currentDialect;
+ });
+
+ // Emit the dialect and number of elements.
+ emitter.emitVarInt(currentDialect->number);
+ emitter.emitVarInt(std::distance(groupStart, it));
+
+ // Emit the entries within the group.
+ for (auto &entry : llvm::make_range(groupStart, it))
+ callback(entry);
+ }
+}
+
+void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
+ EncodingEmitter dialectEmitter;
+
+ // Emit the referenced dialects.
+ auto dialects = numberingState.getDialects();
+ dialectEmitter.emitVarInt(llvm::size(dialects));
+ for (DialectNumbering &dialect : dialects)
+ dialectEmitter.emitVarInt(getSharedStringNumber(dialect.name));
+
+ // Emit the referenced operation names grouped by dialect.
+ auto emitOpName = [&](OpNameNumbering &name) {
+ dialectEmitter.emitVarInt(getSharedStringNumber(name.name.stripDialect()));
+ };
+ writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
+
+ emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes and Types
+
+void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
+ EncodingEmitter attrTypeEmitter;
+ EncodingEmitter offsetEmitter;
+ offsetEmitter.emitVarInt(llvm::size(numberingState.getAttributes()));
+ offsetEmitter.emitVarInt(llvm::size(numberingState.getTypes()));
+
+ // A functor used to emit an attribute or type entry.
+ uint64_t prevOffset = 0;
+ auto emitAttrOrType = [&](auto &entry) {
+ // Emit the entry using the textual format.
+ // TODO: Allow dialects to provide more optimal implementations of attribute
+ // and type encodings.
+ attrTypeEmitter.emitByte(bytecode::AttrTypeCode::kAsmForm);
+ raw_emitter_ostream(attrTypeEmitter) << entry.getValue();
+ attrTypeEmitter.emitByte(0);
+
+ // Record the offset of this entry.
+ uint64_t curOffset = attrTypeEmitter.size();
+ offsetEmitter.emitVarInt(curOffset - prevOffset);
+ prevOffset = curOffset;
+ };
+
+ // Emit the attribute and type entries for each dialect.
+ writeDialectGrouping(offsetEmitter, numberingState.getAttributes(),
+ emitAttrOrType);
+ writeDialectGrouping(offsetEmitter, numberingState.getTypes(),
+ emitAttrOrType);
+
+ // Emit the sections to the stream.
+ emitter.emitSection(bytecode::Section::kAttrTypeOffset,
+ std::move(offsetEmitter));
+ emitter.emitSection(bytecode::Section::kAttrType, std::move(attrTypeEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Operations
+
+void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) {
+ ArrayRef<BlockArgument> args = block->getArguments();
+ bool hasArgs = !args.empty();
+
+ // Emit the number of operations in this block, and if it has arguments. We
+ // use the low bit of the operation count to indicate if the block has
+ // arguments.
+ unsigned numOps = numberingState.getOperationCount(block);
+ emitter.emitVarIntWithFlag(numOps, hasArgs);
+
+ // Emit the arguments of the block.
+ if (hasArgs) {
+ emitter.emitVarInt(args.size());
+ for (BlockArgument arg : args) {
+ emitter.emitVarInt(numberingState.getNumber(arg.getType()));
+ emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
+ }
+ }
+
+ // Emit the operations within the block.
+ for (Operation &op : *block)
+ writeOp(emitter, &op);
+}
+
+void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
+ emitter.emitVarInt(numberingState.getNumber(op->getName()));
+
+ // Emit a mask for the operation components. We need to fill this in later
+ // (when we actually know what needs to be emitted), so emit a placeholder for
+ // now.
+ uint64_t maskOffset = emitter.size();
+ uint8_t opEncodingMask = 0;
+ emitter.emitByte(0);
+
+ // Emit the location for this operation.
+ emitter.emitVarInt(numberingState.getNumber(op->getLoc()));
+
+ // Emit the attributes of this operation.
+ DictionaryAttr attrs = op->getAttrDictionary();
+ if (!attrs.empty()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs;
+ emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary()));
+ }
+
+ // Emit the result types of the operation.
+ if (unsigned numResults = op->getNumResults()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasResults;
+ emitter.emitVarInt(numResults);
+ for (Type type : op->getResultTypes())
+ emitter.emitVarInt(numberingState.getNumber(type));
+ }
+
+ // Emit the operands of the operation.
+ if (unsigned numOperands = op->getNumOperands()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasOperands;
+ emitter.emitVarInt(numOperands);
+ for (Value operand : op->getOperands())
+ emitter.emitVarInt(numberingState.getNumber(operand));
+ }
+
+ // Emit the successors of the operation.
+ if (unsigned numSuccessors = op->getNumSuccessors()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors;
+ emitter.emitVarInt(numSuccessors);
+ for (Block *successor : op->getSuccessors())
+ emitter.emitVarInt(numberingState.getNumber(successor));
+ }
+
+ // Check for regions.
+ unsigned numRegions = op->getNumRegions();
+ if (numRegions)
+ opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions;
+
+ // Update the mask for the operation.
+ *emitter.getRawPointer(maskOffset) = opEncodingMask;
+
+ // With the mask emitted, we can now emit the regions of the operation. We do
+ // this after mask emission to avoid offset complications that may arise by
+ // emitting the regions first (e.g. if the regions are huge, backpatching the
+ // op encoding mask is more annoying).
+ if (numRegions) {
+ bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
+ emitter.emitVarIntWithFlag(numRegions, isIsolatedFromAbove);
+
+ for (Region &region : op->getRegions())
+ writeRegion(emitter, &region);
+ }
+}
+
+void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) {
+ // If the region is empty, we only need to emit the number of blocks (which is
+ // zero).
+ if (region->empty())
+ return emitter.emitVarInt(/*numBlocks*/ 0);
+
+ // Emit the number of blocks and values within the region.
+ unsigned numBlocks, numValues;
+ std::tie(numBlocks, numValues) = numberingState.getBlockValueCount(region);
+ emitter.emitVarInt(numBlocks);
+ emitter.emitVarInt(numValues);
+
+ // Emit the blocks within the region.
+ for (Block &block : *region)
+ writeBlock(emitter, &block);
+}
+
+void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
+ EncodingEmitter irEmitter;
+
+ // Write the IR section the same way as a block with no arguments. Note that
+ // the low-bit of the operation count for a block is used to indicate if the
+ // block has arguments, which in this case is always false.
+ irEmitter.emitVarIntWithFlag(/*numOps*/ 1, /*hasArgs*/ false);
+
+ // Emit the operations.
+ writeOp(irEmitter, op);
+
+ emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Strings
+
+void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
+ EncodingEmitter stringEmitter;
+ stringEmitter.emitVarInt(strings.size());
+
+ // Emit the sizes in reverse order, so that we don't need to backpatch an
+ // offset to the string data or have a separate section.
+ for (const auto &it : llvm::reverse(strings))
+ stringEmitter.emitVarInt(it.first.size() + 1);
+ // Emit the string data itself.
+ for (const auto &it : strings)
+ stringEmitter.emitNulTerminatedString(it.first.val());
+
+ emitter.emitSection(bytecode::Section::kString, std::move(stringEmitter));
+}
+
+size_t BytecodeWriter::getSharedStringNumber(StringRef str) {
+ auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
+ return it.first->second;
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os,
+ StringRef producer) {
+ BytecodeWriter writer(op);
+ writer.write(op, os, producer);
+}
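The PrefixVarInt scheme implemented by `emitVarInt`, `emitVarIntWithFlag`, and `emitMultiByteVarInt` can be checked in isolation. The following is a standalone sketch, not code from this patch: `encodePrefixVarInt`, `encodeVarIntWithFlag`, and `decodePrefixVarInt` are hypothetical names. Unlike the `reinterpret_cast` copy in `emitMultiByteVarInt`, it assembles the bytes explicitly in little-endian order, so its output does not depend on host byte order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Encode `value` as a PrefixVarInt: the number of trailing zero bits in the
// first byte is the number of additional bytes (an all-zero first byte marks
// the nine-byte form), and the remaining bits hold the value, little-endian.
std::vector<uint8_t> encodePrefixVarInt(uint64_t value) {
  std::vector<uint8_t> out;
  // 1..8 bytes can carry 7, 14, ..., 56 bits of data.
  for (unsigned numBytes = 1; numBytes <= 8; ++numBytes) {
    if ((value >> (7 * numBytes)) != 0)
      continue;
    // Shift the value past the length prefix: (numBytes - 1) zeros, then a 1.
    uint64_t encoded = ((value << 1) | 1) << (numBytes - 1);
    for (unsigned i = 0; i < numBytes; ++i)
      out.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
    return out;
  }
  // More than 56 bits: an all-zero marker byte, then the raw 64-bit value.
  out.push_back(0);
  for (unsigned i = 0; i < 8; ++i)
    out.push_back(static_cast<uint8_t>(value >> (8 * i)));
  return out;
}

// Mirrors emitVarIntWithFlag: the flag is stored in the low bit of the value.
std::vector<uint8_t> encodeVarIntWithFlag(uint64_t value, bool flag) {
  return encodePrefixVarInt((value << 1) | (flag ? 1 : 0));
}

// Decode one PrefixVarInt starting at data[pos] and advance `pos` past it.
uint64_t decodePrefixVarInt(const std::vector<uint8_t> &data, size_t &pos) {
  assert(pos < data.size() && "truncated varint");
  uint8_t first = data[pos];
  if (first == 0) {
    assert(pos + 9 <= data.size() && "truncated varint");
    uint64_t value = 0;
    for (unsigned i = 0; i < 8; ++i)
      value |= uint64_t(data[pos + 1 + i]) << (8 * i);
    pos += 9;
    return value;
  }
  // The lowest set bit of the first byte terminates the length prefix.
  unsigned numBytes = 1;
  while ((first & (1u << (numBytes - 1))) == 0)
    ++numBytes;
  assert(pos + numBytes <= data.size() && "truncated varint");
  uint64_t encoded = 0;
  for (unsigned i = 0; i < numBytes; ++i)
    encoded |= uint64_t(data[pos + i]) << (8 * i);
  pos += numBytes;
  // Drop the prefix bits: (numBytes - 1) zeros plus the terminating 1.
  return encoded >> numBytes;
}
```

For example, 128 needs eight bits, so it takes the two-byte form `0x02 0x02`: the low bits `10` of the first byte are the length prefix (one additional byte), and the value occupies the remaining 14 bits.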
diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(MLIRBytecodeWriter
+ BytecodeWriter.cpp
+ IRNumbering.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSupport
+ )
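`writeStringSection` emits the string count, then each string's size (including its nul terminator) in reverse order, then the nul-terminated string data in forward order. Because the first size decoded belongs to the last string, a reader can carve strings off the end of the section without knowing in advance where the string data starts. The round trip below is a sketch under that assumption only: `writeStrings`, `readStrings`, `appendVarInt`, and `readVarInt` are hypothetical names, and this is not the bytecode reader added by this patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Compact PrefixVarInt helpers (same scheme as EncodingEmitter::emitVarInt).
void appendVarInt(std::vector<uint8_t> &out, uint64_t value) {
  for (unsigned n = 1; n <= 8; ++n) {
    if ((value >> (7 * n)) != 0)
      continue;
    uint64_t encoded = ((value << 1) | 1) << (n - 1);
    for (unsigned i = 0; i < n; ++i)
      out.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
    return;
  }
  out.push_back(0); // nine-byte form: zero marker + raw little-endian value
  for (unsigned i = 0; i < 8; ++i)
    out.push_back(static_cast<uint8_t>(value >> (8 * i)));
}

uint64_t readVarInt(const std::vector<uint8_t> &in, size_t &pos) {
  if (in.at(pos) == 0) {
    uint64_t value = 0;
    for (unsigned i = 0; i < 8; ++i)
      value |= uint64_t(in.at(pos + 1 + i)) << (8 * i);
    pos += 9;
    return value;
  }
  unsigned n = 1;
  while ((in[pos] & (1u << (n - 1))) == 0)
    ++n;
  uint64_t encoded = 0;
  for (unsigned i = 0; i < n; ++i)
    encoded |= uint64_t(in.at(pos + i)) << (8 * i);
  pos += n;
  return encoded >> n;
}

// Same layout as writeStringSection: count, sizes (+1 for the nul) in reverse
// order, then the nul-terminated strings in forward order.
std::vector<uint8_t> writeStrings(const std::vector<std::string> &strings) {
  std::vector<uint8_t> out;
  appendVarInt(out, strings.size());
  for (auto it = strings.rbegin(); it != strings.rend(); ++it)
    appendVarInt(out, it->size() + 1);
  for (const std::string &str : strings) {
    out.insert(out.end(), str.begin(), str.end());
    out.push_back(0);
  }
  return out;
}

// The first size decoded belongs to the last string, so strings are carved off
// the end of the section, working backwards.
std::vector<std::string> readStrings(const std::vector<uint8_t> &section) {
  size_t pos = 0;
  uint64_t count = readVarInt(section, pos);
  std::vector<std::string> strings(count);
  size_t dataEnd = section.size();
  for (uint64_t i = 0; i < count; ++i) {
    uint64_t sizeWithNul = readVarInt(section, pos);
    assert(sizeWithNul >= 1 && sizeWithNul <= dataEnd - pos && "malformed");
    dataEnd -= sizeWithNul;
    assert(section[dataEnd + sizeWithNul - 1] == 0 && "missing nul");
    strings[count - 1 - i].assign(
        reinterpret_cast<const char *>(section.data()) + dataEnd,
        sizeWithNul - 1);
  }
  // The size table must end exactly where the string data begins.
  assert(pos == dataEnd && "size table and string data do not meet");
  return strings;
}
```

The reversed size table is what lets the writer emit everything in a single forward pass: no offset to the string data has to be backpatched, and no extra section is needed.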
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -0,0 +1,193 @@
+//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains various utilities that number IR structures in preparation
+// for bytecode emission.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
+#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/MapVector.h"
+
+namespace mlir {
+class BytecodeWriterConfig;
+
+namespace bytecode {
+namespace detail {
+struct DialectNumbering;
+
+//===----------------------------------------------------------------------===//
+// Attribute and Type Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for an Attribute or Type.
+struct AttrTypeNumbering {
+ AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}
+
+ /// The concrete value.
+ PointerUnion<Attribute, Type> value;
+
+ /// The number assigned to this value.
+ unsigned number = 0;
+
+ /// The number of references to this value.
+ unsigned refCount = 1;
+
+ /// The dialect of this value.
+ DialectNumbering *dialect = nullptr;
+};
+struct AttributeNumbering : public AttrTypeNumbering {
+ AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
+ Attribute getValue() const { return value.get<Attribute>(); }
+};
+struct TypeNumbering : public AttrTypeNumbering {
+ TypeNumbering(Type value) : AttrTypeNumbering(value) {}
+ Type getValue() const { return value.get<Type>(); }
+};
+
+//===----------------------------------------------------------------------===//
+// OpName Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents the numbering entry of an operation name.
+struct OpNameNumbering {
+ OpNameNumbering(DialectNumbering *dialect, OperationName name)
+ : dialect(dialect), name(name) {}
+
+ /// The dialect of this value.
+ DialectNumbering *dialect;
+
+ /// The concrete name.
+ OperationName name;
+
+ /// The number assigned to this name.
+ unsigned number = 0;
+
+ /// The number of references to this name.
+ unsigned refCount = 1;
+};
+
+//===----------------------------------------------------------------------===//
+// Dialect Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for a Dialect.
+struct DialectNumbering {
+ DialectNumbering(StringRef name, unsigned number)
+ : name(name), number(number) {}
+
+ /// The namespace of the dialect.
+ StringRef name;
+
+ /// The number assigned to the dialect.
+ unsigned number;
+
+ /// The loaded dialect, or nullptr if the dialect isn't loaded.
+ Dialect *dialect = nullptr;
+};
+
+//===----------------------------------------------------------------------===//
+// IRNumberingState
+//===----------------------------------------------------------------------===//
+
+/// This class manages numbering IR entities in preparation for bytecode
+/// emission.
+class IRNumberingState {
+public:
+ IRNumberingState(Operation *op);
+
+ /// Return the numbered dialects.
+ auto getDialects() {
+ return llvm::make_pointee_range(llvm::make_second_range(dialects));
+ }
+ auto getAttributes() { return llvm::make_pointee_range(orderedAttrs); }
+ auto getOpNames() { return llvm::make_pointee_range(orderedOpNames); }
+ auto getTypes() { return llvm::make_pointee_range(orderedTypes); }
+
+ /// Return the number for the given IR unit.
+ unsigned getNumber(Attribute attr) {
+ assert(attrs.count(attr) && "attribute not numbered");
+ return attrs[attr]->number;
+ }
+ unsigned getNumber(Block *block) {
+ assert(blockIDs.count(block) && "block not numbered");
+ return blockIDs[block];
+ }
+ unsigned getNumber(OperationName opName) {
+ assert(opNames.count(opName) && "opName not numbered");
+ return opNames[opName]->number;
+ }
+ unsigned getNumber(Type type) {
+ assert(types.count(type) && "type not numbered");
+ return types[type]->number;
+ }
+ unsigned getNumber(Value value) {
+ assert(valueIDs.count(value) && "value not numbered");
+ return valueIDs[value];
+ }
+
+ /// Return the block and value counts of the given region.
+ std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
+ assert(regionBlockValueCounts.count(region) && "region not numbered");
+ return regionBlockValueCounts[region];
+ }
+
+ /// Return the number of operations in the given block.
+ unsigned getOperationCount(Block *block) {
+ assert(blockOperationCounts.count(block) && "block not numbered");
+ return blockOperationCounts[block];
+ }
+
+private:
+ /// Number the given IR unit for bytecode emission.
+ void number(Attribute attr);
+ void number(Block &block);
+ DialectNumbering &numberDialect(Dialect *dialect);
+ DialectNumbering &numberDialect(StringRef dialect);
+ void number(Operation &op);
+ void number(OperationName opName);
+ void number(Region &region);
+ void number(Type type);
+
+ /// Mapping from IR to the respective numbering entries.
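+ DenseMap<Attribute, AttributeNumbering *> attrs;
+ DenseMap<OperationName, OpNameNumbering *> opNames;
+ DenseMap<Type, TypeNumbering *> types;
+ DenseMap<Dialect *, DialectNumbering *> registeredDialects;
+ llvm::MapVector<StringRef, DialectNumbering *> dialects;
+ std::vector<AttributeNumbering *> orderedAttrs;
+ std::vector<OpNameNumbering *> orderedOpNames;
+ std::vector<TypeNumbering *> orderedTypes;
+
+ /// Allocators used for the various numbering entries.
+ llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
+ llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
+ llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
+ llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
+
+ /// The value ID for each Block and Value.
+ DenseMap<Block *, unsigned> blockIDs;
+ DenseMap<Value, unsigned> valueIDs;
+
+ /// The number of operations in each block.
+ DenseMap<Block *, unsigned> blockOperationCounts;
+
+ /// A map from region to the number of blocks and values within that region.
+ DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;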
+
+ /// The next value ID to assign when numbering.
+ unsigned nextValueID = 0;
+};
+} // namespace detail
+} // namespace bytecode
+} // namespace mlir
+
+#endif
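The `refCount` fields declared above feed the ordering in the `IRNumberingState` constructor: entries are stable-sorted by descending reference count before numbers are assigned, so the most frequently referenced attributes, types, and operation names receive the smallest indices and therefore the shortest varints. The sketch below models only that first step; `Entry` and `RefCountNumbering` are hypothetical, string keys stand in for `Attribute`/`Type`, and the per-dialect regrouping that follows is omitted.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for AttrTypeNumbering / OpNameNumbering.
struct Entry {
  std::string key;
  unsigned number = 0;
  unsigned refCount = 1;
};

struct RefCountNumbering {
  // std::map nodes never move, so the Entry pointers below stay valid,
  // much like entries carved from a SpecificBumpPtrAllocator.
  std::map<std::string, Entry> entries;
  // Entries in first-seen order (cf. orderedAttrs / orderedTypes).
  std::vector<Entry *> ordered;

  // Like IRNumberingState::number(Type): create on first use, else bump.
  void record(const std::string &key) {
    auto [it, inserted] = entries.try_emplace(key);
    if (!inserted) {
      ++it->second.refCount;
      return;
    }
    it->second.key = key;
    ordered.push_back(&it->second);
  }

  // Most-referenced first; ties keep first-seen order because the sort is
  // stable. Numbers are then assigned densely in the sorted order.
  void assignNumbers() {
    std::stable_sort(ordered.begin(), ordered.end(),
                     [](const Entry *lhs, const Entry *rhs) {
                       return lhs->refCount > rhs->refCount;
                     });
    for (size_t i = 0; i < ordered.size(); ++i)
      ordered[i]->number = static_cast<unsigned>(i);
  }
};
```

Using a stable sort keeps the numbering deterministic for equally referenced entries: they stay in the order the IR walk first encountered them.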
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -0,0 +1,246 @@
+//===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRNumbering.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+using namespace mlir;
+using namespace mlir::bytecode::detail;
+
+//===----------------------------------------------------------------------===//
+// IR Numbering
+//===----------------------------------------------------------------------===//
+
+IRNumberingState::IRNumberingState(Operation *op) {
+ // Number the root operation.
+ number(*op);
+
+ // Push all of the regions of the root operation onto the worklist.
+ SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
+ for (Region &region : op->getRegions())
+ numberContext.emplace_back(&region, nextValueID);
+
+ // Iteratively process each of the nested regions.
+ while (!numberContext.empty()) {
+ Region *region;
+ std::tie(region, nextValueID) = numberContext.pop_back_val();
+ number(*region);
+
+ // Traverse into nested regions.
+ for (Operation &op : region->getOps()) {
+ // Isolated regions don't share value numbers with their parent, so we can
+ // start numbering these regions at zero.
+ unsigned opFirstValueID =
+ op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID;
+ for (Region &region : op.getRegions())
+ numberContext.emplace_back(&region, opFirstValueID);
+ }
+ }
+
+ // Number each of the dialects. For now this is just in the order they were
+ // found, given that the number of dialects on average is small enough to fit
+ // within a single byte (i.e. fewer than 128). If we ever have real world use
+ // cases that have a huge number of dialects, this could be made more
+ // intelligent.
+ for (auto &it : llvm::enumerate(dialects))
+ it.value().second->number = it.index();
+
+ // Number each of the recorded components within each dialect.
+
+ // First sort by ref count so that the most referenced elements are first. We
+ // try to bias more heavily used elements to the front. This allows for more
+ // frequently referenced things to be encoded using smaller varints.
+ auto sortByRefCountFn = [](const auto &lhs, const auto &rhs) {
+ return lhs->refCount > rhs->refCount;
+ };
+ llvm::stable_sort(orderedAttrs, sortByRefCountFn);
+ llvm::stable_sort(orderedOpNames, sortByRefCountFn);
+ llvm::stable_sort(orderedTypes, sortByRefCountFn);
+
+ // After that, we apply a secondary ordering based on the parent dialect. This
+ // ordering is applied to sub-sections of the element list defined by how many
+ // bytes it takes to encode a varint index to that sub-section. This allows
+ // for more efficiently encoding components of the same dialect (e.g. we only
+ // have to encode the dialect reference once).
+ auto sortByDialect = [](unsigned dialectToOrderFirst, const auto &lhs,
+ const auto &rhs) {
+ // Elements of `dialectToOrderFirst` sort before all others; the remaining
+ // elements are ordered by dialect number. Checking both sides keeps this a
+ // strict weak ordering, as required by llvm::stable_sort.
+ if (lhs->dialect->number == dialectToOrderFirst)
+ return rhs->dialect->number != dialectToOrderFirst;
+ if (rhs->dialect->number == dialectToOrderFirst)
+ return false;
+ return lhs->dialect->number < rhs->dialect->number;
+ };
+
+ // Group the elements by dialect within each varint byte-width grouping. The
+ // dialect that ends one grouping is ordered first in the next, which allows
+ // for building larger dialect groupings across byte boundaries.
+ auto groupByDialectPerByte = [&](auto range) {
+ if (range.empty())
+ return;
+
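+ unsigned dialectToOrderFirst = 0;
+ uint64_t elementsInPrevByteGroups = 0;
+ auto iterRange = range;
+ for (unsigned i = 1; i < 9; ++i) {
+ // Update the number of elements in the current byte grouping. Reminder
+ // that varint encodes 7-bits per byte, so indices below 2^(7*i) fit in `i`
+ // bytes, and this grouping holds the indices that need exactly `i` bytes.
+ // A 64-bit shift is required, as (7 * i) reaches 56.
+ uint64_t elementsThroughByteGroup = uint64_t(1) << (7 * i);
+ size_t elementsInByteGroup =
+ elementsThroughByteGroup - elementsInPrevByteGroups;
+ elementsInPrevByteGroups = elementsThroughByteGroup;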
+
+ // Slice out the sub-set of elements that are in the current byte grouping
+ // to be sorted.
+ auto byteSubRange = iterRange.take_front(elementsInByteGroup);
+ iterRange = iterRange.drop_front(byteSubRange.size());
+
+ // Sort the sub range for this byte.
+ llvm::stable_sort(byteSubRange, [&](const auto &lhs, const auto &rhs) {
+ return sortByDialect(dialectToOrderFirst, lhs, rhs);
+ });
+
+ // Update the dialect to order first to be the dialect at the end of the
+ // current grouping. This seeks to allow larger dialect groupings across
+ // byte boundaries.
+ dialectToOrderFirst = byteSubRange.back()->dialect->number;
+
+ // If the data range is now empty, we are done.
+ if (iterRange.empty())
+ break;
+ }
+
+ // Assign the entry numbers based on the sort order.
+ for (auto &entry : llvm::enumerate(range))
+ entry.value()->number = entry.index();
+ };
+ groupByDialectPerByte(llvm::makeMutableArrayRef(orderedAttrs));
+ groupByDialectPerByte(llvm::makeMutableArrayRef(orderedOpNames));
+ groupByDialectPerByte(llvm::makeMutableArrayRef(orderedTypes));
+}
+
+void IRNumberingState::number(Attribute attr) {
+ auto it = attrs.insert({attr, nullptr});
+ if (!it.second) {
+ ++it.first->second->refCount;
+ return;
+ }
+ auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
+ it.first->second = numbering;
+ orderedAttrs.push_back(numbering);
+
+ // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
+ // have a registered dialect when it got created. We don't want to encode this
+ // as the builtin OpaqueAttr, we want to encode it as if the dialect was
+ // actually loaded.
+ if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>())
+ numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
+ else
+ numbering->dialect = &numberDialect(&attr.getDialect());
+}
+
+void IRNumberingState::number(Block &block) {
+ // Number the arguments of the block.
+ for (BlockArgument arg : block.getArguments()) {
+ valueIDs.try_emplace(arg, nextValueID++);
+ number(arg.getLoc());
+ number(arg.getType());
+ }
+
+ // Number the operations in this block.
+ unsigned &numOps = blockOperationCounts[&block];
+ for (Operation &op : block) {
+ number(op);
+ ++numOps;
+ }
+}
+
+auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
+ DialectNumbering *&numbering = registeredDialects[dialect];
+ if (!numbering) {
+ numbering = &numberDialect(dialect->getNamespace());
+ numbering->dialect = dialect;
+ }
+ return *numbering;
+}
+
+auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
+ DialectNumbering *&numbering = dialects[dialect];
+ if (!numbering) {
+ numbering = new (dialectAllocator.Allocate())
+ DialectNumbering(dialect, dialects.size() - 1);
+ }
+ return *numbering;
+}
+
+void IRNumberingState::number(Region &region) {
+ if (region.empty())
+ return;
+ size_t firstValueID = nextValueID;
+
+ // Number the blocks within this region.
+ size_t blockCount = 0;
+ for (auto &it : llvm::enumerate(region)) {
+ blockIDs.try_emplace(&it.value(), it.index());
+ number(it.value());
+ ++blockCount;
+ }
+
+ // Remember the number of blocks and values in this region.
+ regionBlockValueCounts.try_emplace(&region, blockCount,
+ nextValueID - firstValueID);
+}
+
+void IRNumberingState::number(Operation &op) {
+ // Number the components of an operation that won't be numbered elsewhere
+ // (e.g. we don't number operands, regions, or successors here).
+ number(op.getName());
+ for (OpResult result : op.getResults()) {
+ valueIDs.try_emplace(result, nextValueID++);
+ number(result.getType());
+ }
+
+ // Only number the operation's dictionary if it isn't empty.
+ DictionaryAttr dictAttr = op.getAttrDictionary();
+ if (!dictAttr.empty())
+ number(dictAttr);
+
+ number(op.getLoc());
+}
+
+void IRNumberingState::number(OperationName opName) {
+ OpNameNumbering *&numbering = opNames[opName];
+ if (numbering) {
+ ++numbering->refCount;
+ return;
+ }
+ DialectNumbering *dialectNumber = nullptr;
+ if (Dialect *dialect = opName.getDialect())
+ dialectNumber = &numberDialect(dialect);
+ else
+ dialectNumber = &numberDialect(opName.getDialectNamespace());
+
+ numbering =
+ new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
+ orderedOpNames.push_back(numbering);
+}
+
+void IRNumberingState::number(Type type) {
+ auto it = types.insert({type, nullptr});
+ if (!it.second) {
+ ++it.first->second->refCount;
+ return;
+ }
+ auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
+ it.first->second = numbering;
+ orderedTypes.push_back(numbering);
+
+ // Check for OpaqueType, which is a dialect-specific type that didn't have a
+ // registered dialect when it got created. We don't want to encode this as the
+ // builtin OpaqueType, we want to encode it as if the dialect was actually
+ // loaded.
+ if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>())
+ numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
+ else
+ numbering->dialect = &numberDialect(&type.getDialect());
+}
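After the reference-count sort, `groupByDialectPerByte` regroups elements by dialect, but only within runs of indices that share a varint width (indices below 2^7 take one byte, below 2^14 two bytes, and so on), so the regrouping never changes any element's encoded size. The sketch below models that step over a plain vector; `Elem`, `dialectLess`, and `groupByDialect` are hypothetical names, not part of this patch. Its comparator checks both sides against the preferred dialect, which keeps it a strict weak ordering as `std::stable_sort` requires.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical element: `dialect` is the parent dialect number, `id` tracks
// the element's position after the refcount sort.
struct Elem {
  unsigned dialect;
  unsigned id;
};

// Elements of `first` sort before all others; the rest sort by dialect.
bool dialectLess(unsigned first, const Elem &lhs, const Elem &rhs) {
  if (lhs.dialect == first)
    return rhs.dialect != first;
  if (rhs.dialect == first)
    return false;
  return lhs.dialect < rhs.dialect;
}

// Stable-sort each varint byte-width bucket ([0, 2^7), [2^7, 2^14), ...) by
// dialect, preferring the dialect that ended the previous bucket.
void groupByDialect(std::vector<Elem> &elems) {
  unsigned dialectToOrderFirst = 0;
  size_t begin = 0;
  for (unsigned numBytes = 1; numBytes <= 8 && begin < elems.size();
       ++numBytes) {
    size_t end = static_cast<size_t>(
        std::min<uint64_t>(uint64_t(1) << (7 * numBytes), elems.size()));
    std::stable_sort(elems.begin() + begin, elems.begin() + end,
                     [&](const Elem &lhs, const Elem &rhs) {
                       return dialectLess(dialectToOrderFirst, lhs, rhs);
                     });
    dialectToOrderFirst = elems[end - 1].dialect;
    begin = end;
  }
}
```

With 130 elements alternating between dialects 0 and 1, the first bucket (128 one-byte indices) ends with dialect 1, so dialect 1 is placed first in the second bucket and its group continues across the byte boundary.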
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -3,6 +3,7 @@
add_subdirectory(Analysis)
add_subdirectory(AsmParser)
+add_subdirectory(Bytecode)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(IR)
diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt
--- a/mlir/lib/Parser/CMakeLists.txt
+++ b/mlir/lib/Parser/CMakeLists.txt
@@ -6,5 +6,6 @@
LINK_LIBS PUBLIC
MLIRAsmParser
+ MLIRBytecodeReader
MLIRIR
)
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -12,6 +12,7 @@
#include "mlir/Parser/Parser.h"
#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeReader.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@@ -25,6 +26,8 @@
sourceBuf->getBufferIdentifier(),
/*line=*/0, /*column=*/0);
}
+ if (isBytecode(*sourceBuf))
+ return readBytecodeFile(*sourceBuf, block, config);
return parseAsmSourceFile(sourceMgr, block, config);
}
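The Parser.cpp change routes a buffer to `readBytecodeFile` when `isBytecode` accepts it, and otherwise falls back to `parseAsmSourceFile`. The implementation of `isBytecode` is not shown in this section. As a rough sketch of the check it needs to make, the hypothetical `looksLikeMlirBytecode` below tests for the four-byte magic `ML\xEFR` that `BytecodeWriter::write` emits at the start of every file.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical helper: true if `buffer` begins with the MLIR bytecode magic
// number 'M', 'L', 0xEF, 'R' (the bytes written by emitString("ML\xefR")).
bool looksLikeMlirBytecode(std::string_view buffer) {
  constexpr std::string_view kMagic("ML\xEFR", 4);
  // substr clamps to the buffer length, so short buffers simply compare false.
  return buffer.substr(0, kMagic.size()) == kMagic;
}
```

Since 0xEF is not valid as the first byte of a UTF-8 continuation-free ASCII prefix like `ML`, textual `.mlir` input starting with `ML` still fails this check.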
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -5,6 +5,7 @@
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-opt
LINK_LIBS PUBLIC
+ MLIRBytecodeWriter
MLIRPass
MLIRParser
MLIRSupport
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -47,7 +48,8 @@
static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool verifyPasses, SourceMgr &sourceMgr,
MLIRContext *context,
- PassPipelineFn passManagerSetupFn) {
+ PassPipelineFn passManagerSetupFn,
+ bool emitBytecode) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
@@ -86,8 +88,12 @@
// Print the output.
TimingScope outputTiming = timing.nest("Output");
- module->print(os);
- os << '\n';
+ if (emitBytecode) {
+ writeBytecodeToFile(module->getOperation(), os);
+ } else {
+ module->print(os);
+ os << '\n';
+ }
return success();
}
@@ -97,8 +103,8 @@
processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
- PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
- llvm::ThreadPool *threadPool) {
+ bool emitBytecode, PassPipelineFn passManagerSetupFn,
+ DialectRegistry &registry, llvm::ThreadPool *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -122,7 +128,7 @@
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
- &context, passManagerSetupFn);
+ &context, passManagerSetupFn, emitBytecode);
}
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -131,7 +137,7 @@
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
- passManagerSetupFn);
+ passManagerSetupFn, emitBytecode);
// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
@@ -144,7 +150,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext) {
+ bool preloadDialectsInContext,
+ bool emitBytecode) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
// We use an explicit threadpool to avoid creating and joining/destroying
@@ -163,8 +170,8 @@
raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- preloadDialectsInContext, passManagerSetupFn, registry,
- threadPool);
+ preloadDialectsInContext, emitBytecode,
+ passManagerSetupFn, registry, threadPool);
};
return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
splitInputFile, /*insertMarkerInOutput=*/true);
@@ -176,7 +183,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext) {
+ bool preloadDialectsInContext,
+ bool emitBytecode) {
auto passManagerSetupFn = [&](PassManager &pm) {
auto errorHandler = [&](const Twine &msg) {
emitError(UnknownLoc::get(pm.getContext())) << msg;
@@ -186,7 +194,8 @@
};
return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
registry, splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext);
+ allowUnregisteredDialects, preloadDialectsInContext,
+ emitBytecode);
}
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -224,6 +233,10 @@
"show-dialects", cl::desc("Print the list of registered dialects"),
cl::init(false));
+ static cl::opt<bool> emitBytecode(
+ "emit-bytecode", cl::desc("Emit bytecode when generating output"),
+ cl::init(false));
+
InitLLVM y(argc, argv);
// Register any command line options.
@@ -268,7 +281,8 @@
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext)))
+ allowUnregisteredDialects, preloadDialectsInContext,
+ emitBytecode)))
return failure();
// Keep the output file if the invocation of MlirOptMain was successful.
diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/general.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -allow-unregistered-dialect -emit-bytecode %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: "bytecode.test1"
+// CHECK-NEXT: "bytecode.empty"() : () -> ()
+// CHECK-NEXT: "bytecode.attributes"() {attra = 10 : i64, attrb = #bytecode.attr} : () -> ()
+// CHECK-NEXT: %[[RESULTS:.*]]:3 = "bytecode.results"() : () -> (i32, i64, i32)
+// CHECK-NEXT: "bytecode.operands"(%[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#2) : (i32, i64, i32) -> ()
+// CHECK-NEXT: "bytecode.branch"()[^[[BLOCK:.*]]] : () -> ()
+// CHECK-NEXT: ^[[BLOCK]](%[[ARG0:.*]]: i32, %[[ARG1:.*]]: !bytecode.int, %[[ARG2:.*]]: !pdl.operation):
+// CHECK-NEXT: "bytecode.regions"() ({
+// CHECK-NEXT: "bytecode.operands"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (i32, !bytecode.int, !pdl.operation) -> ()
+// CHECK-NEXT: "bytecode.return"() : () -> ()
+// CHECK-NEXT: }) : () -> ()
+// CHECK-NEXT: "bytecode.return"() : () -> ()
+// CHECK-NEXT: }) : () -> ()
+
+"bytecode.test1"() ({
+ "bytecode.empty"() : () -> ()
+ "bytecode.attributes"() {attra = 10, attrb = #bytecode.attr} : () -> ()
+ %results:3 = "bytecode.results"() : () -> (i32, i64, i32)
+ "bytecode.operands"(%results#0, %results#1, %results#2) : (i32, i64, i32) -> ()
+ "bytecode.branch"()[^secondBlock] : () -> ()
+
+^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation):
+ "bytecode.regions"() ({
+ "bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> ()
+ "bytecode.return"() : () -> ()
+ }) : () -> ()
+ "bytecode.return"() : () -> ()
+}) : () -> ()