diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
new file mode 100644
--- /dev/null
+++ b/mlir/docs/BytecodeFormat.md
@@ -0,0 +1,296 @@
+# MLIR Bytecode Format
+
+This document describes the MLIR bytecode format and its encoding.
+
+[TOC]
+
+## Magic Number
+
+MLIR uses the following four-byte magic number to indicate bytecode files:
+
+`['M', 'L', 'ï', 'R']`
+
+In hex, with `ï` stored as the single byte `0xEF`, this is
+`[0x4D, 0x4C, 0xEF, 0x52]`.
+
+## Format Overview
+
+An MLIR bytecode file consists of a byte stream, with a few simple structural
+concepts layered on top.
+
+### Primitives
+
+#### Fixed-Width Integers
+
+```
+ byte ::= `0x00`...`0xFF`
+```
+
+Fixed width integers are unsigned integers of a known byte size. The values are
+stored in little-endian byte order.
+
+TODO: Add larger fixed width integers as necessary.
+
+#### Variable-Width Integers
+
+Variable width integers, or `VarInt`s, provide a compact representation for
+integers. Each encoded VarInt consists of one to nine bytes, which together
+represent a single 64-bit value. The MLIR bytecode utilizes the "PrefixVarInt"
+encoding for VarInts. This encoding is a variant of the
+[LEB128 ("Little-Endian Base 128")](https://en.wikipedia.org/wiki/LEB128)
+encoding, where each byte of the encoding provides up to 7 bits for the value,
+with the remaining bit used to store a tag indicating the number of bytes used
+for the encoding. This means that small unsigned integers (less than 2^7) may be
+stored in one byte, unsigned integers less than 2^14 may be stored in two bytes,
+etc.
+
+The first byte of the encoding includes a length prefix in the low bits. This
+prefix is a bit sequence of '0's followed by a terminal '1', or the end of the
+byte. The number of '0' bits indicates the number of _additional_ bytes, not
+including the prefix byte, used to encode the value. All of the remaining bits
+in the first byte, along with all of the bits in the additional bytes, provide
+the value of the integer. Below are the various possible encodings of the prefix
+byte:
+
+```
+xxxxxxx1: 7 value bits, the encoding uses 1 byte
+xxxxxx10: 14 value bits, the encoding uses 2 bytes
+xxxxx100: 21 value bits, the encoding uses 3 bytes
+xxxx1000: 28 value bits, the encoding uses 4 bytes
+xxx10000: 35 value bits, the encoding uses 5 bytes
+xx100000: 42 value bits, the encoding uses 6 bytes
+x1000000: 49 value bits, the encoding uses 7 bytes
+10000000: 56 value bits, the encoding uses 8 bytes
+00000000: 64 value bits, the encoding uses 9 bytes
+```
+
+#### NUL Terminated Strings
+
+NUL Terminated Strings are terminated with the ASCII NUL character (whose byte
+value is zero). These are not used in cases where a string may contain an
+embedded NUL character. In cases that may hold an embedded NUL character, the
+string is encoded using a length and byte array.
+
+### Sections
+
+```
+section {
+ id: byte
+ length: varint
+}
+```
+
+Sections are a mechanism for grouping data within the bytecode. They enable
+delayed processing, which is useful for out-of-order processing of data,
+lazy-loading, and more. Each section contains a Section ID and a length (which
+allows for skipping over the section).
+
+TODO: Sections should also carry an optional alignment. Add this when necessary.
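+
+Because every section carries its length up front, a reader can locate or skip
+a section without understanding its contents. The following is a minimal sketch
+that assumes the input is an in-memory byte buffer. The names `readVarInt`,
+`SectionView`, and `readSection` are hypothetical helpers, not MLIR APIs.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+
+// Minimal PrefixVarInt decoder that advances `it` on success.
+static bool readVarInt(const uint8_t *&it, const uint8_t *end,
+                       uint64_t &value) {
+  if (it == end)
+    return false;
+  unsigned extra = 0;
+  while (extra < 8 && !(*it & (1u << extra)))
+    ++extra;
+  if (static_cast<size_t>(end - it) < extra + 1)
+    return false;
+  uint64_t raw = 0;
+  if (extra == 8) {
+    for (unsigned i = 0; i < 8; ++i)
+      raw |= uint64_t(it[1 + i]) << (8 * i);
+    value = raw;
+  } else {
+    for (unsigned i = 0; i <= extra; ++i)
+      raw |= uint64_t(it[i]) << (8 * i);
+    value = raw >> (extra + 1);
+  }
+  it += extra + 1;
+  return true;
+}
+
+// A view of one section's ID and payload within the input buffer.
+struct SectionView {
+  uint8_t id;
+  const uint8_t *data;
+  size_t length;
+};
+
+// Read one `section { id: byte, length: varint }` header and step over its
+// payload without decoding it. `it` only advances on success.
+static bool readSection(const uint8_t *&it, const uint8_t *end,
+                        SectionView &section) {
+  assert(it <= end && "iterator is past the end of the buffer");
+  const uint8_t *p = it;
+  if (p == end)
+    return false;
+  uint8_t id = *p++;
+  uint64_t length;
+  if (!readVarInt(p, end, length))
+    return false;
+  if (length > static_cast<uint64_t>(end - p))
+    return false;
+  section = {id, p, static_cast<size_t>(length)};
+  it = p + length;
+  return true;
+}
+```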
+
+## MLIR Encoding
+
+Given the generic structure of MLIR, the bytecode encoding is fairly simple: it
+maps directly onto the core components of MLIR.
+
+### Top Level Structure
+
+The top-level structure of the bytecode contains the 4-byte "magic number", a
+version number, and a list of sections. Each section is currently only expected
+to appear once within a bytecode file.
+
+```
+bytecode {
+ magic: "MLïR",
+ version: varint,
+ sections: section[]
+}
+```
+
+### Dialect Section
+
+The dialect section of the bytecode contains all of the dialects referenced
+within the encoded IR, and some information about the components of those
+dialects that were also referenced.
+
+```
+dialect_section {
+ dialects: dialect[]
+}
+
+dialect {
+ name: nul_terminated_string,
+ numAttrs: varint,
+ numTypes: varint,
+ numOpNames: varint,
+ opNames: nul_terminated_string[]
+}
+```
+
+### Attribute/Type Sections
+
+Attributes and types are encoded using two [sections](#sections), one section
+(`attr_type_section`) containing the actual encoded representation, and another
+section (`attr_type_offset_section`) containing the offsets of each encoded
+attribute/type into the previous section. This allows for attributes and types
+to always be lazily loaded on demand.
+
+```
+attr_type_section {
+ attrs: attribute[],
+ types: type[]
+}
+attr_type_offset_section {
+ offset: varint[]
+}
+
+attribute {
+ code: byte, // kAsmForm
+ encoding: ...
+}
+type {
+ code: byte, // kAsmForm
+ encoding: ...
+}
+```
+
+Each `offset` in the `attr_type_offset_section` above is the size of the
+encoding for the attribute or type, rather than its direct offset into the
+`attr_type_section`, because smaller relative offsets provide more effective
+compression. Attributes and types are grouped by dialect, with each dialect
+grouping in the same order of the dialects within the
+[dialect section](#dialect-section).
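+
+Because the offset section stores sizes rather than absolute offsets, a reader
+recovers each entry's position with a running sum. The entries are visited in
+order, attributes first and then types. The following sketch assumes the sizes
+have already been decoded from their VarInts. `EntrySlice` and
+`computeEntrySlices` are hypothetical names.
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+// The position of one encoded attribute or type within attr_type_section.
+struct EntrySlice {
+  uint64_t offset; // Absolute offset into attr_type_section.
+  uint64_t size;   // Size of the encoded entry in bytes.
+};
+
+// Convert the per-entry sizes from attr_type_offset_section into absolute
+// slices with a running sum. Returns false if the sizes overrun the section.
+bool computeEntrySlices(const std::vector<uint64_t> &sizes,
+                        uint64_t sectionSize,
+                        std::vector<EntrySlice> &slices) {
+  slices.clear();
+  uint64_t offset = 0, remaining = sectionSize;
+  for (uint64_t size : sizes) {
+    if (size > remaining)
+      return false;
+    slices.push_back({offset, size});
+    offset += size;
+    remaining -= size;
+  }
+  assert(offset + remaining == sectionSize);
+  return true;
+}
+```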
+
+#### Attribute/Type Encodings
+
+In the previous section, the forms of `attribute` and `type` both start with a
+`code` field. This field indicates how the attribute or type was encoded. In the
+abstract, an attribute/type is encoded in one of two possible ways: via its
+assembly format, or via a custom dialect-defined encoding.
+
+##### Assembly Format Fallback
+
+In the case where a dialect does not define a method for encoding the attribute
+or type, the textual assembly format of that attribute or type is used as a
+fallback. For example, a type of `!bytecode.type` would be encoded as the NUL
+terminated string "!bytecode.type". This ensures that every attribute and type
+may be encoded, even if the owning dialect has not yet opted in to a more
+efficient serialization.
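+
+The following sketch produces such a fallback entry. It assumes the `kAsmForm`
+code has the value `0`, as in the `AttrTypeCode` enum in `Encoding.h`.
+`encodeAsmFormEntry` is a hypothetical helper.
+
+```cpp
+#include <cassert>
+#include <cstdint>
+#include <string>
+#include <vector>
+
+// Mirrors AttrTypeCode::kAsmForm from Encoding.h.
+constexpr uint8_t kAsmForm = 0;
+
+// Encode an attribute/type entry using the assembly format fallback: the
+// `code` byte followed by the textual form as a NUL terminated string.
+std::vector<uint8_t> encodeAsmFormEntry(const std::string &asmText) {
+  // The NUL terminated encoding cannot represent embedded NUL characters.
+  assert(asmText.find('\0') == std::string::npos);
+  std::vector<uint8_t> entry;
+  entry.push_back(kAsmForm);
+  entry.insert(entry.end(), asmText.begin(), asmText.end());
+  entry.push_back(0);
+  return entry;
+}
+```
+
+For `!bytecode.type` this yields 16 bytes: the code byte, the 14 characters of
+the type, and the terminating NUL.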
+
+##### Dialect Defined Encoding
+
+TODO: This is not yet supported.
+
+### IR Section
+
+The IR section contains the encoded form of operations within the bytecode.
+
+#### Operation Encoding
+
+```
+op {
+ name: varint,
+ encodingMask: byte,
+
+ location: varint?,
+
+ attrDict: varint?,
+
+ firstResultIndex: varint?,
+ numResults: varint?,
+ resultTypes: varint[],
+
+ numOperands: varint?,
+ operands: varint[],
+
+ numSuccessors: varint?,
+ successors: varint[],
+
+ numRegions: varint?,
+ regions: region[]
+}
+```
+
+The encoding of an operation is important because operations are generally the
+most commonly appearing structure in the bytecode. A single encoding is used for
+every type of operation. Given this prevalence, many of the fields of an
+operation are optional. The `encodingMask` field is a bitmask which indicates
+which of the components of the operation are present.
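+
+The following sketch shows how a writer might compute `encodingMask`, using the
+bit values from the `OpEncodingMask` enum in `Encoding.h`. `OpSummary` and
+`computeEncodingMask` are hypothetical. The sketch assumes the location bit is
+set only when the location differs from the previously encoded one, as
+described in the Location subsection.
+
+```cpp
+#include <cassert>
+#include <cstdint>
+
+// Bit values matching the OpEncodingMask enum in Encoding.h.
+enum : uint8_t {
+  kHasLoc = 1 << 0,
+  kHasAttrs = 1 << 1,
+  kHasResults = 1 << 2,
+  kHasOperands = 1 << 3,
+  kHasSuccessors = 1 << 4,
+  kHasInlineRegions = 1 << 5,
+};
+
+// Hypothetical summary of the parts of an operation a writer cares about.
+struct OpSummary {
+  bool locationChanged; // Differs from the previous op/block argument.
+  bool hasAttrs;
+  unsigned numResults, numOperands, numSuccessors, numRegions;
+};
+
+// Set one bit for each optional component that will be written.
+uint8_t computeEncodingMask(const OpSummary &op) {
+  uint8_t mask = 0;
+  if (op.locationChanged)
+    mask |= kHasLoc;
+  if (op.hasAttrs)
+    mask |= kHasAttrs;
+  if (op.numResults != 0)
+    mask |= kHasResults;
+  if (op.numOperands != 0)
+    mask |= kHasOperands;
+  if (op.numSuccessors != 0)
+    mask |= kHasSuccessors;
+  if (op.numRegions != 0)
+    mask |= kHasInlineRegions;
+  // Only the low six bits are defined by the encoding.
+  assert((mask & 0xC0) == 0);
+  return mask;
+}
+```
+
+For example, an operation with attributes, one result, and two operands (but no
+location change, successors, or regions) gets the mask `0x0E`.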
+
+##### Location
+
+If necessary to encode, i.e. if the location for this operation is different
+from the location of the last operation or block argument, the index of the
+location within the attribute table is encoded.
+
+##### Attributes
+
+If the operation has attributes, the index of the operation attribute dictionary
+within the attribute table is encoded.
+
+##### Results
+
+If the operation has results, the value index of the first result is encoded.
+After that, the number of results and the indexes of the result types within the
+type table are encoded.
+
+##### Operands
+
+If the operation has operands, the number of operands and the value index of
+each operand is encoded.
+
+##### Successors
+
+If the operation has successors, the number of successors and the indexes of the
+successor blocks within the parent region are encoded.
+
+##### Regions
+
+If the operation has regions, the number of regions and each region are encoded.
+
+#### Region Encoding
+
+```
+region {
+ code: byte, // kRegion | kRegionEmpty
+
+ numBlocks: varint?,
+ numValues: varint?,
+ blocks: block[]
+}
+```
+
+A region is encoded with a leading code followed by the body. The code indicates
+how the body is encoded. If the code is `kRegionEmpty`, the region has no body.
+If the code is `kRegion`, the body is present.
+
+#### Block Encoding
+
+```
+block {
+ block_code: byte, // kBlockArguments | kOp | kBlockEnd
+ block_element: block_arguments | op | []
+}
+
+block_arguments {
+ code: byte, // kBlockArguments
+
+ firstArgIndex: varint,
+ numArgs: varint?,
+ args: block_argument[]
+
+}
+block_argument {
+ typeIndexAndHasLoc: varint, // (typeIndex << 1) | (hasLoc)
+ location: varint?
+}
+
+```
+
+A block is encoded with an array of elements determined by a leading code. The
+terminal `kBlockEnd` code indicates the end of a block. The `kOp` code indicates
+that an operation follows. If the block has arguments, the first element of the
+block will contain the encoded representation of the arguments, or
+`block_arguments` above. The encoding for the block arguments includes the value
+index of the first argument, the number of arguments, and an encoded list of
+arguments. The `typeIndexAndHasLoc` field of the argument is a varint that in
+the high bits holds the index for the type of that argument, and in the low bit
+contains a flag that indicates if the argument has a location encoded along with
+it. A location is encoded if the argument had a different location than the
+previously encoded argument or operation.
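+
+The `typeIndexAndHasLoc` packing can be sketched as follows. The helper names
+are hypothetical.
+
+```cpp
+#include <cassert>
+#include <cstdint>
+
+// Pack a block argument's type index and "has location" flag into the single
+// `typeIndexAndHasLoc` value: (typeIndex << 1) | hasLoc.
+uint64_t packTypeIndexAndHasLoc(uint64_t typeIndex, bool hasLoc) {
+  // The shift would drop the top bit of extremely large indices.
+  assert(typeIndex < (uint64_t(1) << 63));
+  return (typeIndex << 1) | (hasLoc ? 1 : 0);
+}
+
+// Split the value back into its parts; a location varint follows in the
+// bytecode only when `hasLoc` is true.
+void unpackTypeIndexAndHasLoc(uint64_t value, uint64_t &typeIndex,
+                              bool &hasLoc) {
+  typeIndex = value >> 1;
+  hasLoc = (value & 1) != 0;
+}
+```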
diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeReader.h
@@ -0,0 +1,34 @@
+//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines interfaces to read MLIR bytecode files/streams.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEREADER_H
+#define MLIR_BYTECODE_BYTECODEREADER_H
+
+#include "mlir/IR/AsmState.h"
+#include "mlir/Support/LLVM.h"
+
+namespace llvm {
+class MemoryBufferRef;
+} // namespace llvm
+
+namespace mlir {
+/// Returns true if the given buffer starts with the magic bytes that signal
+/// MLIR bytecode.
+bool isBytecode(llvm::MemoryBufferRef buffer);
+
+/// Read the operations defined within the given memory buffer, containing MLIR
+/// bytecode, into the provided block.
+LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config);
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEREADER_H
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -0,0 +1,52 @@
+//===- BytecodeWriter.h - MLIR Bytecode Writer ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines interfaces to write MLIR bytecode files/streams.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BYTECODE_BYTECODEWRITER_H
+#define MLIR_BYTECODE_BYTECODEWRITER_H
+
+#include "mlir/Support/LLVM.h"
+
+namespace mlir {
+class Operation;
+
+//===----------------------------------------------------------------------===//
+// BytecodeWriterConfig
+//===----------------------------------------------------------------------===//
+
+/// This class provides a configuration for the bytecode writer. It is the main
+/// injection of information into the writer.
+class BytecodeWriterConfig {
+ struct Impl;
+
+public:
+ BytecodeWriterConfig(Operation *op);
+ ~BytecodeWriterConfig();
+
+ /// Return the root operation of the writer.
+ Operation *getRootOp() const;
+
+private:
+ /// A pointer to the allocated storage for the impl state.
+ std::unique_ptr<Impl> impl;
+};
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+/// Write the given bytecode configuration to the provided output stream. For
+/// streams where it matters, the given stream should be in "binary" mode.
+void writeBytecodeToFile(const BytecodeWriterConfig &config, raw_ostream &os);
+
+} // namespace mlir
+
+#endif // MLIR_BYTECODE_BYTECODEWRITER_H
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -642,11 +642,11 @@
OperationState(Location location, OperationName name);
OperationState(Location location, OperationName name, ValueRange operands,
- TypeRange types, ArrayRef<NamedAttribute> attributes,
+ TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
OperationState(Location location, StringRef name, ValueRange operands,
- TypeRange types, ArrayRef<NamedAttribute> attributes,
+ TypeRange types, ArrayRef<NamedAttribute> attributes = {},
BlockRange successors = {},
MutableArrayRef<std::unique_ptr<Region>> regions = {});
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -50,13 +50,15 @@
/// - preloadDialectsInContext will trigger the upfront loading of all
/// dialects from the global registry in the MLIRContext. This option is
/// deprecated and will be removed soon.
+/// - emitBytecode will generate bytecode output instead of text.
LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext = false);
+ bool preloadDialectsInContext = false,
+ bool emitBytecode = false);
/// Support a callback to setup the pass manager.
/// - passManagerSetupFn is the callback invoked to setup the pass manager to
@@ -67,7 +69,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext = false);
+ bool preloadDialectsInContext = false,
+ bool emitBytecode = false);
/// Implementation for tools like `mlir-opt`.
/// - toolName is used for the header displayed by `--help`.
diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(Reader)
+add_subdirectory(Writer)
\ No newline at end of file
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -0,0 +1,127 @@
+//===- Encoding.h - MLIR binary format encoding information -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines enum values describing the structure of MLIR bytecode
+// files.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_BYTECODE_ENCODING_H
+#define LIB_MLIR_BYTECODE_ENCODING_H
+
+#include <cstdint>
+
+namespace mlir {
+namespace bytecode {
+//===----------------------------------------------------------------------===//
+// General constants
+//===----------------------------------------------------------------------===//
+
+enum {
+ /// The current bytecode version.
+ kVersion = 0,
+
+ /// The first non-builtin section code.
+ kFirstNonBuiltinCode = 16,
+};
+
+namespace BuiltinCode {
+enum : uint8_t {
+ /// This value indicates the code for a section.
+ kSection = 0,
+};
+} // namespace BuiltinCode
+
+//===----------------------------------------------------------------------===//
+// Sections
+//===----------------------------------------------------------------------===//
+
+namespace Section {
+enum ID : uint8_t {
+ /// This section contains the dialects referenced within an IR module.
+ kDialect = 0,
+
+ /// This section contains the attributes and types referenced within an IR
+ /// module.
+ kAttrType = 1,
+
+ /// This section contains the offsets for the attribute and types within the
+ /// AttrType section.
+ kAttrTypeOffset = 2,
+
+ /// This section contains the top level operation, and its nested
+ /// regions/operations.
+ kTopLevelOp = 3,
+
+ /// The total number of section types.
+ kNumSections = 4,
+};
+} // namespace Section
+
+//===----------------------------------------------------------------------===//
+// AttrType Section
+//===----------------------------------------------------------------------===//
+
+namespace AttrTypeCode {
+enum : uint8_t {
+ /// This code represents an attribute or type represented in the textual
+ /// assembly format.
+ kAsmForm,
+};
+} // namespace AttrTypeCode
+
+//===----------------------------------------------------------------------===//
+// kTopLevelOp Section
+//===----------------------------------------------------------------------===//
+
+namespace TopLevelOpCode {
+enum : uint8_t {
+ //===--------------------------------------------------------------------===//
+ // Operation Codes
+
+ /// This code represents an operation.
+ kOp = kFirstNonBuiltinCode,
+
+ //===--------------------------------------------------------------------===//
+ // Region Codes
+
+ /// This code represents a non-empty region.
+ kRegion,
+
+ /// This code represents an empty region.
+ kRegionEmpty,
+
+ //===--------------------------------------------------------------------===//
+ // Block Codes
+
+ /// This code represents the argument list of a block.
+ kBlockArguments,
+
+ /// This code represents the end of a block.
+ kBlockEnd,
+};
+} // namespace TopLevelOpCode
+
+/// This enum represents a mask of all of the potential components of an
+/// operation. This mask is used when encoding an operation to indicate which
+/// components are present in the bytecode.
+namespace OpEncodingMask {
+enum : uint8_t {
+ kHasLoc = 1 << 0,
+ kHasAttrs = 1 << 1,
+ kHasResults = 1 << 2,
+ kHasOperands = 1 << 3,
+ kHasSuccessors = 1 << 4,
+ kHasInlineRegions = 1 << 5,
+};
+} // namespace OpEncodingMask
+
+} // namespace bytecode
+} // namespace mlir
+
+#endif
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -0,0 +1,965 @@
+//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeReader.h"
+#include "../Encoding.h"
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/MemoryBufferRef.h"
+#include "llvm/Support/SaveAndRestore.h"
+
+#define DEBUG_TYPE "mlir-bytecode"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// EncodingReader
+//===----------------------------------------------------------------------===//
+
+namespace {
+class EncodingReader {
+public:
+ explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
+ : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
+ explicit EncodingReader(StringRef contents, Location fileLoc)
+ : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
+ contents.size()},
+ fileLoc) {}
+
+ /// Returns true if the entire section has been read.
+ bool empty() const { return dataIt == dataEnd; }
+
+ /// Returns the remaining size of the bytecode.
+ size_t size() const { return dataEnd - dataIt; }
+
+ /// Emit an error using the given arguments.
+ template <typename... Args>
+ LogicalResult emitError(Args &&...args) const {
+ return ::emitError(fileLoc).append(std::forward<Args>(args)...);
+ }
+
+ /// Parse a single byte from the stream.
+ template <typename T>
+ ParseResult parseByte(T &value) {
+ if (empty())
+ return emitError("attempting to parse a byte at the end of the bytecode");
+ value = *dataIt++;
+ return success();
+ }
+ /// Parse a range of bytes of 'length' into the given result.
+ ParseResult parseBytes(size_t length, ArrayRef<uint8_t> &result) {
+ if (length > size()) {
+ return emitError("attempting to parse ", length, " bytes when only ",
+ size(), " remain");
+ }
+ result = {dataIt, length};
+ dataIt += length;
+ return success();
+ }
+ /// Parse a range of bytes of 'length' into the given result, which can be
+ /// assumed to be large enough to hold `length`.
+ ParseResult parseBytes(size_t length, uint8_t *result) {
+ if (length > size()) {
+ return emitError("attempting to parse ", length, " bytes when only ",
+ size(), " remain");
+ }
+ memcpy(result, dataIt, length);
+ dataIt += length;
+ return success();
+ }
+
+ /// Parse a variable length encoded integer from the byte stream. The first
+ /// encoded byte contains a prefix in the low bits indicating the encoded
+ /// length of the value. This length prefix is a bit sequence of '0's followed
+ /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes
+ /// (not including the prefix byte). All remaining bits in the first byte,
+ /// along with all of the bits in additional bytes, provide the value of the
+ /// integer encoded in little-endian order.
+ ParseResult parseVarInt(uint64_t &result) {
+ // Parse the first byte of the encoding, which contains the length prefix.
+ if (parseByte(result))
+ return failure();
+
+ // Handle the overwhelmingly common case where the value is stored in a
+ // single byte. In this case, the first bit is the `1` marker bit.
+ if (LLVM_LIKELY(result & 1)) {
+ result >>= 1;
+ return success();
+ }
+
+ // Handle the overwhelmingly uncommon case where the encoding uses all 9
+ // bytes (i.e. a really really big number). In this case, the marker byte is
+ // all zeros: `00000000`.
+ if (LLVM_UNLIKELY(result == 0))
+ return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result));
+ return parseMultiByteVarInt(result);
+ }
+
+ /// Skip the first `length` bytes within the reader.
+ ParseResult skipBytes(size_t length) {
+ if (length > size()) {
+ return emitError("attempting to skip ", length, " bytes when only ",
+ size(), " remain");
+ }
+ dataIt += length;
+ return success();
+ }
+
+ /// Parse a NUL terminated string into `result` (without including the NUL
+ /// terminator).
+ ParseResult parseNULTerminatedString(StringRef &result) {
+ const char *startIt = (const char *)dataIt;
+ const char *nulIt = (const char *)memchr(startIt, 0, size());
+ if (!nulIt)
+ return emitError("malformed NUL terminated string, no NUL found");
+
+ result = StringRef(startIt, nulIt - startIt);
+ dataIt = (const uint8_t *)nulIt + 1;
+ return success();
+ }
+
+ /// Parse a section header, placing the kind of section in `sectionID` and the
+ /// contents of the section in `sectionData`.
+ ParseResult parseSection(uint8_t &sectionID, ArrayRef<uint8_t> &sectionData) {
+ uint64_t length;
+ if (parseByte(sectionID) || parseVarInt(length))
+ return failure();
+
+ // Parse the actual section data now that we have its length.
+ return parseBytes(length, sectionData);
+ }
+
+private:
+ /// Parse a variable length encoded integer from the byte stream. This method
+ /// is a fallback when the number of bytes used to encode the value is greater
+ /// than 1, but less than the max (9). The provided `result` value can be
+ /// assumed to already contain the first byte of the value. This method is
+ /// marked noinline to avoid pessimizing the common case of single byte
+ /// encoding.
+ LLVM_ATTRIBUTE_NOINLINE ParseResult parseMultiByteVarInt(uint64_t &result) {
+ // Count the number of trailing zeros in the marker byte, this indicates the
+ // number of trailing bytes that are part of the value. We use `uint32_t`
+ // here because we only care about the first byte, and so that we actually
+ // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop
+ // implementation).
+ uint32_t numBytes =
+ llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined);
+
+ // Parse in the remaining bytes of the value.
+ if (parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))
+ return failure();
+
+ // Shift out the low-order bits that were used to mark how the value was
+ // encoded.
+ result >>= (numBytes + 1);
+ return success();
+ }
+
+ /// The current data iterator, and an iterator to the end of the buffer.
+ const uint8_t *dataIt, *dataEnd;
+
+ /// A location for the bytecode used to report errors.
+ Location fileLoc;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// BytecodeDialect
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This struct represents a dialect entry within the bytecode.
+struct BytecodeDialect {
+ BytecodeDialect(Dialect *dialect, StringRef name, unsigned numAttrs,
+ unsigned numTypes)
+ : dialect(dialect), name(name), numAttrs(numAttrs), numTypes(numTypes) {}
+
+ /// The loaded dialect entry, if available, otherwise nullptr.
+ Dialect *dialect;
+
+ /// The name of the dialect.
+ StringRef name;
+
+ /// The number of attributes owned by this dialect in the bytecode.
+ unsigned numAttrs;
+
+ /// The number of types owned by this dialect in the bytecode.
+ unsigned numTypes;
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Attribute/Type Reader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class provides support for reading attribute and type entries from the
+/// bytecode. Attribute and Type entries are read lazily on demand, so we use
+/// this reader to manage when to actually parse them from the bytecode.
+class AttrTypeReader {
+ /// This class represents a single attribute or type entry.
+ template <typename T>
+ struct Entry {
+ /// The entry, or null if it hasn't been resolved yet.
+ T entry = {};
+ /// The raw data of this entry in the bytecode.
+ ArrayRef<uint8_t> data;
+ };
+ using AttrEntry = Entry<Attribute>;
+ using TypeEntry = Entry<Type>;
+
+public:
+ AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {}
+
+ /// Initialize the attribute and type information within the reader.
+ LogicalResult initialize(ArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
+
+ /// Resolve the attribute or type at the given index. Returns nullptr on
+ /// failure.
+ Attribute resolveAttribute(unsigned index) {
+ return resolveEntry(attributes, index, "Attribute");
+ }
+ Type resolveType(unsigned index) {
+ return resolveEntry(types, index, "Type");
+ }
+
+private:
+ /// Initialize the offsets for the attribute and type entries.
+ LogicalResult initializeOffsets(ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData);
+
+ /// Resolve the given entry at `index`.
+ template <typename T>
+ T resolveEntry(SmallVectorImpl<Entry<T>> &entries, unsigned index,
+ StringRef entryType);
+
+ /// Parse the value defined within the given reader. `code` indicates how the
+ /// entry was encoded.
+ LogicalResult parseEntry(EncodingReader &reader, uint8_t code,
+ Attribute &result);
+ LogicalResult parseEntry(EncodingReader &reader, uint8_t code, Type &result);
+
+ /// The set of attribute and type entries.
+ SmallVector<AttrEntry> attributes;
+ SmallVector<TypeEntry> types;
+
+ /// A location used for error emission.
+ Location fileLoc;
+};
+} // namespace
+
+LogicalResult AttrTypeReader::initialize(ArrayRef<BytecodeDialect> dialects,
+ ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData) {
+ // Initialize the entries using the dialect information.
+ unsigned numAttrs = 0, numTypes = 0;
+ for (const BytecodeDialect &dialect : dialects) {
+ numAttrs += dialect.numAttrs;
+ numTypes += dialect.numTypes;
+ }
+ attributes.resize(numAttrs);
+ types.resize(numTypes);
+
+ // With the entries initialized, we can process the offsets.
+ return initializeOffsets(sectionData, offsetSectionData);
+}
+
+LogicalResult
+AttrTypeReader::initializeOffsets(ArrayRef<uint8_t> sectionData,
+ ArrayRef<uint8_t> offsetSectionData) {
+ EncodingReader reader(offsetSectionData, fileLoc);
+
+ // A functor used to accumulate the offsets for the entries in the given
+ // range.
+ uint64_t currentOffset = 0;
+ auto accumulateOffsets = [&](auto &&range) {
+ for (auto &entry : range) {
+ uint64_t entrySize;
+ if (reader.parseVarInt(entrySize))
+ return failure();
+ entry.data = sectionData.slice(currentOffset, entrySize);
+ currentOffset += entrySize;
+ }
+ return success();
+ };
+
+ // Process each of the attributes, and then the types.
+ if (failed(accumulateOffsets(attributes)) || failed(accumulateOffsets(types)))
+ return failure();
+
+ // Ensure that we read everything from the section.
+ if (!reader.empty()) {
+ return reader.emitError(
+ "unexpected trailing data in the Attribute/Type offset section");
+ }
+ return success();
+}
+
+template <typename T>
+T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries,
+ unsigned index, StringRef entryType) {
+ if (index >= entries.size()) {
+ emitError(fileLoc) << "invalid " << entryType << " index: " << index;
+ return {};
+ }
+
+ // If the entry has already been resolved, there is nothing left to do.
+ Entry<T> &entry = entries[index];
+ if (entry.entry)
+ return entry.entry;
+
+ // Parse the entry. Each entry starts with a specific code that indicates how
+ // it is represented.
+ EncodingReader reader(entry.data, fileLoc);
+ uint8_t code;
+ if (reader.parseByte(code) || failed(parseEntry(reader, code, entry.entry)))
+ return T();
+ if (!reader.empty()) {
+ (void)reader.emitError("unexpected trailing bytes after ", entryType,
+ " entry");
+ return T();
+ }
+ return entry.entry;
+}
+
+LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, uint8_t code,
+ Attribute &result) {
+ // Handle the fallback case, where the attribute was encoded using its
+ // assembly format.
+ if (code == bytecode::AttrTypeCode::kAsmForm) {
+ StringRef attrStr;
+ if (failed(reader.parseNULTerminatedString(attrStr)))
+ return failure();
+
+ size_t numRead = 0;
+ if (!(result = parseAttribute(attrStr, fileLoc->getContext(), numRead)))
+ return failure();
+ if (numRead != attrStr.size()) {
+ return reader.emitError(
+ "trailing characters found after Attribute assembly format: ",
+ attrStr.drop_front(numRead));
+ }
+ return success();
+ }
+
+ return reader.emitError("unexpected Attribute encoding: ", code);
+}
+
+LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, uint8_t code,
+ Type &result) {
+ // Handle the fallback case, where the type was encoded using its
+ // assembly format.
+ if (code == bytecode::AttrTypeCode::kAsmForm) {
+ StringRef typeStr;
+ if (failed(reader.parseNULTerminatedString(typeStr)))
+ return failure();
+
+ size_t numRead = 0;
+ if (!(result = parseType(typeStr, fileLoc->getContext(), numRead)))
+ return failure();
+ if (numRead != typeStr.size()) {
+ return reader.emitError(
+ "trailing characters found after Type assembly format: ",
+ typeStr.drop_front(numRead));
+ }
+ return success();
+ }
+
+ return reader.emitError("unexpected Type encoding: ", code);
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Reader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to read a bytecode buffer and translate it into MLIR.
+class BytecodeReader {
+public:
+ BytecodeReader(Location fileLoc, const ParserConfig &config)
+ : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc),
+ forwardRefOpState(UnknownLoc::get(config.getContext()),
+ "builtin.unrealized_conversion_cast", ValueRange(),
+ NoneType::get(config.getContext())) {}
+
+ /// Read the bytecode defined within `buffer` into the given block.
+ LogicalResult read(llvm::MemoryBufferRef buffer, Block *block);
+
+private:
+ /// Return the context for this config.
+ MLIRContext *getContext() const { return config.getContext(); }
+
+ //===--------------------------------------------------------------------===//
+ // Dialect Section
+
+ LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData);
+
+ /// Parse an operation name reference using the given reader.
+ FailureOr<OperationName> parseOpName(EncodingReader &reader);
+
+ //===--------------------------------------------------------------------===//
+ // Attribute/Type Section
+
+ /// Parse an attribute or type using the given reader. Returns nullptr in the
+ /// case of failure.
+ Attribute parseAttribute(EncodingReader &reader);
+ Type parseType(EncodingReader &reader);
+
+ template <typename T>
+ T parseAttribute(EncodingReader &reader) {
+ if (Attribute attr = parseAttribute(reader)) {
+ if (auto derivedAttr = attr.dyn_cast<T>())
+ return derivedAttr;
+ (void)reader.emitError("expected attribute of type: ",
+ llvm::getTypeName<T>(), ", but got: ", attr);
+ }
+ return T();
+ }
+
+ //===--------------------------------------------------------------------===//
+ // TopLevelOp Section
+
+ LogicalResult parseTopLevelOpSection(ArrayRef<uint8_t> sectionData,
+ Block *block);
+ LogicalResult parseOp(EncodingReader &reader, Block *block,
+ ArrayRef<Block *> regionBlocks, LocationAttr &lastLoc);
+ LogicalResult parseRegion(EncodingReader &reader, Region *region,
+ LocationAttr &lastLoc);
+ LogicalResult parseBlock(EncodingReader &reader, Block *block,
+ ArrayRef<Block *> regionBlocks,
+ LocationAttr &lastLoc);
+ LogicalResult parseBlockArguments(EncodingReader &reader, Block *block,
+ LocationAttr &lastLoc);
+
+ //===--------------------------------------------------------------------===//
+ // Value Processing
+
+ /// Parse an operand reference using the given reader. Returns nullptr in the
+ /// case of failure.
+ Value parseOperand(EncodingReader &reader);
+
+ /// Sequentially define the given value range starting at the provided first
+ /// value ID.
+ LogicalResult defineValues(EncodingReader &reader, ValueRange values,
+ unsigned firstValueID);
+
+ /// Create a value to use for a forward reference.
+ Value createForwardRef();
+
+ //===--------------------------------------------------------------------===//
+ // Fields
+
+ /// The configuration of the parser.
+ const ParserConfig &config;
+
+ /// A location to use when emitting errors.
+ Location fileLoc;
+
+ /// The reader used to process attributes and types within the bytecode.
+ AttrTypeReader attrTypeReader;
+
+ /// The version of the bytecode being read.
+ uint64_t version = 0;
+
+ /// The table of IR units referenced within the bytecode file.
+ SmallVector dialects;
+ SmallVector<OperationName> opNames;
+
+ /// The current set of available IR values.
+ std::vector<Value> values;
+ /// A block containing the set of operations defined to create forward
+ /// references.
+ Block forwardRefOps;
+ /// A block containing previously created, and no longer used, forward
+ /// reference operations.
+ Block openForwardRefOps;
+ /// An operation state used when instantiating forward references.
+ OperationState forwardRefOpState;
+};
+} // namespace
+
+LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) {
+ EncodingReader reader(buffer.getBuffer(), fileLoc);
+
+ // Skip over the bytecode header; this should have already been checked.
+ if (reader.skipBytes(StringRef("ML\xefR").size()))
+ return failure();
+
+ // Parse the bytecode version.
+ if (reader.parseVarInt(version))
+ return failure();
+
+ // Validate the bytecode version.
+ if (version < bytecode::kVersion) {
+ return reader.emitError(
+ "bytecode version ", version, " is older than the current version of ",
+ bytecode::kVersion, ", and upgrade is not supported.");
+ }
+ if (version > bytecode::kVersion) {
+ return reader.emitError("bytecode version ", version,
+ " is newer than the current version ",
+ bytecode::kVersion, ".");
+ }
+
+ // The raw data for the AttrTypeOffset section.
+ Optional<ArrayRef<uint8_t>> attrTypeOffsetSection;
+
+ BitVector seenSections(bytecode::Section::kNumSections);
+ while (!reader.empty()) {
+ // Read the next section from the bytecode.
+ uint8_t code;
+ if (reader.parseByte(code) || code != bytecode::BuiltinCode::kSection)
+ return reader.emitError("expected top-level section code");
+ uint8_t sectionID;
+ ArrayRef<uint8_t> sectionData;
+ if (reader.parseSection(sectionID, sectionData))
+ return failure();
+
+ // Check for duplicate sections; we only expect one instance of each.
+ if (seenSections.test(sectionID))
+ return reader.emitError("duplicate top-level section ID: ", sectionID);
+ seenSections.set(sectionID);
+
+ // Process the section.
+ switch (sectionID) {
+ case bytecode::Section::kDialect:
+ if (failed(parseDialectSection(sectionData)))
+ return failure();
+ break;
+ case bytecode::Section::kAttrType:
+ if (!attrTypeOffsetSection) {
+ return reader.emitError(
+ "expected the AttrTypeOffset section before the AttrType section");
+ }
+ if (dialects.empty()) {
+ return reader.emitError("expected the Dialect section before the "
+ "AttrType section");
+ }
+
+ // With everything ready, initialize the attribute/type reader.
+ if (failed(attrTypeReader.initialize(dialects, sectionData,
+ *attrTypeOffsetSection)))
+ return failure();
+ break;
+ case bytecode::Section::kAttrTypeOffset:
+ // We won't parse this section until we process the main AttrType section.
+ // For now, just record the raw data.
+ attrTypeOffsetSection = sectionData;
+ break;
+ case bytecode::Section::kTopLevelOp:
+ if (failed(parseTopLevelOpSection(sectionData, block)))
+ return failure();
+ break;
+ default:
+ return reader.emitError("unexpected top-level section: ", sectionID);
+ }
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Dialect Section
+
+LogicalResult
+BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) {
+ MLIRContext *ctx = getContext();
+
+ EncodingReader sectionReader(sectionData, fileLoc);
+ while (!sectionReader.empty()) {
+ // Read the name of the next dialect.
+ StringRef dialectName;
+ if (sectionReader.parseNULTerminatedString(dialectName))
+ return failure();
+
+ // Parse the attribute and type counts.
+ uint64_t attrCount, typeCount;
+ if (sectionReader.parseVarInt(attrCount) ||
+ sectionReader.parseVarInt(typeCount))
+ return failure();
+
+ // Try to load the dialect.
+ Dialect *dialect = ctx->getOrLoadDialect(dialectName);
+ if (!dialect && !ctx->allowsUnregisteredDialects()) {
+ return sectionReader.emitError(
+ "dialect '", dialectName,
+ "' is unknown. If this is intended, please call "
+ "allowUnregisteredDialects() on the MLIRContext, or use "
+ "-allow-unregistered-dialect with the MLIR tool used.");
+ }
+ dialects.emplace_back(dialect, dialectName, attrCount, typeCount);
+
+ // Parse the operation names of the dialect.
+ uint64_t numOpNames;
+ if (sectionReader.parseVarInt(numOpNames))
+ return failure();
+ SmallString<32> opNameStorage({dialectName, "."});
+ while (numOpNames--) {
+ StringRef opName;
+ if (sectionReader.parseNULTerminatedString(opName))
+ return failure();
+
+ opNameStorage.resize(dialectName.size() + 1);
+ opNameStorage.append(opName);
+ opNames.push_back(OperationName(opNameStorage, ctx));
+ }
+ }
+ return success();
+}
+
+FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) {
+ uint64_t opNameIdx;
+ if (reader.parseVarInt(opNameIdx))
+ return failure();
+
+ if (opNameIdx >= opNames.size())
+ return reader.emitError("invalid operation name index: ", opNameIdx);
+ return opNames[opNameIdx];
+}
+
+//===----------------------------------------------------------------------===//
+// Attribute/Type Section
+
+Attribute BytecodeReader::parseAttribute(EncodingReader &reader) {
+ uint64_t attrIdx;
+ if (reader.parseVarInt(attrIdx))
+ return Attribute();
+ return attrTypeReader.resolveAttribute(attrIdx);
+}
+
+Type BytecodeReader::parseType(EncodingReader &reader) {
+ uint64_t typeIdx;
+ if (reader.parseVarInt(typeIdx))
+ return Type();
+ return attrTypeReader.resolveType(typeIdx);
+}
+
+//===----------------------------------------------------------------------===//
+// TopLevelOp Section
+
+LogicalResult
+BytecodeReader::parseTopLevelOpSection(ArrayRef<uint8_t> sectionData,
+ Block *block) {
+ EncodingReader reader(sectionData, fileLoc);
+
+ LocationAttr lastLoc;
+ if (failed(parseOp(reader, block, /*regionBlocks=*/llvm::None, lastLoc)))
+ return failure();
+ if (!forwardRefOps.empty())
+ return reader.emitError(
+ "not all forward operand references were resolved");
+ return success();
+}
+
+LogicalResult BytecodeReader::parseOp(EncodingReader &reader, Block *block,
+ ArrayRef<Block *> regionBlocks,
+ LocationAttr &lastLoc) {
+ // Parse the name of the operation.
+ FailureOr<OperationName> opName = parseOpName(reader);
+ if (failed(opName))
+ return failure();
+
+ // Parse the operation mask, which indicates which components of the operation
+ // are present.
+ uint8_t opMask;
+ if (reader.parseByte(opMask))
+ return failure();
+
+ // Check to see if this op has a new location.
+ if (opMask & bytecode::OpEncodingMask::kHasLoc) {
+ if (!(lastLoc = parseAttribute<LocationAttr>(reader)))
+ return failure();
+ }
+
+ // With the location and name resolved, we can start building the operation
+ // state.
+ OperationState opState(lastLoc, *opName);
+
+ // Parse the attributes of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
+ DictionaryAttr dictAttr = parseAttribute<DictionaryAttr>(reader);
+ if (!dictAttr)
+ return failure();
+ opState.attributes = dictAttr;
+ }
+
+ // Parse the results of the operation.
+ Optional<uint64_t> firstResultID;
+ if (opMask & bytecode::OpEncodingMask::kHasResults) {
+ firstResultID.emplace(0);
+ if (reader.parseVarInt(*firstResultID))
+ return failure();
+
+ // Parse the result types.
+ uint64_t numResults;
+ if (reader.parseVarInt(numResults))
+ return failure();
+ opState.types.resize(numResults);
+ for (int i = 0, e = numResults; i < e; ++i)
+ if (!(opState.types[i] = parseType(reader)))
+ return failure();
+ }
+
+ // Parse the operands of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasOperands) {
+ uint64_t numOperands;
+ if (reader.parseVarInt(numOperands))
+ return failure();
+ opState.operands.resize(numOperands);
+ for (int i = 0, e = numOperands; i < e; ++i)
+ if (!(opState.operands[i] = parseOperand(reader)))
+ return failure();
+ }
+
+ // Parse the successors of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasSuccessors) {
+ uint64_t numSuccs;
+ if (reader.parseVarInt(numSuccs))
+ return failure();
+ opState.successors.reserve(numSuccs);
+ for (int i = 0, e = numSuccs; i < e; ++i) {
+ uint64_t succID;
+ if (reader.parseVarInt(succID))
+ return failure();
+ if (succID >= regionBlocks.size())
+ return reader.emitError("invalid successor index: ", succID);
+ opState.successors.push_back(regionBlocks[succID]);
+ }
+ }
+
+ // Parse the regions of the operation.
+ if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) {
+ uint64_t numRegions;
+ if (reader.parseVarInt(numRegions))
+ return failure();
+ opState.regions.reserve(numRegions);
+ for (int i = 0, e = numRegions; i < e; ++i) {
+ opState.regions.push_back(std::make_unique<Region>());
+ if (failed(parseRegion(reader, &*opState.regions.back(), lastLoc)))
+ return failure();
+ }
+ }
+
+ // Create the operation.
+ Operation *op = Operation::create(opState);
+ block->push_back(op);
+
+ // If the operation had results, update the value references.
+ if (firstResultID)
+ return defineValues(reader, op->getResults(), *firstResultID);
+ return LogicalResult::success();
+}
+
+LogicalResult BytecodeReader::parseRegion(EncodingReader &reader,
+ Region *region,
+ LocationAttr &lastLoc) {
+ // Read the code defining how this region was encoded.
+ uint8_t regionCode;
+ if (reader.parseByte(regionCode))
+ return failure();
+
+ // If it's an empty region, there is nothing more to do.
+ if (regionCode == bytecode::TopLevelOpCode::kRegionEmpty)
+ return success();
+
+ // Otherwise, we need to parse the region body.
+ if (regionCode != bytecode::TopLevelOpCode::kRegion)
+ return reader.emitError("invalid region code: ", regionCode);
+
+ // Parse the number of blocks and values in this region.
+ uint64_t numBlocks, numValues;
+ if (reader.parseVarInt(numBlocks) || reader.parseVarInt(numValues))
+ return failure();
+
+ // Reserve enough values for those defined in this region. Make sure to reset
+ // the size of the value table after processing though.
+ size_t origNumValues = values.size();
+ auto atExit = llvm::make_scope_exit([&]() { values.resize(origNumValues); });
+ values.resize(values.size() + numValues);
+
+ // Create the blocks within this region. We do this before processing so that
+ // we can rely on the blocks existing when creating operations.
+ SmallVector<Block *> regionBlocks;
+ regionBlocks.reserve(numBlocks);
+ for (uint64_t i = 0; i < numBlocks; ++i) {
+ regionBlocks.push_back(new Block());
+ region->push_back(regionBlocks.back());
+ }
+
+ for (uint64_t i = 0; i < numBlocks; ++i)
+ if (failed(parseBlock(reader, regionBlocks[i], regionBlocks, lastLoc)))
+ return failure();
+ return success();
+}
+
+LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, Block *block,
+ ArrayRef<Block *> regionBlocks,
+ LocationAttr &lastLoc) {
+ // Parse the first code of the block explicitly in case the block has
+ // arguments.
+ uint8_t blockCode = 0;
+ if (reader.parseByte(blockCode))
+ return failure();
+
+ // Check for arguments to the block.
+ if (blockCode == bytecode::TopLevelOpCode::kBlockArguments) {
+ if (failed(parseBlockArguments(reader, block, lastLoc)))
+ return failure();
+
+ // Parse the next block code.
+ if (reader.parseByte(blockCode))
+ return failure();
+ }
+
+ while (blockCode != bytecode::TopLevelOpCode::kBlockEnd) {
+ // Parse an operation within the block.
+ if (blockCode == bytecode::TopLevelOpCode::kOp) {
+ if (failed(parseOp(reader, block, regionBlocks, lastLoc)))
+ return failure();
+ } else {
+ return reader.emitError("unknown block code: ", blockCode);
+ }
+
+ // Parse the next code.
+ if (reader.parseByte(blockCode))
+ return failure();
+ }
+ return success();
+}
+
+LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader,
+ Block *block,
+ LocationAttr &lastLoc) {
+ // Parse the value ID for the first argument, and the number of arguments.
+ uint64_t firstArgID, numArgs;
+ if (reader.parseVarInt(firstArgID) || reader.parseVarInt(numArgs))
+ return failure();
+
+ SmallVector<Type> argTypes;
+ SmallVector<Location> argLocs;
+ argTypes.reserve(numArgs);
+ argLocs.reserve(numArgs);
+ while (numArgs--) {
+ uint64_t typeIdx;
+ if (reader.parseVarInt(typeIdx))
+ return failure();
+
+ // Check the low bit of the type index to see if this argument has a new
+ // location.
+ bool hasNewLoc = (typeIdx & 1) != 0;
+ typeIdx >>= 1;
+
+ // Parse the type, and optionally the location.
+ Type argType = attrTypeReader.resolveType(typeIdx);
+ if (!argType)
+ return failure();
+ if (hasNewLoc && !(lastLoc = parseAttribute<LocationAttr>(reader)))
+ return failure();
+
+ argTypes.push_back(argType);
+ argLocs.push_back(lastLoc);
+ }
+ block->addArguments(argTypes, argLocs);
+ return defineValues(reader, block->getArguments(), firstArgID);
+}
+
+//===----------------------------------------------------------------------===//
+// Value Processing
+
+Value BytecodeReader::parseOperand(EncodingReader &reader) {
+ uint64_t valueIdx;
+ if (failed(reader.parseVarInt(valueIdx)))
+ return nullptr;
+ if (valueIdx >= values.size())
+ return (void)reader.emitError("invalid value index: ", valueIdx), Value();
+
+ // Resolve it, or create a new forward reference if necessary.
+ Value &value = values[valueIdx];
+ if (!value)
+ value = createForwardRef();
+ return value;
+}
+
+LogicalResult BytecodeReader::defineValues(EncodingReader &reader,
+ ValueRange newValues,
+ unsigned firstValueID) {
+ size_t maxId = firstValueID + newValues.size();
+ if (maxId > values.size()) {
+ return reader.emitError(
+ "value index range was outside of the expected range for "
+ "the parent region, got [",
+ firstValueID, ", ", maxId, "), but the maximum index was ",
+ values.size() - 1);
+ }
+
+ // Assign the values and update any forward references.
+ for (unsigned i = 0, e = newValues.size(); i != e; ++i) {
+ Value newValue = newValues[i];
+
+ // Check to see if a definition for this value already exists.
+ if (Value oldValue = std::exchange(values[firstValueID + i], newValue)) {
+ Operation *forwardRefOp = oldValue.getDefiningOp();
+ if (!forwardRefOp || forwardRefOp->getBlock() != &forwardRefOps) {
+ return reader.emitError("value index ", firstValueID + i,
+ " was already defined");
+ }
+
+ oldValue.replaceAllUsesWith(newValue);
+ forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end());
+ }
+ }
+ return LogicalResult::success();
+}
+
+Value BytecodeReader::createForwardRef() {
+ // Reuse a previously released forward reference operation if one is
+ // available. Otherwise, create a new placeholder operation for the reference.
+ if (!openForwardRefOps.empty()) {
+ Operation *op = &openForwardRefOps.back();
+ op->moveBefore(&forwardRefOps, forwardRefOps.end());
+ } else {
+ forwardRefOps.push_back(Operation::create(forwardRefOpState));
+ }
+ return forwardRefOps.back().getResult(0);
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+bool mlir::isBytecode(llvm::MemoryBufferRef buffer) {
+ return buffer.getBuffer().startswith("ML\xefR");
+}
+
+LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block,
+ const ParserConfig &config) {
+ Location sourceFileLoc =
+ FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(),
+ /*line=*/0, /*column=*/0);
+ if (!isBytecode(buffer)) {
+ return emitError(sourceFileLoc,
+ "input buffer is not an MLIR bytecode file");
+ }
+
+ Block parsedBlock;
+ BytecodeReader reader(sourceFileLoc, config);
+ if (failed(reader.read(buffer, &parsedBlock)))
+ return failure();
+
+ // Splice the parsed operations over to the provided top-level block.
+ auto &parsedOps = parsedBlock.getOperations();
+ auto &destOps = block->getOperations();
+ destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
+ parsedOps, parsedOps.begin(), parsedOps.end());
+ return success();
+}
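`BytecodeReader` lets an operand name a value whose definition has not been parsed yet. `parseOperand` hands out a placeholder result from a `builtin.unrealized_conversion_cast` operation parked in `forwardRefOps`, and `defineValues` later replaces all uses of that placeholder with the real value. The following is a minimal, self-contained sketch of the same bookkeeping. The `Value`, `Operand`, and `ValueTable` types are hypothetical stand-ins, not MLIR API. Use-lists are modeled as a plain list of operand pointers, and, unlike the reader, resolved placeholders are not recycled.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for an IR value (not the MLIR `Value` class).
struct Value {
  std::string name;
  bool isPlaceholder = false;
};

// Hypothetical stand-in for an operand slot that refers to a value.
struct Operand {
  Value *value = nullptr;
};

// Tracks value IDs for one scope, creating placeholders for forward uses.
class ValueTable {
public:
  explicit ValueTable(size_t numValues) : values(numValues, nullptr) {}

  // Like parseOperand(): bind `operand` to value `id`, creating a placeholder
  // if the value has not been defined yet.
  void use(size_t id, Operand &operand) {
    Value *&slot = values.at(id);
    if (!slot) {
      placeholders.push_back(std::make_unique<Value>());
      placeholders.back()->name = "<forward-ref>";
      placeholders.back()->isPlaceholder = true;
      slot = placeholders.back().get();
    }
    operand.value = slot;
    uses.push_back(&operand);
  }

  // Like defineValues(): bind `id` to `newValue`, redirecting any uses of a
  // placeholder. Returns false (and changes nothing) if `id` was already
  // defined by a real value.
  bool define(size_t id, Value *newValue) {
    Value *oldValue = values.at(id);
    if (oldValue && !oldValue->isPlaceholder)
      return false;
    values[id] = newValue;
    if (oldValue) {
      for (Operand *operand : uses)
        if (operand->value == oldValue)
          operand->value = newValue;
      ++numResolved;
    }
    return true;
  }

  // Like the final check in parseTopLevelOpSection().
  bool allResolved() const { return numResolved == placeholders.size(); }

private:
  std::vector<Value *> values;
  std::vector<std::unique_ptr<Value>> placeholders;
  std::vector<Operand *> uses;
  size_t numResolved = 0;
};
```

The reader avoids reallocating placeholders by moving each resolved forward reference operation into `openForwardRefOps` for `createForwardRef` to reuse. The sketch keeps every placeholder alive until the table is destroyed.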
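`writeBlock` and `parseBlockArguments` share one VarInt per block argument. The type number is shifted left by one, and the low bit records whether the argument's location differs from the most recently emitted one. When that bit is set, a location number follows. The sketch below models the scheme with plain integers. `Arg`, `encodeArgs`, and `decodeArgs` are hypothetical names, locations are represented by numbers instead of attributes, and the VarInt framing is omitted. Because the flag occupies the low bit, the sketch assumes type numbers stay below 2^63.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical block argument: a type number and a location number.
struct Arg {
  uint64_t typeID;
  uint64_t locID;
};

// Like writeBlock(): one integer per argument holding (typeID << 1) | flag,
// where the flag marks a location change; a changed location number follows.
std::vector<uint64_t> encodeArgs(const std::vector<Arg> &args,
                                 uint64_t &lastLoc) {
  std::vector<uint64_t> out;
  for (const Arg &arg : args) {
    assert((arg.typeID >> 63) == 0 && "type number must leave room for flag");
    bool hasNewLoc = arg.locID != lastLoc;
    lastLoc = arg.locID;
    out.push_back((arg.typeID << 1) | (hasNewLoc ? 1 : 0));
    if (hasNewLoc)
      out.push_back(arg.locID);
  }
  return out;
}

// Like parseBlockArguments(): split off the low bit and, when it is set, read
// the new location; otherwise reuse the last one.
std::vector<Arg> decodeArgs(const std::vector<uint64_t> &in, size_t numArgs,
                            uint64_t &lastLoc) {
  std::vector<Arg> args;
  size_t pos = 0;
  while (numArgs--) {
    uint64_t packed = in.at(pos++);
    if ((packed & 1) != 0)
      lastLoc = in.at(pos++);
    args.push_back({packed >> 1, lastLoc});
  }
  return args;
}
```

The writer and reader must start from the same `lastLoc` for the flags to line up. In the reader, `parseTopLevelOpSection` starts from a null `LocationAttr`.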
diff --git a/mlir/lib/Bytecode/Reader/CMakeLists.txt b/mlir/lib/Bytecode/Reader/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Reader/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(MLIRBytecodeReader
+ BytecodeReader.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode
+
+ LINK_LIBS PUBLIC
+ MLIRAsmParser
+ MLIRIR
+ MLIRSupport
+ )
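The PrefixVarInt scheme implemented by `EncodingEmitter::emitVarInt` and `emitMultiByteVarInt` can be exercised in isolation. The sketch below is a standalone encoder/decoder pair that follows the rules stated in those comments:

- The trailing '0' bits of the first byte count the additional bytes.
- An all-zero first byte is followed by the raw 8-byte value.

The function names are hypothetical. The sketch assembles bytes in little-endian order explicitly. By contrast, `emitMultiByteVarInt` copies the in-memory bytes of a `uint64_t`, which produces the same byte sequence only on little-endian hosts.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Encode `value` as a PrefixVarInt. The low bits of the first byte hold N-1
// '0' bits followed by a '1' (N = total byte count); the value fills the
// remaining bits, little-endian.
std::vector<uint8_t> encodePrefixVarInt(uint64_t value) {
  std::vector<uint8_t> out;
  for (unsigned numBytes = 1; numBytes <= 8; ++numBytes) {
    // N bytes hold 7 * N value bits.
    if ((value >> (7 * numBytes)) == 0) {
      uint64_t encoded = ((value << 1) | 1) << (numBytes - 1);
      for (unsigned i = 0; i < numBytes; ++i)
        out.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
      return out;
    }
  }
  // Too large for the prefixed form: all-zero marker byte, then raw 8 bytes.
  out.push_back(0);
  for (unsigned i = 0; i < 8; ++i)
    out.push_back(static_cast<uint8_t>(value >> (8 * i)));
  return out;
}

// Decode a PrefixVarInt from the start of `data`, setting `numRead` to the
// number of bytes consumed. Returns std::nullopt if `data` is too short.
std::optional<uint64_t> decodePrefixVarInt(const std::vector<uint8_t> &data,
                                           size_t &numRead) {
  if (data.empty())
    return std::nullopt;
  uint8_t first = data[0];
  if (first == 0) {
    if (data.size() < 9)
      return std::nullopt;
    uint64_t value = 0;
    for (unsigned i = 0; i < 8; ++i)
      value |= uint64_t(data[1 + i]) << (8 * i);
    numRead = 9;
    return value;
  }
  // Each trailing '0' bit adds one byte to the encoding.
  unsigned numBytes = 1;
  while ((first & 1) == 0) {
    first >>= 1;
    ++numBytes;
  }
  if (data.size() < numBytes)
    return std::nullopt;
  uint64_t encoded = 0;
  for (unsigned i = 0; i < numBytes; ++i)
    encoded |= uint64_t(data[i]) << (8 * i);
  numRead = numBytes;
  // The prefix occupies the low `numBytes` bits; shift it away.
  return encoded >> numBytes;
}
```

Values below 2^7 take one byte, matching the single-byte fast path in `emitVarInt`. Values of 2^56 or more fall through to the 9-byte form.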
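`writeAttrTypeSection` records, for each entry, the difference between that entry's end offset and the previous entry's end offset. Each value in the AttrTypeOffset section is therefore that entry's size, and a reader can recover every entry's byte range with a running sum. The sketch below shows both sides under that reading. `EncodedSection`, `EntryRange`, `writeEntries`, and `computeEntryRanges` are hypothetical names, entries are treated as opaque strings, and the VarInt framing is omitted. Requiring the sizes to cover the section exactly is an assumption of the sketch, not a check visible in this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical encoded AttrType section plus its per-entry deltas.
struct EncodedSection {
  std::string data;
  std::vector<uint64_t> deltas;
};

// Hypothetical byte range of one entry within the section.
struct EntryRange {
  uint64_t begin;
  uint64_t size;
};

// Writer side, modeled on writeAttrTypeSection(): append each entry and record
// the distance between consecutive end offsets (i.e. the entry size).
EncodedSection writeEntries(const std::vector<std::string> &entries) {
  EncodedSection result;
  uint64_t prevOffset = 0;
  for (const std::string &entry : entries) {
    result.data += entry;
    uint64_t curOffset = result.data.size();
    result.deltas.push_back(curOffset - prevOffset);
    prevOffset = curOffset;
  }
  return result;
}

// Reader side: rebuild each entry's range with a running sum. Returns false if
// the sizes overrun the section or do not cover it exactly.
bool computeEntryRanges(const std::vector<uint64_t> &deltas,
                        uint64_t sectionSize, std::vector<EntryRange> &ranges) {
  uint64_t offset = 0;
  for (uint64_t delta : deltas) {
    if (delta > sectionSize - offset)
      return false;
    ranges.push_back({offset, delta});
    offset += delta;
  }
  return offset == sectionSize;
}
```

Storing sizes rather than absolute offsets keeps most values small, so they usually fit in one or two VarInt bytes.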
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -0,0 +1,474 @@
+//===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "../Encoding.h"
+#include "IRNumbering.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/Support/Debug.h"
+#include
+
+#define DEBUG_TYPE "mlir-bytecode"
+
+using namespace mlir;
+using namespace mlir::bytecode::detail;
+
+//===----------------------------------------------------------------------===//
+// BytecodeWriterConfig
+//===----------------------------------------------------------------------===//
+
+struct BytecodeWriterConfig::Impl {
+ explicit Impl(Operation *op) : rootOp(op) {}
+
+ /// The root operation of the bytecode.
+ Operation *rootOp;
+};
+
+BytecodeWriterConfig::BytecodeWriterConfig(Operation *op)
+ : impl(std::make_unique<Impl>(op)) {}
+BytecodeWriterConfig::~BytecodeWriterConfig() = default;
+
+Operation *BytecodeWriterConfig::getRootOp() const { return impl->rootOp; }
+
+//===----------------------------------------------------------------------===//
+// EncodingEmitter
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class functions as the underlying encoding emitter for the bytecode
+/// writer. This class is a bit different compared to other types of encoders;
+/// it does not use a single buffer, but instead may contain several buffers
+/// (some owned by the writer, and some not) that get concatenated during the
+/// final emission.
+class EncodingEmitter {
+public:
+ EncodingEmitter() = default;
+ EncodingEmitter(const EncodingEmitter &) = delete;
+ EncodingEmitter &operator=(const EncodingEmitter &) = delete;
+
+ /// Write the current contents to the provided stream.
+ void writeTo(raw_ostream &os) const;
+
+ /// Return the current size of the encoded buffer.
+ size_t size() const { return prevResultSize + currentResult.size(); }
+
+ //===--------------------------------------------------------------------===//
+ // Emission
+ //===--------------------------------------------------------------------===//
+
+ /// Return a raw pointer into the result buffer at the specified offset.
+ uint8_t *getRawPointer(uint64_t offset) {
+ assert(offset < size() && offset >= prevResultSize &&
+ "cannot get pointer to previously emitted data");
+ return currentResult.data() + (offset - prevResultSize);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Integer Emission
+
+ /// Emit a single byte.
+ void emitByte(uint8_t byte) { currentResult.push_back(byte); }
+
+ /// Emit a range of bytes.
+ void emitBytes(ArrayRef<uint8_t> bytes) {
+ llvm::append_range(currentResult, bytes);
+ }
+
+ /// Emit a variable length integer. The first encoded byte contains a prefix
+ /// in the low bits indicating the encoded length of the value. This length
+ /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits
+ /// indicate the number of _additional_ bytes (not including the prefix byte).
+ /// All remaining bits in the first byte, along with all of the bits in
+ /// additional bytes, provide the value of the integer encoded in
+ /// little-endian order.
+ void emitVarInt(uint64_t value) {
+ // In the most common case, the value can be represented in a single byte.
+ // Given how hot this case is, explicitly handle that here.
+ if ((value >> 7) == 0)
+ return emitByte((value << 1) | 0x1);
+ emitMultiByteVarInt(value);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // String Emission
+
+ /// Emit the given string as a nul terminated string.
+ void emitNulTerminatedString(StringRef str) {
+ emitString(str);
+ emitByte(0);
+ }
+
+ /// Emit the given string without a nul terminator.
+ void emitString(StringRef str) {
+ emitBytes({reinterpret_cast<const uint8_t *>(str.data()), str.size()});
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Section Emission
+
+ /// Emit a nested section of the given code, whose contents are encoded in the
+ /// provided emitter.
+ void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) {
+ emitByte(bytecode::BuiltinCode::kSection);
+
+ // Emit the section code and length.
+ emitByte(code);
+ emitVarInt(emitter.size());
+
+ // Push our current buffer and then merge the provided section body into
+ // ours.
+ appendResult(std::move(currentResult));
+ for (std::vector<uint8_t> &result : emitter.prevResultStorage)
+ appendResult(std::move(result));
+ appendResult(std::move(emitter.currentResult));
+ }
+
+private:
+ /// Emit the given value using a variable width encoding. This method is a
+ /// fallback when the number of bytes needed to encode the value is greater
+ /// than 1. We mark it noinline here so that the single byte hot path isn't
+ /// pessimized.
+ LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value);
+
+ /// Append a new result buffer to the current contents.
+ void appendResult(std::vector<uint8_t> &&result) {
+ prevResultSize += result.size();
+ prevResultStorage.emplace_back(std::move(result));
+ prevResultList.emplace_back(prevResultStorage.back());
+ }
+
+ /// The result of the emitter currently being built. We refrain from building
+ /// a single buffer to simplify emitting sections, large data, and more. The
+ /// result is thus represented using multiple distinct buffers, some of which
+ /// we own (via prevResultStorage), and some of which are just pointers into
+ /// externally owned buffers.
+ std::vector<uint8_t> currentResult;
+ std::vector<ArrayRef<uint8_t>> prevResultList;
+ std::vector<std::vector<uint8_t>> prevResultStorage;
+
+ /// An up-to-date total size of all of the buffers within `prevResultList`.
+ /// This enables O(1) size checks of the current encoding.
+ size_t prevResultSize = 0;
+};
+
+/// A simple raw_ostream wrapper around an EncodingEmitter. This removes the need
+/// to go through an intermediate buffer when interacting with code that wants a
+/// raw_ostream.
+class raw_emitter_ostream : public raw_ostream {
+public:
+ explicit raw_emitter_ostream(EncodingEmitter &emitter) : emitter(emitter) {
+ SetUnbuffered();
+ }
+
+private:
+ void write_impl(const char *ptr, size_t size) override {
+ emitter.emitBytes({reinterpret_cast<const uint8_t *>(ptr), size});
+ }
+ uint64_t current_pos() const override { return emitter.size(); }
+
+ /// The section being emitted to.
+ EncodingEmitter &emitter;
+};
+} // namespace
+
+void EncodingEmitter::writeTo(raw_ostream &os) const {
+ for (auto &prevResult : prevResultList)
+ os.write((const char *)prevResult.data(), prevResult.size());
+ os.write((const char *)currentResult.data(), currentResult.size());
+}
+
+void EncodingEmitter::emitMultiByteVarInt(uint64_t value) {
+ // Compute the number of bytes needed to encode the value. Each byte can hold
+ // up to 7 bits of data. We only check up to 8 bytes, the longest length that
+ // the 8-bit prefix in the first byte can express.
+ uint64_t it = value >> 7;
+ for (size_t numBytes = 2; numBytes < 9; ++numBytes) {
+ if (LLVM_LIKELY((it >>= 7) == 0)) {
+ uint64_t encodedValue = (value << 1) | 0x1;
+ encodedValue <<= (numBytes - 1);
+ emitBytes({reinterpret_cast<const uint8_t *>(&encodedValue), numBytes});
+ return;
+ }
+ }
+
+ // If the value is too large to fit in eight bytes alongside the length
+ // prefix, emit an all-zero marker byte and splat the raw 64-bit value.
+ emitByte(0);
+ emitBytes({reinterpret_cast<const uint8_t *>(&value), sizeof(value)});
+}
+
+//===----------------------------------------------------------------------===//
+// Bytecode Writer
+//===----------------------------------------------------------------------===//
+
+namespace {
+class BytecodeWriter {
+public:
+ BytecodeWriter(const BytecodeWriterConfig &config) : numberingState(config) {}
+
+ /// Write the bytecode for the given root operation.
+ void write(Operation *rootOp, raw_ostream &os);
+
+private:
+ //===--------------------------------------------------------------------===//
+ // Dialects
+
+ void writeDialectSection(EncodingEmitter &emitter);
+
+ //===--------------------------------------------------------------------===//
+ // Attributes and Types
+
+ void writeAttrTypeSection(EncodingEmitter &emitter);
+
+ //===--------------------------------------------------------------------===//
+ // Operations
+
+ void writeBlock(EncodingEmitter &emitter, Block *block, Attribute &lastLoc);
+ void writeOp(EncodingEmitter &emitter, Operation *op, Attribute &lastLoc);
+ void writeRegion(EncodingEmitter &emitter, Region *region,
+ Attribute &lastLoc);
+ void writeTopLevelOp(EncodingEmitter &emitter, Operation *op);
+
+ //===--------------------------------------------------------------------===//
+ // Fields
+
+ /// The IR numbering state generated for the root operation.
+ IRNumberingState numberingState;
+};
+} // namespace
+
+void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) {
+ EncodingEmitter emitter;
+
+ // Emit the bytecode file header. This is how we identify the output as a
+ // bytecode file.
+ emitter.emitString("ML\xefR");
+
+ // Emit the bytecode version.
+ emitter.emitVarInt(bytecode::kVersion);
+
+ // Emit the dialect section.
+ writeDialectSection(emitter);
+
+ // Emit the attributes and types section.
+ writeAttrTypeSection(emitter);
+
+ // Emit the top level operation section.
+ writeTopLevelOp(emitter, rootOp);
+
+ // Write the generated bytecode to the provided output stream.
+ emitter.writeTo(os);
+}
+
+//===----------------------------------------------------------------------===//
+// Dialects
+
+void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
+ EncodingEmitter dialectEmitter;
+
+ // Emit the referenced dialects.
+ for (DialectNumbering &dialect : numberingState.getDialects()) {
+ // Emit the dialect name.
+ dialectEmitter.emitNulTerminatedString(dialect.name);
+
+ // Emit the number of attributes and types emitted for this dialect.
+ dialectEmitter.emitVarInt(dialect.attributes.size());
+ dialectEmitter.emitVarInt(dialect.types.size());
+
+ // Emit the referenced operation names of this dialect.
+ dialectEmitter.emitVarInt(dialect.opNames.size());
+ for (OpNameNumbering *opName : dialect.opNames)
+ dialectEmitter.emitNulTerminatedString(opName->name.stripDialect());
+ }
+
+ emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Attributes and Types
+
+void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) {
+ EncodingEmitter attrTypeEmitter;
+ EncodingEmitter offsetEmitter;
+
+ // A functor used to emit an attribute or type entry.
+ uint64_t prevOffset = 0;
+ auto emitAttrOrType = [&](auto value) {
+ // Emit the entry using the textual format.
+ // TODO: Allow dialects to provide more optimal implementations of attribute
+ // and type encodings.
+ attrTypeEmitter.emitByte(bytecode::AttrTypeCode::kAsmForm);
+ raw_emitter_ostream(attrTypeEmitter) << value;
+ attrTypeEmitter.emitByte(0);
+
+ // Record the offset of this entry.
+ uint64_t curOffset = attrTypeEmitter.size();
+ offsetEmitter.emitVarInt(curOffset - prevOffset);
+ prevOffset = curOffset;
+ };
+
+ // Emit the attribute and type entries for each dialect.
+ for (DialectNumbering &dialect : numberingState.getDialects())
+ for (AttributeNumbering *attr : dialect.attributes)
+ emitAttrOrType(attr->getValue());
+ for (DialectNumbering &dialect : numberingState.getDialects())
+ for (TypeNumbering *type : dialect.types)
+ emitAttrOrType(type->getValue());
+
+ // Emit the sections to the stream.
+ emitter.emitSection(bytecode::Section::kAttrTypeOffset,
+ std::move(offsetEmitter));
+ emitter.emitSection(bytecode::Section::kAttrType, std::move(attrTypeEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Operations
+
+void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block,
+ Attribute &lastLoc) {
+ // Emit the arguments of the block.
+ ArrayRef<BlockArgument> args = block->getArguments();
+ if (!args.empty()) {
+ emitter.emitByte(bytecode::TopLevelOpCode::kBlockArguments);
+
+ // Emit the value number for the first argument, and the number of arguments
+ // we are encoding.
+ emitter.emitVarInt(numberingState.getNumber(args.front()));
+ emitter.emitVarInt(args.size());
+
+ for (const auto &it : llvm::enumerate(args)) {
+ // Check to see if this argument has a new location.
+ Attribute argLoc = it.value().getLoc();
+ bool argHasNewLoc = argLoc != std::exchange(lastLoc, argLoc);
+
+ // Emit the argument type. We use the low bit of the type number to
+ // indicate if the argument changed locations.
+ uint64_t typeID = numberingState.getNumber(it.value().getType());
+ emitter.emitVarInt((typeID << 1) | (argHasNewLoc ? 1 : 0));
+ if (argHasNewLoc)
+ emitter.emitVarInt(numberingState.getNumber(argLoc));
+ }
+ }
+
+ // Emit the operations within the block.
+ for (Operation &op : *block) {
+ emitter.emitByte(bytecode::TopLevelOpCode::kOp);
+ writeOp(emitter, &op, lastLoc);
+ }
+ // Emit a terminal code to indicate when we are finished emitting operations.
+ emitter.emitByte(bytecode::TopLevelOpCode::kBlockEnd);
+}
+
+void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op,
+ Attribute &lastLoc) {
+ emitter.emitVarInt(numberingState.getNumber(op->getName()));
+
+ // Emit a mask for the operation components. We need to fill this in later
+ // (when we actually know what needs to be emitted), so emit a placeholder for
+ // now.
+ uint64_t maskOffset = emitter.size();
+ uint8_t opEncodingMask = 0;
+ emitter.emitByte(0);
+
+ // Emit the location for this operation.
+ Attribute opLoc = op->getLoc();
+ if (opLoc != std::exchange(lastLoc, opLoc)) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasLoc;
+ emitter.emitVarInt(numberingState.getNumber(opLoc));
+ }
+
+ // Emit the attributes of this operation.
+ DictionaryAttr attrs = op->getAttrDictionary();
+ if (!attrs.empty()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs;
+ emitter.emitVarInt(numberingState.getNumber(attrs));
+ }
+
+ // Emit the result types of the operation.
+ if (unsigned numResults = op->getNumResults()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasResults;
+ emitter.emitVarInt(numberingState.getNumber(op->getResult(0)));
+ emitter.emitVarInt(numResults);
+ for (Type type : op->getResultTypes())
+ emitter.emitVarInt(numberingState.getNumber(type));
+ }
+
+ // Emit the operands of the operation.
+ if (unsigned numOperands = op->getNumOperands()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasOperands;
+ emitter.emitVarInt(numOperands);
+ for (Value operand : op->getOperands())
+ emitter.emitVarInt(numberingState.getNumber(operand));
+ }
+
+ // Emit the successors of the operation.
+ if (unsigned numSuccessors = op->getNumSuccessors()) {
+ opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors;
+ emitter.emitVarInt(numSuccessors);
+ for (Block *successor : op->getSuccessors())
+ emitter.emitVarInt(numberingState.getNumber(successor));
+ }
+
+ // Check for regions.
+ unsigned numRegions = op->getNumRegions();
+ if (numRegions)
+ opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions;
+
+ // Update the mask for the operation.
+ *emitter.getRawPointer(maskOffset) = opEncodingMask;
+
+ // With the mask emitted, we can now emit the regions of the operation. We do
+ // this after mask emission to avoid the offset complications that would arise
+ // from emitting the regions first (e.g. if the regions are huge, backpatching
+ // the op encoding mask is more annoying).
+ if (numRegions) {
+ emitter.emitVarInt(numRegions);
+ for (Region &region : op->getRegions())
+ writeRegion(emitter, &region, lastLoc);
+ }
+}
+
+void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region,
+ Attribute &lastLoc) {
+ if (region->empty())
+ return emitter.emitByte(bytecode::TopLevelOpCode::kRegionEmpty);
+
+ // Emit the number of blocks and values within the region.
+ unsigned numBlocks, numValues;
+ std::tie(numBlocks, numValues) = numberingState.getBlockValueCount(region);
+ emitter.emitByte(bytecode::TopLevelOpCode::kRegion);
+ emitter.emitVarInt(numBlocks);
+ emitter.emitVarInt(numValues);
+
+ // Emit the blocks within the region.
+ for (Block &block : *region)
+ writeBlock(emitter, &block, lastLoc);
+}
+
+void BytecodeWriter::writeTopLevelOp(EncodingEmitter &emitter, Operation *op) {
+ EncodingEmitter topLevelOpEmitter;
+
+ Attribute lastLoc;
+ writeOp(topLevelOpEmitter, op, lastLoc);
+
+ emitter.emitSection(bytecode::Section::kTopLevelOp,
+ std::move(topLevelOpEmitter));
+}
+
+//===----------------------------------------------------------------------===//
+// Entry Points
+//===----------------------------------------------------------------------===//
+
+void mlir::writeBytecodeToFile(const BytecodeWriterConfig &config,
+ raw_ostream &os) {
+ BytecodeWriter writer(config);
+ writer.write(config.getRootOp(), os);
+}
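`writeOp` above emits a placeholder mask byte, appends only the components that are actually present, and then backpatches the mask through `getRawPointer`. `writeBlock` uses a related trick, folding a "location changed" flag into the low bit of each block argument's type number. The standalone C++ sketch below illustrates both ideas. It is not the writer's actual code. It assumes every number fits in a single byte (the real writer uses `emitVarInt`), it handles only attributes and operands, and its names (`SketchOp`, `writeOpHeader`, `packArgType`, `unpackArgType`) and mask bit values are hypothetical, because `bytecode::OpEncodingMask` is not defined in this excerpt.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mask bits for this sketch only; the real values come from
// bytecode::OpEncodingMask, which is not shown in this excerpt.
enum SketchMaskBit : uint8_t {
  kSketchHasAttrs = 1 << 0,
  kSketchHasOperands = 1 << 1,
};

// Hypothetical, simplified operation: every component is already numbered,
// and each number is assumed to fit in a single byte.
struct SketchOp {
  uint8_t nameID;
  bool hasAttrs;
  uint8_t attrDictID;
  std::vector<uint8_t> operandIDs;
};

// Emits the op name, a placeholder mask byte, and only the components that are
// present, then backpatches the mask (as writeOp does via getRawPointer).
void writeOpHeader(std::vector<uint8_t> &out, const SketchOp &op) {
  out.push_back(op.nameID);
  size_t maskOffset = out.size();
  out.push_back(0); // Placeholder, filled in below.
  uint8_t mask = 0;
  if (op.hasAttrs) {
    mask |= kSketchHasAttrs;
    out.push_back(op.attrDictID);
  }
  if (!op.operandIDs.empty()) {
    mask |= kSketchHasOperands;
    out.push_back(static_cast<uint8_t>(op.operandIDs.size()));
    for (uint8_t id : op.operandIDs)
      out.push_back(id);
  }
  out[maskOffset] = mask; // Backpatch now that the contents are known.
}

// Folds a "location changed" flag into the low bit of a type number, as
// writeBlock does for block arguments.
uint64_t packArgType(uint64_t typeID, bool hasNewLoc) {
  return (typeID << 1) | (hasNewLoc ? 1 : 0);
}

// The inverse split that a reader would need to perform.
void unpackArgType(uint64_t packed, uint64_t &typeID, bool &hasNewLoc) {
  typeID = packed >> 1;
  hasNewLoc = (packed & 1) != 0;
}
```

Backpatching stays simple because the placeholder's offset is recorded before any variable-length payload follows it. This is the same reason the comment in `writeOp` gives for emitting regions only after the mask has been patched.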
diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_mlir_library(MLIRBytecodeWriter
+ BytecodeWriter.cpp
+ IRNumbering.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRSupport
+ )
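In `writeAttrTypeSection` in `BytecodeWriter.cpp`, each attribute or type entry is written in textual form followed by a NUL byte. The companion offset section then records the size of each entry, meaning the difference between consecutive end offsets, rather than absolute positions. The sketch below shows that bookkeeping using the hypothetical names `writeEntries` and `startOffsets`, and it omits the leading `kAsmForm` code byte that the real writer emits. Recovering start offsets with a running sum is an assumption about how a reader could use the table, since the reader is not part of this excerpt. Storing sizes rather than absolute offsets also likely keeps the numbers small, which suits variable-width integers.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Appends each entry as its textual form plus a NUL terminator, and records in
// `sizes` how many bytes each entry occupies (curOffset - prevOffset), mirroring
// the bookkeeping in writeAttrTypeSection.
void writeEntries(const std::vector<std::string> &entries, std::string &data,
                  std::vector<uint64_t> &sizes) {
  uint64_t prevOffset = 0;
  for (const std::string &text : entries) {
    data += text;
    data.push_back('\0');
    uint64_t curOffset = data.size();
    sizes.push_back(curOffset - prevOffset);
    prevOffset = curOffset;
  }
}

// Rebuilds the start offset of every entry from the recorded sizes with a
// running sum, so entry i can be located without parsing the entries before it.
std::vector<uint64_t> startOffsets(const std::vector<uint64_t> &sizes) {
  std::vector<uint64_t> starts;
  uint64_t offset = 0;
  for (uint64_t size : sizes) {
    starts.push_back(offset);
    offset += size;
  }
  return starts;
}
```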
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.h
@@ -0,0 +1,173 @@
+//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains various utilities that number IR structures in preparation
+// for bytecode emission.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
+#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H
+
+#include "mlir/IR/OperationSupport.h"
+#include "llvm/ADT/MapVector.h"
+
+namespace mlir {
+class BytecodeWriterConfig;
+
+namespace bytecode {
+namespace detail {
+struct DialectNumbering;
+
+//===----------------------------------------------------------------------===//
+// Attribute and Type Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for an Attribute or Type.
+struct AttrTypeNumbering {
+ AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {}
+
+ /// The concrete value.
+ PointerUnion<Attribute, Type> value;
+
+ /// The number assigned to this value.
+ unsigned number = 0;
+
+ /// The dialect of this value.
+ DialectNumbering *dialect = nullptr;
+};
+struct AttributeNumbering : public AttrTypeNumbering {
+ AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {}
+ Attribute getValue() const { return value.get<Attribute>(); }
+};
+struct TypeNumbering : public AttrTypeNumbering {
+ TypeNumbering(Type value) : AttrTypeNumbering(value) {}
+ Type getValue() const { return value.get<Type>(); }
+};
+
+//===----------------------------------------------------------------------===//
+// OpName Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents the numbering entry of an operation name.
+struct OpNameNumbering {
+ OpNameNumbering(OperationName name) : name(name) {}
+
+ /// The concrete name.
+ OperationName name;
+
+ /// The number assigned to this name.
+ unsigned number = 0;
+};
+
+//===----------------------------------------------------------------------===//
+// Dialect Numbering
+//===----------------------------------------------------------------------===//
+
+/// This class represents a numbering entry for a Dialect.
+struct DialectNumbering {
+ DialectNumbering(StringRef name, unsigned number)
+ : name(name), number(number) {}
+
+ /// The namespace of the dialect.
+ StringRef name;
+
+ /// The number assigned to the dialect.
+ unsigned number;
+
+ /// The loaded dialect, or nullptr if the dialect isn't loaded.
+ Dialect *dialect = nullptr;
+
+ /// Numbered sub-components of the dialect to be emitted.
+ std::vector<OpNameNumbering *> opNames;
+ std::vector<AttributeNumbering *> attributes;
+ std::vector<TypeNumbering *> types;
+};
+
+//===----------------------------------------------------------------------===//
+// IRNumberingState
+//===----------------------------------------------------------------------===//
+
+/// This class manages numbering IR entities in preparation for bytecode
+/// emission.
+class IRNumberingState {
+public:
+ IRNumberingState(const BytecodeWriterConfig &config);
+
+ /// Return the numbered dialects.
+ auto getDialects() {
+ return llvm::make_pointee_range(llvm::make_second_range(dialects));
+ }
+
+ /// Return the number for the given IR unit.
+ unsigned getNumber(Attribute attr) {
+ assert(attrs.count(attr) && "attribute not numbered");
+ return attrs[attr]->number;
+ }
+ unsigned getNumber(Block *block) {
+ assert(blockIDs.count(block) && "block not numbered");
+ return blockIDs[block];
+ }
+ unsigned getNumber(OperationName opName) {
+ assert(opNames.count(opName) && "opName not numbered");
+ return opNames[opName]->number;
+ }
+ unsigned getNumber(Type type) {
+ assert(types.count(type) && "type not numbered");
+ return types[type]->number;
+ }
+ unsigned getNumber(Value value) {
+ assert(valueIDs.count(value) && "value not numbered");
+ return valueIDs[value];
+ }
+
+ /// Return the block and value counts of the given region.
+ std::pair<unsigned, unsigned> getBlockValueCount(Region *region) {
+ assert(regionBlockValueCounts.count(region) && "region not numbered");
+ return regionBlockValueCounts[region];
+ }
+
+private:
+ /// Number the given IR unit for bytecode emission.
+ void number(Attribute attr);
+ void number(Block &block);
+ DialectNumbering &numberDialect(Dialect *dialect);
+ DialectNumbering &numberDialect(StringRef dialect);
+ void number(Operation &op);
+ void number(OperationName opName);
+ void number(Region &region);
+ void number(Type type);
+
+ /// Mapping from IR to the respective numbering entries.
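+ llvm::MapVector<Attribute, AttributeNumbering *> attrs;
+ llvm::MapVector<OperationName, OpNameNumbering *> opNames;
+ llvm::MapVector<Type, TypeNumbering *> types;
+ llvm::MapVector<Dialect *, DialectNumbering *> registeredDialects;
+ llvm::MapVector<StringRef, DialectNumbering *> dialects;
+
+ /// Allocators used for the various numbering entries.
+ llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator;
+ llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator;
+ llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator;
+ llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator;
+
+ /// The number assigned to each Block and Value.
+ DenseMap<Block *, unsigned> blockIDs;
+ DenseMap<Value, unsigned> valueIDs;
+
+ /// A map from region to the number of blocks and values within that region.
+ DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts;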
+
+ /// The next value ID to assign when numbering.
+ unsigned nextValueID = 0;
+};
+} // namespace detail
+} // namespace bytecode
+} // namespace mlir
+
+#endif
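`DialectNumbering` collects op names, attributes, and types per dialect. The `IRNumberingState` constructor in `IRNumbering.cpp` assigns numbers only after the traversal has finished, walking the entries dialect by dialect. As a result, each dialect's entries receive a contiguous range of IDs. The sketch below shows that scheme for a single kind of entry, where the real code keeps separate counters for attributes, op names, and types. It uses the hypothetical names `Entry` and `numberGroupedByDialect`, and plain standard containers stand in for `llvm::MapVector`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for an attribute, type, or op name discovered during
// the IR walk: the owning dialect's namespace plus a unique key.
struct Entry {
  std::string dialect;
  std::string key;
};

// Deduplicates entries in first-seen order, buckets them by dialect (dialects
// also kept in first-seen order), and only then assigns IDs bucket by bucket,
// so each dialect's entries end up with a contiguous ID range.
std::unordered_map<std::string, unsigned>
numberGroupedByDialect(const std::vector<Entry> &discovered) {
  std::vector<std::string> dialectOrder;
  std::unordered_map<std::string, std::vector<std::string>> buckets;
  std::unordered_set<std::string> seen;
  for (const Entry &entry : discovered) {
    if (!seen.insert(entry.key).second)
      continue; // Already numbered, like the early return in number(Type).
    auto it = buckets.find(entry.dialect);
    if (it == buckets.end()) {
      dialectOrder.push_back(entry.dialect);
      it = buckets.emplace(entry.dialect, std::vector<std::string>()).first;
    }
    it->second.push_back(entry.key);
  }

  std::unordered_map<std::string, unsigned> ids;
  unsigned nextID = 0;
  for (const std::string &dialect : dialectOrder)
    for (const std::string &key : buckets[dialect])
      ids[key] = nextID++;
  return ids;
}
```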
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -0,0 +1,165 @@
+//===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRNumbering.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Operation.h"
+
+using namespace mlir;
+using namespace mlir::bytecode::detail;
+
+//===----------------------------------------------------------------------===//
+// IR Numbering
+//===----------------------------------------------------------------------===//
+
+IRNumberingState::IRNumberingState(const BytecodeWriterConfig &config) {
+ Operation *op = config.getRootOp();
+
+ // Number the root operation.
+ number(*op);
+
+ // Push all of the regions of the root operation onto the worklist.
+ SmallVector<std::pair<Region *, unsigned>, 8> numberContext;
+ for (Region &region : op->getRegions())
+ numberContext.emplace_back(&region, nextValueID);
+
+ // Iteratively process each of the nested regions.
+ while (!numberContext.empty()) {
+ Region *region;
+ std::tie(region, nextValueID) = numberContext.pop_back_val();
+ number(*region);
+
+ // Traverse into nested regions.
+ for (Operation &nestedOp : region->getOps())
+ for (Region &nestedRegion : nestedOp.getRegions())
+ numberContext.emplace_back(&nestedRegion, nextValueID);
+ }
+
+ // Walk and number the recorded components within each dialect.
+ unsigned attrID = 0, opNameID = 0, typeID = 0;
+ for (DialectNumbering *dialect : llvm::make_second_range(dialects)) {
+ for (AttributeNumbering *attr : dialect->attributes)
+ attr->number = attrID++;
+ for (OpNameNumbering *opName : dialect->opNames)
+ opName->number = opNameID++;
+ for (TypeNumbering *type : dialect->types)
+ type->number = typeID++;
+ }
+}
+
+void IRNumberingState::number(Attribute attr) {
+ auto it = attrs.insert({attr, nullptr});
+ if (!it.second)
+ return;
+ auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr);
+ it.first->second = numbering;
+
+ // Check for OpaqueAttr, which is a dialect-specific attribute that didn't
+ // have a registered dialect when it got created. Rather than encoding this
+ // as the builtin OpaqueAttr, we want to encode it as if the dialect was
+ // actually loaded.
+ if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>())
+ numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace());
+ else
+ numbering->dialect = &numberDialect(&attr.getDialect());
+
+ numbering->dialect->attributes.push_back(numbering);
+}
+
+void IRNumberingState::number(Block &block) {
+ // Number the arguments of the block.
+ for (BlockArgument arg : block.getArguments()) {
+ valueIDs.try_emplace(arg, nextValueID++);
+ number(arg.getLoc());
+ number(arg.getType());
+ }
+
+ // Number the operations in this block.
+ for (Operation &op : block)
+ number(op);
+}
+
+auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & {
+ DialectNumbering *&numbering = registeredDialects[dialect];
+ if (!numbering) {
+ numbering = &numberDialect(dialect->getNamespace());
+ numbering->dialect = dialect;
+ }
+ return *numbering;
+}
+
+auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & {
+ DialectNumbering *&numbering = dialects[dialect];
+ if (!numbering) {
+ numbering = new (dialectAllocator.Allocate())
+ DialectNumbering(dialect, dialects.size() - 1);
+ }
+ return *numbering;
+}
+
+void IRNumberingState::number(Region &region) {
+ size_t firstValueID = nextValueID;
+
+ // Number the blocks within this region.
+ size_t blockCount = 0;
+ for (auto &it : llvm::enumerate(region)) {
+ blockIDs.try_emplace(&it.value(), it.index());
+ number(it.value());
+ ++blockCount;
+ }
+
+ // Remember the number of blocks and values in this region.
+ regionBlockValueCounts.try_emplace(&region, blockCount,
+ nextValueID - firstValueID);
+}
+
+void IRNumberingState::number(Operation &op) {
+ // Number the components of an operation that won't be numbered elsewhere
+ // (e.g. we don't number operands, regions, or successors here).
+ number(op.getName());
+ for (OpResult result : op.getResults()) {
+ valueIDs.try_emplace(result, nextValueID++);
+ number(result.getType());
+ }
+ number(op.getAttrDictionary());
+ number(op.getLoc());
+}
+
+void IRNumberingState::number(OperationName opName) {
+ OpNameNumbering *&numbering = opNames[opName];
+ if (numbering)
+ return;
+ DialectNumbering *dialectNumber = nullptr;
+ if (Dialect *dialect = opName.getDialect())
+ dialectNumber = &numberDialect(dialect);
+ else
+ dialectNumber = &numberDialect(opName.getDialectNamespace());
+ numbering = new (opNameAllocator.Allocate()) OpNameNumbering(opName);
+ dialectNumber->opNames.emplace_back(numbering);
+}
+
+void IRNumberingState::number(Type type) {
+ auto it = types.insert({type, nullptr});
+ if (!it.second)
+ return;
+ auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type);
+ it.first->second = numbering;
+
+ // Check for OpaqueType, which is a dialect-specific type that didn't have a
+ // registered dialect when it got created. Rather than encoding this as the
+ // builtin OpaqueType, we want to encode it as if the dialect was actually
+ // loaded.
+ if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>())
+ numbering->dialect = &numberDialect(opaqueType.getDialectNamespace());
+ else
+ numbering->dialect = &numberDialect(&type.getDialect());
+
+ numbering->dialect->types.push_back(numbering);
+}
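The worklist in the `IRNumberingState` constructor gives each nested region a starting value ID equal to `nextValueID` after its parent region has been fully numbered. Sibling regions therefore reuse the same ID range. This is safe because a value defined in one region is not visible from a sibling region. The hedged sketch below models that traversal over a hypothetical, flattened region tree (`SketchRegion`, `numberRegions`) in which each region simply declares how many values (block arguments plus op results) it defines.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical simplified region: how many values it defines, which regions
// are nested inside it, and the first value ID it received.
struct SketchRegion {
  unsigned numValues = 0;
  std::vector<SketchRegion *> children;
  unsigned firstID = 0; // Filled in by numberRegions.
};

// Mirrors the constructor's worklist: each region is numbered starting at the
// ID recorded when it was pushed, and its children are pushed with the ID that
// follows all of this region's values.
void numberRegions(const std::vector<SketchRegion *> &rootRegions,
                   unsigned rootValueCount) {
  std::vector<std::pair<SketchRegion *, unsigned>> worklist;
  for (SketchRegion *region : rootRegions)
    worklist.emplace_back(region, rootValueCount);
  while (!worklist.empty()) {
    auto [region, nextValueID] = worklist.back(); // Copy before popping.
    worklist.pop_back();
    region->firstID = nextValueID;
    nextValueID += region->numValues;
    for (SketchRegion *child : region->children)
      worklist.emplace_back(child, nextValueID);
  }
}
```

For example, with one root result and an outer region defining 4 values, both of its nested regions start at ID 5.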
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -3,6 +3,7 @@
add_subdirectory(Analysis)
add_subdirectory(AsmParser)
+add_subdirectory(Bytecode)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(IR)
diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt
--- a/mlir/lib/Parser/CMakeLists.txt
+++ b/mlir/lib/Parser/CMakeLists.txt
@@ -6,5 +6,6 @@
LINK_LIBS PUBLIC
MLIRAsmParser
+ MLIRBytecodeReader
MLIRIR
)
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -12,6 +12,7 @@
#include "mlir/Parser/Parser.h"
#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/Bytecode/BytecodeReader.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
@@ -25,6 +26,8 @@
sourceBuf->getBufferIdentifier(),
/*line=*/0, /*column=*/0);
}
+ if (isBytecode(*sourceBuf))
+ return readBytecodeFile(*sourceBuf, block, config);
return parseAsmSourceFile(sourceMgr, block, config);
}
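With this change, `parseSourceFile` dispatches to `readBytecodeFile` when `isBytecode` recognizes the buffer, and otherwise falls back to the textual parser. A minimal sketch of that kind of detection follows. It assumes the check is a prefix comparison against the four-byte magic number, with the 'ï' in the format document read as the Latin-1 byte 0xEF. The helper name `startsWithBytecodeMagic` is hypothetical and is not the `isBytecode` implementation.

```cpp
#include <cassert>
#include <string_view>

// Hypothetical helper: returns true if the buffer starts with the assumed
// four-byte MLIR bytecode magic number 'M', 'L', 0xEF, 'R'.
bool startsWithBytecodeMagic(std::string_view buffer) {
  constexpr std::string_view kMagic("ML\xEFR", 4);
  // substr clamps the length, so buffers shorter than 4 bytes simply fail.
  return buffer.substr(0, kMagic.size()) == kMagic;
}
```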
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -5,6 +5,7 @@
${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-opt
LINK_LIBS PUBLIC
+ MLIRBytecodeWriter
MLIRPass
MLIRParser
MLIRSupport
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
+#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -47,7 +48,8 @@
static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
bool verifyPasses, SourceMgr &sourceMgr,
MLIRContext *context,
- PassPipelineFn passManagerSetupFn) {
+ PassPipelineFn passManagerSetupFn,
+ bool emitBytecode) {
DefaultTimingManager tm;
applyDefaultTimingManagerCLOptions(tm);
TimingScope timing = tm.getRootScope();
@@ -86,8 +88,12 @@
// Print the output.
TimingScope outputTiming = timing.nest("Output");
- module->print(os);
- os << '\n';
+ if (emitBytecode) {
+ writeBytecodeToFile(module->getOperation(), os);
+ } else {
+ module->print(os);
+ os << '\n';
+ }
return success();
}
@@ -97,8 +103,8 @@
processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects, bool preloadDialectsInContext,
- PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
- llvm::ThreadPool *threadPool) {
+ bool emitBytecode, PassPipelineFn passManagerSetupFn,
+ DialectRegistry &registry, llvm::ThreadPool *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
@@ -122,7 +128,7 @@
if (!verifyDiagnostics) {
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
- &context, passManagerSetupFn);
+ &context, passManagerSetupFn, emitBytecode);
}
SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
@@ -131,7 +137,7 @@
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
- passManagerSetupFn);
+ passManagerSetupFn, emitBytecode);
// Verify the diagnostic handler to make sure that each of the diagnostics
// matched.
@@ -144,7 +150,20 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext) {
+ bool preloadDialectsInContext,
+ bool emitBytecode) {
+ // Check to see if we are trying to output bytecode to a displayed stream.
+ // TODO: Do we need to provide a -f option like LLVM? Should we even
+ // warn/disable in this case?
+ if (emitBytecode && outputStream.is_displayed()) {
+ llvm::errs()
+ << "warning: Attempting to output a bytecode file to a displayed "
+ "stream.\n"
+ "This is inadvisable as it may cause display problems, disabling "
+ "bytecode output.\n\n";
+ emitBytecode = false;
+ }
+
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
// We use an explicit threadpool to avoid creating and joining/destroying
@@ -163,8 +182,8 @@
raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- preloadDialectsInContext, passManagerSetupFn, registry,
- threadPool);
+ preloadDialectsInContext, emitBytecode,
+ passManagerSetupFn, registry, threadPool);
};
return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
splitInputFile, /*insertMarkerInOutput=*/true);
@@ -176,7 +195,8 @@
DialectRegistry &registry, bool splitInputFile,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext) {
+ bool preloadDialectsInContext,
+ bool emitBytecode) {
auto passManagerSetupFn = [&](PassManager &pm) {
auto errorHandler = [&](const Twine &msg) {
emitError(UnknownLoc::get(pm.getContext())) << msg;
@@ -186,7 +206,8 @@
};
return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
registry, splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext);
+ allowUnregisteredDialects, preloadDialectsInContext,
+ emitBytecode);
}
LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
@@ -224,6 +245,10 @@
"show-dialects", cl::desc("Print the list of registered dialects"),
cl::init(false));
+ static cl::opt<bool> emitBytecode(
+ "emit-bytecode", cl::desc("Emit bytecode when generating output"),
+ cl::init(false));
+
InitLLVM y(argc, argv);
// Register any command line options.
@@ -268,7 +293,8 @@
if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
splitInputFile, verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext)))
+ allowUnregisteredDialects, preloadDialectsInContext,
+ emitBytecode)))
return failure();
// Keep the output file if the invocation of MlirOptMain was successful.
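The new guard in `MlirOptMain` turns off `-emit-bytecode` when `outputStream.is_displayed()` reports that the output is going to a display. The sketch below approximates that decision under one assumption: on a POSIX host, a "displayed" stream is one whose file descriptor refers to a terminal, so the sketch swaps in `isatty` for LLVM's `raw_ostream::is_displayed()`. The function name `shouldEmitBytecode` is hypothetical.

```cpp
#include <cassert>
#include <cstdio>
#include <unistd.h>

// Hypothetical helper: keeps the caller's request for bytecode unless the
// output descriptor is a terminal, in which case it warns and falls back to
// textual output (approximating the is_displayed() check with isatty).
bool shouldEmitBytecode(bool requested, int outputFD) {
  if (requested && isatty(outputFD)) {
    std::fprintf(stderr, "warning: not writing bytecode to a terminal; "
                         "falling back to textual output\n");
    return false;
  }
  return requested;
}
```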
diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Bytecode/general.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt -allow-unregistered-dialect -emit-bytecode %s | mlir-opt -allow-unregistered-dialect | FileCheck %s
+
+// CHECK-LABEL: "bytecode.test1"
+// CHECK-NEXT: "bytecode.empty"() : () -> ()
+// CHECK-NEXT: "bytecode.attributes"() {attra = 10 : i64, attrb = #bytecode.attr} : () -> ()
+// CHECK-NEXT: %[[RESULTS:.*]]:3 = "bytecode.results"() : () -> (i32, i64, i32)
+// CHECK-NEXT: "bytecode.operands"(%[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#2) : (i32, i64, i32) -> ()
+// CHECK-NEXT: "bytecode.branch"()[^[[BLOCK:.*]]] : () -> ()
+// CHECK-NEXT: ^[[BLOCK]](%[[ARG0:.*]]: i32, %[[ARG1:.*]]: !bytecode.int, %[[ARG2:.*]]: !pdl.operation):
+// CHECK-NEXT: "bytecode.regions"() ({
+// CHECK-NEXT: "bytecode.operands"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (i32, !bytecode.int, !pdl.operation) -> ()
+// CHECK-NEXT: "bytecode.return"() : () -> ()
+// CHECK-NEXT: }) : () -> ()
+// CHECK-NEXT: "bytecode.return"() : () -> ()
+// CHECK-NEXT: }) : () -> ()
+
+"bytecode.test1"() ({
+ "bytecode.empty"() : () -> ()
+ "bytecode.attributes"() {attra = 10, attrb = #bytecode.attr} : () -> ()
+ %results:3 = "bytecode.results"() : () -> (i32, i64, i32)
+ "bytecode.operands"(%results#0, %results#1, %results#2) : (i32, i64, i32) -> ()
+ "bytecode.branch"()[^secondBlock] : () -> ()
+
+^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation):
+ "bytecode.regions"() ({
+ "bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> ()
+ "bytecode.return"() : () -> ()
+ }) : () -> ()
+ "bytecode.return"() : () -> ()
+}) : () -> ()