diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
new file mode 100644
--- /dev/null
+++ b/mlir/docs/BytecodeFormat.md
@@ -0,0 +1,288 @@
+# MLIR Bytecode Format
+
+This document describes the MLIR bytecode format and its encoding.
+
+[TOC]
+
+## Magic Number
+
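+MLIR uses the following four-byte magic number to indicate bytecode files:
+
+`[0x4D, 0x4C, 0xEF, 0x52]`, i.e. the characters `M`, `L`, `ï`, `R`, where `ï` is
+stored as the single byte `0xEF` (its Latin-1 code point), not as UTF-8.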
+
+## Format Overview
+
+An MLIR bytecode file consists of a byte stream, with a few simple structural
+concepts layered on top.
+
+### Primitives
+
+#### Fixed-Width Integers
+
+```
+ byte ::= `0x00`...`0xFF`
+```
+
+Fixed-width integers are unsigned integers of a known byte size. The values are
+stored in little-endian byte order.
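+
+As an illustration of the byte order, the following minimal sketch decodes a
+fixed-width unsigned integer. The helper `readLittleEndian` is hypothetical and
+not part of the MLIR API. It assumes the caller has already checked that at
+least `sizeof(T)` bytes remain.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+// Decode an unsigned integer of type T stored least significant byte first.
+// Assumes `data` points at no fewer than sizeof(T) readable bytes.
+template <typename T>
+T readLittleEndian(const uint8_t *data) {
+  T value = 0;
+  for (size_t i = 0; i < sizeof(T); ++i)
+    value |= static_cast<T>(data[i]) << (8 * i);
+  return value;
+}
+```
+
+For example, the bytes `0x78 0x56 0x34 0x12` decode to the 32-bit value
+`0x12345678`.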
+
+TODO: Add larger fixed width integers as necessary.
+
+#### Variable-Width Integers
+
+Variable-width integers, or `VarInt`s, provide a compact representation for
+integers. Each encoded VarInt consists of one to nine bytes, which together
+represent a single 64-bit value. The MLIR bytecode uses the "PrefixVarInt"
+encoding for VarInts. This encoding is a variant of the
+[LEB128 ("Little-Endian Base 128")](https://en.wikipedia.org/wiki/LEB128)
+encoding. Most of the bits hold the value, and the remaining bits form a tag
+indicating the number of bytes used. LEB128 spends one continuation bit in every
+byte. PrefixVarInt instead stores the whole tag in the low bits of the first
+byte. This means that small unsigned integers (less than 2^7) may be stored in
+one byte, unsigned integers less than 2^14 may be stored in two bytes, etc.
+
+The first byte of the encoding includes a length prefix in the low bits. This
+prefix is a bit sequence of '0's followed by a terminal '1', or the end of the
+byte. The number of '0' bits indicates the number of _additional_ bytes, not
+including the prefix byte, used to encode the value. All of the remaining bits
+in the first byte, along with all of the bits in the additional bytes, provide
+the value of the integer in little-endian order. Below are the various possible
+encodings of the prefix byte:
+
+```
+xxxxxxx1: 7 value bits, the encoding uses 1 byte
+xxxxxx10: 14 value bits, the encoding uses 2 bytes
+xxxxx100: 21 value bits, the encoding uses 3 bytes
+xxxx1000: 28 value bits, the encoding uses 4 bytes
+xxx10000: 35 value bits, the encoding uses 5 bytes
+xx100000: 42 value bits, the encoding uses 6 bytes
+x1000000: 49 value bits, the encoding uses 7 bytes
+10000000: 56 value bits, the encoding uses 8 bytes
+00000000: 64 value bits, the encoding uses 9 bytes
+```
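+
+The following is a minimal, standalone sketch of this scheme. `encodeVarInt`
+and `decodeVarInt` are hypothetical helpers written for illustration; they are
+not MLIR's reader or writer API. The decoder follows the same steps as
+`EncodingReader::parseVarInt` in `BytecodeReader.cpp`, with two differences: it
+counts trailing zeros with a loop, and it assembles bytes explicitly instead of
+reinterpreting memory. As a result, it does not depend on host endianness.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+// Encode `value` as a PrefixVarInt. The low bits of the first byte hold
+// (numBytes - 1) '0' bits followed by a '1'; the value fills the rest.
+std::vector<uint8_t> encodeVarInt(uint64_t value) {
+  std::vector<uint8_t> out;
+  // The 1..8 byte forms carry 7 value bits per byte.
+  for (unsigned numBytes = 1; numBytes <= 8; ++numBytes) {
+    if ((value >> (7 * numBytes)) != 0)
+      continue;
+    uint64_t encoded = (value << numBytes) | (uint64_t(1) << (numBytes - 1));
+    for (unsigned i = 0; i < numBytes; ++i) // little-endian byte order
+      out.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
+    return out;
+  }
+  // 9-byte form: an all-zero marker byte, then the full 64-bit value.
+  out.push_back(0);
+  for (unsigned i = 0; i < 8; ++i)
+    out.push_back(static_cast<uint8_t>(value >> (8 * i)));
+  return out;
+}
+
+// Decode a PrefixVarInt from `data`, reporting the bytes consumed in
+// `numRead`. Returns false if `size` is too small for the encoded length.
+bool decodeVarInt(const uint8_t *data, size_t size, uint64_t &result,
+                  size_t &numRead) {
+  if (size == 0)
+    return false;
+  if (data[0] == 0) {
+    if (size < 9)
+      return false;
+    result = 0;
+    for (unsigned i = 0; i < 8; ++i)
+      result |= uint64_t(data[1 + i]) << (8 * i);
+    numRead = 9;
+    return true;
+  }
+  // One byte, plus one more per trailing '0' bit in the marker byte.
+  unsigned numBytes = 1;
+  for (uint8_t marker = data[0]; (marker & 1) == 0; marker >>= 1)
+    ++numBytes;
+  if (size < numBytes)
+    return false;
+  uint64_t encoded = 0;
+  for (unsigned i = 0; i < numBytes; ++i)
+    encoded |= uint64_t(data[i]) << (8 * i);
+  result = encoded >> numBytes; // drop the length prefix bits
+  numRead = numBytes;
+  return true;
+}
+```
+
+For example, `128` does not fit in 7 bits, so it takes the two-byte form:
+`(128 << 2) | 0b10 = 0x202`, emitted in little-endian order as `0x02 0x02`.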
+
+#### NUL Terminated Strings
+
+NUL-terminated strings end with the ASCII NUL character (whose byte value is
+zero). They are not used where a string may contain an embedded NUL character;
+such strings are instead encoded using a length and a byte array.
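+
+The following is a minimal sketch of reading such a string. The helper name and
+signature are hypothetical; the reader's own version is
+`EncodingReader::parseNullTerminatedString` in `BytecodeReader.cpp`.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <string_view>
+
+// Read a NUL-terminated string starting at `data`. On success, `result`
+// excludes the terminator and `numRead` includes it. Returns false if no NUL
+// byte appears within the first `size` bytes.
+bool readNulTerminatedString(const uint8_t *data, size_t size,
+                             std::string_view &result, size_t &numRead) {
+  const void *nul = std::memchr(data, 0, size);
+  if (!nul)
+    return false;
+  size_t length = static_cast<const uint8_t *>(nul) - data;
+  result = std::string_view(reinterpret_cast<const char *>(data), length);
+  numRead = length + 1; // skip past the terminator as well
+  return true;
+}
+```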
+
+### Sections
+
+```
+section {
+ id: byte
+ length: varint
+}
+```
+
+Sections are a mechanism for grouping data within the bytecode. They enable
+delayed processing, which is useful for out-of-order processing of data,
+lazy-loading, and more. Each section contains a section ID and a length, which
+allows a reader to skip over the section without decoding its contents.
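+
+Because every section header records its length, a reader can locate all
+sections up front and defer decoding their payloads. The sketch below assumes
+a buffer that starts directly at the first section header, with the magic
+number and version already consumed. `readVarInt` and `scanSections` are
+hypothetical helpers. Rejecting a repeated ID reflects the expectation that
+each section appears only once.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <map>
+#include <utility>
+
+// Minimal PrefixVarInt decoder: returns the number of bytes consumed, or 0 if
+// the input is truncated.
+size_t readVarInt(const uint8_t *data, size_t size, uint64_t &value) {
+  if (size == 0)
+    return 0;
+  size_t numBytes = 9; // an all-zero marker byte selects the 9-byte form
+  if (data[0] != 0) {
+    numBytes = 1; // one extra byte per trailing zero bit of the marker
+    for (uint8_t b = data[0]; (b & 1) == 0; b >>= 1)
+      ++numBytes;
+  }
+  if (size < numBytes)
+    return 0;
+  size_t start = numBytes == 9 ? 1 : 0; // skip the marker in the 9-byte form
+  uint64_t raw = 0;
+  for (size_t i = start; i < numBytes; ++i)
+    raw |= uint64_t(data[i]) << (8 * (i - start));
+  value = numBytes == 9 ? raw : raw >> numBytes; // drop the length prefix
+  return numBytes;
+}
+
+// Record {payload offset, payload length} for each section ID without
+// decoding any payload. Fails on truncated input or a repeated section ID.
+bool scanSections(const uint8_t *data, size_t size,
+                  std::map<uint8_t, std::pair<size_t, size_t>> &sections) {
+  size_t pos = 0;
+  while (pos < size) {
+    uint8_t id = data[pos++];
+    uint64_t length = 0;
+    size_t n = readVarInt(data + pos, size - pos, length);
+    if (n == 0)
+      return false;
+    pos += n;
+    if (length > size - pos)
+      return false;
+    auto inserted =
+        sections.emplace(id, std::make_pair(pos, static_cast<size_t>(length)));
+    if (!inserted.second)
+      return false;
+    pos += static_cast<size_t>(length); // skip the payload for now
+  }
+  return true;
+}
+```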
+
+TODO: Sections should also carry an optional alignment. Add this when necessary.
+
+## MLIR Encoding
+
+Given the generic structure of MLIR, the bytecode encoding is fairly simple: it
+effectively maps onto the core components of MLIR.
+
+### Top Level Structure
+
+The top-level structure of the bytecode contains the 4-byte "magic number", a
+version number, and a list of sections. Each section is currently only expected
+to appear once within a bytecode file.
+
+```
+bytecode {
+ magic: "MLïR",
+ version: varint,
+ sections: section[]
+}
+```
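+
+The following is a minimal sketch of recognizing this header. It works in the
+spirit of `mlir::isBytecode` from `BytecodeReader.h`, but it is a standalone,
+hypothetical helper rather than MLIR's implementation. It checks only the four
+magic bytes; the `version` VarInt follows immediately after them.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+
+// The four magic bytes: 'M', 'L', 0xEF ('ï' in Latin-1), 'R'.
+constexpr uint8_t kMagic[4] = {0x4D, 0x4C, 0xEF, 0x52};
+
+// Returns true if `data` begins with the bytecode magic number.
+bool startsWithBytecodeMagic(const uint8_t *data, size_t size) {
+  return size >= sizeof(kMagic) &&
+         std::memcmp(data, kMagic, sizeof(kMagic)) == 0;
+}
+```
+
+With the current version number of 0, a header is the magic bytes followed by
+the single VarInt byte `0x01`.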
+
+### Dialect Section
+
+The dialect section of the bytecode contains all of the dialects referenced
+within the encoded IR, and some information about the components of those
+dialects that were also referenced.
+
+```
+dialect_section {
+ dialects: dialect[]
+}
+
+dialect {
+ name: nul_terminated_string,
+ numAttrs: varint,
+ numTypes: varint,
+ numOpNames: varint,
+ opNames: nul_terminated_string[]
+}
+```
+
+### Attribute/Type Sections
+
+Attributes and types are encoded using two [sections](#sections): one section
+(`attr_type_section`) contains the actual encoded representation, and another
+section (`attr_type_offset_section`) contains the offsets of each encoded
+attribute/type into the previous section. As such, the
+`attr_type_offset_section` must appear before the `attr_type_section` to
+ensure the latter can be properly loaded. This structure allows attributes and
+types to always be loaded lazily, on demand.
+
+```
+attr_type_section {
+ attrs: attribute[],
+ types: type[]
+}
+attr_type_offset_section {
+ offset: varint[]
+}
+
+attribute {
+ code: byte, // kAsmForm
+ encoding: ...
+}
+type {
+ code: byte, // kAsmForm
+ encoding: ...
+}
+```
+
+Each `offset` in the `attr_type_offset_section` above is the size of the
+encoding for the attribute or type, rather than its direct offset into the
+`attr_type_section`. Smaller relative values need fewer VarInt bytes, so they
+compress more effectively. Attributes and types are grouped by dialect, with the
+dialect groups in the same order as the dialects within the
+[dialect section](#dialect-section). This allows an attribute or type to be
+associated with its dialect without including a dialect reference in each entry.
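+
+To make these two points concrete, here is a sketch using hypothetical helpers
+(`computeSlices`, `owningDialect`) that operate on already-decoded values.
+`computeSlices` turns the per-entry sizes into absolute slices of
+`attr_type_section` with a running sum. This is similar to
+`AttrTypeReader::initializeOffsets` in `BytecodeReader.cpp`, which processes all
+attributes first and then all types. `owningDialect` recovers an entry's owning
+dialect from the per-dialect `numAttrs` (or `numTypes`) counts of the dialect
+section.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+// Location of one encoded attribute/type within `attr_type_section`.
+struct EntrySlice {
+  size_t offset;
+  size_t size;
+};
+
+// Accumulate entry sizes into absolute slices. Returns false if an entry
+// would run past the end of a section of `sectionSize` bytes.
+bool computeSlices(const std::vector<uint64_t> &sizes, size_t sectionSize,
+                   std::vector<EntrySlice> &slices) {
+  uint64_t current = 0; // invariant: current <= sectionSize
+  for (uint64_t entrySize : sizes) {
+    if (entrySize > sectionSize - current)
+      return false;
+    slices.push_back(
+        {static_cast<size_t>(current), static_cast<size_t>(entrySize)});
+    current += entrySize;
+  }
+  return true;
+}
+
+// Entries are grouped by dialect in dialect-section order, so the owner of
+// entry `index` follows from the per-dialect counts alone. Returns
+// counts.size() if `index` is out of range.
+size_t owningDialect(const std::vector<unsigned> &counts, size_t index) {
+  for (size_t i = 0; i < counts.size(); ++i) {
+    if (index < counts[i])
+      return i;
+    index -= counts[i];
+  }
+  return counts.size();
+}
+```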
+
+#### Attribute/Type Encodings
+
+In the previous section, the forms of `attribute` and `type` both start with a
+`code` field. This field indicates how the attribute or type was encoded. In the
+abstract, an attribute/type is encoded in one of two possible ways: via its
+assembly format, or via a custom dialect-defined encoding.
+
+##### Assembly Format Fallback
+
+In the case where a dialect does not define a method for encoding the attribute
+or type, the textual assembly format of that attribute or type is used as a
+fallback. For example, a type of `!bytecode.type` would be encoded as the
+NUL-terminated string "!bytecode.type". This ensures that every attribute and type
+may be encoded, even if the owning dialect has not yet opted in to a more
+efficient serialization.
+
+TODO: We shouldn't redundantly encode the dialect name here, we should use a
+reference to the parent dialect instead.
+
+##### Dialect Defined Encoding
+
+TODO: This is not yet supported.
+
+### IR Section
+
+The IR section contains the encoded form of operations within the bytecode.
+
+#### Operation Encoding
+
+```
+op {
+ name: varint,
+ encodingMask: byte,
+ location: varint,
+
+ attrDict: varint?,
+
+ numResults: varint?,
+ resultTypes: varint[],
+
+ numOperands: varint?,
+ operands: varint[],
+
+ numSuccessors: varint?,
+ successors: varint[],
+
+ regionEncoding: varint?, // (numRegions << 1) | (isIsolatedFromAbove)
+ regions: region[]
+}
+```
+
+The encoding of an operation is important because operations are generally the
+most commonly appearing structure in the bytecode. A single encoding is used for
+every type of operation. Given this prevalence, many of the fields of an
+operation are optional. The `encodingMask` field is a bitmask which indicates
+which of the components of the operation are present.
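+
+The following sketch shows how a writer might compute `encodingMask`. The bit
+values are copied from `OpEncodingMask` in `lib/Bytecode/Encoding.h`. The
+`OpSummary` struct and the `computeEncodingMask` function are hypothetical
+stand-ins for an `Operation`. Treating `kHasInlineRegions` as "has any regions"
+is an assumption based on the field list above.
+
+```cpp
+#include <cassert>
+#include <cstdint>
+
+// Bit values copied from `OpEncodingMask` in lib/Bytecode/Encoding.h.
+enum : uint8_t {
+  kHasAttrs = 1 << 0,
+  kHasResults = 1 << 1,
+  kHasOperands = 1 << 2,
+  kHasSuccessors = 1 << 3,
+  kHasInlineRegions = 1 << 4,
+};
+
+// Hypothetical summary of an operation; only presence/counts matter here.
+struct OpSummary {
+  bool hasAttrs;
+  unsigned numResults, numOperands, numSuccessors, numRegions;
+};
+
+// Set one bit per optional component that is present, so only those fields
+// are emitted (and later parsed).
+uint8_t computeEncodingMask(const OpSummary &op) {
+  uint8_t mask = 0;
+  if (op.hasAttrs)
+    mask |= kHasAttrs;
+  if (op.numResults)
+    mask |= kHasResults;
+  if (op.numOperands)
+    mask |= kHasOperands;
+  if (op.numSuccessors)
+    mask |= kHasSuccessors;
+  if (op.numRegions)
+    mask |= kHasInlineRegions;
+  return mask;
+}
+```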
+
+##### Location
+
+The location is encoded as the index of the location within the attribute table.
+
+##### Attributes
+
+If the operation has attributes, the index of the operation attribute dictionary
+within the attribute table is encoded.
+
+##### Results
+
+If the operation has results, the number of results and the indexes of the
+result types within the type table are encoded.
+
+##### Operands
+
+If the operation has operands, the number of operands and the value index of
+each operand are encoded. The value index is the position of the value's
+definition, counted from the start of the nearest ancestor region that is
+isolated from above.
+
+##### Successors
+
+If the operation has successors, the number of successors and the indexes of the
+successor blocks within the parent region are encoded.
+
+##### Regions
+
+If the operation has regions, the number of regions and whether the regions are
+isolated from above are encoded together in a single varint,
+`(numRegions << 1) | isIsolatedFromAbove`. Afterwards, each region is encoded
+inline.
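+
+Packing `regionEncoding` takes a single shift and OR. The following is a
+minimal sketch; the helper names are hypothetical.
+
+```cpp
+#include <cassert>
+#include <cstdint>
+
+// Pack the region count and the isolated-from-above flag into the value that
+// is then written as the `regionEncoding` VarInt.
+uint64_t packRegionEncoding(uint64_t numRegions, bool isIsolatedFromAbove) {
+  return (numRegions << 1) | (isIsolatedFromAbove ? 1 : 0);
+}
+
+// Inverse of packRegionEncoding: the low bit is the flag, the rest the count.
+void unpackRegionEncoding(uint64_t encoding, uint64_t &numRegions,
+                          bool &isIsolatedFromAbove) {
+  numRegions = encoding >> 1;
+  isIsolatedFromAbove = (encoding & 1) != 0;
+}
+```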
+
+#### Region Encoding
+
+```
+region {
+ numBlocks: varint,
+
+ numValues: varint?,
+ blocks: block[]
+}
+```
+
+A region is encoded first with the number of blocks it contains. If the region
+is non-empty, the number of values defined directly within the region is
+encoded, followed by the blocks of the region.
+
+#### Block Encoding
+
+```
+block {
+ encoding: varint, // (numOps << 1) | (hasBlockArgs)
+ arguments: block_arguments?, // Optional based on encoding
+ ops : op[]
+}
+
+block_arguments {
+ numArgs: varint?,
+ args: block_argument[]
+}
+
+block_argument {
+ typeIndex: varint,
+ location: varint
+}
+```
+
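+A block is encoded with an array of operations and block arguments. The first
+field is an encoding that combines the number of operations in the block with a
+flag indicating whether the block has arguments. If the flag is set, the
+`block_arguments` follow, in which each argument is encoded as a type index and
+a location.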
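+
+The following sketch decodes the block prefix. To keep it short, it operates
+on a sequence of VarInt values that are assumed to be already decoded.
+`readBlockHeader`, `BlockHeader`, and `BlockArgRef` are hypothetical. The
+sketch also assumes that `numArgs` is always present when the argument flag is
+set.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+struct BlockArgRef {
+  uint64_t typeIndex; // index into the type table
+  uint64_t location;  // index of the location attribute
+};
+
+struct BlockHeader {
+  uint64_t numOps = 0;
+  std::vector<BlockArgRef> args;
+};
+
+// Decode `(numOps << 1) | hasBlockArgs` and, if the flag is set, the block
+// arguments. On success, `pos` points at the first operation's values.
+bool readBlockHeader(const std::vector<uint64_t> &values, size_t &pos,
+                     BlockHeader &header) {
+  if (pos >= values.size())
+    return false;
+  uint64_t encoding = values[pos++];
+  header.numOps = encoding >> 1;
+  header.args.clear();
+  if ((encoding & 1) == 0)
+    return true; // no block arguments
+  if (pos >= values.size())
+    return false;
+  uint64_t numArgs = values[pos++];
+  for (uint64_t i = 0; i < numArgs; ++i) {
+    if (values.size() - pos < 2)
+      return false; // truncated (typeIndex, location) pair
+    header.args.push_back({values[pos], values[pos + 1]});
+    pos += 2;
+  }
+  return true;
+}
+```
+
+For example, the values `7, 2, 10, 20, 11, 21` describe a block with three
+operations and two arguments. `7` is `(3 << 1) | 1` and `2` is the argument
+count. The arguments have type indices 10 and 11, with locations 20 and 21.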
diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -0,0 +1,34 @@
+//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines interfaces to read MLIR bytecode files/streams.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEREADER_H
+#define MLIR_BYTECODE_BYTECODEREADER_H
+
+#include "mlir/IR/AsmState.h"
+#include "mlir/Support/LLVM.h"
+
+namespace llvm {
+class MemoryBufferRef;
+} // namespace llvm
+
+namespace mlir {
+/// Returns true if the given buffer starts with the magic bytes that signal
+/// MLIR bytecode.
+bool isBytecode(llvm::MemoryBufferRef buffer);
+
+/// Read the operations defined within the given memory buffer, containing MLIR
+/// bytecode, into the provided block.
+LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config);
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEREADER_H
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -0,0 +1,31 @@
+//===- BytecodeWriter.h - MLIR Bytecode Writer ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines interfaces to write MLIR bytecode files/streams.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEWRITER_H
+#define MLIR_BYTECODE_BYTECODEWRITER_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Operation;
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+/// Write the bytecode for the given operation to the provided output stream.
+/// For streams where it matters, the given stream should be in "binary" mode.
+void writeBytecodeToFile(Operation *op, raw_ostream &os);
+
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEWRITER_H
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -642,11 +642,11 @@
OperationState(Location location, OperationName name);
OperationState(Location location, OperationName name, ValueRange operands,
- TypeRange types, ArrayRef<NamedAttribute> attributes,
+ TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
OperationState(Location location, StringRef name, ValueRange operands,
- TypeRange types, ArrayRef<NamedAttribute> attributes,
+ TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -50,13 +50,15 @@
/// - preloadDialectsInContext will trigger the upfront loading of all
/// dialects from the global registry in the MLIRContext. This option is
/// deprecated and will be removed soon.
+/// - emitBytecode will generate bytecode output instead of text.
LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext = false);
+ bool preloadDialectsInContext = false,
+ bool emitBytecode = false);
/// Support a callback to setup the pass manager.
/// - passManagerSetupFn is the callback invoked to setup the pass manager to
@@ -67,7 +69,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext = false);
+ bool preloadDialectsInContext = false,
+ bool emitBytecode = false);
/// Implementation for tools like `mlir-opt`.
/// - toolName is used for the header displayed by `--help`.
diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Reader)
+add_subdirectory(Writer)
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -0,0 +1,86 @@
+//===- Encoding.h - MLIR binary format encoding information -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines enum values describing the structure of MLIR bytecode
+// files.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_BYTECODE_ENCODING_H
+#define LIB_MLIR_BYTECODE_ENCODING_H
+
+#include <cstdint>
+
+namespace mlir {
+namespace bytecode {
+//===----------------------------------------------------------------------===//
+// General constants
+//===----------------------------------------------------------------------===//
+
+enum {
+ /// The current bytecode version.
+ kVersion = 0,
+};
+
+//===----------------------------------------------------------------------===//
+// Sections
+//===----------------------------------------------------------------------===//
+
+namespace Section {
+enum ID : uint8_t {
+ /// This section contains the dialects referenced within an IR module.
+ kDialect = 0,
+
+ /// This section contains the attributes and types referenced within an IR
+ /// module.
+ kAttrType = 1,
+
+ /// This section contains the offsets for the attribute and types within the
+ /// AttrType section.
+ kAttrTypeOffset = 2,
+
+ /// This section contains the list of operations serialized into the bytecode,
+ /// and their nested regions/operations.
+ kIR = 3,
+
+ /// The total number of section types.
+ kNumSections = 4,
+};
+} // namespace Section
+
+//===----------------------------------------------------------------------===//
+// AttrType Section
+//===----------------------------------------------------------------------===//
+
+enum class AttrTypeCode : uint8_t {
+ /// This code represents an attribute or type represented in the textual
+ /// assembly format.
+ kAsmForm,
+};
+
+//===----------------------------------------------------------------------===//
+// IR Section
+//===----------------------------------------------------------------------===//
+
+/// This enum represents a mask of all of the potential components of an
+/// operation. This mask is used when encoding an operation to indicate which
+/// components are present in the bytecode.
+namespace OpEncodingMask {
+enum : uint8_t {
+ kHasAttrs = 1 << 0,
+ kHasResults = 1 << 1,
+ kHasOperands = 1 << 2,
+ kHasSuccessors = 1 << 3,
+ kHasInlineRegions = 1 << 4,
+};
+} // namespace OpEncodingMask
+
+} // namespace bytecode
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -0,0 +1,1077 @@
+//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// TODO: Support for big-endian architectures.
+// TODO: Properly preserve use lists of values.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "../Encoding.h"
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/Verifier.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/SaveAndRestore.h"
+
+#define DEBUG_TYPE "mlir-bytecode-reader"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// EncodingReader
+//===----------------------------------------------------------------------===//
+
+namespace {
+class EncodingReader {
+public:
+ explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
+ : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
+ explicit EncodingReader(StringRef contents, Location fileLoc)
+ : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
+ contents.size()},
+ fileLoc) {}
+
+ /// Returns true if the entire section has been read.
+ bool empty() const { return dataIt == dataEnd; }
+
+ /// Returns the remaining size of the bytecode.
+ size_t size() const { return dataEnd - dataIt; }
+
+ /// Emit an error using the given arguments.
+ template <typename... Args>
+ LogicalResult emitError(Args &&...args) const {
+ return ::emitError(fileLoc).append(std::forward<Args>(args)...);
+ }
+
+ /// Parse a single byte from the stream.
+ template <typename T>
+ ParseResult parseByte(T &value) {
+ if (empty())
+ return emitError("attempting to parse a byte at the end of the bytecode");
+ value = static_cast<T>(*dataIt++);
+ return success();
+ }
+ /// Parse a range of bytes of 'length' into the given result.
+ ParseResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
+ if (length > size()) {
+ return emitError("attempting to parse ", length, " bytes when only ",
+ size(), " remain");
+ }
+ result = {dataIt, length};
+ dataIt += length;
+ return success();
+ }
+ /// Parse a range of bytes of 'length' into the given result, which can be
+ /// assumed to be large enough to hold `length`.
+ ParseResult parseBytes(size_t length, uint8_t *result) {
+ if (length > size()) {
+ return emitError("attempting to parse ", length, " bytes when only ",
+ size(), " remain");
+ }
+ memcpy(result, dataIt, length);
+ dataIt += length;
+ return success();
+ }
+
+ /// Parse a variable length encoded integer from the byte stream. The first
+ /// encoded byte contains a prefix in the low bits indicating the encoded
+ /// length of the value. This length prefix is a bit sequence of '0's followed
+ /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
+ /// (not including the prefix byte). All remaining bits in the first byte,
+ /// along with all of the bits in additional bytes, provide the value of the
+ /// integer encoded in little-endian order.
+ ParseResult parseVarInt(uint64_t &result) {
+ // Parse the first byte of the encoding, which contains the length prefix.
+ if (parseByte(result))
+ return failure();
+
+ // Handle the overwhelmingly common case where the value is stored in a
+ // single byte. In this case, the first bit is the `1` marker bit.
+ if (LLVM_LIKELY(result & 1)) {
+ result >>= 1;
+ return success();
+ }
+
+ // Handle the overwhelmingly uncommon case where the value required all 8
+ // value bytes (i.e. a really really big number). In this case, the marker
+ // byte is all zeros: `00000000`.
+ if (LLVM_UNLIKELY(result == 0))
+ return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result));
+ return parseMultiByteVarInt(result);
+ }
+
+ /// Skip the first `length` bytes within the reader.
+ ParseResult skipBytes(size_t length) {
+ if (length > size()) {
+ return emitError("attempting to skip ", length, " bytes when only ",
+ size(), " remain");
+ }
+ dataIt += length;
+ return success();
+ }
+
+ /// Parse a null-terminated string into `result` (without including the NUL
+ /// terminator).
+ ParseResult parseNullTerminatedString(StringRef &result) {
+ const char *startIt = (const char *)dataIt;
+ const char *nulIt = (const char *)memchr(startIt, 0, size());
+ if (!nulIt)
+ return emitError("malformed null-terminated string, no NUL found");
+
+ result = StringRef(startIt, nulIt - startIt);
+ dataIt = (const uint8_t *)nulIt + 1;
+ return success();
+ }
+
+ /// Parse a section header, placing the kind of section in `sectionID` and the
+ /// contents of the section in `sectionData`.
+ ParseResult parseSection(bytecode::Section::ID &sectionID,
+ ArrayRef<uint8_t> &sectionData) {
+ uint64_t length;
+ if (parseByte(sectionID) || parseVarInt(length))
+ return failure();
+ if (sectionID >= bytecode::Section::kNumSections)
+ return emitError("invalid section ID: ", static_cast<unsigned>(sectionID));
+
+ // Parse the actual section data now that we have its length.
+ return parseBytes(length, sectionData);
+ }
+
+private:
+ /// Parse a variable length encoded integer from the byte stream. This method
+ /// is a fallback when the number of bytes used to encode the value is greater
+ /// than 1, but less than the max (9). The provided `result` value can be
+ /// assumed to already contain the first byte of the value.
+ /// NOTE: This method is marked noinline to avoid pessimizing the common case
+ /// of single byte encoding.
+ LLVM_ATTRIBUTE_NOINLINE ParseResult parseMultiByteVarInt(uint64_t &result) {
+ // Count the number of trailing zeros in the marker byte, this indicates the
+ // number of trailing bytes that are part of the value. We use `uint32_t`
+ // here because we only care about the first byte, and so that we actually
+ // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
+ // implementation).
+ uint32_t numBytes =
+ llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined);
+ assert(numBytes > 0 && numBytes <= 7 &&
+ "unexpected number of trailing zeros in varint encoding");
+
+ // Parse in the remaining bytes of the value.
+ if (parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))
+ return failure();
+
+ // Shift out the low-order bits that were used to mark how the value was
+ // encoded.
+ result >>= (numBytes + 1);
+ return success();
+ }
+
+ /// The current data iterator, and an iterator to the end of the buffer.
+ const uint8_t *dataIt, *dataEnd;
+
+ /// A location for the bytecode used to report errors.
+ Location fileLoc;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BytecodeDialect
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This struct represents a dialect entry within the bytecode.
+struct BytecodeDialect {
+ BytecodeDialect(Dialect *dialect, StringRef name, unsigned numAttrs,
+ unsigned numTypes)
+ : dialect(dialect), name(name), numAttrs(numAttrs), numTypes(numTypes) {}
+
+ /// The loaded dialect entry, if available, otherwise nullptr.
+ Dialect *dialect;
+
+ /// The name of the dialect.
+ StringRef name;
+
+ /// The number of attributes owned by this dialect in the bytecode.
+ unsigned numAttrs;
+
+ /// The number of types owned by this dialect in the bytecode.
+ unsigned numTypes;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Attribute/Type Reader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides support for reading attribute and type entries from the
+/// bytecode. Attribute and Type entries are read lazily on demand, so we use
+/// this reader to manage when to actually parse them from the bytecode.
+class AttrTypeReader {
+ /// This class represents a single attribute or type entry.
+ template <typename T>
+ struct Entry {
+ /// The entry, or null if it hasn't been resolved yet.
+ T entry = {};
+ /// The parent dialect of this entry.
+ BytecodeDialect *dialect = nullptr;
+ /// The raw data of this entry in the bytecode.
+ ArrayRef<uint8_t> data;
+ };
+ using AttrEntry = Entry<Attribute>;
+ using TypeEntry = Entry<Type>;
+
+public:
+ AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
+
+ /// Initialize the attribute and type information within the reader.
+ LogicalResult initialize(ArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
+
+ /// Resolve the attribute or type at the given index. Returns nullptr on
+ /// failure.
+ Attribute resolveAttribute(unsigned index) {
+ return resolveEntry(attributes, index, "Attribute");
+ }
+ Type resolveType(unsigned index) {
+ return resolveEntry(types, index, "Type");
+ }
+
+private:
+ /// Initialize the offsets for the attribute and type entries.
+ LogicalResult initializeOffsets(ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
+
+ /// Resolve the given entry at `index`.
+ template <typename T>
+ T resolveEntry(SmallVectorImpl<Entry<T>> &entries, unsigned index,
+ StringRef entryType);
+
+ /// Parse the value defined within the given reader. `code` indicates how the
+ /// entry was encoded.
+ LogicalResult parseEntry(EncodingReader &reader, bytecode::AttrTypeCode code,
+ Attribute &result);
+ LogicalResult parseEntry(EncodingReader &reader, bytecode::AttrTypeCode code,
+ Type &result);
+
+ /// The set of attribute and type entries.
+ SmallVector<AttrEntry> attributes;
+ SmallVector<TypeEntry> types;
+
+ /// A location used for error emission.
+ Location fileLoc;
+};
+} // namespace
+
+LogicalResult AttrTypeReader::initialize(ArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData) {
+ // Initialize the entries using the dialect information.
+ unsigned numAttrs = 0, numTypes = 0;
+ for (const BytecodeDialect &dialect : dialects) {
+ numAttrs += dialect.numAttrs;
+ numTypes += dialect.numTypes;
+ }
+ attributes.resize(numAttrs);
+ types.resize(numTypes);
+
+ // With the entries initialized, we can process the offsets.
+ return initializeOffsets(sectionData, offsetSectionData);
+}
+
+LogicalResult
+AttrTypeReader::initializeOffsets(ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData) {
+ EncodingReader offsetReader(offsetSectionData, fileLoc);
+
+ // A functor used to accumulate the offsets for the entries in the given
+ // range.
+ uint64_t currentOffset = 0;
+ auto accumulateOffsets = [&](auto &&range) {
+ for (auto &entry : range) {
+ uint64_t entrySize;
+ if (offsetReader.parseVarInt(entrySize))
+ return failure();
+ if (currentOffset + entrySize > sectionData.size())
+ return offsetReader.emitError(
+ "Attribute or Type entry offset points past the end of section");
+ entry.data = sectionData.slice(currentOffset, entrySize);
+ currentOffset += entrySize;
+ }
+ return success();
+ };
+
+ // Process each of the attributes, and then the types.
+ if (failed(accumulateOffsets(attributes)) || failed(accumulateOffsets(types)))
+ return failure();
+
+ // Ensure that we read everything from the section.
+ if (!offsetReader.empty()) {
+ return offsetReader.emitError(
+ "unexpected trailing data in the Attribute/Type offset section");
+ }
+ return success();
+}
+
+template <typename T>
+T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
+ unsigned index, StringRef entryType) {
+ if (index >= entries.size()) {
+ emitError(fileLoc) << "invalid " << entryType << " index: " << index;
+ return {};
+ }
+
+ // If the entry has already been resolved, there is nothing left to do.
+ Entry<T> &entry = entries[index];
+ if (entry.entry)
+ return entry.entry;
+
+ // Parse the entry. Each entry starts with a specific code that indicates how
+ // it is represented.
+ EncodingReader reader(entry.data, fileLoc);
+ bytecode::AttrTypeCode code;
+ if (reader.parseByte(code) || failed(parseEntry(reader, code, entry.entry)))
+ return T();
+ if (!reader.empty()) {
+ (void)reader.emitError("unexpected trailing bytes after " + entryType +
+ " entry");
+ return T();
+ }
+ return entry.entry;
+}
+
+LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader,
+ bytecode::AttrTypeCode code,
+ Attribute &result) {
+ // Handle the fallback case, where the attribute was encoded using its
+ // assembly format.
+ if (code == bytecode::AttrTypeCode::kAsmForm) {
+ StringRef attrStr;
+ if (failed(reader.parseNullTerminatedString(attrStr)))
+ return failure();
+
+ size_t numRead = 0;
+ if (!(result = parseAttribute(attrStr, fileLoc->getContext(), numRead)))
+ return failure();
+ if (numRead != attrStr.size()) {
+ return reader.emitError(
+ "trailing characters found after Attribute assembly format: ",
+ attrStr.drop_front(numRead));
+ }
+ return success();
+ }
+
+ return reader.emitError("unexpected Attribute encoding: ",
+ static_cast<uint64_t>(code));
+}
+
+LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader,
+ bytecode::AttrTypeCode code,
+ Type &result) {
+ // Handle the fallback case, where the type was encoded using its
+ // assembly format.
+ if (code == bytecode::AttrTypeCode::kAsmForm) {
+ StringRef typeStr;
+ if (failed(reader.parseNullTerminatedString(typeStr)))
+ return failure();
+
+ size_t numRead = 0;
+ if (!(result = parseType(typeStr, fileLoc->getContext(), numRead)))
+ return failure();
+ if (numRead != typeStr.size()) {
+ return reader.emitError(
+ "trailing characters found after Type assembly format: " +
+ typeStr.drop_front(numRead));
+ }
+ return success();
+ }
+
+ return reader.emitError("unexpected Type encoding: ",
+ static_cast<uint64_t>(code));
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Reader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to read a bytecode buffer and translate it into MLIR.
+class BytecodeReader {
+public:
+ BytecodeReader(Location fileLoc, const ParserConfig &config)
+ : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
+ forwardRefOpState(UnknownLoc::get(config.getContext()),
+ "builtin.unrealized_conversion_cast", ValueRange(),
+ NoneType::get(config.getContext())) {}
+
+ /// Read the bytecode defined within `buffer` into the given block.
+ LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
+
+private:
+ /// Return the context for this config.
+ MLIRContext *getContext() const { return config.getContext(); }
+
+ //===--------------------------------------------------------------------===//
+ // Dialect Section
+
+ LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
+
+ /// Parse an operation name reference using the given reader.
+ FailureOr<OperationName> parseOpName(EncodingReader &reader);
+
+ //===--------------------------------------------------------------------===//
+ // Attribute/Type Section
+
+ /// Parse an attribute or type using the given reader. Returns nullptr in the
+ /// case of failure.
+ Attribute parseAttribute(EncodingReader &reader);
+ Type parseType(EncodingReader &reader);
+
+ template <typename T>
+ T parseAttribute(EncodingReader &reader) {
+ if (Attribute attr = parseAttribute(reader)) {
+ if (auto derivedAttr = attr.dyn_cast<T>())
+ return derivedAttr;
+ (void)reader.emitError("expected attribute of type: ",
+ llvm::getTypeName<T>(), ", but got: ", attr);
+ }
+ return T();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // IR Section
+
+ /// This struct represents the current read state of a range of regions. This
+ /// struct is used to enable iterative parsing of regions.
+ struct RegionReadState {
+ RegionReadState(Operation *op, bool isIsolatedFromAbove)
+ : RegionReadState(op->getRegions(), isIsolatedFromAbove) {}
+ RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove)
+ : curRegion(regions.begin()), endRegion(regions.end()),
+ isIsolatedFromAbove(isIsolatedFromAbove) {}
+
+ /// The current regions being read.
+ MutableArrayRef<Region>::iterator curRegion, endRegion;
+
+ /// The number of values defined immediately within this region.
+ unsigned numValues = 0;
+
+ /// The current blocks of the region being read.
+ SmallVector<Block *> curBlocks;
+ Region::iterator curBlock = {};
+
+ /// The number of operations remaining to be read from the current block
+ /// being read.
+ unsigned numOpsRemaining = 0;
+
+ /// A flag indicating if the regions being read are isolated from above.
+ bool isIsolatedFromAbove = false;
+ };
+
+ LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block);
+ LogicalResult parseRegions(EncodingReader &reader,
+ std::vector<RegionReadState> &regionStack,
+ RegionReadState &readState);
+ FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
+ RegionReadState &readState,
+ bool &isIsolatedFromAbove);
+
+ LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState);
+ LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
+ LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
+
+ //===--------------------------------------------------------------------===//
+ // Value Processing
+
+ /// Parse an operand reference using the given reader. Returns nullptr in the
+ /// case of failure.
+ Value parseOperand(EncodingReader &reader);
+
+ /// Sequentially define the given value range.
+ LogicalResult defineValues(EncodingReader &reader, ValueRange values);
+
+ /// Create a value to use for a forward reference.
+ Value createForwardRef();
+
+ //===--------------------------------------------------------------------===//
+ // Fields
+
+ /// This class represents a single value scope. Value scopes are delimited
+ /// by regions that are isolated from above.
+ struct ValueScope {
+ /// Push a new region state onto this scope, reserving enough values for
+ /// those defined within the current region of the provided state.
+ void push(RegionReadState &readState) {
+ nextValueIDs.push_back(values.size());
+ values.resize(values.size() + readState.numValues);
+ }
+
+ /// Pop the values defined for the current region within the provided region
+ /// state.
+ void pop(RegionReadState &readState) {
+ values.resize(values.size() - readState.numValues);
+ nextValueIDs.pop_back();
+ }
+
+ /// The set of values defined in this scope.
+ std::vector<Value> values;
+
+ /// The ID of the next value to define for each region currently being
+ /// processed in this scope.
+ SmallVector<unsigned> nextValueIDs;
+ };
+
+ /// The configuration of the parser.
+ const ParserConfig &config;
+
+ /// A location to use when emitting errors.
+ Location fileLoc;
+
+ /// The reader used to process attributes and types within the bytecode.
+ AttrTypeReader attrTypeReader;
+
+ /// The version of the bytecode being read.
+ uint64_t version = 0;
+
+ /// The table of IR units referenced within the bytecode file.
+ SmallVector dialects;
+ SmallVector<OperationName> opNames;
+
+ /// The current set of available IR value scopes.
+ std::vector<ValueScope> valueScopes;
+ /// A block containing the set of operations defined to create forward
+ /// references.
+ Block forwardRefOps;
+ /// A block containing previously created, and no longer used, forward
+ /// reference operations.
+ Block openForwardRefOps;
+ /// An operation state used when instantiating forward references.
+ OperationState forwardRefOpState;
+};
+} // namespace
+
+LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
+ EncodingReader reader(buffer.getBuffer(), fileLoc);
+
+ // Skip over the bytecode header; it should have already been checked.
+ if (reader.skipBytes(StringRef("ML\xefR").size()))
+ return failure();
+
+ // Parse the bytecode version.
+ if (reader.parseVarInt(version))
+ return failure();
+
+ // Validate the bytecode version.
+ if (version < bytecode::kVersion) {
+ return reader.emitError(
+ "bytecode version ", version, " is older than the current version of ",
+ bytecode::kVersion, ", and upgrade is not supported.");
+ }
+ if (version > bytecode::kVersion) {
+ return reader.emitError("bytecode version ", version,
+ " is newer than the current version ",
+ bytecode::kVersion, ".");
+ }
+
+ // The raw data for the AttrTypeOffset section.
+ Optional<ArrayRef<uint8_t>> attrTypeOffsetSection;
+
+ BitVector seenSections(bytecode::Section::kNumSections);
+ while (!reader.empty()) {
+ // Read the next section from the bytecode.
+ bytecode::Section::ID sectionID;
+ ArrayRef<uint8_t> sectionData;
+ if (reader.parseSection(sectionID, sectionData))
+ return failure();
+
+ // Check for duplicate sections; we only expect one instance of each.
+ if (seenSections.test(sectionID))
+ return reader.emitError("duplicate top-level section ID: ", sectionID);
+ seenSections.set(sectionID);
+
+ // Process the section.
+ switch (sectionID) {
+ case bytecode::Section::kDialect:
+ if (failed(parseDialectSection(sectionData)))
+ return failure();
+ break;
+ case bytecode::Section::kAttrType:
+ if (!attrTypeOffsetSection) {
+ return reader.emitError(
+ "expected the AttrTypeOffset section before the AttrType section");
+ }
+ if (dialects.empty()) {
+ return reader.emitError("expected the Dialect section before the "
+ "AttrType section");
+ }
+
+ // With everything ready, initialize the attribute/type reader.
+ if (failed(attrTypeReader.initialize(dialects, sectionData,
+ *attrTypeOffsetSection)))
+ return failure();
+ break;
+ case bytecode::Section::kAttrTypeOffset:
+ // We won't parse this section until we process the main AttrType section.
+ // For now, just record the raw data.
+ attrTypeOffsetSection = sectionData;
+ break;
+ case bytecode::Section::kIR:
+ if (failed(parseIRSection(sectionData, block)))
+ return failure();
+ break;
+ default:
+ return reader.emitError("unexpected top-level section: ", sectionID);
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Dialect Section
+
+LogicalResult
+BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
+ MLIRContext *ctx = getContext();
+
+ EncodingReader sectionReader(sectionData, fileLoc);
+ while (!sectionReader.empty()) {
+ // Read the name of the next dialect.
+ StringRef dialectName;
+ if (sectionReader.parseNullTerminatedString(dialectName))
+ return failure();
+
+ // Parse the attribute and type counts.
+ uint64_t attrCount, typeCount;
+ if (sectionReader.parseVarInt(attrCount) ||
+ sectionReader.parseVarInt(typeCount))
+ return failure();
+
+ // Try to load the dialect.
+ Dialect *dialect = ctx->getOrLoadDialect(dialectName);
+ if (!dialect && !ctx->allowsUnregisteredDialects()) {
+ return sectionReader.emitError(
+ "dialect '", dialectName,
+ "' is unknown. If this is intended, please call "
+ "allowUnregisteredDialects() on the MLIRContext, or use "
+ "-allow-unregistered-dialect with the MLIR tool used.");
+ }
+ dialects.emplace_back(dialect, dialectName, attrCount, typeCount);
+
+ // Parse the operation names of the dialect.
+ uint64_t numOpNames;
+ if (sectionReader.parseVarInt(numOpNames))
+ return failure();
+ SmallString<32> opNameStorage({dialectName, "."});
+ while (numOpNames--) {
+ StringRef opName;
+ if (sectionReader.parseNullTerminatedString(opName))
+ return failure();
+
+ opNameStorage.resize(dialectName.size() + 1);
+ opNameStorage.append(opName);
+ opNames.push_back(OperationName(opNameStorage, ctx));
+ }
+ }
+ return success();
+}
+
+FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
+ uint64_t opNameIdx;
+ if (reader.parseVarInt(opNameIdx))
+ return failure();
+
+ if (opNameIdx >= opNames.size())
+ return reader.emitError("invalid operation name index: ", opNameIdx);
+ return opNames[opNameIdx];
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute/Type Section
+
+Attribute BytecodeReader::parseAttribute(EncodingReader &reader) {
+ uint64_t attrIdx;
+ if (reader.parseVarInt(attrIdx))
+ return Attribute();
+ return attrTypeReader.resolveAttribute(attrIdx);
+}
+
+Type BytecodeReader::parseType(EncodingReader &reader) {
+ uint64_t typeIdx;
+ if (reader.parseVarInt(typeIdx))
+ return Type();
+ return attrTypeReader.resolveType(typeIdx);
+}
+
+//===----------------------------------------------------------------------===//
+// IR Section
+
+LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData,
+ Block *block) {
+ EncodingReader reader(sectionData, fileLoc);
+
+ // A stack of operation regions currently being read from the bytecode.
+ std::vector<RegionReadState> regionStack;
+
+ // Parse the top-level block using a temporary module operation.
+ OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc);
+ regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true);
+ regionStack.back().curBlocks.push_back(moduleOp->getBody());
+ regionStack.back().curBlock = regionStack.back().curRegion->begin();
+ if (failed(parseBlock(reader, regionStack.back())))
+ return failure();
+ valueScopes.emplace_back(ValueScope());
+ valueScopes.back().push(regionStack.back());
+
+ // Iteratively parse regions until everything has been resolved.
+ while (!regionStack.empty())
+ if (failed(parseRegions(reader, regionStack, regionStack.back())))
+ return failure();
+ if (!forwardRefOps.empty()) {
+ return reader.emitError(
+ "not all forward operand references were resolved");
+ }
+
+ // Verify that the parsed operations are valid.
+ if (failed(verify(*moduleOp)))
+ return failure();
+
+ // Splice the parsed operations over to the provided top-level block.
+ auto &parsedOps = moduleOp->getBody()->getOperations();
+ auto &destOps = block->getOperations();
+ destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
+ parsedOps, parsedOps.begin(), parsedOps.end());
+ return success();
+}
+
+LogicalResult
+BytecodeReader::parseRegions(EncodingReader &reader,
+ std::vector<RegionReadState> &regionStack,
+ RegionReadState &readState) {
+ // Read the regions of this operation.
+ for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
+ // If the current block hasn't been set up yet, parse the header for this
+ // region.
+ if (readState.curBlock == Region::iterator()) {
+ if (failed(parseRegion(reader, readState)))
+ return failure();
+
+ // If the region is empty, there is nothing more to do.
+ if (readState.curRegion->empty())
+ continue;
+ }
+
+ // Parse the blocks within the region.
+ do {
+ while (readState.numOpsRemaining--) {
+ // Read in the next operation. We don't read its regions directly, we
+ // handle those afterwards as necessary.
+ bool isIsolatedFromAbove = false;
+ FailureOr<Operation *> op =
+ parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
+ if (failed(op))
+ return failure();
+
+ // If the op has regions, add it to the stack for processing.
+ if ((*op)->getNumRegions()) {
+ regionStack.emplace_back(*op, isIsolatedFromAbove);
+
+ // If the op is isolated from above, push a new value scope.
+ if (isIsolatedFromAbove)
+ valueScopes.emplace_back(ValueScope());
+ return success();
+ }
+ }
+
+ // Move to the next block of the region.
+ if (++readState.curBlock == readState.curRegion->end())
+ break;
+ if (failed(parseBlock(reader, readState)))
+ return failure();
+ } while (true);
+
+ // Reset the current block and any values reserved for this region.
+ readState.curBlock = {};
+ valueScopes.back().pop(readState);
+ }
+
+ // When the regions have been fully parsed, pop them off of the read stack. If
+ // the regions were isolated from above, we also pop the last value scope.
+ regionStack.pop_back();
+ if (readState.isIsolatedFromAbove)
+ valueScopes.pop_back();
+ return success();
+}
+
+FailureOr<Operation *>
+BytecodeReader::parseOpWithoutRegions(EncodingReader &reader,
+ RegionReadState &readState,
+ bool &isIsolatedFromAbove) {
+ // Parse the name of the operation.
+ FailureOr opName = parseOpName(reader);
+ if (failed(opName))
+ return failure();
+
+ // Parse the operation mask, which indicates which components of the operation
+ // are present.
+ uint8_t opMask;
+ if (reader.parseByte(opMask))
+ return failure();
+
+ // Parse the location.
+ LocationAttr opLoc = parseAttribute<LocationAttr>(reader);
+ if (!opLoc)
+ return failure();
+
+ // With the location and name resolved, we can start building the operation
+ // state.
+ OperationState opState(opLoc, *opName);
+
+ // Parse the attributes of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
+ DictionaryAttr dictAttr = parseAttribute<DictionaryAttr>(reader);
+ if (!dictAttr)
+ return failure();
+ opState.attributes = dictAttr;
+ }
+
+ // Parse the results of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasResults) {
+ uint64_t numResults;
+ if (reader.parseVarInt(numResults))
+ return failure();
+ opState.types.resize(numResults);
+ for (int i = 0, e = numResults; i < e; ++i)
+ if (!(opState.types[i] = parseType(reader)))
+ return failure();
+ }
+
+ // Parse the operands of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasOperands) {
+ uint64_t numOperands;
+ if (reader.parseVarInt(numOperands))
+ return failure();
+ opState.operands.resize(numOperands);
+ for (int i = 0, e = numOperands; i < e; ++i)
+ if (!(opState.operands[i] = parseOperand(reader)))
+ return failure();
+ }
+
+ // Parse the successors of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
+ uint64_t numSuccs;
+ if (reader.parseVarInt(numSuccs))
+ return failure();
+ opState.successors.reserve(numSuccs);
+ for (int i = 0, e = numSuccs; i < e; ++i) {
+ uint64_t succID;
+ if (reader.parseVarInt(succID))
+ return failure();
+ if (succID >= readState.curBlocks.size())
+ return reader.emitError("invalid successor index: ", succID);
+ opState.successors.push_back(readState.curBlocks[succID]);
+ }
+ }
+
+ // Parse the regions of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
+ uint64_t numRegions;
+ if (reader.parseVarInt(numRegions))
+ return failure();
+
+ // Extract whether the regions are isolated from above; this flag is encoded
+ // in the low bit of the numRegions varint.
+ isIsolatedFromAbove = numRegions & 1;
+ numRegions >>= 1;
+
+ opState.regions.reserve(numRegions);
+ for (int i = 0, e = numRegions; i < e; ++i)
+ opState.regions.push_back(std::make_unique<Region>());
+ }
+
+ // Create the operation at the back of the current block.
+ Operation *op = Operation::create(opState);
+ readState.curBlock->push_back(op);
+
+ // If the operation had results, update the value references.
+ if (op->getNumResults() && failed(defineValues(reader, op->getResults())))
+ return failure();
+
+ return op;
+}
+
+LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
+ RegionReadState &readState) {
+ // Parse the number of blocks in the region.
+ uint64_t numBlocks;
+ if (reader.parseVarInt(numBlocks))
+ return failure();
+
+ // If the region is empty, there is nothing else to do.
+ if (numBlocks == 0)
+ return success();
+
+ // Parse the number of values defined in this region.
+ uint64_t numValues;
+ if (reader.parseVarInt(numValues))
+ return failure();
+ readState.numValues = numValues;
+
+ // Create the blocks within this region. We do this before processing so that
+ // we can rely on the blocks existing when creating operations.
+ readState.curBlocks.clear();
+ readState.curBlocks.reserve(numBlocks);
+ for (uint64_t i = 0; i < numBlocks; ++i) {
+ readState.curBlocks.push_back(new Block());
+ readState.curRegion->push_back(readState.curBlocks.back());
+ }
+
+ // Prepare the current value scope for this region.
+ valueScopes.back().push(readState);
+
+ // Parse the entry block of the region.
+ readState.curBlock = readState.curRegion->begin();
+ return parseBlock(reader, readState);
+}
+
+LogicalResult BytecodeReader::parseBlock(EncodingReader &reader,
+ RegionReadState &readState) {
+ // Parse the number of operations in the block.
+ uint64_t numOps;
+ if (reader.parseVarInt(numOps))
+ return failure();
+ // Extract out the low-bit of the operation count, which indicates if the
+ // block has arguments.
+ bool hasArgs = numOps & 1;
+ readState.numOpsRemaining = numOps >> 1;
+
+ // Parse the arguments of the block.
+ if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock)))
+ return failure();
+
+ // We don't parse the operations of the block here; that's done elsewhere.
+ return success();
+}
+
+LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
+ Block *block) {
+ // Parse the number of arguments.
+ uint64_t numArgs;
+ if (reader.parseVarInt(numArgs))
+ return failure();
+
+ SmallVector<Type> argTypes;
+ SmallVector<Location> argLocs;
+ argTypes.reserve(numArgs);
+ argLocs.reserve(numArgs);
+
+ while (numArgs--) {
+ Type argType = parseType(reader);
+ if (!argType)
+ return failure();
+ LocationAttr argLoc = parseAttribute<LocationAttr>(reader);
+ if (!argLoc)
+ return failure();
+
+ argTypes.push_back(argType);
+ argLocs.push_back(argLoc);
+ }
+ block->addArguments(argTypes, argLocs);
+ return defineValues(reader, block->getArguments());
+}
+
+//===----------------------------------------------------------------------===//
+// Value Processing
+
+Value BytecodeReader::parseOperand(EncodingReader &reader) {
+ uint64_t valueIdx;
+ if (failed(reader.parseVarInt(valueIdx)))
+ return nullptr;
+ std::vector<Value> &values = valueScopes.back().values;
+ if (valueIdx >= values.size())
+ return (void)reader.emitError("invalid value index: ", valueIdx), Value();
+
+ // Resolve it, or create a new forward reference if necessary.
+ Value &value = values[valueIdx];
+ if (!value)
+ value = createForwardRef();
+ return value;
+}
+
+LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
+ ValueRange newValues) {
+ ValueScope &valueScope = valueScopes.back();
+ std::vector<Value> &values = valueScope.values;
+
+ unsigned &valueID = valueScope.nextValueIDs.back();
+ unsigned valueIDEnd = valueID + newValues.size();
+ if (valueIDEnd > values.size()) {
+ return reader.emitError(
+ "value index range was outside of the expected range for "
+ "the parent region, got [",
+ valueID, ", ", valueIDEnd, "), but the maximum index was ",
+ values.size() - 1);
+ }
+
+ // Assign the values and update any forward references.
+ for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) {
+ Value newValue = newValues[i];
+
+ // Check to see if a definition for this value already exists.
+ if (Value oldValue = std::exchange(values[valueID], newValue)) {
+ Operation *forwardRefOp = oldValue.getDefiningOp();
+ if (!forwardRefOp || forwardRefOp->getBlock() != &forwardRefOps) {
+ return reader.emitError("value index ", valueID,
+ " was already defined");
+ }
+
+ oldValue.replaceAllUsesWith(newValue);
+ forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
+ }
+ }
+ return success();
+}
+
+Value BytecodeReader::createForwardRef() {
+ // Check for an available existing operation to use. Otherwise, create a new
+ // fake operation to use for the reference.
+ if (!openForwardRefOps.empty()) {
+ Operation *op = &openForwardRefOps.back();
+ op->moveBefore(&forwardRefOps, forwardRefOps.end());
+ } else {
+ forwardRefOps.push_back(Operation::create(forwardRefOpState));
+ }
+ return forwardRefOps.back().getResult(0);
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
+ return buffer.getBuffer().startswith("ML\xefR");
+}
+
+LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config) {
+ Location sourceFileLoc =
+ FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
+ /*line=*/0, /*column=*/0);
+ if (!isBytecode(buffer)) {
+ return emitError(sourceFileLoc,
+ "input buffer is not an MLIR bytecode file");
+ }
+
+ BytecodeReader reader(sourceFileLoc, config);
+ return reader.read(buffer, block);
+}
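The forward-reference scheme used by `parseOperand`, `defineValues`, and `createForwardRef` above can be modeled outside of MLIR. What follows is a minimal, hypothetical sketch: `Node`, `Op`, and `ValueTable` are illustrative stand-ins rather than MLIR types, and where the reader calls `replaceAllUsesWith` on the result of a placeholder `builtin.unrealized_conversion_cast` operation, this sketch instead patches recorded (operation, operand index) pairs.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an SSA value; placeholders play the role of the
// reader's forward-reference operation results.
struct Node {
  std::string name;
  bool isPlaceholder;
};
using Value = std::shared_ptr<Node>;

// Hypothetical stand-in for an operation: only its operand list matters here.
struct Op {
  std::vector<Value> operands;
};

// Models a single value scope: a fixed number of slots indexed by value ID.
class ValueTable {
public:
  explicit ValueTable(size_t numValues) : values(numValues), uses(numValues) {}

  // Analogous to parseOperand: append value `idx` to `op`, creating a
  // placeholder if the value has not been defined yet. `op` must outlive
  // any pending placeholder use recorded here.
  bool addOperand(Op &op, size_t idx) {
    if (idx >= values.size())
      return false; // invalid value index
    if (!values[idx])
      values[idx] = std::make_shared<Node>(Node{"<fwdref>", true});
    // Remember where placeholders are used so they can be patched later.
    if (values[idx]->isPlaceholder)
      uses[idx].emplace_back(&op, op.operands.size());
    op.operands.push_back(values[idx]);
    return true;
  }

  // Analogous to defineValues: bind slot `idx` to `v`, redirecting any uses
  // of a placeholder. Redefining an already-defined slot is an error.
  bool define(size_t idx, Value v) {
    assert(v && !v->isPlaceholder && "definitions must be real values");
    if (idx >= values.size())
      return false;
    if (values[idx] && !values[idx]->isPlaceholder)
      return false; // value was already defined
    for (const auto &use : uses[idx])
      use.first->operands[use.second] = v;
    uses[idx].clear();
    values[idx] = std::move(v);
    return true;
  }

  // Analogous to the final forwardRefOps.empty() check in parseIRSection.
  bool allResolved() const {
    for (const Value &v : values)
      if (v && v->isPlaceholder)
        return false;
    return true;
  }

private:
  std::vector<Value> values;
  // For each slot, the (operation, operand index) pairs using its placeholder.
  std::vector<std::vector<std::pair<Op *, size_t>>> uses;
};
```

Unlike the reader, this sketch does not recycle placeholders through a free list (the role of `openForwardRefOps`), and it models a single value scope rather than one per isolated-from-above region.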
diff --git a/mlir/lib/Bytecode/Reader/CMakeLists.txt b/mlir/lib/Bytecode/Reader/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Reader/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(MLIRBytecodeReader
+ BytecodeReader.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode
+
+ LINK_LIBS PUBLIC
+ MLIRAsmParser
+ MLIRIR
+ MLIRSupport
+ )
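The PrefixVarInt scheme produced by `EncodingEmitter::emitVarInt` and `emitMultiByteVarInt` in the writer can be exercised in isolation. The following is a sketch that assumes the layout described in those methods' comments; `encodePrefixVarInt`, `decodePrefixVarInt`, and `roundTrips` are hypothetical helpers, not part of the MLIR API. Unlike `emitMultiByteVarInt`, which copies the bytes of a `uint64_t` directly (and so yields little-endian output only on little-endian hosts), this sketch extracts each byte with shifts.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Encode `value` as a PrefixVarInt: the number of trailing zero bits in the
// first byte gives the number of additional bytes, and the remaining bits hold
// the value in little-endian order.
std::vector<uint8_t> encodePrefixVarInt(uint64_t value) {
  std::vector<uint8_t> out;
  // Find the smallest byte count (1..8) whose 7 bits per byte fit the value.
  for (unsigned numBytes = 1; numBytes <= 8; ++numBytes) {
    if ((value >> (7 * numBytes)) == 0) {
      // Prefix: (numBytes - 1) zero bits followed by a terminating 1 bit.
      uint64_t encoded = ((value << 1) | 1) << (numBytes - 1);
      for (unsigned i = 0; i < numBytes; ++i)
        out.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
      return out;
    }
  }
  // Values needing more than 56 bits: an all-zero marker byte followed by the
  // full 64-bit value in little-endian order.
  out.push_back(0);
  for (unsigned i = 0; i < 8; ++i)
    out.push_back(static_cast<uint8_t>(value >> (8 * i)));
  return out;
}

// Decode a PrefixVarInt from the start of `bytes`; `consumed` receives the
// number of bytes read.
uint64_t decodePrefixVarInt(const std::vector<uint8_t> &bytes,
                            size_t &consumed) {
  assert(!bytes.empty() && "expected at least one byte");
  uint8_t first = bytes[0];
  if (first == 0) {
    assert(bytes.size() >= 9 && "truncated varint");
    uint64_t value = 0;
    for (unsigned i = 0; i < 8; ++i)
      value |= uint64_t(bytes[1 + i]) << (8 * i);
    consumed = 9;
    return value;
  }
  // Count trailing zeros of the first byte to find the total byte count.
  unsigned numBytes = 1;
  while ((first & (1u << (numBytes - 1))) == 0)
    ++numBytes;
  assert(bytes.size() >= numBytes && "truncated varint");
  uint64_t encoded = 0;
  for (unsigned i = 0; i < numBytes; ++i)
    encoded |= uint64_t(bytes[i]) << (8 * i);
  consumed = numBytes;
  // Drop the prefix: (numBytes - 1) zero bits plus the terminating 1 bit.
  return encoded >> numBytes;
}

// Check that `value` survives an encode/decode round trip.
bool roundTrips(uint64_t value) {
  std::vector<uint8_t> bytes = encodePrefixVarInt(value);
  size_t consumed = 0;
  return decodePrefixVarInt(bytes, consumed) == value &&
         consumed == bytes.size();
}
```

For example, 128 no longer fits in 7 bits, so it takes two bytes: the single trailing zero in the first byte announces one additional byte, and the encoding is `0x02 0x02`.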
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -0,0 +1,446 @@
+//===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "../Encoding.h"
+#include "IRNumbering.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/Debug.h"
+#include
+
+#define DEBUG_TYPE "mlir-bytecode-writer"
+
+using namespace mlir;
+using namespace mlir::bytecode::detail;
+
+//===----------------------------------------------------------------------===//
+// EncodingEmitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class functions as the underlying encoding emitter for the bytecode
+/// writer. This class is a bit different compared to other types of encoders;
+/// it does not use a single buffer, but instead may contain several buffers
+/// (some owned by the writer, and some not) that get concatenated during the final
+/// emission.
+class EncodingEmitter {
+public:
+ EncodingEmitter() = default;
+ EncodingEmitter(const EncodingEmitter &) = delete;
+ EncodingEmitter &operator=(const EncodingEmitter &) = delete;
+
+ /// Write the current contents to the provided stream.
+ void writeTo(raw_ostream &os) const;
+
+ /// Return the current size of the encoded buffer.
+ size_t size() const { return prevResultSize + currentResult.size(); }
+
+ //===--------------------------------------------------------------------===//
+ // Emission
+ //===--------------------------------------------------------------------===//
+
+ /// Return a raw pointer into the result buffer at the specified offset.
+ uint8_t *getRawPointer(uint64_t offset) {
+ assert(offset < size() && offset >= prevResultSize &&
+ "cannot get pointer to previously emitted data");
+ return currentResult.data() + (offset - prevResultSize);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Integer Emission
+
+ /// Emit a single byte.
+ template <typename T>
+ void emitByte(T byte) {
+ currentResult.push_back(static_cast<uint8_t>(byte));
+ }
+
+ /// Emit a range of bytes.
+ void emitBytes(ArrayRef<uint8_t> bytes) {
+ llvm::append_range(currentResult, bytes);
+ }
+
+ /// Emit a variable length integer. The first encoded byte contains a prefix
+ /// in the low bits indicating the encoded length of the value. This length
+ /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits
+ /// indicate the number of _additional_ bytes (not including the prefix byte).
+ /// All remaining bits in the first byte, along with all of the bits in
+ /// additional bytes, provide the value of the integer encoded in
+ /// little-endian order.
+ void emitVarInt(uint64_t value) {
+ // In the most common case, the value can be represented in a single byte.
+ // Given how hot this case is, explicitly handle that here.
+ if ((value >> 7) == 0)
+ return emitByte((value << 1) | 0x1);
+ emitMultiByteVarInt(value);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // String Emission
+
+ /// Emit the given string as a nul terminated string.
+ void emitNulTerminatedString(StringRef str) {
+ emitString(str);
+ emitByte(0);
+ }
+
+ /// Emit the given string without a nul terminator.
+ void emitString(StringRef str) {
+ emitBytes({reinterpret_cast<const uint8_t *>(str.data()), str.size()});
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Section Emission
+
+ /// Emit a nested section of the given code, whose contents are encoded in the
+ /// provided emitter.
+ void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) {
+ // Emit the section code and length.
+ emitByte(code);
+ emitVarInt(emitter.size());
+
+ // Push our current buffer and then merge the provided section body into
+ // ours.
+ appendResult(std::move(currentResult));
+ for (std::vector<uint8_t> &result : emitter.prevResultStorage)
+ appendResult(std::move(result));
+ appendResult(std::move(emitter.currentResult));
+ }
+
+private:
+ /// Emit the given value using a variable width encoding. This method is a
+ /// fallback when the number of bytes needed to encode the value is greater
+ /// than 1. We mark it noinline here so that the single byte hot path isn't
+ /// pessimized.
+ LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value);
+
+ /// Append a new result buffer to the current contents.
+ void appendResult(std::vector<uint8_t> &&result) {
+ prevResultSize += result.size();
+ prevResultStorage.emplace_back(std::move(result));
+ prevResultList.emplace_back(prevResultStorage.back());
+ }
+
+ /// The result of the emitter currently being built. We refrain from building
+ /// a single buffer to simplify emitting sections, large data, and more. The
+ /// result is thus represented using multiple distinct buffers, some of which
+ /// we own (via prevResultStorage), and some of which are just pointers into
+ /// externally owned buffers.
+ std::vector<uint8_t> currentResult;
+ std::vector<ArrayRef<uint8_t>> prevResultList;
+ std::vector<std::vector<uint8_t>> prevResultStorage;
+
+ /// An up-to-date total size of all of the buffers within `prevResultList`.
+ /// This enables O(1) size checks of the current encoding.
+ size_t prevResultSize = 0;
+};
+
+/// A simple raw_ostream wrapper around an EncodingEmitter. This removes the need
+/// to go through an intermediate buffer when interacting with code that wants a
+/// raw_ostream.
+class raw_emitter_ostream : public raw_ostream {
+public:
+ explicit raw_emitter_ostream(EncodingEmitter &emitter) : emitter(emitter) {
+ SetUnbuffered();
+ }
+
+private:
+ void write_impl(const char *ptr, size_t size) override {
+ emitter.emitBytes({reinterpret_cast<const uint8_t *>(ptr), size});
+ }
+ uint64_t current_pos() const override { return emitter.size(); }
+
+ /// The section being emitted to.
+ EncodingEmitter &emitter;
+};
+} // namespace
+
+void EncodingEmitter::writeTo(raw_ostream &os) const {
+ for (auto &prevResult : prevResultList)
+ os.write((const char *)prevResult.data(), prevResult.size());
+ os.write((const char *)currentResult.data(), currentResult.size());
+}
+
+void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
+ // Compute the number of bytes needed to encode the value. Each byte can hold
+ // up to 7 bits of data. The length prefix in the first byte can describe at
+ // most 8 bytes in total, so we only check byte counts up to 8 here.
+ uint64_t it = value >> 7;
+ for (size_t numBytes = 2; numBytes < 9; ++numBytes) {
+ if (LLVM_LIKELY((it >>= 7) == 0)) {
+ uint64_t encodedValue = (value << 1) | 0x1;
+ encodedValue <<= (numBytes - 1);
+ emitBytes({reinterpret_cast<const uint8_t *>(&encodedValue), numBytes});
+ return;
+ }
+ }
+
+ // If the value is too large to encode with a length prefix (it needs more
+ // than 56 bits), emit a special all-zero marker byte and splat the value directly.
+ emitByte(0);
+ emitBytes({reinterpret_cast<const uint8_t *>(&value), sizeof(value)});
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Writer
+//===----------------------------------------------------------------------===//
+
+namespace {
+class BytecodeWriter {
+public:
+ BytecodeWriter(Operation *op) : numberingState(op) {}
+
+ /// Write the bytecode for the given root operation.
+ void write(Operation *rootOp, raw_ostream &os);
+
+private:
+ //===--------------------------------------------------------------------===//
+ // Dialects
+
+ void writeDialectSection(EncodingEmitter &emitter);
+
+ //===--------------------------------------------------------------------===//
+ // Attributes and Types
+
+ void writeAttrTypeSection(EncodingEmitter &emitter);
+
+ //===--------------------------------------------------------------------===//
+ // Operations
+
+ void writeBlock(EncodingEmitter &emitter, Block *block);
+ void writeOp(EncodingEmitter &emitter, Operation *op);
+ void writeRegion(EncodingEmitter &emitter, Region *region);
+ void writeIRSection(EncodingEmitter &emitter, Operation *op);
+
+ //===--------------------------------------------------------------------===//
+ // Fields
+
+ /// The IR numbering state generated for the root operation.
+ IRNumberingState numberingState;
+};
+} // namespace
+
+void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
+ EncodingEmitter emitter;
+
+ // Emit the bytecode file header. This is how we identify the output as a
+ // bytecode file.
+ emitter.emitString("ML\xefR");
+
+ // Emit the bytecode version.
+ emitter.emitVarInt(bytecode::kVersion);
+
+ // Emit the dialect section.
+ writeDialectSection(emitter);
+
+ // Emit the attributes and types section.
+ writeAttrTypeSection(emitter);
+
+ // Emit the IR section.
+ writeIRSection(emitter, rootOp);
+
+ // Write the generated bytecode to the provided output stream.
+ emitter.writeTo(os);
+}
+
+//===----------------------------------------------------------------------===//
+// Dialects
+
+void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
+ EncodingEmitter dialectEmitter;
+
+ // Emit the referenced dialects.
+ for (DialectNumbering &dialect : numberingState.getDialects()) {
+ // Emit the dialect name.
+ dialectEmitter.emitNulTerminatedString(dialect.name);
+
+ // Emit the number of attributes and types emitted for this dialect.
+ dialectEmitter.emitVarInt(dialect.attributes.size());
+ dialectEmitter.emitVarInt(dialect.types.size());
+
+ // Emit the referenced operation names of this dialect.
+ dialectEmitter.emitVarInt(dialect.opNames.size());
+ for (OpNameNumbering *opName : dialect.opNames)
+ dialectEmitter.emitNulTerminatedString(opName->name.stripDialect());
+ }
+
+ emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes and Types
+
+void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
+ EncodingEmitter attrTypeEmitter;
+ EncodingEmitter offsetEmitter;
+
+ // A functor used to emit an attribute or type entry.
+ uint64_t prevOffset = 0;
+ auto emitAttrOrType = [&](auto value) {
+ // Emit the entry using the textual format.
+ // TODO: Allow dialects to provide more optimal implementations of attribute
+ // and type encodings.
+ attrTypeEmitter.emitByte(bytecode::AttrTypeCode::kAsmForm);
+ raw_emitter_ostream(attrTypeEmitter) << value;
+ attrTypeEmitter.emitByte(0);
+
+ // Record the offset of this entry.
+ uint64_t curOffset = attrTypeEmitter.size();
+ offsetEmitter.emitVarInt(curOffset - prevOffset);
+ prevOffset = curOffset;
+ };
+
+ // Emit the attribute and type entries for each dialect.
+ for (DialectNumbering &dialect : numberingState.getDialects())
+ for (AttributeNumbering *attr : dialect.attributes)
+ emitAttrOrType(attr->getValue());
+ for (DialectNumbering &dialect : numberingState.getDialects())
+ for (TypeNumbering *type : dialect.types)
+ emitAttrOrType(type->getValue());
+
+ // Emit the sections to the stream.
+ emitter.emitSection(bytecode::Section::kAttrTypeOffset,
+ std::move(offsetEmitter));
+ emitter.emitSection(bytecode::Section::kAttrType, std::move(attrTypeEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Operations
+
+void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) {
+ ArrayRef<BlockArgument> args = block->getArguments();
+ bool hasArgs = !args.empty();
+
+ // Emit the number of operations in this block, and whether it has arguments.
+ // We use the low bit of the operation count to indicate if the block has
+ // arguments.
+ unsigned numOps = numberingState.getOperationCount(block);
+ emitter.emitVarInt((numOps << 1) | (hasArgs ? 1 : 0));
+
+ // Emit the arguments of the block.
+ if (hasArgs) {
+ emitter.emitVarInt(args.size());
+ for (BlockArgument arg : args) {
+ emitter.emitVarInt(numberingState.getNumber(arg.getType()));
+ emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
+ }
+ }
+
+ // Emit the operations within the block.
+ for (Operation &op : *block)
+ writeOp(emitter, &op);
+}
+
+void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
+ emitter.emitVarInt(numberingState.getNumber(op->getName()));
+
+ // Emit a mask for the operation components. We need to fill this in later
+ // (when we actually know what needs to be emitted), so emit a placeholder for
+ // now.
+ uint64_t maskOffset = emitter.size();
+ uint8_t opEncodingMask = 0;
+ emitter.emitByte(0);
+
+ // Emit the location for this operation.
+ emitter.emitVarInt(numberingState.getNumber(op->getLoc()));
+
+ // Emit the attributes of this operation.
+ DictionaryAttr attrs = op->getAttrDictionary();
+ if (!attrs.empty()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs;
+ emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary()));
+ }
+
+ // Emit the result types of the operation.
+ if (unsigned numResults = op->getNumResults()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasResults;
+ emitter.emitVarInt(numResults);
+ for (Type type : op->getResultTypes())
+ emitter.emitVarInt(numberingState.getNumber(type));
+ }
+
+ // Emit the operands of the operation.
+ if (unsigned numOperands = op->getNumOperands()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasOperands;
+ emitter.emitVarInt(numOperands);
+ for (Value operand : op->getOperands())
+ emitter.emitVarInt(numberingState.getNumber(operand));
+ }
+
+ // Emit the successors of the operation.
+ if (unsigned numSuccessors = op->getNumSuccessors()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors;
+ emitter.emitVarInt(numSuccessors);
+ for (Block *successor : op->getSuccessors())
+ emitter.emitVarInt(numberingState.getNumber(successor));
+ }
+
+ // Check for regions.
+ unsigned numRegions = op->getNumRegions();
+ if (numRegions)
+ opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions;
+
+ // Update the mask for the operation.
+ *emitter.getRawPointer(maskOffset) = opEncodingMask;
+
+ // With the mask emitted, we can now emit the regions of the operation. We do
+ // this after mask emission to avoid offset complications that may arise by
+ // emitting the regions first (e.g. if the regions are huge, backpatching the
+ // op encoding mask is more annoying).
+ if (numRegions) {
+ bool isIsolatedFromAbove = op->hasTrait<OpTrait::IsIsolatedFromAbove>();
+ emitter.emitVarInt((numRegions << 1) | (isIsolatedFromAbove ? 1 : 0));
+
+ for (Region &region : op->getRegions())
+ writeRegion(emitter, &region);
+ }
+}
+
+void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region) {
+ // If the region is empty, we only need to emit the number of blocks (which is
+ // zero).
+ if (region->empty())
+ return emitter.emitVarInt(/*numBlocks=*/0);
+
+ // Emit the number of blocks and values within the region.
+ unsigned numBlocks, numValues;
+ std::tie(numBlocks, numValues) = numberingState.getBlockValueCount(region);
+ emitter.emitVarInt(numBlocks);
+ emitter.emitVarInt(numValues);
+
+ // Emit the blocks within the region.
+ for (Block &block : *region)
+ writeBlock(emitter, &block);
+}
+
+void BytecodeWriter::writeIRSection(EncodingEmitter &emitter, Operation *op) {
+ EncodingEmitter irEmitter;
+
+ // Write the IR section the same way as a block with no arguments. Note that
+ // the low-bit of the operation count for a block is used to indicate if the
+ // block has arguments, which in this case is always false.
+ uint64_t numOps = 1;
+ irEmitter.emitVarInt(numOps << 1);
+
+ // Emit the operations.
+ writeOp(irEmitter, op);
+
+ emitter.emitSection(bytecode::Section::kIR, std::move(irEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+void mlir::writeBytecodeToFile(Operation *op, raw_ostream &os) {
+ BytecodeWriter writer(op);
+ writer.write(op, os);
+}
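The writer reuses one trick in several places: a count is shifted left by one and a boolean is stored in the freed low bit before the value is emitted as a VarInt (`writeBlock` flags block arguments, `writeOp` flags regions that are isolated from above, and `writeIRSection` always leaves the bit clear). `writeAttrTypeSection` similarly stores, for each attribute or type entry, the difference between consecutive end offsets rather than the absolute offset. The following is a minimal standalone sketch of both encodings and the inverses a reader would need; the helper names are hypothetical, not part of the MLIR API, and the VarInt emission itself is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical helper: pack a count and a boolean flag the same way the
// writer does, i.e. (count << 1) | flag. The result is what gets emitted as
// a VarInt.
uint64_t packCountAndFlag(uint64_t count, bool flag) {
  // The shift drops the top bit, so the count must fit in 63 bits.
  assert(count < (uint64_t(1) << 63) && "count too large to pack");
  return (count << 1) | (flag ? 1 : 0);
}

// Hypothetical inverse, as a reader would apply it after decoding the VarInt.
std::pair<uint64_t, bool> unpackCountAndFlag(uint64_t encoded) {
  return {encoded >> 1, (encoded & 1) != 0};
}

// Sketch of the offset-section scheme in writeAttrTypeSection: given the
// running end offset after each entry, record only the difference from the
// previous end offset (i.e. the size of each entry).
std::vector<uint64_t> encodeOffsetDeltas(const std::vector<uint64_t> &endOffsets) {
  std::vector<uint64_t> deltas;
  uint64_t prevOffset = 0;
  for (uint64_t curOffset : endOffsets) {
    deltas.push_back(curOffset - prevOffset);
    prevOffset = curOffset;
  }
  return deltas;
}

// Reader-side inverse: a running sum recovers the absolute end offsets.
std::vector<uint64_t> decodeOffsetDeltas(const std::vector<uint64_t> &deltas) {
  std::vector<uint64_t> endOffsets;
  uint64_t offset = 0;
  for (uint64_t delta : deltas) {
    offset += delta;
    endOffsets.push_back(offset);
  }
  return endOffsets;
}
```

For example, a block with three operations and arguments is emitted as the VarInt 7, and the IR section header for its single root operation is 2. Storing deltas keeps most offset entries small, which matters because the PrefixVarInt encoding fits values below 2^7 in a single byte.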
diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(MLIRBytecodeWriter
+ BytecodeWriter.cpp
+ IRNumbering.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSupport
+ )
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -0,0 +1,182 @@
+//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains various utilities that number IR structures in preparation
+// for bytecode emission.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
+#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/MapVector.h"
+
+namespace mlir {
+class BytecodeWriterConfig;
+
+namespace bytecode {
+namespace detail {
+struct DialectNumbering;
+
+//===----------------------------------------------------------------------===//
+// Attribute and Type Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for an Attribute or Type.
+struct AttrTypeNumbering {
+ AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}
+
+ /// The concrete value.
+ PointerUnion<Attribute, Type> value;
+
+ /// The number assigned to this value.
+ unsigned number = 0;
+
+ /// The dialect of this value.
+ DialectNumbering *dialect = nullptr;
+};
+struct AttributeNumbering : public AttrTypeNumbering {
+ AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
+ Attribute getValue() const { return value.get<Attribute>(); }
+};
+struct TypeNumbering : public AttrTypeNumbering {
+ TypeNumbering(Type value) : AttrTypeNumbering(value) {}
+ Type getValue() const { return value.get<Type>(); }
+};
+
+//===----------------------------------------------------------------------===//
+// OpName Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents the numbering entry of an operation name.
+struct OpNameNumbering {
+ OpNameNumbering(OperationName name) : name(name) {}
+
+ /// The concrete name.
+ OperationName name;
+
+ /// The number assigned to this name.
+ unsigned number = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// Dialect Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for a Dialect.
+struct DialectNumbering {
+ DialectNumbering(StringRef name, unsigned number)
+ : name(name), number(number) {}
+
+ /// The namespace of the dialect.
+ StringRef name;
+
+ /// The number assigned to the dialect.
+ unsigned number;
+
+ /// The loaded dialect, or nullptr if the dialect isn't loaded.
+ Dialect *dialect = nullptr;
+
+ /// Numbered sub-components of the dialect to be emitted.
+ std::vector<OpNameNumbering *> opNames;
+ std::vector<AttributeNumbering *> attributes;
+ std::vector<TypeNumbering *> types;
+};
+
+//===----------------------------------------------------------------------===//
+// IRNumberingState
+//===----------------------------------------------------------------------===//
+
+/// This class manages numbering IR entities in preparation for bytecode
+/// emission.
+class IRNumberingState {
+public:
+ IRNumberingState(Operation *op);
+
+ /// Return the numbered dialects.
+ auto getDialects() {
+ return llvm::make_pointee_range(llvm::make_second_range(dialects));
+ }
+
+ /// Return the number for the given IR unit.
+ unsigned getNumber(Attribute attr) {
+ assert(attrs.count(attr) && "attribute not numbered");
+ return attrs[attr]->number;
+ }
+ unsigned getNumber(Block *block) {
+ assert(blockIDs.count(block) && "block not numbered");
+ return blockIDs[block];
+ }
+ unsigned getNumber(OperationName opName) {
+ assert(opNames.count(opName) && "opName not numbered");
+ return opNames[opName]->number;
+ }
+ unsigned getNumber(Type type) {
+ assert(types.count(type) && "type not numbered");
+ return types[type]->number;
+ }
+ unsigned getNumber(Value value) {
+ assert(valueIDs.count(value) && "value not numbered");
+ return valueIDs[value];
+ }
+
+ /// Return the block and value counts of the given region.
+ std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
+ assert(regionBlockValueCounts.count(region) && "region not numbered");
+ return regionBlockValueCounts[region];
+ }
+
+ /// Return the number of operations in the given block.
+ unsigned getOperationCount(Block *block) {
+ assert(blockOperationCounts.count(block) && "block not numbered");
+ return blockOperationCounts[block];
+ }
+
+private:
+ /// Number the given IR unit for bytecode emission.
+ void number(Attribute attr);
+ void number(Block &block);
+ DialectNumbering &numberDialect(Dialect *dialect);
+ DialectNumbering &numberDialect(StringRef dialect);
+ void number(Operation &op);
+ void number(OperationName opName);
+ void number(Region &region);
+ void number(Type type);
+
+ /// Mapping from IR to the respective numbering entries.
+ llvm::MapVector<Attribute, AttributeNumbering *> attrs;
+ llvm::MapVector<OperationName, OpNameNumbering *> opNames;
+ llvm::MapVector<Type, TypeNumbering *> types;
+ llvm::MapVector<Dialect *, DialectNumbering *> registeredDialects;
+ llvm::MapVector<StringRef, DialectNumbering *> dialects;
+
+ /// Allocators used for the various numbering entries.
+ llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
+ llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
+ llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
+ llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
+
+ /// The value ID for each Block and Value.
+ DenseMap<Block *, unsigned> blockIDs;
+ DenseMap<Value, unsigned> valueIDs;
+
+ /// The number of operations in each block.
+ DenseMap<Block *, unsigned> blockOperationCounts;
+
+ /// A map from region to the number of blocks and values within that region.
+ DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;
+
+ /// The next value ID to assign when numbering.
+ unsigned nextValueID = 0;
+};
+} // namespace detail
+} // namespace bytecode
+} // namespace mlir
+
+#endif
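Attributes, types, and operation names are not numbered in first-use order. `IRNumberingState` first buckets each unique entry under its dialect (dialects themselves are ordered by first use, via `llvm::MapVector`) and only then hands out consecutive IDs one dialect at a time, so every dialect owns a contiguous ID range and `writeDialectSection` only needs to emit per-dialect attribute and type counts. A simplified, single-kind sketch of that two-phase scheme follows, using hypothetical names (`Entity`, `numberByDialect`) and plain standard containers instead of the MLIR data structures.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an Attribute, Type, or OperationName: a unique
// name plus the namespace of the dialect that owns it.
struct Entity {
  std::string dialect;
  std::string name;
};

// Assigns numbers the way IRNumberingState does for a single entity kind:
// entities are deduplicated, bucketed by dialect (dialects ordered by first
// use), and then numbered contiguously one dialect at a time.
std::unordered_map<std::string, unsigned>
numberByDialect(const std::vector<Entity> &uses) {
  std::vector<std::string> dialectOrder; // dialects in first-use order
  std::unordered_map<std::string, std::vector<std::string>> perDialect;
  std::unordered_set<std::string> seen;

  for (const Entity &entity : uses) {
    // Skip entities that were already recorded (cf. attrs.insert(...)).
    if (!seen.insert(entity.name).second)
      continue;
    auto it = perDialect.find(entity.dialect);
    if (it == perDialect.end()) {
      dialectOrder.push_back(entity.dialect);
      it = perDialect.emplace(entity.dialect, std::vector<std::string>()).first;
    }
    it->second.push_back(entity.name);
  }

  // Walk the dialects in order and hand out consecutive IDs.
  std::unordered_map<std::string, unsigned> numbers;
  unsigned nextID = 0;
  for (const std::string &dialect : dialectOrder)
    for (const std::string &name : perDialect[dialect])
      numbers[name] = nextID++;
  assert(nextID == seen.size() && "every unique entity gets one number");
  return numbers;
}
```

For the uses `test.a`, `builtin.i32`, `test.b`, `test.a`, this yields `test.a` = 0, `test.b` = 1, `builtin.i32` = 2, whereas first-use numbering would have given `builtin.i32` the ID 1. The real implementation shares one dialect order across all three kinds and attributes `OpaqueAttr`/`OpaqueType` to their dialect namespace, which this sketch omits.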
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -0,0 +1,172 @@
+//===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRNumbering.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+
+using namespace mlir;
+using namespace mlir::bytecode::detail;
+
+//===----------------------------------------------------------------------===//
+// IR Numbering
+//===----------------------------------------------------------------------===//
+
+IRNumberingState::IRNumberingState(Operation *op) {
+ // Number the root operation.
+ number(*op);
+
+ // Push all of the regions of the root operation onto the worklist.
+ SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
+ for (Region &region : op->getRegions())
+ numberContext.emplace_back(&region, nextValueID);
+
+ // Iteratively process each of the nested regions.
+ while (!numberContext.empty()) {
+ Region *region;
+ std::tie(region, nextValueID) = numberContext.pop_back_val();
+ number(*region);
+
+ // Traverse into nested regions.
+ for (Operation &op : region->getOps()) {
+ // Isolated regions don't share value numbers with their parent, so we can
+ // start numbering these regions at zero.
+ unsigned opFirstValueID =
+ op.hasTrait<OpTrait::IsIsolatedFromAbove>() ? 0 : nextValueID;
+ for (Region &region : op.getRegions())
+ numberContext.emplace_back(&region, opFirstValueID);
+ }
+ }
+
+ // Walk and number the recorded components within each dialect.
+ unsigned attrID = 0, opNameID = 0, typeID = 0;
+ for (DialectNumbering *dialect : llvm::make_second_range(dialects)) {
+ for (AttributeNumbering *attr : dialect->attributes)
+ attr->number = attrID++;
+ for (OpNameNumbering *opName : dialect->opNames)
+ opName->number = opNameID++;
+ for (TypeNumbering *type : dialect->types)
+ type->number = typeID++;
+ }
+}
+
+void IRNumberingState::number(Attribute attr) {
+ auto it = attrs.insert({attr, nullptr});
+ if (!it.second)
+ return;
+ auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
+ it.first->second = numbering;
+
+ // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
+ // have a registered dialect when it got created. We don't want to encode this
+ // as the builtin OpaqueAttr; instead, we encode it as if the dialect were
+ // actually loaded.
+ if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>())
+ numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
+ else
+ numbering->dialect = &numberDialect(&attr.getDialect());
+
+ numbering->dialect->attributes.push_back(numbering);
+}
+
+void IRNumberingState::number(Block &block) {
+ // Number the arguments of the block.
+ for (BlockArgument arg : block.getArguments()) {
+ valueIDs.try_emplace(arg, nextValueID++);
+ number(arg.getLoc());
+ number(arg.getType());
+ }
+
+ // Number the operations in this block.
+ unsigned &numOps = blockOperationCounts[&block];
+ for (Operation &op : block) {
+ number(op);
+ ++numOps;
+ }
+}
+
+auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
+ DialectNumbering *&numbering = registeredDialects[dialect];
+ if (!numbering) {
+ numbering = &numberDialect(dialect->getNamespace());
+ numbering->dialect = dialect;
+ }
+ return *numbering;
+}
+
+auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
+ DialectNumbering *&numbering = dialects[dialect];
+ if (!numbering) {
+ numbering = new (dialectAllocator.Allocate())
+ DialectNumbering(dialect, dialects.size() - 1);
+ }
+ return *numbering;
+}
+
+void IRNumberingState::number(Region &region) {
+ if (region.empty())
+ return;
+ size_t firstValueID = nextValueID;
+
+ // Number the blocks within this region.
+ size_t blockCount = 0;
+ for (auto &it : llvm::enumerate(region)) {
+ blockIDs.try_emplace(&it.value(), it.index());
+ number(it.value());
+ ++blockCount;
+ }
+
+ // Remember the number of blocks and values in this region.
+ regionBlockValueCounts.try_emplace(&region, blockCount,
+ nextValueID - firstValueID);
+}
+
+void IRNumberingState::number(Operation &op) {
+ // Number the components of an operation that won't be numbered elsewhere
+ // (e.g. we don't number operands, regions, or successors here).
+ number(op.getName());
+ for (OpResult result : op.getResults()) {
+ valueIDs.try_emplace(result, nextValueID++);
+ number(result.getType());
+ }
+ number(op.getAttrDictionary());
+ number(op.getLoc());
+}
+
+void IRNumberingState::number(OperationName opName) {
+ OpNameNumbering *&numbering = opNames[opName];
+ if (numbering)
+ return;
+ DialectNumbering *dialectNumber = nullptr;
+ if (Dialect *dialect = opName.getDialect())
+ dialectNumber = &numberDialect(dialect);
+ else
+ dialectNumber = &numberDialect(opName.getDialectNamespace());
+ numbering = new (opNameAllocator.Allocate()) OpNameNumbering(opName);
+ dialectNumber->opNames.emplace_back(numbering);
+}
+
+void IRNumberingState::number(Type type) {
+ auto it = types.insert({type, nullptr});
+ if (!it.second)
+ return;
+ auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
+ it.first->second = numbering;
+
+ // Check for OpaqueType, which is a dialect-specific type that didn't have a
+ // registered dialect when it got created. We don't want to encode this as the
+ // builtin OpaqueType; instead, we encode it as if the dialect were actually
+ // loaded.
+ if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>())
+ numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
+ else
+ numbering->dialect = &numberDialect(&type.getDialect());
+
+ numbering->dialect->types.push_back(numbering);
+}
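The `IRNumberingState` constructor walks regions with an explicit worklist of (region, first value ID) pairs. Nested regions of an ordinary op start numbering after all values defined directly in the enclosing region, while regions of an op that is `IsolatedFromAbove` restart at zero. The sketch below replays that traversal on a hypothetical flattened region tree (`RegionNode` and `assignFirstValueIDs` are illustrative names, not MLIR APIs) to show which ID each region starts from.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical flattened model of a region tree. Each node counts the values
// (block arguments and op results) defined directly in the region, lists the
// regions attached to ops in it, and records whether its owning op is
// IsolatedFromAbove.
struct RegionNode {
  unsigned numValues = 0;
  std::vector<int> nestedRegions;
  bool ownerIsolatedFromAbove = false;
};

// Mirrors the worklist in the IRNumberingState constructor and returns the
// first value ID assigned inside each region. `rootNumResults` accounts for
// the root op's results, which are numbered before its regions.
std::vector<unsigned> assignFirstValueIDs(const std::vector<RegionNode> &regions,
                                          const std::vector<int> &rootRegions,
                                          unsigned rootNumResults) {
  std::vector<unsigned> firstValueIDs(regions.size(), 0);
  std::vector<std::pair<int, unsigned>> worklist;
  for (int root : rootRegions)
    worklist.emplace_back(root, rootNumResults);

  while (!worklist.empty()) {
    auto [region, nextValueID] = worklist.back();
    worklist.pop_back();
    assert(region >= 0 && static_cast<std::size_t>(region) < regions.size());

    // Number every value defined directly in this region.
    firstValueIDs[region] = nextValueID;
    nextValueID += regions[region].numValues;

    // Nested regions start after all of this region's values, unless the
    // owning op is isolated from above, in which case they restart at zero.
    for (int nested : regions[region].nestedRegions)
      worklist.emplace_back(
          nested, regions[nested].ownerIsolatedFromAbove ? 0 : nextValueID);
  }
  return firstValueIDs;
}
```

For example, if a root region defines three values and contains two ordinary nested regions plus one isolated region, both ordinary regions start at ID 3 and the isolated one at 0. Sibling regions cannot reference each other's values, so their ID ranges are allowed to overlap, which keeps the numbers (and the emitted VarInts) small.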
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -3,6 +3,7 @@
add_subdirectory(Analysis)
add_subdirectory(AsmParser)
+add_subdirectory(Bytecode)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(IR)
diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt
--- a/mlir/lib/Parser/CMakeLists.txt
+++ b/mlir/lib/Parser/CMakeLists.txt
@@ -6,5 +6,6 @@
LINK_LIBS PUBLIC
MLIRAsmParser
+ MLIRBytecodeReader
MLIRIR
)
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -12,6 +12,7 @@
#include "mlir/Parser/Parser.h"
#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeReader.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@@ -25,6 +26,8 @@
sourceBuf->getBufferIdentifier(),
/*line=*/0, /*column=*/0);
}
+ if (isBytecode(*sourceBuf))
+ return readBytecodeFile(*sourceBuf, block, config);
return parseAsmSourceFile(sourceMgr, block, config);
}
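With this change, the parser entry point hands the buffer to `readBytecodeFile` whenever `isBytecode` recognizes it, and otherwise falls back to the textual parser. As an illustration only, assuming detection amounts to checking for the magic number emitted by `BytecodeWriter::write`, a standalone prefix check could look like the following; `looksLikeMLIRBytecode` is a hypothetical name, and the reader's actual `isBytecode` implementation is not part of this excerpt.

```cpp
#include <string_view>

// Hypothetical helper (not the MLIR reader's implementation): returns true if
// `buffer` begins with the four-byte magic number that BytecodeWriter::write
// emits, i.e. 'M', 'L', 0xEF, 'R'.
bool looksLikeMLIRBytecode(std::string_view buffer) {
  constexpr std::string_view kMagic("ML\xefR", 4);
  // substr() clamps to the buffer size, so shorter inputs simply fail.
  return buffer.substr(0, kMagic.size()) == kMagic;
}
```

The version VarInt that `write` emits right after the magic number is not examined here.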
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -5,6 +5,7 @@
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-opt
LINK_LIBS PUBLIC
+ MLIRBytecodeWriter
MLIRPass
MLIRParser
MLIRSupport
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -47,7 +48,8 @@
static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool verifyPasses, SourceMgr &sourceMgr,
MLIRContext *context,
- PassPipelineFn passManagerSetupFn) {
+ PassPipelineFn passManagerSetupFn,
+ bool emitBytecode) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
@@ -86,8 +88,12 @@
// Print the output.
TimingScope outputTiming = timing.nest("Output");
- module->print(os);
- os << '\n';
+ if (emitBytecode) {
+ writeBytecodeToFile(module->getOperation(), os);
+ } else {
+ module->print(os);
+ os << '\n';
+ }
return success();
}
@@ -97,8 +103,8 @@
processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
- PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
- llvm::ThreadPool *threadPool) {
+ bool emitBytecode, PassPipelineFn passManagerSetupFn,
+ DialectRegistry &registry, llvm::ThreadPool *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -122,7 +128,7 @@
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
- &context, passManagerSetupFn);
+ &context, passManagerSetupFn, emitBytecode);
}
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -131,7 +137,7 @@
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
- passManagerSetupFn);
+ passManagerSetupFn, emitBytecode);
// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
@@ -144,7 +150,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext) {
+ bool preloadDialectsInContext,
+ bool emitBytecode) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
// We use an explicit threadpool to avoid creating and joining/destroying
@@ -163,8 +170,8 @@
raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- preloadDialectsInContext, passManagerSetupFn, registry,
- threadPool);
+ preloadDialectsInContext, emitBytecode,
+ passManagerSetupFn, registry, threadPool);
};
return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
splitInputFile, /*insertMarkerInOutput=*/true);
@@ -176,7 +183,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext) {
+ bool preloadDialectsInContext,
+ bool emitBytecode) {
auto passManagerSetupFn = [&](PassManager &pm) {
auto errorHandler = [&](const Twine &msg) {
emitError(UnknownLoc::get(pm.getContext())) << msg;
@@ -186,7 +194,8 @@
};
return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
registry, splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext);
+ allowUnregisteredDialects, preloadDialectsInContext,
+ emitBytecode);
}
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -224,6 +233,10 @@
"show-dialects", cl::desc("Print the list of registered dialects"),
cl::init(false));
+ static cl::opt<bool> emitBytecode(
+ "emit-bytecode", cl::desc("Emit bytecode when generating output"),
+ cl::init(false));
+
InitLLVM y(argc, argv);
// Register any command line options.
@@ -268,7 +281,8 @@
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext)))
+ allowUnregisteredDialects, preloadDialectsInContext,
+ emitBytecode)))
return failure();
// Keep the output file if the invocation of MlirOptMain was successful.
diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/general.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -allow-unregistered-dialect -emit-bytecode %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: "bytecode.test1"
+// CHECK-NEXT: "bytecode.empty"() : () -> ()
+// CHECK-NEXT: "bytecode.attributes"() {attra = 10 : i64, attrb = #bytecode.attr} : () -> ()
+// CHECK-NEXT: %[[RESULTS:.*]]:3 = "bytecode.results"() : () -> (i32, i64, i32)
+// CHECK-NEXT: "bytecode.operands"(%[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#2) : (i32, i64, i32) -> ()
+// CHECK-NEXT: "bytecode.branch"()[^[[BLOCK:.*]]] : () -> ()
+// CHECK-NEXT: ^[[BLOCK]](%[[ARG0:.*]]: i32, %[[ARG1:.*]]: !bytecode.int, %[[ARG2:.*]]: !pdl.operation):
+// CHECK-NEXT: "bytecode.regions"() ({
+// CHECK-NEXT: "bytecode.operands"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (i32, !bytecode.int, !pdl.operation) -> ()
+// CHECK-NEXT: "bytecode.return"() : () -> ()
+// CHECK-NEXT: }) : () -> ()
+// CHECK-NEXT: "bytecode.return"() : () -> ()
+// CHECK-NEXT: }) : () -> ()
+
+"bytecode.test1"() ({
+ "bytecode.empty"() : () -> ()
+ "bytecode.attributes"() {attra = 10, attrb = #bytecode.attr} : () -> ()
+ %results:3 = "bytecode.results"() : () -> (i32, i64, i32)
+ "bytecode.operands"(%results#0, %results#1, %results#2) : (i32, i64, i32) -> ()
+ "bytecode.branch"()[^secondBlock] : () -> ()
+
+^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation):
+ "bytecode.regions"() ({
+ "bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> ()
+ "bytecode.return"() : () -> ()
+ }) : () -> ()
+ "bytecode.return"() : () -> ()
+}) : () -> ()