diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -38,10 +38,6 @@
   /// Emit an error to the reader.
   virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0;
 
-  //===--------------------------------------------------------------------===//
-  // IR
-  //===--------------------------------------------------------------------===//
-
   /// Read out a list of elements, invoking the provided callback for each
   /// element. The callback function may be in any of the following forms:
   ///   * LogicalResult(T &)
@@ -71,6 +67,10 @@
     return success();
   }
 
+  //===--------------------------------------------------------------------===//
+  // IR
+  //===--------------------------------------------------------------------===//
+
   /// Read a reference to the given attribute.
   virtual LogicalResult readAttribute(Attribute &result) = 0;
   template <typename T>
@@ -114,6 +114,13 @@
 
   /// Read a signed variable width integer.
   virtual LogicalResult readSignedVarInt(int64_t &result) = 0;
+
+  /// Read a list of signed variable width integers, where each element is
+  /// decoded with `readSignedVarInt`.
+  LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) {
+    return readList(result,
+                    [this](int64_t &value) { return readSignedVarInt(value); });
+  }
 
   /// Read an APInt that is known to have been encoded with the given width.
   virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0;
@@ -178,6 +185,12 @@
   /// Write a signed variable width integer to the output stream. This should be
   /// the preferred method for emitting signed integers whenever possible.
   virtual void writeSignedVarInt(int64_t value) = 0;
+
+  /// Write a list of signed variable width integers, where each element is
+  /// encoded with `writeSignedVarInt`.
+  void writeSignedVarInts(ArrayRef<int64_t> values) {
+    writeList(values, [this](int64_t value) { writeSignedVarInt(value); });
+  }
 
   /// Write an APInt to the bytecode stream whose bitwidth will be known
   /// externally at read time. This method is useful for encoding APInt values
diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp
--- a/mlir/lib/IR/BuiltinDialectBytecode.cpp
+++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp
@@ -140,6 +140,119 @@
   /// }
   ///
   kFunctionType = 2,
+
+  /// BFloat16Type {
+  /// }
+  ///
+  kBFloat16Type = 3,
+
+  /// Float16Type {
+  /// }
+  ///
+  kFloat16Type = 4,
+
+  /// Float32Type {
+  /// }
+  ///
+  kFloat32Type = 5,
+
+  /// Float64Type {
+  /// }
+  ///
+  kFloat64Type = 6,
+
+  /// Float80Type {
+  /// }
+  ///
+  kFloat80Type = 7,
+
+  /// Float128Type {
+  /// }
+  ///
+  kFloat128Type = 8,
+
+  /// ComplexType {
+  ///   elementType: Type
+  /// }
+  ///
+  kComplexType = 9,
+
+  /// MemRefType {
+  ///   shape: svarint[],
+  ///   elementType: Type,
+  ///   layout: Attribute
+  /// }
+  ///
+  kMemRefType = 10,
+
+  /// MemRefTypeWithMemSpace {
+  ///   memorySpace: Attribute,
+  ///   shape: svarint[],
+  ///   elementType: Type,
+  ///   layout: Attribute
+  /// }
+  /// Variant of MemRefType with non-default memory space.
+  kMemRefTypeWithMemSpace = 11,
+
+  /// NoneType {
+  /// }
+  ///
+  kNoneType = 12,
+
+  /// RankedTensorType {
+  ///   shape: svarint[],
+  ///   elementType: Type
+  /// }
+  ///
+  kRankedTensorType = 13,
+
+  /// RankedTensorTypeWithEncoding {
+  ///   encoding: Attribute,
+  ///   shape: svarint[],
+  ///   elementType: Type
+  /// }
+  /// Variant of RankedTensorType with an encoding.
+  kRankedTensorTypeWithEncoding = 14,
+
+  /// TupleType {
+  ///   elementTypes: Type[]
+  /// }
+  ///
+  kTupleType = 15,
+
+  /// UnrankedMemRefType {
+  ///   elementType: Type
+  /// }
+  ///
+  kUnrankedMemRefType = 16,
+
+  /// UnrankedMemRefTypeWithMemSpace {
+  ///   memorySpace: Attribute,
+  ///   elementType: Type
+  /// }
+  /// Variant of UnrankedMemRefType with non-default memory space.
+  kUnrankedMemRefTypeWithMemSpace = 17,
+
+  /// UnrankedTensorType {
+  ///   elementType: Type
+  /// }
+  ///
+  kUnrankedTensorType = 18,
+
+  /// VectorType {
+  ///   shape: svarint[],
+  ///   elementType: Type
+  /// }
+  ///
+  kVectorType = 19,
+
+  /// VectorTypeWithScalableDims {
+  ///   numScalableDims: varint,
+  ///   shape: svarint[],
+  ///   elementType: Type
+  /// }
+  /// Variant of VectorType with scalable dimensions.
+  kVectorTypeWithScalableDims = 20,
 };
 } // namespace builtin_encoding
 
