diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md new file mode 100644 --- /dev/null +++ b/mlir/docs/BytecodeFormat.md @@ -0,0 +1,296 @@ +# MLIR Bytecode Format + +This document describes the MLIR bytecode format and its encoding. + +[TOC] + +## Magic Number + +MLIR uses the following four-byte magic number to indicate bytecode files: + +'\[‘M’<sub>8</sub>, ‘L’<sub>8</sub>, ‘ï’<sub>8</sub>, ‘R’<sub>8</sub>\]' + +## Format Overview + +An MLIR Bytecode file consists of a byte stream, with a few simple +structural concepts layered on top. + +### Primitives + +#### Fixed-Width Integers + +``` + byte ::= `0x00`...`0xFF` +``` + +Fixed width integers are unsigned integers of a known byte size. The values are +stored in little-endian byte order. + +TODO: Add larger fixed width integers as necessary. + +#### Variable-Width Integers + +Variable width integers, or `VarInt`s, provide a compact representation for +integers. Each encoded VarInt consists of one to nine bytes, which together +represent a single 64-bit value. The MLIR bytecode utilizes the "PrefixVarInt" +encoding for VarInts. This encoding is a variant of the +[LEB128 ("Little-Endian Base 128")](https://en.wikipedia.org/wiki/LEB128) +encoding, where each byte of the encoding provides up to 7 bits for the value, +with the remaining bit used to store a tag indicating the number of bytes used +for the encoding. This means that small unsigned integers (less than 2^7) may be +stored in one byte, unsigned integers less than 2^14 may be stored in two bytes, +etc. + +The first byte of the encoding includes a length prefix in the low bits. This +prefix is a bit sequence of '0's followed by a terminal '1', or the end of the +byte. The number of '0' bits indicates the number of _additional_ bytes, not +including the prefix byte, used to encode the value. All of the remaining bits +in the first byte, along with all of the bits in the additional bytes, provide +the value of the integer. 
Below are the various possible encodings of the prefix +byte: + +``` +xxxxxxx1: 7 value bits, the encoding uses 1 byte +xxxxxx10: 14 value bits, the encoding uses 2 bytes +xxxxx100: 21 value bits, the encoding uses 3 bytes +xxxx1000: 28 value bits, the encoding uses 4 bytes +xxx10000: 35 value bits, the encoding uses 5 bytes +xx100000: 42 value bits, the encoding uses 6 bytes +x1000000: 49 value bits, the encoding uses 7 bytes +10000000: 56 value bits, the encoding uses 8 bytes +00000000: 64 value bits, the encoding uses 9 bytes +``` + +#### NUL Terminated Strings + +NUL Terminated Strings are terminated with the ASCII NUL character (whose byte +value is zero). These are not used in cases where a string may contain an +embedded NUL character. In cases that may hold an embedded NUL character, the +string is encoded using a length and byte array. + +### Sections + +``` +section { + id: byte + length: varint +} +``` + +Sections are a mechanism for grouping data within the bytecode. They enable +delayed processing, which is useful for out-of-order processing of data, +lazy-loading, and more. Each section contains a Section ID and a length (which +allows a reader to skip over the section). + +TODO: Sections should also carry an optional alignment. Add this when necessary. + +## MLIR Encoding + +Given the generic structure of MLIR, the bytecode encoding is actually fairly +simple. It effectively maps to the core components of MLIR. + +### Top Level Structure + +The top-level structure of the bytecode contains the 4-byte "magic number", a +version number, and a list of sections. Each section is currently only expected +to appear once within a bytecode file. + +``` +bytecode { + magic: "MLïR", + version: varint, + sections: section[] +} +``` + +### Dialect Section + +The dialect section of the bytecode contains all of the dialects referenced +within the encoded IR, and some information about the components of those +dialects that were also referenced. 
+ +``` +dialect_section { + dialects: dialect[] +} + +dialect { + name: nul_terminated_string, + numAttrs: varint, + numTypes: varint, + numOpNames: varint, + opNames: nul_terminated_string[] +} +``` + +### Attribute/Type Sections + +Attributes and types are encoded using two [sections](#sections), one section +(`attr_type_section`) containing the actual encoded representation, and another +section (`attr_type_offset_section`) containing the offsets of each encoded +attribute/type into the previous section. This allows for attributes and types +to always be lazily loaded on demand. + +``` +attr_type_section { + attrs: attribute[], + types: type[] +} +attr_type_offset_section { + offset: varint[] +} + +attribute { + code: byte, // kAsmForm + encoding: ... +} +type { + code: byte, // kAsmForm + encoding: ... +} +``` + +Each `offset` in the `attr_type_offset_section` above is the size of the +encoding for the attribute or type. We avoid using the direct offset into the +`attr_type_section`, as smaller relative offsets provide more effective +compression. Attributes and types are grouped by dialect, with each dialect +grouping in the same order as the dialects within the +[dialect section](#dialect-section). + +#### Attribute/Type Encodings + +In the previous section, the forms of `attribute` and `type` both start with a +`code` field. This field indicates how the attribute or type was encoded. In the +abstract, an attribute/type is encoded in one of two possible ways: via its +assembly format, or via a custom dialect-defined encoding. + +##### Assembly Format Fallback + +In the case where a dialect does not define a method for encoding the attribute +or type, the textual assembly format of that attribute or type is used as a +fallback. For example, a type of `!bytecode.type` would be encoded as the +NUL-terminated string "!bytecode.type". 
This ensures that every attribute and type +may be encoded, even if the owning dialect has not yet opted in to a more +efficient serialization. + +##### Dialect Defined Encoding + +TODO: This is not yet supported. + +### IR Section + +The IR section contains the encoded form of operations within the bytecode. + +#### Operation Encoding + +``` +op { + name: varint, + encodingMask: byte, + + location: varint?, + + attrDict: varint?, + + firstResultIndex: varint?, + numResults: varint?, + resultTypes: varint[], + + numOperands: varint?, + operands: varint[], + + numSuccessors: varint?, + successors: varint[], + + numRegions: varint?, + regions: region[] +} +``` + +The encoding of an operation is important because this is generally the most +commonly appearing structure in the bytecode. A single encoding is used for +every type of operation. Given this prevalence, many of the fields of an +operation are optional. The `encodingMask` field is a bitmask which indicates +which of the components of the operation are present. + +##### Location + +If necessary to encode, i.e. if the location for this operation is different +from the location for the last operation or block argument, the index of the +location within the attribute table is encoded. + +##### Attributes + +If the operation has attributes, the index of the operation attribute dictionary +within the attribute table is encoded. + +##### Results + +If the operation has results, the value index of the first result is encoded. +After that, the number of results and the indexes of the result types within the +type table are encoded. + +##### Operands + +If the operation has operands, the number of operands and the value index of +each operand is encoded. + +##### Successors + +If the operation has successors, the number of successors and the indexes of the +successor blocks within the parent region are encoded. + +##### Regions + +If the operation has regions, the number of regions and each region are encoded. 
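Nearly every field of `op` above is a `varint`, so it may help to see the PrefixVarInt scheme from the Variable-Width Integers section written out. The following is a minimal, standalone sketch and not the code in this patch: `encodeVarInt` and `decodeVarInt` are hypothetical helper names, and the sketch assembles bytes with explicit shifts rather than the `memcpy` approach used by the reader, so it does not depend on host endianness.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the patch): encode `value` as a
// PrefixVarInt, following the prefix table in the format document.
std::vector<uint8_t> encodeVarInt(uint64_t value) {
  std::vector<uint8_t> bytes;
  // Values needing more than 56 bits use the 9-byte form: an all-zero prefix
  // byte followed by the raw 64-bit value in little-endian order.
  if (value >> 56) {
    bytes.push_back(0x00);
    for (unsigned i = 0; i < 8; ++i)
      bytes.push_back(static_cast<uint8_t>(value >> (8 * i)));
    return bytes;
  }
  // Otherwise pick the smallest byte count n in [1, 8] with 7 * n value bits.
  unsigned numBytes = 1;
  while (numBytes < 8 && (value >> (7 * numBytes)))
    ++numBytes;
  // Shift the value past the prefix and set the terminal '1' marker bit.
  uint64_t encoded = (value << numBytes) | (uint64_t(1) << (numBytes - 1));
  for (unsigned i = 0; i < numBytes; ++i)
    bytes.push_back(static_cast<uint8_t>(encoded >> (8 * i)));
  return bytes;
}

// Hypothetical helper (not part of the patch): decode one PrefixVarInt from
// `data`. Returns the number of bytes consumed, or 0 if `size` is too small.
size_t decodeVarInt(const uint8_t *data, size_t size, uint64_t &result) {
  assert(data && "expected a valid buffer");
  if (size == 0)
    return 0;
  // Count the '0' bits below the marker; 8 means the prefix byte is all zero.
  unsigned extra = 0;
  while (extra < 8 && !((data[0] >> extra) & 1))
    ++extra;
  if (size < extra + 1)
    return 0;
  result = 0;
  if (extra == 8) {
    // 9-byte form: the value is the 8 bytes following the prefix byte.
    for (unsigned i = 0; i < 8; ++i)
      result |= uint64_t(data[1 + i]) << (8 * i);
    return 9;
  }
  // Gather the prefix byte and the additional bytes, then drop the prefix.
  for (unsigned i = 0; i <= extra; ++i)
    result |= uint64_t(data[i]) << (8 * i);
  result >>= (extra + 1);
  return extra + 1;
}
```

Under these assumptions the value 5 encodes as the single byte `0x0B`, and 300 encodes as the two bytes `0xB2 0x04`, whose first byte ends in the `10` prefix from the table.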
+ +#### Region Encoding + +``` +region { + code: byte, // kRegion | kRegionEmpty + + numBlocks: varint?, + numValues: varint?, + blocks: block[] +} +``` + +A region is encoded with a leading code followed by the body. The code indicates +how the body is encoded. If the code is `kRegionEmpty`, the region has no body. +If the code is `kRegion`, the body is present. + +#### Block Encoding + +``` +block { + block_code: byte, // kBlockArguments | kOp | kBlockEnd + block_element: block_arguments | op | [] +} + +block_arguments { + code: byte, // kBlockArguments + + firstArgIndex: varint, + numArgs: varint?, + args: block_argument[] + +} +block_argument { + typeIndexAndHasLoc: varint, // (typeIndex << 1) | (hasLoc) + location: varint? +} + +``` + +A block is encoded with an array of elements determined by a leading code. The +terminal `kBlockEnd` code indicates the end of a block. The `kOp` code indicates +that an operation follows. If the block has arguments, the first element of the +block will contain the encoded representation of the arguments, or +`block_arguments` above. The encoding for the block arguments includes the value +index of the first argument, the number of arguments, and an encoded list of +arguments. The `typeIndexAndHasLoc` field of the argument is a varint that in +the high-bits holds the index for the type of that argument, and in the low bit +contains a flag that indicates if the argument has a location encoded along with +it. A location is encoded if the argument had a different location than the +previously encoded argument or operation. diff --git a/mlir/include/mlir/Bytecode/BytecodeReader.h b/mlir/include/mlir/Bytecode/BytecodeReader.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeReader.h @@ -0,0 +1,34 @@ +//===- BytecodeReader.h - MLIR Bytecode Reader ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines interfaces to read MLIR bytecode files/streams. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEREADER_H +#define MLIR_BYTECODE_BYTECODEREADER_H + +#include "mlir/IR/AsmState.h" +#include "mlir/Support/LLVM.h" + +namespace llvm { +class MemoryBufferRef; +} // namespace llvm + +namespace mlir { +/// Returns true if the given buffer starts with the magic bytes that signal +/// MLIR bytecode. +bool isBytecode(llvm::MemoryBufferRef buffer); + +/// Read the operations defined within the given memory buffer, containing MLIR +/// bytecode, into the provided block. +LogicalResult readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, + const ParserConfig &config); +} // namespace mlir + +#endif // MLIR_BYTECODE_BYTECODEREADER_H diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h @@ -0,0 +1,52 @@ +//===- BytecodeWriter.h - MLIR Bytecode Writer ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines interfaces to write MLIR bytecode files/streams. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BYTECODE_BYTECODEWRITER_H +#define MLIR_BYTECODE_BYTECODEWRITER_H + +#include "mlir/Support/LLVM.h" + +namespace mlir { +class Operation; + +//===----------------------------------------------------------------------===// +// BytecodeWriterConfig +//===----------------------------------------------------------------------===// + +/// This class provides a configuration for the bytecode writer. It is the main +/// injection of information into the writer. +class BytecodeWriterConfig { + struct Impl; + +public: + BytecodeWriterConfig(Operation *op); + ~BytecodeWriterConfig(); + + /// Return the root operation of the writer. + Operation *getRootOp() const; + +private: + /// A pointer to the allocated storage for the impl state. + std::unique_ptr<Impl> impl; +}; + +//===----------------------------------------------------------------------===// +// Entry Points +//===----------------------------------------------------------------------===// + +/// Write the given bytecode configuration to the provided output stream. For +/// streams where it matters, the given stream should be in "binary" mode. 
+void writeBytecodeToFile(const BytecodeWriterConfig &config, raw_ostream &os); + +} // namespace mlir + +#endif // MLIR_BYTECODE_BYTECODEWRITER_H diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -642,11 +642,11 @@ OperationState(Location location, OperationName name); OperationState(Location location, OperationName name, ValueRange operands, - TypeRange types, ArrayRef<NamedAttribute> attributes, + TypeRange types, ArrayRef<NamedAttribute> attributes = {}, BlockRange successors = {}, MutableArrayRef<std::unique_ptr<Region>> regions = {}); OperationState(Location location, StringRef name, ValueRange operands, - TypeRange types, ArrayRef<NamedAttribute> attributes, + TypeRange types, ArrayRef<NamedAttribute> attributes = {}, BlockRange successors = {}, MutableArrayRef<std::unique_ptr<Region>> regions = {}); diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h --- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h +++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h @@ -50,13 +50,15 @@ /// - preloadDialectsInContext will trigger the upfront loading of all /// dialects from the global registry in the MLIRContext. This option is /// deprecated and will be removed soon. +/// - emitBytecode will generate bytecode output instead of text. LogicalResult MlirOptMain(llvm::raw_ostream &outputStream, std::unique_ptr<llvm::MemoryBuffer> buffer, const PassPipelineCLParser &passPipeline, DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext = false); + bool preloadDialectsInContext = false, + bool emitBytecode = false); /// Support a callback to setup the pass manager. 
/// - passManagerSetupFn is the callback invoked to setup the pass manager to @@ -67,7 +69,8 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext = false); + bool preloadDialectsInContext = false, + bool emitBytecode = false); /// Implementation for tools like `mlir-opt`. /// - toolName is used for the header displayed by `--help`. diff --git a/mlir/lib/Bytecode/CMakeLists.txt b/mlir/lib/Bytecode/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(Reader) +add_subdirectory(Writer) \ No newline at end of file diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Encoding.h @@ -0,0 +1,127 @@ +//===- Encoding.h - MLIR binary format encoding information -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header defines enum values describing the structure of MLIR bytecode +// files. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_BYTECODE_ENCODING_H +#define LIB_MLIR_BYTECODE_ENCODING_H + +#include <cstdint> + +namespace mlir { +namespace bytecode { +//===----------------------------------------------------------------------===// +// General constants +//===----------------------------------------------------------------------===// + +enum { + /// The current bytecode version. + kVersion = 0, + + /// The first non-builtin section code. + kFirstNonBuiltinCode = 16, +}; + +namespace BuiltinCode { +enum : uint8_t { + /// This value indicates the code for a section. 
+ kSection = 0, +}; +} // namespace BuiltinCode + +//===----------------------------------------------------------------------===// +// Sections +//===----------------------------------------------------------------------===// + +namespace Section { +enum ID : uint8_t { + /// This section contains the dialects referenced within an IR module. + kDialect = 0, + + /// This section contains the attributes and types referenced within an IR + /// module. + kAttrType = 1, + + /// This section contains the offsets for the attribute and types within the + /// AttrType section. + kAttrTypeOffset = 2, + + /// This section contains the top level operation, and its nested + /// regions/operations. + kTopLevelOp = 3, + + /// The total number of section types. + kNumSections = 4, +}; +} // namespace Section + +//===----------------------------------------------------------------------===// +// AttrType Section +//===----------------------------------------------------------------------===// + +namespace AttrTypeCode { +enum : uint8_t { + /// This code represents an attribute or type represented in the textual + /// assembly format. + kAsmForm, +}; +} // namespace AttrTypeCode + +//===----------------------------------------------------------------------===// +// kTopLevelOp Section +//===----------------------------------------------------------------------===// + +namespace TopLevelOpCode { +enum : uint8_t { + //===--------------------------------------------------------------------===// + // Operation Codes + + /// This code represents an operation. + kOp = kFirstNonBuiltinCode, + + //===--------------------------------------------------------------------===// + // Region Codes + + /// This code represents a non-empty region. + kRegion, + + /// This code represents an empty region. + kRegionEmpty, + + //===--------------------------------------------------------------------===// + // Block Codes + + /// This code represents the argument list of a block. 
+ kBlockArguments, + + /// This code represents the end of a block. + kBlockEnd, +}; +} // namespace TopLevelOpCode + +/// This enum represents a mask of all of the potential components of an +/// operation. This mask is used when encoding an operation to indicate which +/// components are present in the bytecode. +namespace OpEncodingMask { +enum : uint8_t { + kHasLoc = 1 << 0, + kHasAttrs = 1 << 1, + kHasResults = 1 << 2, + kHasOperands = 1 << 3, + kHasSuccessors = 1 << 4, + kHasInlineRegions = 1 << 5, +}; +} // namespace OpEncodingMask + +} // namespace bytecode +} // namespace mlir + +#endif diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -0,0 +1,965 @@ +//===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeReader.h" +#include "../Encoding.h" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/MemoryBufferRef.h" +#include "llvm/Support/SaveAndRestore.h" + +#define DEBUG_TYPE "mlir-bytecode" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// EncodingReader +//===----------------------------------------------------------------------===// + +namespace { +class EncodingReader { +public: + explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) + : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} + explicit EncodingReader(StringRef contents, Location fileLoc) + : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), + contents.size()}, + fileLoc) {} + + /// Returns true if the entire section has been read. + bool empty() const { return dataIt == dataEnd; } + + /// Returns the remaining size of the bytecode. + size_t size() const { return dataEnd - dataIt; } + + /// Emit an error using the given arguments. + template <typename... Args> + LogicalResult emitError(Args &&...args) const { + return ::emitError(fileLoc).append(std::forward<Args>(args)...); + } + + /// Parse a single byte from the stream. + template <typename T> + ParseResult parseByte(T &value) { + if (empty()) + return emitError("attempting to parse a byte at the end of the bytecode"); + value = *dataIt++; + return success(); + } + /// Parse a range of bytes of 'length' into the given result. 
+ ParseResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { + if (length > size()) { + return emitError("attempting to parse ", length, " bytes when only ", + size(), " remain"); + } + result = {dataIt, length}; + dataIt += length; + return success(); + } + /// Parse a range of bytes of 'length' into the given result, which can be + /// assumed to be large enough to hold `length`. + ParseResult parseBytes(size_t length, uint8_t *result) { + if (length > size()) { + return emitError("attempting to parse ", length, " bytes when only ", + size(), " remain"); + } + memcpy(result, dataIt, length); + dataIt += length; + return success(); + } + + /// Parse a variable length encoded integer from the byte stream. The first + /// encoded byte contains a prefix in the low bits indicating the encoded + /// length of the value. This length prefix is a bit sequence of '0's followed + /// by a '1'. The number of '0' bits indicates the number of _additional_ bytes + /// (not including the prefix byte). All remaining bits in the first byte, + /// along with all of the bits in additional bytes, provide the value of the + /// integer encoded in little-endian order. + ParseResult parseVarInt(uint64_t &result) { + // Parse the first byte of the encoding, which contains the length prefix. + if (parseByte(result)) + return failure(); + + // Handle the overwhelmingly common case where the value is stored in a + // single byte. In this case, the first bit is the `1` marker bit. + if (LLVM_LIKELY(result & 1)) { + result >>= 1; + return success(); + } + + // Handle the overwhelmingly uncommon case where the value required all 8 + // additional bytes (i.e. a really really big number). In this case, the + // marker byte is all zeros: `00000000`. + if (LLVM_UNLIKELY(result == 0)) + return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result)); + return parseMultiByteVarInt(result); + } + + /// Skip the first `length` bytes within the reader. 
+ ParseResult skipBytes(size_t length) { + if (length > size()) { + return emitError("attempting to skip ", length, " bytes when only ", + size(), " remain"); + } + dataIt += length; + return success(); + } + + /// Parse a NUL terminated string into `result` (without including the NUL + /// terminator). + ParseResult parseNULTerminatedString(StringRef &result) { + const char *startIt = (const char *)dataIt; + const char *nulIt = (const char *)memchr(startIt, 0, size()); + if (!nulIt) + return emitError("malformed NUL terminated string, no NUL found"); + + result = StringRef(startIt, nulIt - startIt); + dataIt = (const uint8_t *)nulIt + 1; + return success(); + } + + /// Parse a section header, placing the kind of section in `sectionID` and the + /// contents of the section in `sectionData`. + ParseResult parseSection(uint8_t &sectionID, ArrayRef<uint8_t> &sectionData) { + uint64_t length; + if (parseByte(sectionID) || parseVarInt(length)) + return failure(); + + // Parse the actual section data now that we have its length. + return parseBytes(length, sectionData); + } + +private: + /// Parse a variable length encoded integer from the byte stream. This method + /// is a fallback when the number of bytes used to encode the value is greater + /// than 1, but less than the max (9). The provided `result` value can be + /// assumed to already contain the first byte of the value. This method is + /// marked noinline to avoid pessimizing the common case of single byte + /// encoding. + LLVM_ATTRIBUTE_NOINLINE ParseResult parseMultiByteVarInt(uint64_t &result) { + // Count the number of trailing zeros in the marker byte, this indicates the + // number of trailing bytes that are part of the value. We use `uint32_t` + // here because we only care about the first byte, and so that we actually + // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop + // implementation). 
+ uint32_t numBytes = + llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined); + + // Parse in the remaining bytes of the value. + if (parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1)) + return failure(); + + // Shift out the low-order bits that were used to mark how the value was + // encoded. + result >>= (numBytes + 1); + return success(); + } + + /// The current data iterator, and an iterator to the end of the buffer. + const uint8_t *dataIt, *dataEnd; + + /// A location for the bytecode used to report errors. + Location fileLoc; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// BytecodeDialect +//===----------------------------------------------------------------------===// + +namespace { +/// This struct represents a dialect entry within the bytecode. +struct BytecodeDialect { + BytecodeDialect(Dialect *dialect, StringRef name, unsigned numAttrs, + unsigned numTypes) + : dialect(dialect), name(name), numAttrs(numAttrs), numTypes(numTypes) {} + + /// The loaded dialect entry, if available, otherwise nullptr. + Dialect *dialect; + + /// The name of the dialect. + StringRef name; + + /// The number of attributes owned by this dialect in the bytecode. + unsigned numAttrs; + + /// The number of types owned by this dialect in the bytecode. + unsigned numTypes; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Attribute/Type Reader +//===----------------------------------------------------------------------===// + +namespace { +/// This class provides support for reading attribute and type entries from the +/// bytecode. Attribute and Type entries are read lazily on demand, so we use +/// this reader to manage when to actually parse them from the bytecode. +class AttrTypeReader { + /// This class represents a single attribute or type entry. 
+ template <typename T> + struct Entry { + /// The entry, or null if it hasn't been resolved yet. 
+ T entry = {}; + /// The raw data of this entry in the bytecode. + ArrayRef<uint8_t> data; + }; + using AttrEntry = Entry<Attribute>; + using TypeEntry = Entry<Type>; + +public: + AttrTypeReader(Location fileLoc) : fileLoc(fileLoc) {} + + /// Initialize the attribute and type information within the reader. + LogicalResult initialize(ArrayRef<BytecodeDialect> dialects, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData); + + /// Resolve the attribute or type at the given index. Returns nullptr on + /// failure. + Attribute resolveAttribute(unsigned index) { + return resolveEntry(attributes, index, "Attribute"); + } + Type resolveType(unsigned index) { + return resolveEntry(types, index, "Type"); + } + +private: + /// Initialize the offsets for the attribute and type entries. + LogicalResult initializeOffsets(ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData); + + /// Resolve the given entry at `index`. + template <typename T> + T resolveEntry(SmallVectorImpl<Entry<T>> &entries, unsigned index, + StringRef entryType); + + /// Parse the value defined within the given reader. `code` indicates how the + /// entry was encoded. + LogicalResult parseEntry(EncodingReader &reader, uint8_t code, + Attribute &result); + LogicalResult parseEntry(EncodingReader &reader, uint8_t code, Type &result); + + /// The set of attribute and type entries. + SmallVector<AttrEntry> attributes; + SmallVector<TypeEntry> types; + + /// A location used for error emission. + Location fileLoc; +}; +} // namespace + +LogicalResult AttrTypeReader::initialize(ArrayRef<BytecodeDialect> dialects, + ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData) { + // Initialize the entries using the dialect information. 
+ unsigned numAttrs = 0, numTypes = 0; + for (const BytecodeDialect &dialect : dialects) { + numAttrs += dialect.numAttrs; + numTypes += dialect.numTypes; + } + attributes.resize(numAttrs); + types.resize(numTypes); + + // With the entries initialized, we can process the offsets. 
+ return initializeOffsets(sectionData, offsetSectionData); +} + +LogicalResult +AttrTypeReader::initializeOffsets(ArrayRef<uint8_t> sectionData, + ArrayRef<uint8_t> offsetSectionData) { + EncodingReader reader(offsetSectionData, fileLoc); + + // A functor used to accumulate the offsets for the entries in the given + // range. + uint64_t currentOffset = 0; + auto accumulateOffsets = [&](auto &&range) { + for (auto &entry : range) { + uint64_t entrySize; + if (reader.parseVarInt(entrySize)) + return failure(); + entry.data = sectionData.slice(currentOffset, entrySize); + currentOffset += entrySize; + } + return success(); + }; + + // Process each of the attributes, and then the types. + if (failed(accumulateOffsets(attributes)) || failed(accumulateOffsets(types))) + return failure(); + + // Ensure that we read everything from the section. + if (!reader.empty()) { + return reader.emitError( + "unexpected trailing data in the Attribute/Type offset section"); + } + return success(); +} + +template <typename T> +T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, + unsigned index, StringRef entryType) { + if (index >= entries.size()) { + emitError(fileLoc) << "invalid " << entryType << " index: " << index; + return {}; + } + + // If the entry has already been resolved, there is nothing left to do. + Entry<T> &entry = entries[index]; + if (entry.entry) + return entry.entry; + + // Parse the entry. Each entry starts with a specific code that indicates how + // it is represented. + EncodingReader reader(entry.data, fileLoc); + uint8_t code; + if (reader.parseByte(code) || failed(parseEntry(reader, code, entry.entry))) + return T(); + if (!reader.empty()) { + (void)reader.emitError("unexpected trailing bytes after " + entryType + + " entry"); + return T(); + } + return entry.entry; +} + +LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, uint8_t code, + Attribute &result) { + // Handle the fallback case, where the attribute was encoded using its + // assembly format. 
+ if (code == bytecode::AttrTypeCode::kAsmForm) { + StringRef attrStr; + if (failed(reader.parseNULTerminatedString(attrStr))) + return failure(); + + size_t numRead = 0; + if (!(result = parseAttribute(attrStr, fileLoc->getContext(), numRead))) + return failure(); + if (numRead != attrStr.size()) { + return reader.emitError( + "trailing characters found after Attribute assembly format: ", + attrStr.drop_front(numRead)); + } + return success(); + } + + return reader.emitError("unexpected Attribute encoding: ", code); +} + +LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, uint8_t code, + Type &result) { + // Handle the fallback case, where the type was encoded using its + // assembly format. + if (code == bytecode::AttrTypeCode::kAsmForm) { + StringRef typeStr; + if (failed(reader.parseNULTerminatedString(typeStr))) + return failure(); + + size_t numRead = 0; + if (!(result = parseType(typeStr, fileLoc->getContext(), numRead))) + return failure(); + if (numRead != typeStr.size()) { + return reader.emitError( + "trailing characters found after Type assembly format: " + + typeStr.drop_front(numRead)); + } + return success(); + } + + return reader.emitError("unexpected Type encoding: ", code); +} + +//===----------------------------------------------------------------------===// +// Bytecode Reader +//===----------------------------------------------------------------------===// + +namespace { +/// This class is used to read a bytecode buffer and translate it into MLIR. +class BytecodeReader { +public: + BytecodeReader(Location fileLoc, const ParserConfig &config) + : config(config), fileLoc(fileLoc), attrTypeReader(fileLoc), + forwardRefOpState(UnknownLoc::get(config.getContext()), + "builtin.unrealized_conversion_cast", ValueRange(), + NoneType::get(config.getContext())) {} + + /// Read the bytecode defined within `buffer` into the given block. 
+ LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); + +private: + /// Return the context for this config. + MLIRContext *getContext() const { return config.getContext(); } + + //===--------------------------------------------------------------------===// + // Dialect Section + + LogicalResult parseDialectSection(ArrayRef sectionData); + + /// Parse an operation name reference using the given reader. + FailureOr parseOpName(EncodingReader &reader); + + //===--------------------------------------------------------------------===// + // Attribute/Type Section + + /// Parse an attribute or type using the given reader. Returns nullptr in the + /// case of failure. + Attribute parseAttribute(EncodingReader &reader); + Type parseType(EncodingReader &reader); + + template + T parseAttribute(EncodingReader &reader) { + if (Attribute attr = parseAttribute(reader)) { + if (auto derivedAttr = attr.dyn_cast()) + return derivedAttr; + (void)reader.emitError("expected attribute of type: ", + llvm::getTypeName(), ", but got: ", attr); + } + return T(); + } + + //===--------------------------------------------------------------------===// + // TopLevelOp Section + + LogicalResult parseTopLevelOpSection(ArrayRef sectionData, + Block *block); + LogicalResult parseOp(EncodingReader &reader, Block *block, + ArrayRef regionBlocks, LocationAttr &lastLoc); + LogicalResult parseRegion(EncodingReader &reader, Region *region, + LocationAttr &lastLoc); + LogicalResult parseBlock(EncodingReader &reader, Block *block, + ArrayRef regionBlocks, + LocationAttr &lastLoc); + LogicalResult parseBlockArguments(EncodingReader &reader, Block *block, + LocationAttr &lastLoc); + + //===--------------------------------------------------------------------===// + // Value Processing + + /// Parse an operand reference using the given reader. Returns nullptr in the + /// case of failure. 
+ Value parseOperand(EncodingReader &reader); + + /// Sequentially define the given value range starting at the provided first + /// value ID. + LogicalResult defineValues(EncodingReader &reader, ValueRange values, + unsigned firstValueID); + + /// Create a value to use for a forward reference. + Value createForwardRef(); + + //===--------------------------------------------------------------------===// + // Fields + + /// The configuration of the parser. + const ParserConfig &config; + + /// A location to use when emitting errors. + Location fileLoc; + + /// The reader used to process attribute and types within the bytecode. + AttrTypeReader attrTypeReader; + + /// The version of the bytecode being read. + uint64_t version = 0; + + /// The table of IR units referenced within the bitcode file. + SmallVector dialects; + SmallVector opNames; + + /// The current set of available IR values. + std::vector values; + /// A block containing the set of operations defined to create forward + /// references. + Block forwardRefOps; + /// A block containing previously created, and no longer used, forward + /// reference operations. + Block openForwardRefOps; + /// An operation state used when instantiating forward references. + OperationState forwardRefOpState; +}; +} // namespace + +LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { + EncodingReader reader(buffer.getBuffer(), fileLoc); + + // Skip over the bytecode header, this should have already been checked. + if (reader.skipBytes(StringRef("ML\xefR").size())) + return failure(); + + // Parse the bytecode version. + if (reader.parseVarInt(version)) + return failure(); + + // Validate the bytecode version. 
+ if (version < bytecode::kVersion) { + return reader.emitError( + "bytecode version ", version, " is older than the current version of ", + bytecode::kVersion, ", and upgrade is not supported."); + } + if (version > bytecode::kVersion) { + return reader.emitError("bytecode version ", version, + " is newer than the current version ", + bytecode::kVersion, "."); + } + + // The raw data for the AttrTypeOffset section. + Optional<ArrayRef<uint8_t>> attrTypeOffsetSection; + + BitVector seenSections(bytecode::Section::kNumSections); + while (!reader.empty()) { + // Read the next section from the bytecode. + uint8_t code; + if (reader.parseByte(code) || code != bytecode::BuiltinCode::kSection) + return reader.emitError("expected top-level section code"); + uint8_t sectionID; + ArrayRef<uint8_t> sectionData; + if (reader.parseSection(sectionID, sectionData)) + return failure(); + + // Check for duplicate sections, we only expect one instance of each. + if (seenSections.test(sectionID)) + return reader.emitError("duplicate top-level section ID: ", sectionID); + seenSections.set(sectionID); + + // Process the section. + switch (sectionID) { + case bytecode::Section::kDialect: + if (failed(parseDialectSection(sectionData))) + return failure(); + break; + case bytecode::Section::kAttrType: + if (!attrTypeOffsetSection) { + return reader.emitError( + "expected the AttrTypeOffset section before the AttrType section"); + } + if (dialects.empty()) { + return reader.emitError("expected the Dialect section before the " + "AttrTypeOffset section"); + } + + // With everything ready, initialize the attribute/type reader. + if (failed(attrTypeReader.initialize(dialects, sectionData, + *attrTypeOffsetSection))) + return failure(); + break; + case bytecode::Section::kAttrTypeOffset: + // We won't parse this section until we process the main AttrType section. + // For now, just record the raw data.
+ attrTypeOffsetSection = sectionData; + break; + case bytecode::Section::kTopLevelOp: + if (failed(parseTopLevelOpSection(sectionData, block))) + return failure(); + break; + default: + return reader.emitError("unexpected top-level section: ", sectionID); + } + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Dialect Section + +LogicalResult +BytecodeReader::parseDialectSection(ArrayRef sectionData) { + MLIRContext *ctx = getContext(); + + EncodingReader sectionReader(sectionData, fileLoc); + while (!sectionReader.empty()) { + // Read the name of the next dialect. + StringRef dialectName; + if (sectionReader.parseNULTerminatedString(dialectName)) + return failure(); + + // Parse the attribute and type counts. + uint64_t attrCount, typeCount; + if (sectionReader.parseVarInt(attrCount) || + sectionReader.parseVarInt(typeCount)) + return failure(); + + // Try to load the dialect. + Dialect *dialect = ctx->getOrLoadDialect(dialectName); + if (!dialect && !ctx->allowsUnregisteredDialects()) { + return sectionReader.emitError( + "dialect '", dialectName, + "' is unknown. If this is intended, please call " + "allowUnregisteredDialects() on the MLIRContext, or use " + "-allow-unregistered-dialect with the MLIR tool used."); + } + dialects.emplace_back(dialect, dialectName, attrCount, typeCount); + + // Parse the operation names of the dialect. 
+ uint64_t numOpNames; + if (sectionReader.parseVarInt(numOpNames)) + return failure(); + SmallString<32> opNameStorage({dialectName, "."}); + while (numOpNames--) { + StringRef opName; + if (sectionReader.parseNULTerminatedString(opName)) + return failure(); + + opNameStorage.resize(dialectName.size() + 1); + opNameStorage.append(opName); + opNames.push_back(OperationName(opNameStorage, ctx)); + } + } + return success(); +} + +FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { + uint64_t opNameIdx; + if (reader.parseVarInt(opNameIdx)) + return failure(); + + if (opNameIdx >= opNames.size()) + return reader.emitError("invalid operation name index: ", opNameIdx); + return opNames[opNameIdx]; +} + +//===----------------------------------------------------------------------===// +// Attribute/Type Section + +Attribute BytecodeReader::parseAttribute(EncodingReader &reader) { + uint64_t attrIdx; + if (reader.parseVarInt(attrIdx)) + return Attribute(); + return attrTypeReader.resolveAttribute(attrIdx); +} + +Type BytecodeReader::parseType(EncodingReader &reader) { + uint64_t typeIdx; + if (reader.parseVarInt(typeIdx)) + return Type(); + return attrTypeReader.resolveType(typeIdx); +} + +//===----------------------------------------------------------------------===// +// TopLevelOp Section + +LogicalResult +BytecodeReader::parseTopLevelOpSection(ArrayRef<uint8_t> sectionData, + Block *block) { + EncodingReader reader(sectionData, fileLoc); + + LocationAttr lastLoc; + if (failed(parseOp(reader, block, /*regionBlocks=*/llvm::None, lastLoc))) + return failure(); + if (!forwardRefOps.empty()) + return reader.emitError( + "not all forward operand references were resolved"); + return success(); +} + +LogicalResult BytecodeReader::parseOp(EncodingReader &reader, Block *block, + ArrayRef<Block *> regionBlocks, + LocationAttr &lastLoc) { + // Parse the name of the operation.
+ FailureOr opName = parseOpName(reader); + if (failed(opName)) + return failure(); + + // Parse the operation mask, which indicates which components of the operation + // are present. + uint8_t opMask; + if (reader.parseByte(opMask)) + return failure(); + + /// Check to see if this op has a new location. + if (opMask & bytecode::OpEncodingMask::kHasLoc) { + if (!(lastLoc = parseAttribute(reader))) + return failure(); + } + + // With the location and name resolved, we can start building the operation + // state. + OperationState opState(lastLoc, *opName); + + // Parse the attributes of the operation. + if (opMask & bytecode::OpEncodingMask::kHasAttrs) { + DictionaryAttr dictAttr = parseAttribute(reader); + if (!dictAttr) + return failure(); + opState.attributes = dictAttr; + } + + /// Parse the results of the operation. + Optional firstResultID; + if (opMask & bytecode::OpEncodingMask::kHasResults) { + firstResultID.emplace(0); + if (reader.parseVarInt(*firstResultID)) + return failure(); + + // Parse the result types. + uint64_t numResults; + if (reader.parseVarInt(numResults)) + return failure(); + opState.types.resize(numResults); + for (int i = 0, e = numResults; i < e; ++i) + if (!(opState.types[i] = parseType(reader))) + return failure(); + } + + /// Parse the operands of the operation. + if (opMask & bytecode::OpEncodingMask::kHasOperands) { + uint64_t numOperands; + if (reader.parseVarInt(numOperands)) + return failure(); + opState.operands.resize(numOperands); + for (int i = 0, e = numOperands; i < e; ++i) + if (!(opState.operands[i] = parseOperand(reader))) + return failure(); + } + + /// Parse the successors of the operation. 
+ if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { + uint64_t numSuccs; + if (reader.parseVarInt(numSuccs)) + return failure(); + opState.successors.reserve(numSuccs); + for (int i = 0, e = numSuccs; i < e; ++i) { + uint64_t succID; + if (reader.parseVarInt(succID)) + return failure(); + if (succID >= regionBlocks.size()) + return reader.emitError("invalid successor index: ", succID); + opState.successors.push_back(regionBlocks[succID]); + } + } + + /// Parse the regions of the operation. + if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { + uint64_t numRegions; + if (reader.parseVarInt(numRegions)) + return failure(); + opState.regions.reserve(numRegions); + for (int i = 0, e = numRegions; i < e; ++i) { + opState.regions.push_back(std::make_unique<Region>()); + if (failed(parseRegion(reader, &*opState.regions.back(), lastLoc))) + return failure(); + } + } + + // Create the operation. + Operation *op = Operation::create(opState); + block->push_back(op); + + // If the operation had results, update the value references. + if (firstResultID) + return defineValues(reader, op->getResults(), *firstResultID); + return LogicalResult::success(); +} + +LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, + Region *region, + LocationAttr &lastLoc) { + // Read the code defining how this region was encoded. + uint8_t regionCode; + if (reader.parseByte(regionCode)) + return failure(); + + // If it's an empty region, there is nothing more to do. + if (regionCode == bytecode::TopLevelOpCode::kRegionEmpty) + return success(); + + // Otherwise, we need to parse the region body. + if (regionCode != bytecode::TopLevelOpCode::kRegion) + return reader.emitError("invalid region code: ", regionCode); + + // Parse the number of blocks and values in this region. + uint64_t numBlocks, numValues; + if (reader.parseVarInt(numBlocks) || reader.parseVarInt(numValues)) + return failure(); + + // Reserve enough values for those defined in this region.
Make sure to reset + // the size of the value table after processing though. + size_t origNumValues = values.size(); + auto atExit = llvm::make_scope_exit([&]() { values.resize(origNumValues); }); + values.resize(values.size() + numValues); + + // Create the blocks within this region. We do this before processing so that + // we can rely on the blocks existing when creating operations. + SmallVector regionBlocks; + regionBlocks.reserve(numBlocks); + for (uint64_t i = 0; i < numBlocks; ++i) { + regionBlocks.push_back(new Block()); + region->push_back(regionBlocks.back()); + } + + for (uint64_t i = 0; i < numBlocks; ++i) + if (failed(parseBlock(reader, regionBlocks[i], regionBlocks, lastLoc))) + return failure(); + return success(); +} + +LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, Block *block, + ArrayRef regionBlocks, + LocationAttr &lastLoc) { + // Parse the first code of the block explicitly in case the block has + // arguments. + uint8_t blockCode = 0; + if (reader.parseByte(blockCode)) + return failure(); + + // Check for arguments to the block. + if (blockCode == bytecode::TopLevelOpCode::kBlockArguments) { + if (failed(parseBlockArguments(reader, block, lastLoc))) + return failure(); + + // Parse the next block code. + if (reader.parseByte(blockCode)) + return failure(); + } + + while (blockCode != bytecode::TopLevelOpCode::kBlockEnd) { + // Parse an operation within the block. + if (blockCode == bytecode::TopLevelOpCode::kOp) { + if (failed(parseOp(reader, block, regionBlocks, lastLoc))) + return failure(); + } else { + return reader.emitError("unknown block code: ", blockCode); + } + + // Parse the next code. + if (reader.parseByte(blockCode)) + return failure(); + } + return success(); +} + +LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, + Block *block, + LocationAttr &lastLoc) { + // Parse the value ID for the first argument, and the number of arguments. 
+ uint64_t firstArgID, numArgs; + if (reader.parseVarInt(firstArgID) || reader.parseVarInt(numArgs)) + return failure(); + + SmallVector argTypes; + SmallVector argLocs; + argTypes.reserve(numArgs); + argLocs.reserve(numArgs); + while (numArgs--) { + uint64_t typeIdx; + if (reader.parseVarInt(typeIdx)) + return failure(); + + // Check the low bit of the type index to see if this argument has a new + // location. + bool hasNewLoc = (typeIdx & 1) != 0; + typeIdx >>= 1; + + // Parse the type, and optionally the location. + Type argType = attrTypeReader.resolveType(typeIdx); + if (!argType) + return failure(); + if (hasNewLoc && !(lastLoc = parseAttribute(reader))) + return failure(); + + argTypes.push_back(argType); + argLocs.push_back(lastLoc); + } + block->addArguments(argTypes, argLocs); + return defineValues(reader, block->getArguments(), firstArgID); +} + +//===----------------------------------------------------------------------===// +// Value Processing + +Value BytecodeReader::parseOperand(EncodingReader &reader) { + uint64_t valueIdx; + if (failed(reader.parseVarInt(valueIdx))) + return nullptr; + if (valueIdx >= values.size()) + return (void)reader.emitError("invalid value index: ", valueIdx), Value(); + + // Resolve it, or create a new forward reference if necessary. + Value &value = values[valueIdx]; + if (!value) + value = createForwardRef(); + return value; +} + +LogicalResult BytecodeReader::defineValues(EncodingReader &reader, + ValueRange newValues, + unsigned firstValueID) { + size_t maxId = firstValueID + newValues.size(); + if (maxId > values.size()) { + return reader.emitError( + "value index range was outside of the expected range for " + "the parent region, got [", + firstValueID, ", ", maxId, "), but the maximum index was ", + values.size() - 1); + } + + // Assign the values and update any forward references. 
+ for (unsigned i = 0, e = newValues.size(); i != e; ++i) { + Value newValue = newValues[i]; + + // Check to see if a definition for this value already exists. + if (Value oldValue = std::exchange(values[firstValueID + i], newValue)) { + Operation *forwardRefOp = oldValue.getDefiningOp(); + if (!forwardRefOp || forwardRefOp->getBlock() != &forwardRefOps) { + return reader.emitError("value index ", firstValueID + i, + " was already defined"); + } + + oldValue.replaceAllUsesWith(newValue); + forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); + } + } + return LogicalResult::success(); +} + +Value BytecodeReader::createForwardRef() { + // Check for an available existing operation to use. Otherwise, create a new + // fake operation to use for the reference. + if (!openForwardRefOps.empty()) { + Operation *op = &openForwardRefOps.back(); + op->moveBefore(&forwardRefOps, forwardRefOps.end()); + } else { + forwardRefOps.push_back(Operation::create(forwardRefOpState)); + } + return forwardRefOps.back().getResult(0); +} + +//===----------------------------------------------------------------------===// +// Entry Points +//===----------------------------------------------------------------------===// + +bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { + return buffer.getBuffer().startswith("ML\xefR"); +} + +LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, + const ParserConfig &config) { + Location sourceFileLoc = + FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), + /*line=*/0, /*column=*/0); + if (!isBytecode(buffer)) { + return emitError(sourceFileLoc, + "input buffer is not an MLIR bytecode file"); + } + + Block parsedBlock; + BytecodeReader reader(sourceFileLoc, config); + if (failed(reader.read(buffer, &parsedBlock))) + return failure(); + + // Splice the parsed operations over to the provided top-level block.
+ auto &parsedOps = parsedBlock.getOperations(); + auto &destOps = block->getOperations(); + destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()), + parsedOps, parsedOps.begin(), parsedOps.end()); + return success(); +} diff --git a/mlir/lib/Bytecode/Reader/CMakeLists.txt b/mlir/lib/Bytecode/Reader/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Reader/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(MLIRBytecodeReader + BytecodeReader.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode + + LINK_LIBS PUBLIC + MLIRAsmParser + MLIRIR + MLIRSupport + ) diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -0,0 +1,474 @@ +//===- BytecodeWriter.cpp - MLIR Bytecode Writer --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Bytecode/BytecodeWriter.h" +#include "../Encoding.h" +#include "IRNumbering.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/Debug.h" +#include + +#define DEBUG_TYPE "mlir-bytecode" + +using namespace mlir; +using namespace mlir::bytecode::detail; + +//===----------------------------------------------------------------------===// +// BytecodeWriterConfig +//===----------------------------------------------------------------------===// + +struct BytecodeWriterConfig::Impl { + explicit Impl(Operation *op) : rootOp(op) {} + + /// The root operation of the bytecode. 
+ Operation *rootOp; +}; + +BytecodeWriterConfig::BytecodeWriterConfig(Operation *op) + : impl(std::make_unique(op)) {} +BytecodeWriterConfig::~BytecodeWriterConfig() = default; + +Operation *BytecodeWriterConfig::getRootOp() const { return impl->rootOp; } + +//===----------------------------------------------------------------------===// +// EncodingEmitter +//===----------------------------------------------------------------------===// + +namespace { +/// This class functions as the underlying encoding emitter for the bytecode +/// writer. This class is a bit different compared to other types of encoders; +/// it does not use a single buffer, but instead may contain several buffers +/// (some owned by the writer, and some not) that get concatted during the final +/// emission. +class EncodingEmitter { +public: + EncodingEmitter() = default; + EncodingEmitter(const EncodingEmitter &) = delete; + EncodingEmitter &operator=(const EncodingEmitter &) = delete; + + /// Write the current contents to the provided stream. + void writeTo(raw_ostream &os) const; + + /// Return the current size of the encoded buffer. + size_t size() const { return prevResultSize + currentResult.size(); } + + //===--------------------------------------------------------------------===// + // Emission + //===--------------------------------------------------------------------===// + + /// Return a raw pointer into the result buffer at the specified offset. + uint8_t *getRawPointer(uint64_t offset) { + assert(offset < size() && offset >= prevResultSize && + "cannot get pointer to previously emitted data"); + return currentResult.data() + (offset - prevResultSize); + } + + //===--------------------------------------------------------------------===// + // Integer Emission + + /// Emit a single byte. + void emitByte(uint8_t byte) { currentResult.push_back(byte); } + + /// Emit a range of bytes. 
+ void emitBytes(ArrayRef bytes) { + llvm::append_range(currentResult, bytes); + } + + /// Emit a variable length integer. The first encoded byte contains a prefix + /// in the low bits indicating the encoded length of the value. This length + /// prefix is a bit sequence of '0's followed by a '1'. The number of '0' bits + /// indicate the number of _additional_ bytes (not including the prefix byte). + /// All remaining bits in the first byte, along with all of the bits in + /// additional bytes, provide the value of the integer encoded in + /// little-endian order. + void emitVarInt(uint64_t value) { + // In the most common case, the value can be represented in a single byte. + // Given how hot this case is, explicitly handle that here. + if ((value >> 7) == 0) + return emitByte((value << 1) | 0x1); + emitMultiByteVarInt(value); + } + + //===--------------------------------------------------------------------===// + // String Emission + + /// Emit the given string as a nul terminated string. + void emitNulTerminatedString(StringRef str) { + emitString(str); + emitByte(0); + } + + /// Emit the given string without a nul terminator. + void emitString(StringRef str) { + emitBytes({reinterpret_cast(str.data()), str.size()}); + } + + //===--------------------------------------------------------------------===// + // Section Emission + + /// Emit a nested section of the given code, whose contents are encoded in the + /// provided emitter. + void emitSection(bytecode::Section::ID code, EncodingEmitter &&emitter) { + emitByte(bytecode::BuiltinCode::kSection); + + // Emit the section code and length. + emitByte(code); + emitVarInt(emitter.size()); + + // Push our current buffer and then merge the provided section body into + // ours. 
+ appendResult(std::move(currentResult)); + for (std::vector &result : emitter.prevResultStorage) + appendResult(std::move(result)); + appendResult(std::move(emitter.currentResult)); + } + +private: + /// Emit the given value using a variable width encoding. This method is a + /// fallback when the number of bytes needed to encode the value is greater + /// than 1. We mark it noinline here so that the single byte hot path isn't + /// pessimized. + LLVM_ATTRIBUTE_NOINLINE void emitMultiByteVarInt(uint64_t value); + + /// Append a new result buffer to the current contents. + void appendResult(std::vector &&result) { + prevResultSize += result.size(); + prevResultStorage.emplace_back(std::move(result)); + prevResultList.emplace_back(prevResultStorage.back()); + } + + /// The result of the emitter currently being built. We refrain from building + /// a single buffer to simplify emitting sections, large data, and more. The + /// result is thus represented using multiple distinct buffers, some of which + /// we own (via prevResultStorage), and some of which are just pointers into + /// externally owned buffers. + std::vector currentResult; + std::vector> prevResultList; + std::vector> prevResultStorage; + + /// An up-to-date total size of all of the buffers within `prevResultList`. + /// This enables O(1) size checks of the current encoding. + size_t prevResultSize = 0; +}; + +/// A simple raw_ostream wrapper around a EncodingEmitter. This removes the need +/// to go through an intermediate buffer when interacting with code that wants a +/// raw_ostream. +class raw_emitter_ostream : public raw_ostream { +public: + explicit raw_emitter_ostream(EncodingEmitter &emitter) : emitter(emitter) { + SetUnbuffered(); + } + +private: + void write_impl(const char *ptr, size_t size) override { + emitter.emitBytes({reinterpret_cast(ptr), size}); + } + uint64_t current_pos() const override { return emitter.size(); } + + /// The section being emitted to. 
+ EncodingEmitter &emitter; +}; +} // namespace + +void EncodingEmitter::writeTo(raw_ostream &os) const { + for (auto &prevResult : prevResultList) + os.write((const char *)prevResult.data(), prevResult.size()); + os.write((const char *)currentResult.data(), currentResult.size()); +} + +void EncodingEmitter::emitMultiByteVarInt(uint64_t value) { + // Compute the number of bytes needed to encode the value. Each byte can hold + // up to 7-bits of data. We only check up to the number of bits we can encode + // in the first byte (8). + uint64_t it = value >> 7; + for (size_t numBytes = 2; numBytes < 9; ++numBytes) { + if (LLVM_LIKELY(it >>= 7) == 0) { + uint64_t encodedValue = (value << 1) | 0x1; + encodedValue <<= (numBytes - 1); + emitBytes({reinterpret_cast(&encodedValue), numBytes}); + return; + } + } + + // If the value is too large to encode in a single byte, emit a special all + // zero marker byte and splat the value directly. + emitByte(0); + emitBytes({reinterpret_cast(&value), sizeof(value)}); +} + +//===----------------------------------------------------------------------===// +// Bytecode Writer +//===----------------------------------------------------------------------===// + +namespace { +class BytecodeWriter { +public: + BytecodeWriter(const BytecodeWriterConfig &config) : numberingState(config) {} + + /// Write the bytecode for the given root operation. 
+ void write(Operation *rootOp, raw_ostream &os); + +private: + //===--------------------------------------------------------------------===// + // Dialects + + void writeDialectSection(EncodingEmitter &emitter); + + //===--------------------------------------------------------------------===// + // Attributes and Types + + void writeAttrTypeSection(EncodingEmitter &emitter); + + //===--------------------------------------------------------------------===// + // Operations + + void writeBlock(EncodingEmitter &emitter, Block *block, Attribute &lastLoc); + void writeOp(EncodingEmitter &emitter, Operation *op, Attribute &lastLoc); + void writeRegion(EncodingEmitter &emitter, Region *region, + Attribute &lastLoc); + void writeTopLevelOp(EncodingEmitter &emitter, Operation *op); + + //===--------------------------------------------------------------------===// + // Fields + + /// The IR numbering state generated for the root operation. + IRNumberingState numberingState; +}; +} // namespace + +void BytecodeWriter::write(Operation *rootOp, raw_ostream &os) { + EncodingEmitter emitter; + + // Emit the bytecode file header. This is how we identify the output as a + // bytecode file. + emitter.emitString("ML\xefR"); + + // Emit the bytecode version. + emitter.emitVarInt(bytecode::kVersion); + + // Emit the dialect section. + writeDialectSection(emitter); + + // Emit the attributes and types section. + writeAttrTypeSection(emitter); + + // Emit the top level operation section. + writeTopLevelOp(emitter, rootOp); + + // Write the generated bytecode to the provided output stream. + emitter.writeTo(os); +} + +//===----------------------------------------------------------------------===// +// Dialects + +void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) { + EncodingEmitter dialectEmitter; + + // Emit the referenced dialects. + for (DialectNumbering &dialect : numberingState.getDialects()) { + // Emit the dialect name. 
+ dialectEmitter.emitNulTerminatedString(dialect.name); + + // Emit the number of attributes and types emitted for this dialect. + dialectEmitter.emitVarInt(dialect.attributes.size()); + dialectEmitter.emitVarInt(dialect.types.size()); + + // Emit the referenced operation names of this dialect. + dialectEmitter.emitVarInt(dialect.opNames.size()); + for (OpNameNumbering *opName : dialect.opNames) + dialectEmitter.emitNulTerminatedString(opName->name.stripDialect()); + } + + emitter.emitSection(bytecode::Section::kDialect, std::move(dialectEmitter)); +} + +//===----------------------------------------------------------------------===// +// Attributes and Types + +void BytecodeWriter::writeAttrTypeSection(EncodingEmitter &emitter) { + EncodingEmitter attrTypeEmitter; + EncodingEmitter offsetEmitter; + + // A functor used to emit an attribute or type entry. + uint64_t prevOffset = 0; + auto emitAttrOrType = [&](auto value) { + // Emit the entry using the textual format. + // TODO: Allow dialects to provide more optimal implementations of attribute + // and type encodings. + attrTypeEmitter.emitByte(bytecode::AttrTypeCode::kAsmForm); + raw_emitter_ostream(attrTypeEmitter) << value; + attrTypeEmitter.emitByte(0); + + // Record the offset of this entry. + uint64_t curOffset = attrTypeEmitter.size(); + offsetEmitter.emitVarInt(curOffset - prevOffset); + prevOffset = curOffset; + }; + + // Emit the attribute and type entries for each dialect. + for (DialectNumbering &dialect : numberingState.getDialects()) + for (AttributeNumbering *attr : dialect.attributes) + emitAttrOrType(attr->getValue()); + for (DialectNumbering &dialect : numberingState.getDialects()) + for (TypeNumbering *type : dialect.types) + emitAttrOrType(type->getValue()); + + // Emit the sections to the stream. 
+ emitter.emitSection(bytecode::Section::kAttrTypeOffset, + std::move(offsetEmitter)); + emitter.emitSection(bytecode::Section::kAttrType, std::move(attrTypeEmitter)); +} + +//===----------------------------------------------------------------------===// +// Operations + +void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block, + Attribute &lastLoc) { + // Emit the arguments of the block. + ArrayRef args = block->getArguments(); + if (!args.empty()) { + emitter.emitByte(bytecode::TopLevelOpCode::kBlockArguments); + + // Emit the value number for the first argument, and the number of arguments + // we are encoding. + emitter.emitVarInt(numberingState.getNumber(args.front())); + emitter.emitVarInt(args.size()); + + for (const auto &it : llvm::enumerate(args)) { + // Check to see if this argument has a new location. + Attribute argLoc = it.value().getLoc(); + bool argHasNewLoc = argLoc != std::exchange(lastLoc, argLoc); + + // Emit the argument type. We use the low bit of the type number to + // indicate if the argument changed locations. + uint64_t typeID = numberingState.getNumber(it.value().getType()); + emitter.emitVarInt((typeID << 1) | (argHasNewLoc ? 1 : 0)); + if (argHasNewLoc) + emitter.emitVarInt(numberingState.getNumber(argLoc)); + } + } + + // Emit the operations within the block. + for (Operation &op : *block) { + emitter.emitByte(bytecode::TopLevelOpCode::kOp); + writeOp(emitter, &op, lastLoc); + } + // Emit a terminal code to indicate when we are finished emitting operations. + emitter.emitByte(bytecode::TopLevelOpCode::kBlockEnd); +} + +void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op, + Attribute &lastLoc) { + emitter.emitVarInt(numberingState.getNumber(op->getName())); + + // Emit a mask for the operation components. We need to fill this in later + // (when we actually know what needs to be emitted), so emit a placeholder for + // now. 
+ uint64_t maskOffset = emitter.size(); + uint8_t opEncodingMask = 0; + emitter.emitByte(0); + + // Emit the location for this operation. + Attribute opLoc = op->getLoc(); + if (opLoc != std::exchange(lastLoc, opLoc)) { + opEncodingMask |= bytecode::OpEncodingMask::kHasLoc; + emitter.emitVarInt(numberingState.getNumber(opLoc)); + } + + // Emit the attributes of this operation. + DictionaryAttr attrs = op->getAttrDictionary(); + if (!attrs.empty()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasAttrs; + emitter.emitVarInt(numberingState.getNumber(op->getAttrDictionary())); + } + + // Emit the result types of the operation. + if (unsigned numResults = op->getNumResults()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasResults; + emitter.emitVarInt(numberingState.getNumber(op->getResult(0))); + emitter.emitVarInt(numResults); + for (Type type : op->getResultTypes()) + emitter.emitVarInt(numberingState.getNumber(type)); + } + + // Emit the operands of the operation. + if (unsigned numOperands = op->getNumOperands()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasOperands; + emitter.emitVarInt(numOperands); + for (Value operand : op->getOperands()) + emitter.emitVarInt(numberingState.getNumber(operand)); + } + + // Emit the successors of the operation. + if (unsigned numSuccessors = op->getNumSuccessors()) { + opEncodingMask |= bytecode::OpEncodingMask::kHasSuccessors; + emitter.emitVarInt(numSuccessors); + for (Block *successor : op->getSuccessors()) + emitter.emitVarInt(numberingState.getNumber(successor)); + } + + // Check for regions. + unsigned numRegions = op->getNumRegions(); + if (numRegions) + opEncodingMask |= bytecode::OpEncodingMask::kHasInlineRegions; + + // Update the mask for the operation. + *emitter.getRawPointer(maskOffset) = opEncodingMask; + + // With the mask emitted, we can now emit the regions of the operation. We do + // this after mask emission to avoid offset complications that may arise by + // emitting the regions first (e.g. 
if the regions are huge, backpatching the + // op encoding mask is more annoying). + if (numRegions) { + emitter.emitVarInt(numRegions); + for (Region &region : op->getRegions()) + writeRegion(emitter, &region, lastLoc); + } +} + +void BytecodeWriter::writeRegion(EncodingEmitter &emitter, Region *region, + Attribute &lastLoc) { + if (region->empty()) + return emitter.emitByte(bytecode::TopLevelOpCode::kRegionEmpty); + + // Emit the number of blocks and values within the region. + unsigned numBlocks, numValues; + std::tie(numBlocks, numValues) = numberingState.getBlockValueCount(region); + emitter.emitByte(bytecode::TopLevelOpCode::kRegion); + emitter.emitVarInt(numBlocks); + emitter.emitVarInt(numValues); + + // Emit the blocks within the region. + for (Block &block : *region) + writeBlock(emitter, &block, lastLoc); +} + +void BytecodeWriter::writeTopLevelOp(EncodingEmitter &emitter, Operation *op) { + EncodingEmitter topLevelOpEmitter; + + Attribute lastLoc; + writeOp(topLevelOpEmitter, op, lastLoc); + + emitter.emitSection(bytecode::Section::kTopLevelOp, + std::move(topLevelOpEmitter)); +} + +//===----------------------------------------------------------------------===// +// Entry Points +//===----------------------------------------------------------------------===// + +void mlir::writeBytecodeToFile(const BytecodeWriterConfig &config, + raw_ostream &os) { + BytecodeWriter writer(config); + writer.write(config.getRootOp(), os); +} diff --git a/mlir/lib/Bytecode/Writer/CMakeLists.txt b/mlir/lib/Bytecode/Writer/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Writer/CMakeLists.txt @@ -0,0 +1,11 @@ +add_mlir_library(MLIRBytecodeWriter + BytecodeWriter.cpp + IRNumbering.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Bytecode + + LINK_LIBS PUBLIC + MLIRIR + MLIRSupport + ) diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.h b/mlir/lib/Bytecode/Writer/IRNumbering.h new file mode 100644 --- /dev/null +++
b/mlir/lib/Bytecode/Writer/IRNumbering.h @@ -0,0 +1,173 @@ +//===- IRNumbering.h - MLIR bytecode IR numbering ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains various utilities that number IR structures in preparation +// for bytecode emission. +// +//===----------------------------------------------------------------------===// + +#ifndef LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H +#define LIB_MLIR_BYTECODE_WRITER_IRNUMBERING_H + +#include "mlir/IR/OperationSupport.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { +class BytecodeWriterConfig; + +namespace bytecode { +namespace detail { +struct DialectNumbering; + +//===----------------------------------------------------------------------===// +// Attribute and Type Numbering +//===----------------------------------------------------------------------===// + +/// This class represents a numbering entry for an Attribute or Type. +struct AttrTypeNumbering { + AttrTypeNumbering(PointerUnion<Attribute, Type> value) : value(value) {} + + /// The concrete value. + PointerUnion<Attribute, Type> value; + + /// The number assigned to this value. + unsigned number = 0; + + /// The dialect of this value.
+ DialectNumbering *dialect = nullptr; +}; +struct AttributeNumbering : public AttrTypeNumbering { + AttributeNumbering(Attribute value) : AttrTypeNumbering(value) {} + Attribute getValue() const { return value.get<Attribute>(); } +}; +struct TypeNumbering : public AttrTypeNumbering { + TypeNumbering(Type value) : AttrTypeNumbering(value) {} + Type getValue() const { return value.get<Type>(); } +}; + +//===----------------------------------------------------------------------===// +// OpName Numbering +//===----------------------------------------------------------------------===// + +/// This class represents the numbering entry of an operation name. +struct OpNameNumbering { + OpNameNumbering(OperationName name) : name(name) {} + + /// The concrete name. + OperationName name; + + /// The number assigned to this name. + unsigned number = 0; +}; + +//===----------------------------------------------------------------------===// +// Dialect Numbering +//===----------------------------------------------------------------------===// + +/// This class represents a numbering entry for a Dialect. +struct DialectNumbering { + DialectNumbering(StringRef name, unsigned number) + : name(name), number(number) {} + + /// The namespace of the dialect. + StringRef name; + + /// The number assigned to the dialect. + unsigned number; + + /// The loaded dialect, or nullptr if the dialect isn't loaded. + Dialect *dialect = nullptr; + + /// Numbered sub-components of the dialect to be emitted. + std::vector<OpNameNumbering *> opNames; + std::vector<AttributeNumbering *> attributes; + std::vector<TypeNumbering *> types; +}; + +//===----------------------------------------------------------------------===// +// IRNumberingState +//===----------------------------------------------------------------------===// + +/// This class manages numbering IR entities in preparation of bytecode +/// emission. 
+class IRNumberingState { +public: + IRNumberingState(const BytecodeWriterConfig &config); + + /// Return the numbered dialects.
+ auto getDialects() { + return llvm::make_pointee_range(llvm::make_second_range(dialects)); + } + + /// Return the number for the given IR unit. + unsigned getNumber(Attribute attr) { + assert(attrs.count(attr) && "attribute not numbered"); + return attrs[attr]->number; + } + unsigned getNumber(Block *block) { + assert(blockIDs.count(block) && "block not numbered"); + return blockIDs[block]; + } + unsigned getNumber(OperationName opName) { + assert(opNames.count(opName) && "opName not numbered"); + return opNames[opName]->number; + } + unsigned getNumber(Type type) { + assert(types.count(type) && "type not numbered"); + return types[type]->number; + } + unsigned getNumber(Value value) { + assert(valueIDs.count(value) && "value not numbered"); + return valueIDs[value]; + } + + /// Return the block and value counts of the given region. + std::pair<unsigned, unsigned> getBlockValueCount(Region *region) { + assert(regionBlockValueCounts.count(region) && "value not numbered"); + return regionBlockValueCounts[region]; + } + +private: + /// Number the given IR unit for bytecode emission. + void number(Attribute attr); + void number(Block &block); + DialectNumbering &numberDialect(Dialect *dialect); + DialectNumbering &numberDialect(StringRef dialect); + void number(Operation &op); + void number(OperationName opName); + void number(Region &region); + void number(Type type); + + /// Mapping from IR to the respective numbering entries. + llvm::MapVector<Attribute, AttributeNumbering *> attrs; + llvm::MapVector<OperationName, OpNameNumbering *> opNames; + llvm::MapVector<Type, TypeNumbering *> types; + llvm::MapVector<Dialect *, DialectNumbering *> registeredDialects; + llvm::MapVector<StringRef, DialectNumbering *> dialects; + + /// Allocators used for the various numbering entries. 
+ llvm::SpecificBumpPtrAllocator<AttributeNumbering> attrAllocator; + llvm::SpecificBumpPtrAllocator<DialectNumbering> dialectAllocator; + llvm::SpecificBumpPtrAllocator<OpNameNumbering> opNameAllocator; + llvm::SpecificBumpPtrAllocator<TypeNumbering> typeAllocator; + + /// The value ID for each Block and Value. + DenseMap<Block *, unsigned> blockIDs; + DenseMap<Value, unsigned> valueIDs; + + /// A map from region to the number of blocks and values within that region.
+ DenseMap<Region *, std::pair<unsigned, unsigned>> regionBlockValueCounts; + + /// The next value ID to assign when numbering. + unsigned nextValueID = 0; +}; +} // namespace detail +} // namespace bytecode +} // namespace mlir + +#endif diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -0,0 +1,165 @@ +//===- IRNumbering.cpp - MLIR Bytecode IR numbering -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRNumbering.h" +#include "mlir/Bytecode/BytecodeWriter.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace mlir::bytecode::detail; + +//===----------------------------------------------------------------------===// +// IR Numbering +//===----------------------------------------------------------------------===// + +IRNumberingState::IRNumberingState(const BytecodeWriterConfig &config) { + Operation *op = config.getRootOp(); + + // Number the root operation. + number(*op); + + // Push all of the regions of the root operation onto the worklist. + SmallVector<std::pair<Region *, unsigned>, 8> numberContext; + for (Region &region : op->getRegions()) + numberContext.emplace_back(&region, nextValueID); + + // Iteratively process each of the nested regions. + while (!numberContext.empty()) { + Region *region; + std::tie(region, nextValueID) = numberContext.pop_back_val(); + number(*region); + + // Traverse into nested regions. 
+ for (Operation &op : region->getOps()) + for (Region &region : op.getRegions()) + numberContext.emplace_back(&region, nextValueID); + } + + // Walk and number the recorded components within each dialect.
+ unsigned attrID = 0, opNameID = 0, typeID = 0; + for (DialectNumbering *dialect : llvm::make_second_range(dialects)) { + for (AttributeNumbering *attr : dialect->attributes) + attr->number = attrID++; + for (OpNameNumbering *opName : dialect->opNames) + opName->number = opNameID++; + for (TypeNumbering *type : dialect->types) + type->number = typeID++; + } +} + +void IRNumberingState::number(Attribute attr) { + auto it = attrs.insert({attr, nullptr}); + if (!it.second) + return; + auto *numbering = new (attrAllocator.Allocate()) AttributeNumbering(attr); + it.first->second = numbering; + + // Check for OpaqueAttr, which is a dialect-specific attribute that didn't + // have a registered dialect when it got created. We don't want to encode this + // as the builtin OpaqueAttr, we want to encode it as if the dialect was + // actually loaded. + if (OpaqueAttr opaqueAttr = attr.dyn_cast<OpaqueAttr>()) + numbering->dialect = &numberDialect(opaqueAttr.getDialectNamespace()); + else + numbering->dialect = &numberDialect(&attr.getDialect()); + + numbering->dialect->attributes.push_back(numbering); +} + +void IRNumberingState::number(Block &block) { + // Number the arguments of the block. + for (BlockArgument arg : block.getArguments()) { + valueIDs.try_emplace(arg, nextValueID++); + number(arg.getLoc()); + number(arg.getType()); + } + + // Number the operations in this block.
+ for (Operation &op : block) + number(op); +} + +auto IRNumberingState::numberDialect(Dialect *dialect) -> DialectNumbering & { + DialectNumbering *&numbering = registeredDialects[dialect]; + if (!numbering) { + numbering = &numberDialect(dialect->getNamespace()); + numbering->dialect = dialect; + } + return *numbering; +} + +auto IRNumberingState::numberDialect(StringRef dialect) -> DialectNumbering & { + DialectNumbering *&numbering = dialects[dialect]; + if (!numbering) { + numbering = new (dialectAllocator.Allocate()) + DialectNumbering(dialect, dialects.size() - 1); + } + return *numbering; +} + +void IRNumberingState::number(Region &region) { + size_t firstValueID = nextValueID; + + // Number the blocks within this region. + size_t blockCount = 0; + for (auto &it : llvm::enumerate(region)) { + blockIDs.try_emplace(&it.value(), it.index()); + number(it.value()); + ++blockCount; + } + + // Remember the number of blocks and values in this region. + regionBlockValueCounts.try_emplace(&region, blockCount, + nextValueID - firstValueID); +} + +void IRNumberingState::number(Operation &op) { + // Number the components of an operation that won't be numbered elsewhere + // (e.g. we don't number operands, regions, or successors here).
+ number(op.getName()); + for (OpResult result : op.getResults()) { + valueIDs.try_emplace(result, nextValueID++); + number(result.getType()); + } + number(op.getAttrDictionary()); + number(op.getLoc()); +} + +void IRNumberingState::number(OperationName opName) { + OpNameNumbering *&numbering = opNames[opName]; + if (numbering) + return; + DialectNumbering *dialectNumber = nullptr; + if (Dialect *dialect = opName.getDialect()) + dialectNumber = &numberDialect(dialect); + else + dialectNumber = &numberDialect(opName.getDialectNamespace()); + numbering = new (opNameAllocator.Allocate()) OpNameNumbering(opName); + dialectNumber->opNames.emplace_back(numbering); +} + +void IRNumberingState::number(Type type) { + auto it = types.insert({type, nullptr}); + if (!it.second) + return; + auto *numbering = new (typeAllocator.Allocate()) TypeNumbering(type); + it.first->second = numbering; + + // Check for OpaqueType, which is a dialect-specific type that didn't have a + // registered dialect when it got created. We don't want to encode this as the + // builtin OpaqueType, we want to encode it as if the dialect was actually + // loaded. 
+ if (OpaqueType opaqueType = type.dyn_cast<OpaqueType>()) + numbering->dialect = &numberDialect(opaqueType.getDialectNamespace()); + else + numbering->dialect = &numberDialect(&type.getDialect()); + + numbering->dialect->types.push_back(numbering); +} diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(Analysis) add_subdirectory(AsmParser) +add_subdirectory(Bytecode) add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(IR) diff --git a/mlir/lib/Parser/CMakeLists.txt b/mlir/lib/Parser/CMakeLists.txt --- a/mlir/lib/Parser/CMakeLists.txt +++ b/mlir/lib/Parser/CMakeLists.txt @@ -6,5 +6,6 @@ LINK_LIBS PUBLIC MLIRAsmParser + MLIRBytecodeReader MLIRIR ) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -12,6 +12,7 @@ #include "mlir/Parser/Parser.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Bytecode/BytecodeReader.h" #include "llvm/Support/SourceMgr.h" using namespace mlir; @@ -25,6 +26,8 @@ sourceBuf->getBufferIdentifier(), /*line=*/0, /*column=*/0); } + if (isBytecode(*sourceBuf)) + return readBytecodeFile(*sourceBuf, block, config); return parseAsmSourceFile(sourceMgr, block, config); } diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt --- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt +++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt @@ -5,6 +5,7 @@ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Tools/mlir-opt LINK_LIBS PUBLIC + MLIRBytecodeWriter MLIRPass MLIRParser MLIRSupport diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" @@ -47,7 +48,8 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses, SourceMgr &sourceMgr, MLIRContext *context, - PassPipelineFn passManagerSetupFn) { + PassPipelineFn passManagerSetupFn, + bool emitBytecode) { DefaultTimingManager tm; applyDefaultTimingManagerCLOptions(tm); TimingScope timing = tm.getRootScope(); @@ -86,8 +88,12 @@ // Print the output. TimingScope outputTiming = timing.nest("Output"); - module->print(os); - os << '\n'; + if (emitBytecode) { + writeBytecodeToFile(module->getOperation(), os); + } else { + module->print(os); + os << '\n'; + } return success(); } @@ -97,8 +103,8 @@ processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, bool preloadDialectsInContext, - PassPipelineFn passManagerSetupFn, DialectRegistry &registry, - llvm::ThreadPool *threadPool) { + bool emitBytecode, PassPipelineFn passManagerSetupFn, + DialectRegistry &registry, llvm::ThreadPool *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); @@ -122,7 +128,7 @@ if (!verifyDiagnostics) { SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, - &context, passManagerSetupFn); + &context, passManagerSetupFn, emitBytecode); } SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context); @@ -131,7 +137,7 @@ // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context, - passManagerSetupFn); + passManagerSetupFn, emitBytecode); // Verify the diagnostic handler to make sure that each of the diagnostics // matched.
@@ -144,7 +150,20 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext) { + bool preloadDialectsInContext, + bool emitBytecode) { + // Check to see if we are trying to output bytecode to a displayed stream. + // TODO: Do we need to provide a -f option like LLVM? Should we even + // warn/disable in this case? + if (emitBytecode && outputStream.is_displayed()) { + llvm::errs() + << "warning: Attempting to output a bytecode file to a displayed " + "stream.\n" + "This is inadvisable as it may cause display problems, disabling " + "bytecode output.\n\n"; + emitBytecode = false; + } + // The split-input-file mode is a very specific mode that slices the file // up into small pieces and checks each independently. // We use an explicit threadpool to avoid creating and joining/destroying @@ -163,8 +182,8 @@ raw_ostream &os) { return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics, verifyPasses, allowUnregisteredDialects, - preloadDialectsInContext, passManagerSetupFn, registry, - threadPool); + preloadDialectsInContext, emitBytecode, + passManagerSetupFn, registry, threadPool); }; return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, splitInputFile, /*insertMarkerInOutput=*/true); @@ -176,7 +195,8 @@ DialectRegistry &registry, bool splitInputFile, bool verifyDiagnostics, bool verifyPasses, bool allowUnregisteredDialects, - bool preloadDialectsInContext) { + bool preloadDialectsInContext, + bool emitBytecode) { auto passManagerSetupFn = [&](PassManager &pm) { auto errorHandler = [&](const Twine &msg) { emitError(UnknownLoc::get(pm.getContext())) << msg; @@ -186,7 +206,8 @@ }; return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn, registry, splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext); + allowUnregisteredDialects, preloadDialectsInContext, + emitBytecode); } LogicalResult
mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName, @@ -224,6 +245,10 @@ "show-dialects", cl::desc("Print the list of registered dialects"), cl::init(false)); + static cl::opt<bool> emitBytecode( + "emit-bytecode", cl::desc("Emit bytecode when generating output"), + cl::init(false)); + InitLLVM y(argc, argv); // Register any command line options. @@ -268,7 +293,8 @@ if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects, preloadDialectsInContext))) + allowUnregisteredDialects, preloadDialectsInContext, + emitBytecode))) return failure(); // Keep the output file if the invocation of MlirOptMain was successful. diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/general.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt -allow-unregistered-dialect -emit-bytecode %s | mlir-opt -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: "bytecode.test1" +// CHECK-NEXT: "bytecode.empty"() : () -> () +// CHECK-NEXT: "bytecode.attributes"() {attra = 10 : i64, attrb = #bytecode.attr} : () -> () +// CHECK-NEXT: %[[RESULTS:.*]]:3 = "bytecode.results"() : () -> (i32, i64, i32) +// CHECK-NEXT: "bytecode.operands"(%[[RESULTS]]#0, %[[RESULTS]]#1, %[[RESULTS]]#2) : (i32, i64, i32) -> () +// CHECK-NEXT: "bytecode.branch"()[^[[BLOCK:.*]]] : () -> () +// CHECK-NEXT: ^[[BLOCK]](%[[ARG0:.*]]: i32, %[[ARG1:.*]]: !bytecode.int, %[[ARG2:.*]]: !pdl.operation): +// CHECK-NEXT: "bytecode.regions"() ({ +// CHECK-NEXT: "bytecode.operands"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (i32, !bytecode.int, !pdl.operation) -> () +// CHECK-NEXT: "bytecode.return"() : () -> () +// CHECK-NEXT: }) : () -> () +// CHECK-NEXT: "bytecode.return"() : () -> () +// CHECK-NEXT: }) : () -> () + +"bytecode.test1"() ({ + "bytecode.empty"() : () -> () + "bytecode.attributes"() {attra = 10, attrb = #bytecode.attr} : () -> () + %results:3 =
"bytecode.results"() : () -> (i32, i64, i32) + "bytecode.operands"(%results#0, %results#1, %results#2) : (i32, i64, i32) -> () + "bytecode.branch"()[^secondBlock] : () -> () + +^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation): + "bytecode.regions"() ({ + "bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> () + "bytecode.return"() : () -> () + }) : () -> () + "bytecode.return"() : () -> () +}) : () -> ()