diff --git a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt
@@ -1,4 +1,6 @@
 add_mlir_dialect_library(MLIRQuantDialect
+  QuantDialectBytecode.h
+  QuantDialectBytecode.cpp
   QuantOps.cpp
   QuantTypes.cpp
   TypeDetail.h
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.h
@@ -0,0 +1,27 @@
+//===- QuantDialectBytecode.h - Quant Bytecode Implementation --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header defines hooks into the quantization dialect bytecode
+// implementation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
+#define LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
+
+namespace mlir::quant {
+class QuantizationDialect;
+
+namespace detail {
+/// Add the interfaces necessary for encoding the quantization dialect
+/// components in bytecode.
+void addBytecodeInterface(QuantizationDialect *dialect);
+} // namespace detail
+} // namespace mlir::quant
+
+#endif // LIB_MLIR_DIALECT_QUANT_IR_QUANTDIALECTBYTECODE_H
diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp
@@ -0,0 +1,298 @@
+//===- QuantDialectBytecode.cpp - Quant Bytecode Implementation -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "QuantDialectBytecode.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::quant; + +//===----------------------------------------------------------------------===// +// Encoding +//===----------------------------------------------------------------------===// + +namespace { +namespace quant_encoding { +/// This enum contains marker codes used to indicate which type is currently +/// being decoded, and how it should be decoded. The order of these codes should +/// generally be unchanged, as any changes will inevitably break compatibility +/// with older bytecode. 
+enum TypeCode {
+  /// AnyQuantizedType {
+  ///   flags: varint
+  ///   storageType: Type
+  ///   storageTypeMin: svarint
+  ///   storageTypeMax: svarint
+  /// }
+  ///
+  kAnyQuantizedType = 1,
+
+  /// AnyQuantizedType {
+  ///   flags: varint
+  ///   storageType: Type
+  ///   expressedType: Type
+  ///   storageTypeMin: svarint
+  ///   storageTypeMax: svarint
+  /// }
+  ///
+  kAnyQuantizedTypeWithExpressedType = 2,
+
+  /// CalibratedQuantizedType {
+  ///   expressedType: Type
+  ///   min: APFloat
+  ///   max: APFloat
+  /// }
+  ///
+  kCalibratedQuantizedType = 3,
+
+  /// UniformQuantizedType {
+  ///   flags: varint
+  ///   storageType: Type
+  ///   expressedType: Type
+  ///   scale: APFloat
+  ///   zeroPoint: svarint
+  ///   storageTypeMin: svarint
+  ///   storageTypeMax: svarint
+  /// }
+  ///
+  kUniformQuantizedType = 4,
+
+  /// UniformQuantizedPerAxisType {
+  ///   flags: varint
+  ///   storageType: Type
+  ///   expressedType: Type
+  ///   scale: APFloat[]
+  ///   zeroPoint: svarint[]
+  ///   quantizedDimension: varint
+  ///   storageTypeMin: svarint
+  ///   storageTypeMax: svarint
+  /// }
+  ///
+  kUniformQuantizedPerAxisType = 5,
+};
+
+} // namespace quant_encoding
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// QuantDialectBytecodeInterface
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This class implements the bytecode interface for the Quant dialect.
+struct QuantDialectBytecodeInterface : public BytecodeDialectInterface {
+  QuantDialectBytecodeInterface(Dialect *dialect)
+      : BytecodeDialectInterface(dialect) {}
+
+  //===--------------------------------------------------------------------===//
+  // Types
+
+  Type readType(DialectBytecodeReader &reader) const override;
+  LogicalResult writeType(Type type,
+                          DialectBytecodeWriter &writer) const override;
+
+  AnyQuantizedType readAnyQuantizedType(bool withExpressedType,
+                                        DialectBytecodeReader &reader) const;
+  void write(AnyQuantizedType type, DialectBytecodeWriter &writer) const;
+
+  CalibratedQuantizedType
+  readCalibratedQuantizedType(DialectBytecodeReader &reader) const;
+  void write(CalibratedQuantizedType type, DialectBytecodeWriter &writer) const;
+
+  UniformQuantizedType
+  readUniformQuantizedType(DialectBytecodeReader &reader) const;
+  void write(UniformQuantizedType type, DialectBytecodeWriter &writer) const;
+
+  UniformQuantizedPerAxisType
+  readUniformQuantizedPerAxisType(DialectBytecodeReader &reader) const;
+  void write(UniformQuantizedPerAxisType type,
+             DialectBytecodeWriter &writer) const;
+};
+} // namespace
+
+void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) {
+  dialect->addInterfaces<QuantDialectBytecodeInterface>();
+}
+
+//===----------------------------------------------------------------------===//
+// Types
+//===----------------------------------------------------------------------===//
+
+Type QuantDialectBytecodeInterface::readType(
+    DialectBytecodeReader &reader) const {
+  uint64_t code;
+  if (failed(reader.readVarInt(code)))
+    return Type();
+
+  switch (code) {
+  case quant_encoding::kAnyQuantizedType:
+    return readAnyQuantizedType(/*withExpressedType=*/false, reader);
+  case quant_encoding::kAnyQuantizedTypeWithExpressedType:
+    return readAnyQuantizedType(/*withExpressedType=*/true, reader);
+  case quant_encoding::kCalibratedQuantizedType:
+    return readCalibratedQuantizedType(reader);
+  case quant_encoding::kUniformQuantizedType:
+    return readUniformQuantizedType(reader);
+  case quant_encoding::kUniformQuantizedPerAxisType:
+    return readUniformQuantizedPerAxisType(reader);
+
+  default:
+    reader.emitError() << "unknown quant type code: " << code;
+    return Type();
+  }
+}
+
+LogicalResult
+QuantDialectBytecodeInterface::writeType(Type type,
+                                         DialectBytecodeWriter &writer) const {
+  return TypeSwitch<Type, LogicalResult>(type)
+      .Case<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
+            UniformQuantizedPerAxisType>(
+          [&](auto attr) { return write(attr, writer), success(); })
+      .Default([&](Type) { return failure(); });
+}
+
+AnyQuantizedType QuantDialectBytecodeInterface::readAnyQuantizedType(
+    bool withExpressedType, DialectBytecodeReader &reader) const {
+  uint64_t flags;
+  Type storageType, expressedType;
+  int64_t storageTypeMin, storageTypeMax;
+  if (failed(reader.readVarInt(flags)) ||
+      failed(reader.readType(storageType)) ||
+      (withExpressedType && failed(reader.readType(expressedType))) ||
+      failed(reader.readSignedVarInt(storageTypeMin)) ||
+      failed(reader.readSignedVarInt(storageTypeMax)))
+    return reader.emitError("invalid AnyQuantizedType"), AnyQuantizedType();
+  return AnyQuantizedType::get(flags, storageType, expressedType,
+                               storageTypeMin, storageTypeMax);
+}
+void QuantDialectBytecodeInterface::write(AnyQuantizedType type,
+                                          DialectBytecodeWriter &writer) const {
+  // The type code records whether an expressed type follows the storage type.
+  if (type.getExpressedType())
+    writer.writeVarInt(quant_encoding::kAnyQuantizedTypeWithExpressedType);
+  else
+    writer.writeVarInt(quant_encoding::kAnyQuantizedType);
+  writer.writeVarInt(type.getFlags());
+  writer.writeType(type.getStorageType());
+  if (type.getExpressedType())
+    writer.writeType(type.getExpressedType());
+  writer.writeSignedVarInt(type.getStorageTypeMin());
+  writer.writeSignedVarInt(type.getStorageTypeMax());
+}
+
+CalibratedQuantizedType
+QuantDialectBytecodeInterface::readCalibratedQuantizedType(
+    DialectBytecodeReader &reader) const {
+  Type expressedType;
+  FailureOr<APFloat> min, max;
+  if (failed(reader.readType(expressedType)) ||
+      failed(min = reader.readAPFloatWithKnownSemantics(
+                 llvm::APFloat::IEEEdouble())) ||
+      failed(max = reader.readAPFloatWithKnownSemantics(
+                 llvm::APFloat::IEEEdouble())))
+    return reader.emitError("invalid CalibratedQuantizedType"),
+           CalibratedQuantizedType();
+  return CalibratedQuantizedType::get(expressedType,
+                                      min.value().convertToDouble(),
+                                      max.value().convertToDouble());
+}
+void QuantDialectBytecodeInterface::write(CalibratedQuantizedType type,
+                                          DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(quant_encoding::kCalibratedQuantizedType);
+  writer.writeType(type.getExpressedType());
+  writer.writeAPFloatWithKnownSemantics(APFloat(type.getMin()));
+  writer.writeAPFloatWithKnownSemantics(APFloat(type.getMax()));
+}
+
+UniformQuantizedType QuantDialectBytecodeInterface::readUniformQuantizedType(
+    DialectBytecodeReader &reader) const {
+  uint64_t flags;
+  Type storageType, expressedType;
+  FailureOr<APFloat> scale;
+  int64_t zeroPoint, storageTypeMin, storageTypeMax;
+  if (failed(reader.readVarInt(flags)) ||
+      failed(reader.readType(storageType)) ||
+      failed(reader.readType(expressedType)) ||
+      failed(scale = reader.readAPFloatWithKnownSemantics(
+                 llvm::APFloat::IEEEdouble())) ||
+      failed(reader.readSignedVarInt(zeroPoint)) ||
+      failed(reader.readSignedVarInt(storageTypeMin)) ||
+      failed(reader.readSignedVarInt(storageTypeMax)))
+    return reader.emitError("invalid UniformQuantizedType"),
+           UniformQuantizedType();
+  return UniformQuantizedType::get(flags, storageType, expressedType,
+                                   scale.value().convertToDouble(), zeroPoint,
+                                   storageTypeMin, storageTypeMax);
+}
+void QuantDialectBytecodeInterface::write(UniformQuantizedType type,
+                                          DialectBytecodeWriter &writer) const {
+  writer.writeVarInt(quant_encoding::kUniformQuantizedType);
+  writer.writeVarInt(type.getFlags());
+  writer.writeType(type.getStorageType());
+  writer.writeType(type.getExpressedType());
+  writer.writeAPFloatWithKnownSemantics(APFloat(type.getScale()));
+  writer.writeSignedVarInt(type.getZeroPoint());
+  writer.writeSignedVarInt(type.getStorageTypeMin());
+  writer.writeSignedVarInt(type.getStorageTypeMax());
+}
+
+UniformQuantizedPerAxisType
+QuantDialectBytecodeInterface::readUniformQuantizedPerAxisType(
+    DialectBytecodeReader &reader) const {
+  uint64_t flags, quantizedDimension;
+  Type storageType, expressedType;
+  SmallVector<double> scales;
+  SmallVector<int64_t> zeroPoints;
+  int64_t storageTypeMin, storageTypeMax;
+
+  auto scalesRead = [&](double &val) -> LogicalResult {
+    FailureOr<APFloat> fl =
+        reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble());
+    if (failed(fl))
+      return failure();
+    val = fl->convertToDouble();
+    return success();
+  };
+
+  if (failed(reader.readVarInt(flags)) ||
+      failed(reader.readType(storageType)) ||
+      failed(reader.readType(expressedType)) ||
+      failed(reader.readList(scales, scalesRead)) ||
+      failed(reader.readSignedVarInts(zeroPoints)) ||
+      failed(reader.readVarInt(quantizedDimension)) ||
+      failed(reader.readSignedVarInt(storageTypeMin)) ||
+      failed(reader.readSignedVarInt(storageTypeMax)))
+    return reader.emitError("invalid UniformQuantizedPerAxisType"),
+           UniformQuantizedPerAxisType();
+  return UniformQuantizedPerAxisType::get(
+      flags, storageType, expressedType, scales, zeroPoints,
+      static_cast<int32_t>(quantizedDimension), storageTypeMin, storageTypeMax);
+}
+void QuantDialectBytecodeInterface::write(UniformQuantizedPerAxisType type,
+                                          DialectBytecodeWriter &writer) const {
+  // Must be the per-axis code: readType() dispatches on this value.
+  writer.writeVarInt(quant_encoding::kUniformQuantizedPerAxisType);
+  writer.writeVarInt(type.getFlags());
+  writer.writeType(type.getStorageType());
+  writer.writeType(type.getExpressedType());
+  writer.writeList(type.getScales(), [&](double val) {
+    writer.writeAPFloatWithKnownSemantics(APFloat(val));
+  });
+  writer.writeSignedVarInts(type.getZeroPoints());
+  writer.writeVarInt(type.getQuantizedDimension());
+  writer.writeSignedVarInt(type.getStorageTypeMin());
+  writer.writeSignedVarInt(type.getStorageTypeMax());
+}
diff --git
a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Quant/QuantOps.h" +#include "QuantDialectBytecode.h" #include "TypeDetail.h" #include "mlir/Dialect/Quant/QuantTypes.h" @@ -32,6 +33,7 @@ #define GET_OP_LIST #include "mlir/Dialect/Quant/QuantOps.cpp.inc" >(); + addBytecodeInterface(this); } OpFoldResult StorageCastOp::fold(ArrayRef operands) { diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir @@ -0,0 +1,69 @@ +// RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s + +// Bytecode currently does not support big-endian platforms +// UNSUPPORTED: s390x- + +//===----------------------------------------------------------------------===// +// AnyQuantized +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseAnyFullySpecified +module @parseAnyFullySpecified attributes { + // CHECK: bytecode.test = !quant.any:f32> + bytecode.test = !quant.any:f32> +} {} + +// CHECK-LABEL: parseAnyNoExpressedType +module @parseAnyNoExpressedType attributes { + // CHECK: bytecode.test = !quant.any> + bytecode.test = !quant.any> +} {} + +// CHECK-LABEL: parseAnyOnlyStorageType +module @parseAnyOnlyStorageType attributes { + // CHECK: bytecode.test = !quant.any> + bytecode.test = !quant.any> +} {} + +//===----------------------------------------------------------------------===// +// CalibratedQuantized +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseCalibrated +module @parseCalibrated attributes { + // CHECK: !quant.calibrated> + bytecode.test = !quant.calibrated> +} {} + 
+//===----------------------------------------------------------------------===//
+// UniformQuantized
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: parseUniformPerLayer
+module @parseUniformPerLayer attributes {
+  // CHECK: !quant.uniform:f32, 9.987200e-01:127>
+  bytecode.test = !quant.uniform:f32, 9.987200e-01:127>
+} {}
+
+//===----------------------------------------------------------------------===//
+// UniformQuantizedPerAxis
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: parseUniformPerAxisScaleZero
+module @parseUniformPerAxisScaleZero attributes {
+  // CHECK: !quant.uniform
+  bytecode.test = !quant.uniform
+} {}
+
+// CHECK-LABEL: parseUniformPerAxisScaleNoZero
+module @parseUniformPerAxisScaleNoZero attributes {
+  // CHECK: !quant.uniform
+  bytecode.test = !quant.uniform
+} {}
+
+// CHECK-LABEL: parseUniformPerAxisMixed
+module @parseUniformPerAxisMixed attributes {
+  // CHECK: !quant.uniform
+  bytecode.test = !quant.uniform
+} {}
+
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -7406,6 +7406,8 @@
 cc_library(
     name = "QuantOps",
     srcs = [
+        "lib/Dialect/Quant/IR/QuantDialectBytecode.cpp",
+        "lib/Dialect/Quant/IR/QuantDialectBytecode.h",
         "lib/Dialect/Quant/IR/QuantOps.cpp",
         "lib/Dialect/Quant/IR/QuantTypes.cpp",
         "lib/Dialect/Quant/IR/TypeDetail.h",