@@ -194,13 +307,32 @@
   // Types
 
   Type readType(DialectBytecodeReader &reader) const override;
+  ComplexType readComplexType(DialectBytecodeReader &reader) const;
   IntegerType readIntegerType(DialectBytecodeReader &reader) const;
   FunctionType readFunctionType(DialectBytecodeReader &reader) const;
+  MemRefType readMemRefType(DialectBytecodeReader &reader,
+                            bool hasMemSpace) const;
+  RankedTensorType readRankedTensorType(DialectBytecodeReader &reader,
+                                        bool hasEncoding) const;
+  TupleType readTupleType(DialectBytecodeReader &reader) const;
+  UnrankedMemRefType readUnrankedMemRefType(DialectBytecodeReader &reader,
+                                            bool hasMemSpace) const;
+  UnrankedTensorType
+  readUnrankedTensorType(DialectBytecodeReader &reader) const;
+  VectorType readVectorType(DialectBytecodeReader &reader,
+                            bool hasScalableDims) const;
 
   LogicalResult writeType(Type type,
                           DialectBytecodeWriter &writer) const override;
+  void write(ComplexType type, DialectBytecodeWriter &writer) const;
   void write(IntegerType type, DialectBytecodeWriter &writer) const;
   void write(FunctionType type, DialectBytecodeWriter &writer) const;
+  void write(MemRefType type, DialectBytecodeWriter &writer) const;
+  void write(RankedTensorType type, DialectBytecodeWriter &writer) const;
+  void write(TupleType type, DialectBytecodeWriter &writer) const;
+  void write(UnrankedMemRefType type, DialectBytecodeWriter &writer) const;
+  void write(UnrankedTensorType type, DialectBytecodeWriter &writer) const;
+  void write(VectorType type, DialectBytecodeWriter &writer) const;
 };
 } // namespace
 
@@ -576,9 +708,45 @@
     return readIntegerType(reader);
   case builtin_encoding::kIndexType:
     return IndexType::get(getContext());
-
   case builtin_encoding::kFunctionType:
     return readFunctionType(reader);
+  case builtin_encoding::kBFloat16Type:
+    return BFloat16Type::get(getContext());
+  case builtin_encoding::kFloat16Type:
+    return Float16Type::get(getContext());
+  case builtin_encoding::kFloat32Type:
+    return Float32Type::get(getContext());
+  case builtin_encoding::kFloat64Type:
+    return Float64Type::get(getContext());
+  case builtin_encoding::kFloat80Type:
+    return Float80Type::get(getContext());
+  case builtin_encoding::kFloat128Type:
+    return Float128Type::get(getContext());
+  case builtin_encoding::kComplexType:
+    return readComplexType(reader);
+  case builtin_encoding::kMemRefType:
+    return readMemRefType(reader, /*hasMemSpace=*/false);
+  case builtin_encoding::kMemRefTypeWithMemSpace:
+    return readMemRefType(reader, /*hasMemSpace=*/true);
+  case builtin_encoding::kNoneType:
+    return NoneType::get(getContext());
+  case builtin_encoding::kRankedTensorType:
+    return readRankedTensorType(reader, /*hasEncoding=*/false);
+  case builtin_encoding::kRankedTensorTypeWithEncoding:
+    return readRankedTensorType(reader, /*hasEncoding=*/true);
+  case builtin_encoding::kTupleType:
+    return readTupleType(reader);
+  case builtin_encoding::kUnrankedMemRefType:
+    return readUnrankedMemRefType(reader, /*hasMemSpace=*/false);
+  case builtin_encoding::kUnrankedMemRefTypeWithMemSpace:
+    return readUnrankedMemRefType(reader, /*hasMemSpace=*/true);
+  case builtin_encoding::kUnrankedTensorType:
+    return readUnrankedTensorType(reader);
+  case builtin_encoding::kVectorType:
+    return readVectorType(reader, /*hasScalableDims=*/false);
+  case builtin_encoding::kVectorTypeWithScalableDims:
+    return readVectorType(reader, /*hasScalableDims=*/true);
+
   default:
     reader.emitError() << "unknown builtin type code: " << code;
     return Type();
