diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -154,6 +154,7 @@ dialect_section { numDialects: varint, dialectNames: varint[], + numTotalOpNames: varint, opNames: op_name_group[] } @@ -444,8 +445,8 @@ } block_argument { - typeIndex: varint, - location: varint + typeAndLocation: varint, // (type << 1) | (hasLocation) + location: varint? // Optional, else unknown location } ``` diff --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h --- a/mlir/include/mlir/Bytecode/Encoding.h +++ b/mlir/include/mlir/Bytecode/Encoding.h @@ -29,7 +29,7 @@ kMinSupportedVersion = 0, /// The current bytecode version. - kVersion = 3, + kVersion = 4, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -1603,6 +1603,14 @@ opNames.emplace_back(dialect, opName); return success(); }; + // Avoid re-allocation in bytecode version > 3 where the number of ops is + // known. + if (version > 3) { + uint64_t numOps; + if (failed(sectionReader.parseVarInt(numOps))) + return failure(); + opNames.reserve(numOps); + } while (!sectionReader.empty()) if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) return failure(); @@ -2175,13 +2183,25 @@ argTypes.reserve(numArgs); argLocs.reserve(numArgs); + Location unknownLoc = UnknownLoc::get(config.getContext()); while (numArgs--) { Type argType; - LocationAttr argLoc; - if (failed(parseType(reader, argType)) || - failed(parseAttribute(reader, argLoc))) - return failure(); - + LocationAttr argLoc = unknownLoc; + if (version > 3) { + // Parse the type with hasLoc flag to determine if it has a location. 
+ uint64_t typeIdx; + bool hasLoc; + if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) || + !(argType = attrTypeReader.resolveType(typeIdx))) + return failure(); + if (hasLoc && failed(parseAttribute(reader, argLoc))) + return failure(); + } else { + // All args have a type and a location. + if (failed(parseType(reader, argType)) || + failed(parseAttribute(reader, argLoc))) + return failure(); + } argTypes.push_back(argType); argLocs.push_back(argLoc); } diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -585,6 +585,9 @@ std::move(versionEmitter)); } + if (config.bytecodeVersion > 3) + dialectEmitter.emitVarInt(size(numberingState.getOpNames())); + // Emit the referenced operation names grouped by dialect. auto emitOpName = [&](OpNameNumbering &name) { dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect())); @@ -670,8 +673,16 @@ if (hasArgs) { emitter.emitVarInt(args.size()); for (BlockArgument arg : args) { - emitter.emitVarInt(numberingState.getNumber(arg.getType())); - emitter.emitVarInt(numberingState.getNumber(arg.getLoc())); + Location argLoc = arg.getLoc(); + if (config.bytecodeVersion > 3) { + emitter.emitVarIntWithFlag(numberingState.getNumber(arg.getType()), + !isa<UnknownLoc>(argLoc)); + if (!isa<UnknownLoc>(argLoc)) + emitter.emitVarInt(numberingState.getNumber(argLoc)); + } else { + emitter.emitVarInt(numberingState.getNumber(arg.getType())); + emitter.emitVarInt(numberingState.getNumber(argLoc)); + } } if (config.bytecodeVersion > 2) { uint64_t maskOffset = emitter.size(); @@ -755,7 +766,7 @@ for (Region &region : op->getRegions()) { // If the region is not isolated from above, or we are emitting bytecode - // targetting version <2, we don't use a section. + // targeting version <2, we don't use a section. 
+ if (!isIsolatedFromAbove || config.bytecodeVersion < 2) { writeRegion(emitter, &region); continue; diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir --- a/mlir/test/Bytecode/general.mlir +++ b/mlir/test/Bytecode/general.mlir @@ -32,7 +32,7 @@ } "bytecode.branch"()[^secondBlock] : () -> () -^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation): +^secondBlock(%arg1: i32 loc(unknown), %arg2: !bytecode.int, %arg3: !pdl.operation loc(unknown)): "bytecode.regions"() ({ "bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> () "bytecode.return"() : () -> () diff --git a/mlir/test/Bytecode/invalid/invalid-structure.mlir b/mlir/test/Bytecode/invalid/invalid-structure.mlir --- a/mlir/test/Bytecode/invalid/invalid-structure.mlir +++ b/mlir/test/Bytecode/invalid/invalid-structure.mlir @@ -9,7 +9,7 @@ //===--------------------------------------------------------------------===// // RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION -// VERSION: bytecode version 127 is newer than the current version 3 +// VERSION: bytecode version 127 is newer than the current version //===--------------------------------------------------------------------===// // Producer