Refactor MLIR bytecode string-section handling into dedicated helper classes.

Reader: the string table and the code that parses it move out of
BytecodeReader into a new StringSectionReader (initialize + parseString).
Writer: the string set and the code that emits it move out of BytecodeWriter
into a new StringSectionBuilder (insert + write).

Review amendments on top of the pure code move:
  * StringSectionReader::initialize rejects a string size of zero. Sizes are
    written as `size() + 1`, and the reader computes `stringSize - 1`, which
    would wrap around for a zero size read from a malformed file.
  * The comment on StringSectionBuilder::strings no longer says the map value
    is unused; insert() stores and returns it as the string's index.
  * The comment on BytecodeReader::stringReader describes the reader rather
    than the table it replaced.

NOTE(review): template arguments on the untouched context lines (opNames,
valueScopes, the emitBytes call) and the `uint8_t` element type of ArrayRef
were restored from the upstream tree, not from this patch; confirm them
against the target revision before applying.

diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -240,6 +240,73 @@
   return resolveEntry(reader, entries, entryIdx, entry, entryStr);
 }
 
+//===----------------------------------------------------------------------===//
+// StringSectionReader
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to read references to the string section from the
+/// bytecode.
+class StringSectionReader {
+public:
+  /// Initialize the string section reader with the given section data.
+  LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData);
+
+  /// Parse a shared string from the string section. The shared string is
+  /// encoded using an index to a corresponding string in the string section.
+  LogicalResult parseString(EncodingReader &reader, StringRef &result) {
+    return parseEntry(reader, strings, result, "string");
+  }
+
+private:
+  /// The table of strings referenced within the bytecode file.
+  SmallVector<StringRef> strings;
+};
+} // namespace
+
+LogicalResult StringSectionReader::initialize(Location fileLoc,
+                                              ArrayRef<uint8_t> sectionData) {
+  EncodingReader stringReader(sectionData, fileLoc);
+
+  // Parse the number of strings in the section.
+  uint64_t numStrings;
+  if (failed(stringReader.parseVarInt(numStrings)))
+    return failure();
+  strings.resize(numStrings);
+
+  // Parse each of the strings. The sizes of the strings are encoded in reverse
+  // order, so that's the order we populate the table.
+  size_t stringDataEndOffset = sectionData.size();
+  for (StringRef &string : llvm::reverse(strings)) {
+    uint64_t stringSize;
+    if (failed(stringReader.parseVarInt(stringSize)))
+      return failure();
+    // Every encoded size counts the trailing null character, so zero is
+    // malformed and would underflow the `stringSize - 1` length below.
+    if (stringSize == 0)
+      return stringReader.emitError("string size must be non-zero");
+    if (stringDataEndOffset < stringSize) {
+      return stringReader.emitError(
+          "string size exceeds the available data size");
+    }
+
+    // Extract the string from the data, dropping the null character.
+    size_t stringOffset = stringDataEndOffset - stringSize;
+    string = StringRef(
+        reinterpret_cast<const char *>(sectionData.data() + stringOffset),
+        stringSize - 1);
+    stringDataEndOffset = stringOffset;
+  }
+
+  // Check that the only remaining data was for the strings, i.e. the reader
+  // should be at the same offset as the first string.
+  if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) {
+    return stringReader.emitError("unexpected trailing data between the "
+                                  "offsets for strings and their data");
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // BytecodeDialect
 //===----------------------------------------------------------------------===//
@@ -595,17 +662,6 @@
   LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState);
   LogicalResult parseBlockArguments(EncodingReader &reader, Block *block);
 
-  //===--------------------------------------------------------------------===//
-  // String Section
-
-  LogicalResult parseStringSection(ArrayRef<uint8_t> sectionData);
-
-  /// Parse a shared string from the string section. The shared string is
-  /// encoded using an index to a corresponding string in the string section.
-  LogicalResult parseSharedString(EncodingReader &reader, StringRef &result) {
-    return parseEntry(reader, strings, result, "string");
-  }
-
   //===--------------------------------------------------------------------===//
   // Value Processing
 
@@ -667,7 +723,7 @@
   SmallVector<BytecodeOperationName> opNames;
 
-  /// The table of strings referenced within the bytecode file.
-  SmallVector<StringRef> strings;
+  /// The reader for the string section of the bytecode file.
+  StringSectionReader stringReader;
 
   /// The current set of available IR value scopes.
   std::vector<ValueScope> valueScopes;
@@ -726,7 +782,8 @@
   }
 
   // Process the string section first.
-  if (failed(parseStringSection(*sectionDatas[bytecode::Section::kString])))
+  if (failed(stringReader.initialize(
+          fileLoc, *sectionDatas[bytecode::Section::kString])))
     return failure();
 
   // Process the dialect section.
@@ -777,13 +834,13 @@
 
   // Parse each of the dialects.
   for (uint64_t i = 0; i < numDialects; ++i)
-    if (failed(parseSharedString(sectionReader, dialects[i].name)))
+    if (failed(stringReader.parseString(sectionReader, dialects[i].name)))
       return failure();
 
   // Parse the operation names, which are grouped by dialect.
   auto parseOpName = [&](BytecodeDialect *dialect) {
     StringRef opName;
-    if (failed(parseSharedString(sectionReader, opName)))
+    if (failed(stringReader.parseString(sectionReader, opName)))
       return failure();
     opNames.emplace_back(dialect, opName);
     return success();
@@ -1091,51 +1148,6 @@
   return defineValues(reader, block->getArguments());
 }
 