@@ -588,16 +756,59 @@
 LogicalResult BuiltinDialectBytecodeInterface::writeType(
     Type type, DialectBytecodeWriter &writer) const {
   return TypeSwitch<Type, LogicalResult>(type)
-      .Case<IntegerType, FunctionType>([&](auto type) {
+      .Case<ComplexType, IntegerType, FunctionType, MemRefType,
+            RankedTensorType, TupleType, UnrankedMemRefType,
+            UnrankedTensorType, VectorType>([&](auto type) {
         write(type, writer);
         return success();
       })
       .Case([&](IndexType) {
         return writer.writeVarInt(builtin_encoding::kIndexType), success();
       })
+      .Case([&](BFloat16Type) {
+        return writer.writeVarInt(builtin_encoding::kBFloat16Type), success();
+      })
+      .Case([&](Float16Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat16Type), success();
+      })
+      .Case([&](Float32Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat32Type), success();
+      })
+      .Case([&](Float64Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat64Type), success();
+      })
+      .Case([&](Float80Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat80Type), success();
+      })
+      .Case([&](Float128Type) {
+        return writer.writeVarInt(builtin_encoding::kFloat128Type), success();
+      })
+      .Case([&](NoneType) {
+        return writer.writeVarInt(builtin_encoding::kNoneType), success();
+      })
       .Default([&](Type) { return failure(); });
 }
 
+//===----------------------------------------------------------------------===//
+// ComplexType
+
+ComplexType BuiltinDialectBytecodeInterface::readComplexType(
+    DialectBytecodeReader &reader) const {
+  Type elementType;
+  if (failed(reader.readType(elementType)))
+    return ComplexType();
+  // Build via getChecked so invalid parameters decoded from bytecode are
+  // diagnosed through the reader instead of reaching the unchecked `get`.
+  return ComplexType::getChecked([&] { return reader.emitError(); },
+                                 elementType);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    ComplexType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kComplexType);
+  writer.writeType(type.getElementType());
+}
+
 //===----------------------------------------------------------------------===//
 // IntegerType
 
@@ -634,3 +845,162 @@
   writer.writeTypes(type.getInputs());
   writer.writeTypes(type.getResults());
 }
