diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -143,6 +143,9 @@ /// Read a string from the bytecode. virtual LogicalResult readString(StringRef &result) = 0; + /// Read a blob from the bytecode. + virtual LogicalResult readBlob(ArrayRef &result) = 0; + private: /// Read a handle to a dialect resource. virtual FailureOr readResourceHandle() = 0; @@ -225,6 +228,11 @@ /// only be called if such a guarantee can be made, such as when the string is /// owned by an attribute or type. virtual void writeOwnedString(StringRef str) = 0; + + /// Write a blob to the bytecode, which is owned by the caller and is + /// guaranteed to not die before the end of the bytecode process. The blob is + /// written as-is, with no additional compression or compaction. + virtual void writeOwnedBlob(ArrayRef blob) = 0; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -887,6 +887,17 @@ return stringReader.parseString(reader, result); } + LogicalResult readBlob(ArrayRef &result) override { + uint64_t dataSize; + ArrayRef data; + if (failed(reader.parseVarInt(dataSize)) || + failed(reader.parseBytes(dataSize, data))) + return failure(); + result = llvm::makeArrayRef(reinterpret_cast(data.data()), + data.size()); + return success(); + } + private: AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -543,6 +543,12 @@ 
emitter.emitVarInt(stringSection.insert(str)); } + void writeOwnedBlob(ArrayRef blob) override { + emitter.emitVarInt(blob.size()); + emitter.emitOwnedBlob(ArrayRef( + reinterpret_cast(blob.data()), blob.size())); + } + private: EncodingEmitter &emitter; IRNumberingState &numberingState; diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -39,6 +39,7 @@ // references. This could potentially be useful for optimizing things like // file locations. } + void writeOwnedBlob(ArrayRef blob) override {} /// The parent numbering state that is populated by this writer. IRNumberingState &state; diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -123,6 +123,32 @@ /// handle: ResourceHandle /// } kDenseResourceElementsAttr = 16, + + /// DenseArrayAttr { + /// type: RankedTensorType, + /// data: blob + /// } + kDenseArrayAttr = 17, + + /// DenseIntOrFPElementsAttr { + /// type: ShapedType, + /// data: blob + /// } + kDenseIntOrFPElementsAttr = 18, + + /// DenseStringElementsAttr { + /// type: ShapedType, + /// isSplat: varint, + /// data: string[] + /// } + kDenseStringElementsAttr = 19, + + /// SparseElementsAttr { + /// type: ShapedType, + /// indices: DenseIntElementsAttr, + /// values: DenseElementsAttr + /// } + kSparseElementsAttr = 20, }; /// This enum contains marker codes used to indicate which type is currently @@ -279,11 +305,18 @@ Attribute readAttribute(DialectBytecodeReader &reader) const override; ArrayAttr readArrayAttr(DialectBytecodeReader &reader) const; + DenseArrayAttr readDenseArrayAttr(DialectBytecodeReader &reader) const; + DenseElementsAttr + readDenseIntOrFPElementsAttr(DialectBytecodeReader &reader) const; + DenseStringElementsAttr + 
readDenseStringElementsAttr(DialectBytecodeReader &reader) const; DenseResourceElementsAttr readDenseResourceElementsAttr(DialectBytecodeReader &reader) const; DictionaryAttr readDictionaryAttr(DialectBytecodeReader &reader) const; FloatAttr readFloatAttr(DialectBytecodeReader &reader) const; IntegerAttr readIntegerAttr(DialectBytecodeReader &reader) const; + SparseElementsAttr + readSparseElementsAttr(DialectBytecodeReader &reader) const; StringAttr readStringAttr(DialectBytecodeReader &reader, bool hasType) const; SymbolRefAttr readSymbolRefAttr(DialectBytecodeReader &reader, bool hasNestedRefs) const; @@ -298,11 +331,16 @@ LogicalResult writeAttribute(Attribute attr, DialectBytecodeWriter &writer) const override; void write(ArrayAttr attr, DialectBytecodeWriter &writer) const; + void write(DenseArrayAttr attr, DialectBytecodeWriter &writer) const; + void write(DenseIntOrFPElementsAttr attr, + DialectBytecodeWriter &writer) const; + void write(DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const; void write(DenseResourceElementsAttr attr, DialectBytecodeWriter &writer) const; void write(DictionaryAttr attr, DialectBytecodeWriter &writer) const; void write(IntegerAttr attr, DialectBytecodeWriter &writer) const; void write(FloatAttr attr, DialectBytecodeWriter &writer) const; + void write(SparseElementsAttr attr, DialectBytecodeWriter &writer) const; void write(StringAttr attr, DialectBytecodeWriter &writer) const; void write(SymbolRefAttr attr, DialectBytecodeWriter &writer) const; void write(TypeAttr attr, DialectBytecodeWriter &writer) const; @@ -394,6 +432,14 @@ return UnknownLoc::get(getContext()); case builtin_encoding::kDenseResourceElementsAttr: return readDenseResourceElementsAttr(reader); + case builtin_encoding::kDenseArrayAttr: + return readDenseArrayAttr(reader); + case builtin_encoding::kDenseIntOrFPElementsAttr: + return readDenseIntOrFPElementsAttr(reader); + case builtin_encoding::kDenseStringElementsAttr: + return 
readDenseStringElementsAttr(reader); + case builtin_encoding::kSparseElementsAttr: + return readSparseElementsAttr(reader); default: reader.emitError() << "unknown builtin attribute code: " << code; return Attribute(); @@ -403,8 +449,10 @@ LogicalResult BuiltinDialectBytecodeInterface::writeAttribute( Attribute attr, DialectBytecodeWriter &writer) const { return TypeSwitch(attr) - .Case([&](auto attr) { + .Case([&](auto attr) { write(attr, writer); return success(); }) @@ -441,6 +489,78 @@ writer.writeAttributes(attr.getValue()); } +//===----------------------------------------------------------------------===// +// DenseArrayAttr + +DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr( + DialectBytecodeReader &reader) const { + RankedTensorType type; + ArrayRef blob; + if (failed(reader.readType(type)) || failed(reader.readBlob(blob))) + return DenseArrayAttr(); + return DenseArrayAttr::get(type, blob); +} + +void BuiltinDialectBytecodeInterface::write( + DenseArrayAttr attr, DialectBytecodeWriter &writer) const { + writer.writeVarInt(builtin_encoding::kDenseArrayAttr); + writer.writeType(attr.getType()); + writer.writeOwnedBlob(attr.getRawData()); +} + +//===----------------------------------------------------------------------===// +// DenseIntOrFPElementsAttr + +DenseElementsAttr BuiltinDialectBytecodeInterface::readDenseIntOrFPElementsAttr( + DialectBytecodeReader &reader) const { + ShapedType type; + ArrayRef blob; + if (failed(reader.readType(type)) || failed(reader.readBlob(blob))) + return DenseIntOrFPElementsAttr(); + return DenseIntOrFPElementsAttr::getFromRawBuffer(type, blob); +} + +void BuiltinDialectBytecodeInterface::write( + DenseIntOrFPElementsAttr attr, DialectBytecodeWriter &writer) const { + writer.writeVarInt(builtin_encoding::kDenseIntOrFPElementsAttr); + writer.writeType(attr.getType()); + writer.writeOwnedBlob(attr.getRawData()); +} + +//===----------------------------------------------------------------------===// +// 
DenseStringElementsAttr + +DenseStringElementsAttr +BuiltinDialectBytecodeInterface::readDenseStringElementsAttr( + DialectBytecodeReader &reader) const { + ShapedType type; + uint64_t isSplat; + if (failed(reader.readType(type)) || failed(reader.readVarInt(isSplat))) + return DenseStringElementsAttr(); + + SmallVector values(isSplat ? 1 : type.getNumElements()); + for (StringRef &value : values) + if (failed(reader.readString(value))) + return DenseStringElementsAttr(); + return DenseStringElementsAttr::get(type, values); +} + +void BuiltinDialectBytecodeInterface::write( + DenseStringElementsAttr attr, DialectBytecodeWriter &writer) const { + writer.writeVarInt(builtin_encoding::kDenseStringElementsAttr); + writer.writeType(attr.getType()); + + bool isSplat = attr.isSplat(); + writer.writeVarInt(isSplat); + + // If the attribute is a splat, only write out the single value. + if (isSplat) + return writer.writeOwnedString(attr.getRawStringData().front()); + + for (StringRef str : attr.getRawStringData()) + writer.writeOwnedString(str); +} + //===----------------------------------------------------------------------===// // DenseResourceElementsAttr @@ -550,6 +670,28 @@ writer.writeAPIntWithKnownWidth(attr.getValue()); } +//===----------------------------------------------------------------------===// +// SparseElementsAttr + +SparseElementsAttr BuiltinDialectBytecodeInterface::readSparseElementsAttr( + DialectBytecodeReader &reader) const { + ShapedType type; + DenseIntElementsAttr indices; + DenseElementsAttr values; + if (failed(reader.readType(type)) || failed(reader.readAttribute(indices)) || + failed(reader.readAttribute(values))) + return SparseElementsAttr(); + return SparseElementsAttr::get(type, indices, values); +} + +void BuiltinDialectBytecodeInterface::write( + SparseElementsAttr attr, DialectBytecodeWriter &writer) const { + writer.writeVarInt(builtin_encoding::kSparseElementsAttr); + writer.writeType(attr.getType()); + 
writer.writeAttribute(attr.getIndices()); + writer.writeAttribute(attr.getValues()); +} + //===----------------------------------------------------------------------===// // StringAttr diff --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir --- a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir +++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -emit-bytecode %s | mlir-opt -mlir-print-local-scope | FileCheck %s +// RUN: mlir-opt -emit-bytecode -allow-unregistered-dialect %s | mlir-opt -allow-unregistered-dialect -mlir-print-local-scope | FileCheck %s // Bytecode currently does not support big-endian platforms // UNSUPPORTED: s390x- @@ -13,6 +13,44 @@ bytecode.array = [unit] } {} +//===----------------------------------------------------------------------===// +// DenseArrayAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @TestDenseArray +module @TestDenseArray attributes { + // CHECK: bytecode.test1 = array + // CHECK: bytecode.test2 = array + // CHECK: bytecode.test3 = array, + bytecode.test2 = array, + bytecode.test3 = array +} {} + +//===----------------------------------------------------------------------===// +// DenseIntOrFPElementsAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @TestDenseIntOrFPElements +// CHECK: bytecode.test1 = dense : tensor<256xi1> +// CHECK: bytecode.test2 = dense<[10, 32, -1]> : tensor<3xi8> +// CHECK: bytecode.test3 = dense<[1.{{.*}}e+01, 3.2{{.*}}e+01, 1.809{{.*}}e+03]> : tensor<3xf64> +module @TestDenseIntOrFPElements attributes { + bytecode.test1 = dense : tensor<256xi1>, + bytecode.test2 = dense<[10, 32, 255]> : tensor<3xi8>, + bytecode.test3 = dense<[10.0, 32.0, 1809.0]> : tensor<3xf64> +} {} + +//===----------------------------------------------------------------------===// +// DenseStringElementsAttr 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @TestDenseStringElementsAttr +module @TestDenseStringElementsAttr attributes { + bytecode.test1 = dense<"splat"> : tensor<256x!bytecode.string>, + bytecode.test2 = dense<["foo", "bar", "baz"]> : tensor<3x!bytecode.string> +} {} + //===----------------------------------------------------------------------===// // FloatAttr //===----------------------------------------------------------------------===// @@ -45,6 +83,17 @@ bytecode.int3 = 90000000000000000300000000000000000001 : i128 } {} +//===----------------------------------------------------------------------===// +// SparseElementsAttr +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: @TestSparseElements +module @TestSparseElements attributes { + // CHECK-LITERAL: bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32> + bytecode.sparse = sparse<[[0, 0], [1, 2]], [1, 5]> : tensor<3x4xi32> +} {} + + //===----------------------------------------------------------------------===// // StringAttr //===----------------------------------------------------------------------===//