diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md
--- a/mlir/docs/BytecodeFormat.md
+++ b/mlir/docs/BytecodeFormat.md
@@ -154,6 +154,7 @@
 dialect_section {
   numDialects: varint,
   dialectNames: varint[],
+  numTotalOpNames: varint, // Only present in bytecode version >= 3
   opNames: op_name_group[]
 }
 
@@ -405,8 +406,8 @@
 }
 
 block_argument {
-  typeIndex: varint,
-  location: varint
+  typeAndLocation: varint, // (type << 1) | (hasLocation)
+  location: varint?        // Optional, else unknown location
 }
 ```
 
diff --git a/mlir/lib/Bytecode/Encoding.h b/mlir/lib/Bytecode/Encoding.h
--- a/mlir/lib/Bytecode/Encoding.h
+++ b/mlir/lib/Bytecode/Encoding.h
@@ -27,7 +27,7 @@
   kMinSupportedVersion = 0,
 
   /// The current bytecode version.
-  kVersion = 1,
+  kVersion = 3,
 
   /// An arbitrary value used to fill alignment padding.
   kAlignmentByte = 0xCB,
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -1443,6 +1443,14 @@
     opNames.emplace_back(dialect, opName);
     return success();
   };
+  // In bytecode version >= 3 the total number of op names is emitted up front;
+  // use it to reserve storage and avoid re-allocation while parsing.
+  if (version >= 3) {
+    uint64_t numOps;
+    if (failed(sectionReader.parseVarInt(numOps)))
+      return failure();
+    opNames.reserve(numOps);
+  }
   while (!sectionReader.empty())
     if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
       return failure();
@@ -1756,15 +1764,27 @@
   argTypes.reserve(numArgs);
   argLocs.reserve(numArgs);
 
+  Location unknownLoc = UnknownLoc::get(config.getContext());
   while (numArgs--) {
     Type argType;
-    LocationAttr argLoc;
-    if (failed(parseType(reader, argType)) ||
-        failed(parseAttribute(reader, argLoc)))
-      return failure();
+    LocationAttr argLoc = nullptr;
+    uint64_t typeIdx;
+    bool hasLoc = true;
+    if (version >= 3) {
+      // The type index and a "has location" flag share one varint; the
+      // location itself is only encoded when it is not UnknownLoc.
+      if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
+          !(argType = attrTypeReader.resolveType(typeIdx)) ||
+          (hasLoc && failed(parseAttribute(reader, argLoc))))
+        return failure();
+    } else {
+      if (failed(parseType(reader, argType)) ||
+          failed(parseAttribute(reader, argLoc)))
+        return failure();
+    }
     argTypes.push_back(argType);
-    argLocs.push_back(argLoc);
+    argLocs.push_back(argLoc ? Location(argLoc) : unknownLoc);
   }
   block->addArguments(argTypes, argLocs);
   return defineValues(reader, block->getArguments());
 }
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -579,6 +579,10 @@
                                  std::move(versionEmitter));
   }
 
+  // Emit the total number of op names so the reader can pre-size its storage.
+  if (config.bytecodeVersion >= 3)
+    dialectEmitter.emitVarInt(size(numberingState.getOpNames()));
+
   // Emit the referenced operation names grouped by dialect.
   auto emitOpName = [&](OpNameNumbering &name) {
     dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect()));
@@ -664,8 +668,19 @@
   if (hasArgs) {
     emitter.emitVarInt(args.size());
     for (BlockArgument arg : args) {
-      emitter.emitVarInt(numberingState.getNumber(arg.getType()));
-      emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
+      Location argLoc = arg.getLoc();
+      if (config.bytecodeVersion >= 3) {
+        // Pack a "has location" flag into the type varint and only emit the
+        // location when it is not UnknownLoc.
+        bool hasLoc = !isa<UnknownLoc>(argLoc);
+        emitter.emitVarIntWithFlag(numberingState.getNumber(arg.getType()),
+                                   hasLoc);
+        if (hasLoc)
+          emitter.emitVarInt(numberingState.getNumber(argLoc));
+      } else {
+        emitter.emitVarInt(numberingState.getNumber(arg.getType()));
+        emitter.emitVarInt(numberingState.getNumber(argLoc));
+      }
     }
   }
 
diff --git a/mlir/test/Bytecode/general.mlir b/mlir/test/Bytecode/general.mlir
--- a/mlir/test/Bytecode/general.mlir
+++ b/mlir/test/Bytecode/general.mlir
@@ -32,7 +32,7 @@
   }
 
   "bytecode.branch"()[^secondBlock] : () -> ()
-^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation):
+^secondBlock(%arg1: i32 loc(unknown), %arg2: !bytecode.int, %arg3: !pdl.operation loc(unknown)):
   "bytecode.regions"() ({
     "bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> ()
     "bytecode.return"() : () -> ()