+
+//===----------------------------------------------------------------------===//
+// MemRefType
+
+MemRefType
+BuiltinDialectBytecodeInterface::readMemRefType(DialectBytecodeReader &reader,
+                                                bool hasMemSpace) const {
+  Attribute memorySpace;
+  if (hasMemSpace && failed(reader.readAttribute(memorySpace)))
+    return MemRefType();
+  SmallVector<int64_t> shape;
+  Type elementType;
+  MemRefLayoutAttrInterface layout;
+  if (failed(reader.readSignedVarInts(shape)) ||
+      failed(reader.readType(elementType)) ||
+      failed(reader.readAttribute(layout)))
+    return MemRefType();
+  return MemRefType::getChecked([&] { return reader.emitError(); }, shape,
+                                elementType, layout, memorySpace);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    MemRefType type, DialectBytecodeWriter &writer) const {
+  if (Attribute memSpace = type.getMemorySpace()) {
+    writer.writeVarInt(builtin_encoding::kMemRefTypeWithMemSpace);
+    writer.writeAttribute(memSpace);
+  } else {
+    writer.writeVarInt(builtin_encoding::kMemRefType);
+  }
+  writer.writeSignedVarInts(type.getShape());
+  writer.writeType(type.getElementType());
+  writer.writeAttribute(type.getLayout());
+}
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+
+RankedTensorType BuiltinDialectBytecodeInterface::readRankedTensorType(
+    DialectBytecodeReader &reader, bool hasEncoding) const {
+  Attribute encoding;
+  if (hasEncoding && failed(reader.readAttribute(encoding)))
+    return RankedTensorType();
+  SmallVector<int64_t> shape;
+  Type elementType;
+  if (failed(reader.readSignedVarInts(shape)) ||
+      failed(reader.readType(elementType)))
+    return RankedTensorType();
+  return RankedTensorType::getChecked([&] { return reader.emitError(); }, shape,
+                                      elementType, encoding);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    RankedTensorType type, DialectBytecodeWriter &writer) const {
+  if (Attribute encoding = type.getEncoding()) {
+    writer.writeVarInt(builtin_encoding::kRankedTensorTypeWithEncoding);
+    writer.writeAttribute(encoding);
+  } else {
+    writer.writeVarInt(builtin_encoding::kRankedTensorType);
+  }
+  writer.writeSignedVarInts(type.getShape());
+  writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// TupleType
+
+TupleType BuiltinDialectBytecodeInterface::readTupleType(
+    DialectBytecodeReader &reader) const {
+  SmallVector<Type> elements;
+  if (failed(reader.readTypes(elements)))
+    return TupleType();
+  return TupleType::get(getContext(), elements);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    TupleType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kTupleType);
+  writer.writeTypes(type.getTypes());
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+
+UnrankedMemRefType BuiltinDialectBytecodeInterface::readUnrankedMemRefType(
+    DialectBytecodeReader &reader, bool hasMemSpace) const {
+  Attribute memorySpace;
+  if (hasMemSpace && failed(reader.readAttribute(memorySpace)))
+    return UnrankedMemRefType();
+  Type elementType;
+  if (failed(reader.readType(elementType)))
+    return UnrankedMemRefType();
+  return UnrankedMemRefType::getChecked([&] { return reader.emitError(); },
+                                        elementType, memorySpace);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    UnrankedMemRefType type, DialectBytecodeWriter &writer) const {
+  if (Attribute memSpace = type.getMemorySpace()) {
+    writer.writeVarInt(builtin_encoding::kUnrankedMemRefTypeWithMemSpace);
+    writer.writeAttribute(memSpace);
+  } else {
+    writer.writeVarInt(builtin_encoding::kUnrankedMemRefType);
+  }
+  writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// UnrankedTensorType
+
+UnrankedTensorType BuiltinDialectBytecodeInterface::readUnrankedTensorType(
+    DialectBytecodeReader &reader) const {
+  Type elementType;
+  if (failed(reader.readType(elementType)))
+    return UnrankedTensorType();
+  return UnrankedTensorType::getChecked([&] { return reader.emitError(); },
+                                        elementType);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    UnrankedTensorType type, DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(builtin_encoding::kUnrankedTensorType);
+  writer.writeType(type.getElementType());
+}
+
+//===----------------------------------------------------------------------===//
+// VectorType
+
+VectorType
+BuiltinDialectBytecodeInterface::readVectorType(DialectBytecodeReader &reader,
+                                                bool hasScalableDims) const {
+  uint64_t numScalableDims = 0;
+  if (hasScalableDims && failed(reader.readVarInt(numScalableDims)))
+    return VectorType();
+  SmallVector<int64_t> shape;
+  Type elementType;
+  if (failed(reader.readSignedVarInts(shape)) ||
+      failed(reader.readType(elementType)))
+    return VectorType();
+  // A vector cannot have more scalable dimensions than it has dimensions.
+  if (numScalableDims > shape.size()) {
+    reader.emitError() << "invalid number of scalable vector dimensions: "
+                       << numScalableDims;
+    return VectorType();
+  }
+  return VectorType::getChecked([&] { return reader.emitError(); }, shape,
+                                elementType, numScalableDims);
+}
+
+void BuiltinDialectBytecodeInterface::write(
+    VectorType type, DialectBytecodeWriter &writer) const {
+  if (unsigned numScalableDims = type.getNumScalableDims()) {
+    writer.writeVarInt(builtin_encoding::kVectorTypeWithScalableDims);
+    writer.writeVarInt(numScalableDims);
+  } else {
+    writer.writeVarInt(builtin_encoding::kVectorType);
+  }
+  writer.writeSignedVarInts(type.getShape());
+  writer.writeType(type.getElementType());
+}
diff --git a/mlir/test/Dialect/Builtin/Bytecode/types.mlir b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
--- a/mlir/test/Dialect/Builtin/Bytecode/types.mlir
+++ b/mlir/test/Dialect/Builtin/Bytecode/types.mlir
@@ -3,6 +3,40 @@
 // Bytecode currently does not support big-endian platforms
 // UNSUPPORTED: s390x-
 
+//===----------------------------------------------------------------------===//
+// ComplexType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestComplex
+module @TestComplex attributes {
+  // CHECK: bytecode.test = complex<i32>
+  bytecode.test = complex<i32>
+} {}
+
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestFloat
+module @TestFloat attributes {
+  // CHECK: bytecode.test = bf16,
+  // CHECK: bytecode.test1 = f16,
+  // CHECK: bytecode.test2 = f32,
+  // CHECK: bytecode.test3 = f64,
+  // CHECK: bytecode.test4 = f80,
+  // CHECK: bytecode.test5 = f128
+  bytecode.test = bf16,
+  bytecode.test1 = f16,
+  bytecode.test2 = f32,
+  bytecode.test3 = f64,
+  bytecode.test4 = f80,
+  bytecode.test5 = f128
+} {}
+
+//===----------------------------------------------------------------------===//
+// IntegerType
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: @TestInteger
 module @TestInteger attributes {
   // CHECK: bytecode.int = i1024,
@@ -13,12 +47,20 @@
   bytecode.int2 = ui512
 } {}
 
+//===----------------------------------------------------------------------===//
+// IndexType
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: @TestIndex
 module @TestIndex attributes {
   // CHECK: bytecode.index = index
   bytecode.index = index
 } {}
 
+//===----------------------------------------------------------------------===//
+// FunctionType
+//===----------------------------------------------------------------------===//
+
 // CHECK-LABEL: @TestFunc
 module @TestFunc attributes {
   // CHECK: bytecode.func = () -> (),
@@ -26,3 +68,83 @@
   bytecode.func = () -> (),
   bytecode.func1 = (i1) -> (i32)
 } {}
+
+//===----------------------------------------------------------------------===//
+// MemRefType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestMemRef
+module @TestMemRef attributes {
+  // CHECK: bytecode.test = memref<2xi8>,
+  // CHECK: bytecode.test1 = memref<2xi8, 1>
+  bytecode.test = memref<2xi8>,
+  bytecode.test1 = memref<2xi8, 1>
+} {}
+
+//===----------------------------------------------------------------------===//
+// NoneType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestNone
+module @TestNone attributes {
+  // CHECK: bytecode.test = none
+  bytecode.test = none
+} {}
+
+//===----------------------------------------------------------------------===//
+// RankedTensorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestRankedTensor
+module @TestRankedTensor attributes {
+  // CHECK: bytecode.test = tensor<16x32x?xf64>,
+  // CHECK: bytecode.test1 = tensor<16xf64, "sparse">
+  bytecode.test = tensor<16x32x?xf64>,
+  bytecode.test1 = tensor<16xf64, "sparse">
+} {}
+
+//===----------------------------------------------------------------------===//
+// TupleType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestTuple
+module @TestTuple attributes {
+  // CHECK: bytecode.test = tuple<>,
+  // CHECK: bytecode.test1 = tuple<i32, i1, f32>
+  bytecode.test = tuple<>,
+  bytecode.test1 = tuple<i32, i1, f32>
+} {}
+
+//===----------------------------------------------------------------------===//
+// UnrankedMemRefType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestUnrankedMemRef
+module @TestUnrankedMemRef attributes {
+  // CHECK: bytecode.test = memref<*xi8>,
+  // CHECK: bytecode.test1 = memref<*xi8, 1>
+  bytecode.test = memref<*xi8>,
+  bytecode.test1 = memref<*xi8, 1>
+} {}
+
+//===----------------------------------------------------------------------===//
+// UnrankedTensorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestUnrankedTensor
+module @TestUnrankedTensor attributes {
+  // CHECK: bytecode.test = tensor<*xi8>
+  bytecode.test = tensor<*xi8>
+} {}
+
+//===----------------------------------------------------------------------===//
+// VectorType
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @TestVector
+module @TestVector attributes {
+  // CHECK: bytecode.test = vector<8x8x128xi8>,
+  // CHECK: bytecode.test1 = vector<8x[8]xf32>
+  bytecode.test = vector<8x8x128xi8>,
+  bytecode.test1 = vector<8x[8]xf32>
+} {}