-//===----------------------------------------------------------------------===//
-// String Section
-
-LogicalResult
-BytecodeReader::parseStringSection(ArrayRef<uint8_t> sectionData) {
-  EncodingReader stringReader(sectionData, fileLoc);
-
-  // Parse the number of strings in the section.
-  uint64_t numStrings;
-  if (failed(stringReader.parseVarInt(numStrings)))
-    return failure();
-  strings.resize(numStrings);
-
-  // Parse each of the strings. The sizes of the strings are encoded in reverse
-  // order, so that's the order we populate the table.
-  size_t stringDataEndOffset = sectionData.size();
-  size_t totalStringDataSize = 0;
-  for (StringRef &string : llvm::reverse(strings)) {
-    uint64_t stringSize;
-    if (failed(stringReader.parseVarInt(stringSize)))
-      return failure();
-    if (stringDataEndOffset < stringSize) {
-      return stringReader.emitError(
-          "string size exceeds the available data size");
-    }
-
-    // Extract the string from the data, dropping the null character.
-    size_t stringOffset = stringDataEndOffset - stringSize;
-    string = StringRef(
-        reinterpret_cast<const char *>(sectionData.data() + stringOffset),
-        stringSize - 1);
-    stringDataEndOffset = stringOffset;
-
-    // Update the total string data size.
-    totalStringDataSize += stringSize;
-  }
-
-  // Check that the only remaining data was for the strings
-  if (stringReader.size() != totalStringDataSize) {
-    return stringReader.emitError("unexpected trailing data between the "
-                                  "offsets for strings and their data");
-  }
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // Value Processing
 
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -196,6 +196,41 @@
   emitBytes({reinterpret_cast<uint8_t *>(&value), sizeof(value)});
 }
 
+//===----------------------------------------------------------------------===//
+// StringSectionBuilder
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class is used to simplify the process of emitting the string section.
+class StringSectionBuilder {
+public:
+  /// Add the given string to the string section, and return the index of the
+  /// string within the section.
+  size_t insert(StringRef str) {
+    auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
+    return it.first->second;
+  }
+
+  /// Write the current set of strings to the given emitter.
+  void write(EncodingEmitter &emitter) {
+    emitter.emitVarInt(strings.size());
+
+    // Emit the sizes in reverse order, so that we don't need to backpatch an
+    // offset to the string data or have a separate section.
+    for (const auto &it : llvm::reverse(strings))
+      emitter.emitVarInt(it.first.size() + 1);
+    // Emit the string data itself.
+    for (const auto &it : strings)
+      emitter.emitNulTerminatedString(it.first.val());
+  }
+
+private:
+  /// The strings referenced within the bytecode, in insertion order. Each
+  /// string maps to its index within the string section.
+  llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // Bytecode Writer
 //===----------------------------------------------------------------------===//
@@ -232,19 +267,14 @@
 
   void writeStringSection(EncodingEmitter &emitter);
 
-  /// Get the number for the given shared string, that is contained within the
-  /// string section.
-  size_t getSharedStringNumber(StringRef str);
-
   //===--------------------------------------------------------------------===//
   // Fields
 
+  /// The builder used for the string section.
+  StringSectionBuilder stringSection;
+
   /// The IR numbering state generated for the root operation.
   IRNumberingState numberingState;
-
-  /// A set of strings referenced within the bytecode. The value of the map is
-  /// unused.
-  llvm::MapVector<llvm::CachedHashStringRef, size_t> strings;
 };
 } // namespace
 
@@ -314,11 +344,11 @@
   auto dialects = numberingState.getDialects();
   dialectEmitter.emitVarInt(llvm::size(dialects));
   for (DialectNumbering &dialect : dialects)
-    dialectEmitter.emitVarInt(getSharedStringNumber(dialect.name));
+    dialectEmitter.emitVarInt(stringSection.insert(dialect.name));
 
   // Emit the referenced operation names grouped by dialect.
   auto emitOpName = [&](OpNameNumbering &name) {
-    dialectEmitter.emitVarInt(getSharedStringNumber(name.name.stripDialect()));
+    dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect()));
   };
   writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);
 
@@ -491,24 +521,10 @@
 
 void BytecodeWriter::writeStringSection(EncodingEmitter &emitter) {
   EncodingEmitter stringEmitter;
-  stringEmitter.emitVarInt(strings.size());
-
-  // Emit the sizes in reverse order, so that we don't need to backpatch an
-  // offset to the string data or have a separate section.
-  for (const auto &it : llvm::reverse(strings))
-    stringEmitter.emitVarInt(it.first.size() + 1);
-  // Emit the string data itself.
-  for (const auto &it : strings)
-    stringEmitter.emitNulTerminatedString(it.first.val());
-
+  stringSection.write(stringEmitter);
   emitter.emitSection(bytecode::Section::kString, std::move(stringEmitter));
 }
 
-size_t BytecodeWriter::getSharedStringNumber(StringRef str) {
-  auto it = strings.insert({llvm::CachedHashStringRef(str), strings.size()});
-  return it.first->second;
-}
-
 //===----------------------------------------------------------------------===//
 // Entry Points
 //===----------------------------------------------------------------------===//