diff --git a/mlir/docs/BytecodeFormat.md b/mlir/docs/BytecodeFormat.md --- a/mlir/docs/BytecodeFormat.md +++ b/mlir/docs/BytecodeFormat.md @@ -65,6 +65,22 @@ 00000000: 64 value bits, the encoding uses 9 bytes ``` +##### Signed Variable-Width Integers + +Signed variable-width integer values are encoded in a similar fashion to +[varints](#variable-width-integers), but employ +[zigzag encoding](https://en.wikipedia.org/wiki/Variable-length_quantity#Zigzag_encoding). +This encoding uses the low bit of the value to indicate the sign, which allows +negative numbers to be encoded more efficiently. If a negative value were encoded +using a normal [varint](#variable-width-integers), it would be treated as an +extremely large unsigned value. Using zigzag encoding allows for a smaller +number of active bits in the value, leading to a smaller encoding. Below is the +basic computation for generating a zigzag encoding of a 64-bit signed value (`>>` is an arithmetic shift): + +``` +(value << 1) ^ (value >> 63) +``` + #### Strings Strings are blobs of characters with an associated length. diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -78,14 +78,14 @@ return readList(attrs, [this](T &attr) { return readAttribute(attr); }); } template - LogicalResult parseAttribute(T &result) { + LogicalResult readAttribute(T &result) { Attribute baseResult; - if (failed(parseAttribute(baseResult))) + if (failed(readAttribute(baseResult))) return failure(); if ((result = baseResult.dyn_cast())) return success(); - return emitError() << "expected attribute of type: " - << llvm::getTypeName() << ", but got: " << baseResult; + return emitError() << "expected " << llvm::getTypeName() + << ", but got: " << baseResult; } /// Read a reference to the given type. 
@@ -94,6 +94,16 @@ LogicalResult readTypes(SmallVectorImpl &types) { return readList(types, [this](T &type) { return readType(type); }); } + template + LogicalResult readType(T &result) { + Type baseResult; + if (failed(readType(baseResult))) + return failure(); + if ((result = baseResult.dyn_cast())) + return success(); + return emitError() << "expected " << llvm::getTypeName() + << ", but got: " << baseResult; + } //===--------------------------------------------------------------------===// // Primitives @@ -103,6 +113,14 @@ // TODO: Add a signed variant when necessary. virtual LogicalResult readVarInt(uint64_t &result) = 0; + /// Read an APInt that is known to have been encoded with the given width. + virtual FailureOr readKnownWidthAPInt(unsigned bitWidth) = 0; + + /// Read an APFloat that is known to have been encoded with the given + /// semantics. + virtual FailureOr + readKnownSemanticsAPFloat(const llvm::fltSemantics &semantics) = 0; + /// Read a string from the bytecode. virtual LogicalResult readString(StringRef &result) = 0; }; @@ -156,6 +174,16 @@ // TODO: Add a signed variant when necessary. virtual void writeVarInt(uint64_t value) = 0; + /// Write an APInt to the bytecode stream whose bitwidth will be known + /// externally at read time. This method is useful for encoding APInt values + /// when the width is known via external means, such as via a type. + virtual void writeKnownWidthAPInt(const APInt &value) = 0; + + /// Write an APFloat to the bytecode stream whose semantics will be known + /// externally at read time. This method is useful for encoding APFloat values + /// when the semantics are known via external means, such as via a type. + virtual void writeKnownSemanticsAPFloat(const APFloat &value) = 0; + /// Write a string to the bytecode, which is owned by the caller and is /// guaranteed to not die before the end of the bytecode process. 
This should /// only be called if such a guarantee can be made, such as when the string is diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -128,6 +128,16 @@ return parseMultiByteVarInt(result); } + /// Parse a signed variable length encoded integer from the byte stream. A + /// signed varint is encoded as a normal varint with zigzag encoding applied, + /// i.e. the low bit of the value is used to indicate the sign. + LogicalResult parseSignedVarInt(uint64_t &result) { + if (failed(parseVarInt(result))) + return failure(); + result = (result >> 1) ^ (~(result & 1) + 1); + return success(); + } + /// Parse a variable length encoded integer whose low bit is used to encode an /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { @@ -511,6 +521,44 @@ return reader.parseVarInt(result); } + FailureOr readKnownWidthAPInt(unsigned bitWidth) override { + // Small values are encoded using a single byte. + if (bitWidth <= 8) { + uint8_t value; + if (failed(reader.parseByte(value))) + return failure(); + return APInt(bitWidth, value); + } + + // Large values up to 64 bits are encoded using a single varint. + if (bitWidth <= 64) { + uint64_t value; + if (failed(reader.parseSignedVarInt(value))) + return failure(); + return APInt(bitWidth, value); + } + + // Otherwise, for really big values we encode the array of active words in + // the value. 
+ uint64_t numActiveWords; + if (failed(reader.parseVarInt(numActiveWords))) + return failure(); + SmallVector words(numActiveWords); + for (uint64_t i = 0; i < numActiveWords; ++i) + if (failed(reader.parseSignedVarInt(words[i]))) + return failure(); + return APInt(bitWidth, words); + } + + FailureOr + readKnownSemanticsAPFloat(const llvm::fltSemantics &semantics) override { + FailureOr intVal = + readKnownWidthAPInt(APFloat::getSizeInBits(semantics)); + if (failed(intVal)) + return failure(); + return APFloat(semantics, *intVal); + } + LogicalResult readString(StringRef &result) override { return stringReader.parseString(reader, result); } diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -85,6 +85,14 @@ emitMultiByteVarInt(value); } + /// Emit a signed variable length integer. Signed varints are encoded using + /// a varint with zigzag encoding, meaning that we use the low bit of the + /// value to indicate the sign of the value. This allows for more efficient + /// encoding of negative values by limiting the number of active bits. + void emitSignedVarInt(uint64_t value) { + emitVarInt((value << 1) ^ (uint64_t)((int64_t)value >> 63)); + } + /// Emit a variable length integer whose low bit is used to encode the /// provided flag, i.e. encoded as: (value << 1) | (flag ? 1 : 0). void emitVarIntWithFlag(uint64_t value, bool flag) { @@ -384,6 +392,33 @@ void writeVarInt(uint64_t value) override { emitter.emitVarInt(value); } + void writeKnownWidthAPInt(const APInt &value) override { + size_t bitWidth = value.getBitWidth(); + + // If the value is at most a single byte wide, just emit it directly without going + // through a varint. + if (bitWidth <= 8) + return emitter.emitByte(value.getLimitedValue()); + + // If the value fits within a single varint, emit it directly. 
+ if (bitWidth <= 64) + return emitter.emitSignedVarInt(value.getLimitedValue()); + + // Otherwise, we need to encode a variable number of active words. We use + // active words instead of the number of total words under the observation + // that smaller values will be more common. + unsigned numActiveWords = value.getActiveWords(); + emitter.emitVarInt(numActiveWords); + + const uint64_t *rawValueData = value.getRawData(); + for (unsigned i = 0; i < numActiveWords; ++i) + emitter.emitSignedVarInt(rawValueData[i]); + } + + void writeKnownSemanticsAPFloat(const APFloat &value) override { + writeKnownWidthAPInt(value.bitcastToAPInt()); + } + void writeOwnedString(StringRef str) override { emitter.emitVarInt(stringSection.insert(str)); } diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -27,6 +27,8 @@ /// Stubbed out methods that are not used for numbering. void writeVarInt(uint64_t) override {} + void writeKnownWidthAPInt(const APInt &value) override {} + void writeKnownSemanticsAPFloat(const APFloat &value) override {} void writeOwnedString(StringRef) override { // TODO: It might be nice to prenumber strings and sort by the number of // references. This could potentially be useful for optimizing things like diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -38,9 +38,49 @@ kDictionaryAttr = 1, /// StringAttr { - /// string + /// value: string /// } kStringAttr = 2, + + /// TypedStringAttr { + /// value: string, + /// type: Type + /// } + /// A variant of StringAttr with a type. + kTypedStringAttr = 3, + + /// FlatSymbolRefAttr { + /// rootReference: StringAttr + /// } + /// A variant of SymbolRefAttr with no leaf references. 
+ kFlatSymbolRefAttr = 4, + + /// SymbolRefAttr { + /// rootReference: StringAttr, + /// leafReferences: FlatSymbolRefAttr[] + /// } + kSymbolRefAttr = 5, + + /// TypeAttr { + /// value: Type + /// } + kTypeAttr = 6, + + /// UnitAttr { + /// } + kUnitAttr = 7, + + /// IntegerAttr { + /// type: Type, + /// value: APInt + /// } + kIntegerAttr = 8, + + /// FloatAttr { + /// type: FloatType, + /// value: APFloat + /// } + kFloatAttr = 9, }; /// This enum contains marker codes used to indicate which type is currently @@ -86,13 +126,22 @@ Attribute readAttribute(DialectBytecodeReader &reader) const override; ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const; DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const; - StringAttr readStringAttr(DialectBytecodeReader &reader) const; + FloatAttr readFloatAttr(DialectBytecodeReader &reader) const; + IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const; + StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const; + SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader, + bool hasNestedRefs) const; + TypeAttr readTypeAttr(DialectBytecodeReader &reader) const; LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const override; void write(ArrayAttr attr, DialectBytecodeWriter &writer) const; void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const; + void write(IntegerAttr attr, DialectBytecodeWriter &writer) const; + void write(FloatAttr attr, DialectBytecodeWriter &writer) const; void write(StringAttr attr, DialectBytecodeWriter &writer) const; + void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const; + void write(TypeAttr attr, DialectBytecodeWriter &writer) const; //===--------------------------------------------------------------------===// // Types @@ -126,7 +175,21 @@ case builtin_encoding::kDictionaryAttr: return readDictionaryAttr(reader); case builtin_encoding::kStringAttr: - return readStringAttr(reader); + 
return readStringAttr(reader, /*hasType=*/false); + case builtin_encoding::kTypedStringAttr: + return readStringAttr(reader, /*hasType=*/true); + case builtin_encoding::kFlatSymbolRefAttr: + return readSymbolRefAttr(reader, /*hasNestedRefs=*/false); + case builtin_encoding::kSymbolRefAttr: + return readSymbolRefAttr(reader, /*hasNestedRefs=*/true); + case builtin_encoding::kTypeAttr: + return readTypeAttr(reader); + case builtin_encoding::kUnitAttr: + return UnitAttr::get(getContext()); + case builtin_encoding::kIntegerAttr: + return readIntegerAttr(reader); + case builtin_encoding::kFloatAttr: + return readFloatAttr(reader); default: reader.emitError() << "unknown builtin attribute code: " << code; return Attribute(); @@ -157,12 +220,75 @@ return DictionaryAttr::get(getContext(), attrs); } -StringAttr BuiltinDialectBytecodeInterface::readStringAttr( +FloatAttr BuiltinDialectBytecodeInterface::readFloatAttr( + DialectBytecodeReader &reader) const { + FloatType type; + if (failed(reader.readType(type))) + return FloatAttr(); + FailureOr value = + reader.readKnownSemanticsAPFloat(type.getFloatSemantics()); + if (failed(value)) + return FloatAttr(); + return FloatAttr::get(type, *value); +} + +IntegerAttr BuiltinDialectBytecodeInterface::readIntegerAttr( DialectBytecodeReader &reader) const { + Type type; + if (failed(reader.readType(type))) + return IntegerAttr(); + + // Extract the value storage width from the type. 
+ unsigned bitWidth; + if (auto intType = type.dyn_cast()) { + bitWidth = intType.getWidth(); + } else if (type.isa()) { + bitWidth = IndexType::kInternalStorageBitWidth; + } else { + reader.emitError() + << "expected integer or index type for IntegerAttr, but got: " << type; + return IntegerAttr(); + } + + FailureOr value = reader.readKnownWidthAPInt(bitWidth); + if (failed(value)) + return IntegerAttr(); + return IntegerAttr::get(type, *value); +} + +StringAttr +BuiltinDialectBytecodeInterface::readStringAttr(DialectBytecodeReader &reader, + bool hasType) const { StringRef string; if (failed(reader.readString(string))) return StringAttr(); - return StringAttr::get(getContext(), string); + + // Read the type if present. + Type type; + if (!hasType) + type = NoneType::get(getContext()); + else if (failed(reader.readType(type))) + return StringAttr(); + return StringAttr::get(string, type); +} + +SymbolRefAttr BuiltinDialectBytecodeInterface::readSymbolRefAttr( + DialectBytecodeReader &reader, bool hasNestedRefs) const { + StringAttr rootReference; + if (failed(reader.readAttribute(rootReference))) + return SymbolRefAttr(); + SmallVector nestedReferences; + if (hasNestedRefs && failed(reader.readAttributes(nestedReferences))) + return SymbolRefAttr(); + return SymbolRefAttr::get(rootReference, nestedReferences); +} + +TypeAttr BuiltinDialectBytecodeInterface::readTypeAttr( + DialectBytecodeReader &reader) const { + Type type; + if (failed(reader.readType(type))) + return TypeAttr(); + return TypeAttr::get(type); } //===----------------------------------------------------------------------===// @@ -171,10 +297,15 @@ LogicalResult BuiltinDialectBytecodeInterface::writeAttribute( Attribute attr, DialectBytecodeWriter &writer) const { return TypeSwitch(attr) - .Case([&](auto attr) { + .Case([&](auto attr) { write(attr, writer); return success(); }) + .Case([&](UnitAttr) { + writer.writeVarInt(builtin_encoding::kUnitAttr); + return success(); + }) .Default([&](Attribute) 
{ return failure(); }); } @@ -193,12 +324,52 @@ }); } +void BuiltinDialectBytecodeInterface::write( + FloatAttr attr, DialectBytecodeWriter &writer) const { + writer.writeVarInt(builtin_encoding::kFloatAttr); + writer.writeType(attr.getType()); + writer.writeKnownSemanticsAPFloat(attr.getValue()); +} + +void BuiltinDialectBytecodeInterface::write( + IntegerAttr attr, DialectBytecodeWriter &writer) const { + writer.writeVarInt(builtin_encoding::kIntegerAttr); + writer.writeType(attr.getType()); + writer.writeKnownWidthAPInt(attr.getValue()); +} + void BuiltinDialectBytecodeInterface::write( StringAttr attr, DialectBytecodeWriter &writer) const { + // The type is only encoded when it isn't NoneType; typed strings are + // significantly less common than untyped ones. + Type type = attr.getType(); + if (!type.isa()) { + writer.writeVarInt(builtin_encoding::kTypedStringAttr); + writer.writeOwnedString(attr.getValue()); + writer.writeType(type); + return; + } writer.writeVarInt(builtin_encoding::kStringAttr); writer.writeOwnedString(attr.getValue()); } +void BuiltinDialectBytecodeInterface::write( + SymbolRefAttr attr, DialectBytecodeWriter &writer) const { + ArrayRef nestedRefs = attr.getNestedReferences(); + writer.writeVarInt(nestedRefs.empty() ? 
builtin_encoding::kFlatSymbolRefAttr + : builtin_encoding::kSymbolRefAttr); + + writer.writeAttribute(attr.getRootReference()); + if (!nestedRefs.empty()) + writer.writeAttributes(nestedRefs); +} + +void BuiltinDialectBytecodeInterface::write( + TypeAttr attr, DialectBytecodeWriter &writer) const { + writer.writeVarInt(builtin_encoding::kTypeAttr); + writer.writeType(attr.getValue()); +} + //===----------------------------------------------------------------------===// // Types: Reader diff --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir --- a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir +++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir @@ -9,8 +9,48 @@ bytecode.array = [unit] } {} +// CHECK-LABEL: @TestFloat +module @TestFloat attributes { + // CHECK: bytecode.float = 1.000000e+01 : f64 + // CHECK: bytecode.float1 = 0.10000{{.*}} : f80 + // CHECK: bytecode.float2 = 0.10000{{.*}} : f128 + // CHECK: bytecode.float3 = -5.000000e-01 : bf16 + bytecode.float = 10.0 : f64, + bytecode.float1 = 0.1 : f80, + bytecode.float2 = 0.1 : f128, + bytecode.float3 = -0.5 : bf16 +} {} + +// CHECK-LABEL: @TestInt +module @TestInt attributes { + // CHECK: bytecode.int = false + // CHECK: bytecode.int1 = -1 : i8 + // CHECK: bytecode.int2 = 800 : ui64 + // CHECK: bytecode.int3 = 90000000000000000300000000000000000001 : i128 + bytecode.int = false, + bytecode.int1 = -1 : i8, + bytecode.int2 = 800 : ui64, + bytecode.int3 = 90000000000000000300000000000000000001 : i128 +} {} + // CHECK-LABEL: @TestString module @TestString attributes { // CHECK: bytecode.string = "hello" - bytecode.string = "hello" + // CHECK: bytecode.string2 = "hello" : i32 + bytecode.string = "hello", + bytecode.string2 = "hello" : i32 +} {} + +// CHECK-LABEL: @TestSymbolRef +module @TestSymbolRef attributes { + // CHECK: bytecode.ref = @foo + // CHECK: bytecode.ref2 = @foo::@bar::@foo + bytecode.ref = @foo, + bytecode.ref2 = @foo::@bar::@foo +} {} + +// 
CHECK-LABEL: @TestType +module @TestType attributes { + // CHECK: bytecode.type = i178 + bytecode.type = i178 } {}