diff --git a/mlir/docs/DefiningDialects/_index.md b/mlir/docs/DefiningDialects/_index.md --- a/mlir/docs/DefiningDialects/_index.md +++ b/mlir/docs/DefiningDialects/_index.md @@ -299,8 +299,131 @@ void MyDialect::getCanonicalizationPatterns(RewritePatternSet &results) const; ``` -See the documentation for [Canonicalization in MLIR](../Canonicalization.md) for a much more -detailed description about canonicalization patterns. +See the documentation for [Canonicalization in MLIR](../Canonicalization.md) for +a more detailed description about canonicalization patterns. + +### Defining bytecode format for dialect attributes and types + +By default, bytecode serialization of dialect attributes and types uses the +regular textual format. Dialects can define a more compact bytecode format for +the attributes and types in the dialect by defining and attaching +`BytecodeDialectInterface` to the dialect. Basic support for generating +readers/writers for the bytecode dialect interface can be generated using ODS's +`-gen-bytecode`. The rest of the section will show an example. + +One can define the printing and parsing for a type in dialect `Foo` as follows: + +```td +include "mlir/IR/BytecodeBase.td" + +let cType = "MemRefType" in { +// Written in pseudo code showing the lowered encoding: +// /// MemRefType { +// /// shape: svarint[], +// /// elementType: Type, +// /// layout: Attribute +// /// } +// /// +// and the enum value: +// kMemRefType = 1, +// +// The corresponding definition in the ODS generator: +def MemRefType : DialectType<(type + Array:$shape, + Type:$elementType, + MemRefLayout:$layout +)> { + let printerPredicate = "!$_val.getMemorySpace()"; +} + +// /// MemRefTypeWithMemSpace { +// /// memorySpace: Attribute, +// /// shape: svarint[], +// /// elementType: Type, +// /// layout: Attribute +// /// } +// /// Variant of MemRefType with non-default memory space. 
+// kMemRefTypeWithMemSpace = 2, +def MemRefTypeWithMemSpace : DialectType<(type + Attribute:$memorySpace, + Array:$shape, + Type:$elementType, + MemRefLayout:$layout +)> { + let printerPredicate = "!!$_val.getMemorySpace()"; + // Note: order of serialization does not match order of builder. + let cBuilder = "get<$_resultType>(context, shape, elementType, layout, memorySpace)"; +} +} + +def FooDialectTypes : DialectTypes<"Foo"> { + let elems = [ + ReservedOrDead, // assigned index 0 + MemRefType, // assigned index 1 + MemRefTypeWithMemSpace, // assigned index 2 + ... + ]; +} +... +``` + +Here we have: + +* An outermost `cType`, as we are encoding one C++ type using two + different variants. +* The different `DialectType` instances are differentiated in printing by the + printer predicate, while in parsing the variant is already encoded and a + different builder function is invoked. +* A custom `cBuilder` is specified, as the way it is laid out on disk in the + bytecode doesn't match the order of arguments to the build methods of the + type. +* Many of the common dialect bytecode reading and writing atoms (such as + `VarInt`, `SignedVarInt`, `Blob`) are defined in `BytecodeBase` while one can + also define custom forms or combine via `CompositeBytecode` instances. +* `ReservedOrDead` is a special keyword to indicate a skipped enum instance + for which no read/write or dispatch code is generated. +* `Array` is a helper method for which during printing a list is serialized + (e.g., a varint of number of items followed by said number of items) or + parsed. 
+ +The generated code consists of four standalone methods with which the +following interface can define the bytecode dialect interface: + +```c++ +#include "mlir/Dialect/Foo/FooDialectBytecode.cpp.inc" + +struct FooDialectBytecodeInterface : public BytecodeDialectInterface { + FooDialectBytecodeInterface(Dialect *dialect) + : BytecodeDialectInterface(dialect) {} + + //===--------------------------------------------------------------------===// + // Attributes + + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); + } + + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); + } + + //===--------------------------------------------------------------------===// + // Types + + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); + } + + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); + } +}; +``` + +along with defining the corresponding build rules to invoke the generator +(`-gen-bytecode -bytecode-dialect="Foo"`). 
## Defining an Extensible dialect @@ -452,7 +575,6 @@ rewriter.createOperation(state); ``` - ### Defining a type at runtime Contrary to types defined in C++ or in TableGen, types defined at runtime can diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt @@ -1,2 +1,6 @@ add_mlir_dialect(QuantOps quant) add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) + +set(LLVM_TARGET_DEFINITIONS QuantDialectBytecode.td) +mlir_tablegen(QuantDialectBytecode.cpp.inc -gen-bytecode -bytecode-dialect="Quant") +add_public_tablegen_target(MLIRQuantDialectBytecodeIncGen) diff --git a/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Quant/QuantDialectBytecode.td @@ -0,0 +1,100 @@ +//===-- QuantDialectBytecode.td - Quant bytecode defs ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the Quant bytecode reader/writer definition file. 
+// +//===----------------------------------------------------------------------===// + +#ifndef QUANT_BYTECODE +#define QUANT_BYTECODE + +include "mlir/IR/BytecodeBase.td" + +def DoubleAPFloat: + WithParser <"succeeded(readDoubleAPFloat($_reader, $_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeAPFloatWithKnownSemantics(APFloat($_getter))", + WithType <"double">>>>; +def DoubleAPFloatList : List; + +let cType = "AnyQuantizedType" in { + +def AnyQuantizedType: DialectType<(type + VarInt:$flags, + Type:$storageType, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax +)> { + let printerPredicate = "!$_val.getExpressedType()"; + let cBuilder = [{ + get<$_resultType>(context, flags, storageType, nullptr, + storageTypeMin, storageTypeMax) + }]; +} + +def AnyQuantizedTypeWithExpressedType: DialectType<(type + VarInt:$flags, + Type:$storageType, + Type:$expressedType, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax +)> { + let printerPredicate = "!!$_val.getExpressedType()"; +} +} + +def CalibratedQuantizedType: DialectType<(type + Type:$expressedType, + DoubleAPFloat:$min, + DoubleAPFloat:$max +)>; + +def UniformQuantizedType: DialectType<(type + VarInt:$flags, + Type:$storageType, + Type:$expressedType, + DoubleAPFloat:$scale, + SignedVarInt:$zeroPoint, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax +)>; + +def UniformQuantizedPerAxisType: DialectType<(type + VarInt:$flags, + Type:$storageType, + Type:$expressedType, + VarInt:$quantizedDimension, + SignedVarInt:$storageTypeMin, + SignedVarInt:$storageTypeMax, + Array:$scales, + Array:$zeroPoints +)> { + // Note: builder order differs from bytecode. + let cBuilder = [{ + get<$_resultType>(context, flags, storageType, expressedType, scales, + zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax) + }]; +} + +/// This enum contains marker codes used to indicate which attribute is +/// currently being decoded, and how it should be decoded. 
The order of these +/// codes should generally be unchanged, as any changes will inevitably break +/// compatibility with older bytecode. + +def QuantDialectTypes : DialectTypes<"Quant"> { + let elems = [ + ReservedOrDead, + AnyQuantizedType, + AnyQuantizedTypeWithExpressedType, + CalibratedQuantizedType, + UniformQuantizedType, + UniformQuantizedPerAxisType + ]; +} + +#endif // QUANT_BYTECODE \ No newline at end of file diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td --- a/mlir/include/mlir/IR/BytecodeBase.td +++ b/mlir/include/mlir/IR/BytecodeBase.td @@ -61,6 +61,9 @@ class WithGetter> : Bytecode; +// Representation of a bytecode element consisting of other bytecode atoms. +// E.g., it is effectively a struct of bytecode elements. Set the members by +// defining a members dag: `dag members = (attr ...)`. class CompositeBytecode : WithType; class AttributeKind : diff --git a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Quant/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/IR/CMakeLists.txt @@ -11,6 +11,7 @@ DEPENDS MLIRQuantOpsIncGen + MLIRQuantDialectBytecodeIncGen LINK_LIBS PUBLIC MLIRIR diff --git a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantDialectBytecode.cpp @@ -20,279 +20,51 @@ using namespace mlir; using namespace mlir::quant; -//===----------------------------------------------------------------------===// -// Encoding -//===----------------------------------------------------------------------===// - namespace { -namespace quant_encoding { -/// This enum contains marker codes used to indicate which type is currently -/// being decoded, and how it should be decoded. The order of these codes should -/// generally be unchanged, as any changes will inevitably break compatibility -/// with older bytecode. 
-enum TypeCode { - /// AnyQuantizedType { - /// flags: varint - /// storageType: Type - /// storageTypeMin: svarint - /// storageTypeMax: svarint - /// } - /// - kAnyQuantizedType = 1, - - /// AnyQuantizedType { - /// flags: varint - /// storageType: Type - /// expressedType: Type - /// storageTypeMin: svarint - /// storageTypeMax: svarint - /// } - /// - kAnyQuantizedTypeWithExpressedType = 2, - /// CalibratedQuantizedType { - /// expressedType: Type - /// min: APFloat - /// max: APFloat - /// } - /// - kCalibratedQuantizedType = 3, - - /// UniformQuantizedType { - /// flags: varint - /// storageType: Type - /// expressedType: Type - /// scale: APFloat - /// zeroPoint: svarint - /// storageTypeMin: svarint - /// storageTypeMax: svarint - /// } - /// - kUniformQuantizedType = 4, - - /// UniformQuantizedPerAxisType { - /// flags: varint - /// storageType: Type - /// expressedType: Type - /// quantizedDimension: varint - /// storageTypeMin: svarint - /// storageTypeMax: svarint - /// scale: APFloat[] - /// zeroPoint: svarint[] - /// } - /// - kUniformQuantizedPerAxisType = 5, -}; - -} // namespace quant_encoding -} // namespace +static LogicalResult readDoubleAPFloat(DialectBytecodeReader &reader, + double &val) { + auto valOr = + reader.readAPFloatWithKnownSemantics(llvm::APFloat::IEEEdouble()); + if (failed(valOr)) + return failure(); + val = valOr->convertToDouble(); + return success(); +} -//===----------------------------------------------------------------------===// -// QuantDialectBytecodeInterface -//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc" -namespace { /// This class implements the bytecode interface for the Quant dialect. 
struct QuantDialectBytecodeInterface : public BytecodeDialectInterface { QuantDialectBytecodeInterface(Dialect *dialect) : BytecodeDialectInterface(dialect) {} //===--------------------------------------------------------------------===// - // Types + // Attributes - Type readType(DialectBytecodeReader &reader) const override; - LogicalResult writeType(Type type, - DialectBytecodeWriter &writer) const override; + Attribute readAttribute(DialectBytecodeReader &reader) const override { + return ::readAttribute(getContext(), reader); + } - AnyQuantizedType readAnyQuantizedType(bool withExpressedType, - DialectBytecodeReader &reader) const; - void write(AnyQuantizedType type, DialectBytecodeWriter &writer) const; + LogicalResult writeAttribute(Attribute attr, + DialectBytecodeWriter &writer) const override { + return ::writeAttribute(attr, writer); + } - CalibratedQuantizedType - readCalibratedQuantizedType(DialectBytecodeReader &reader) const; - void write(CalibratedQuantizedType type, DialectBytecodeWriter &writer) const; + //===--------------------------------------------------------------------===// + // Types - UniformQuantizedType - readUniformQuantizedType(DialectBytecodeReader &reader) const; - void write(UniformQuantizedType type, DialectBytecodeWriter &writer) const; + Type readType(DialectBytecodeReader &reader) const override { + return ::readType(getContext(), reader); + } - UniformQuantizedPerAxisType - readUniformQuantizedPerAxisType(DialectBytecodeReader &reader) const; - void write(UniformQuantizedPerAxisType type, - DialectBytecodeWriter &writer) const; + LogicalResult writeType(Type type, + DialectBytecodeWriter &writer) const override { + return ::writeType(type, writer); + } }; } // namespace void quant::detail::addBytecodeInterface(QuantizationDialect *dialect) { dialect->addInterfaces(); } - -//===----------------------------------------------------------------------===// -// Types 
-//===----------------------------------------------------------------------===// - -Type QuantDialectBytecodeInterface::readType( - DialectBytecodeReader &reader) const { - uint64_t code; - if (failed(reader.readVarInt(code))) - return Type(); - - switch (code) { - case quant_encoding::kAnyQuantizedType: - return readAnyQuantizedType(/*withExpressedType=*/false, reader); - case quant_encoding::kAnyQuantizedTypeWithExpressedType: - return readAnyQuantizedType(/*withExpressedType=*/true, reader); - case quant_encoding::kCalibratedQuantizedType: - return readCalibratedQuantizedType(reader); - case quant_encoding::kUniformQuantizedType: - return readUniformQuantizedType(reader); - case quant_encoding::kUniformQuantizedPerAxisType: - return readUniformQuantizedPerAxisType(reader); - - default: - reader.emitError() << "unknown builtin type code: " << code; - return Type(); - } -} - -LogicalResult -QuantDialectBytecodeInterface::writeType(Type type, - DialectBytecodeWriter &writer) const { - return TypeSwitch(type) - .Case( - [&](auto attr) { return write(attr, writer), success(); }) - .Default([&](Type) { return failure(); }); -} - -AnyQuantizedType QuantDialectBytecodeInterface::readAnyQuantizedType( - bool withExpressedType, DialectBytecodeReader &reader) const { - uint64_t flags; - Type storageType, expressedType; - int64_t storageTypeMin, storageTypeMax; - if (failed(reader.readVarInt(flags)) || - failed(reader.readType(storageType)) || - (withExpressedType && failed(reader.readType(expressedType))) || - failed(reader.readSignedVarInt(storageTypeMin)) || - failed(reader.readSignedVarInt(storageTypeMax))) - return reader.emitError("invalid AnyQuantizedType"), AnyQuantizedType(); - return AnyQuantizedType::get(flags, storageType, expressedType, - storageTypeMin, storageTypeMax); -} -void QuantDialectBytecodeInterface::write(AnyQuantizedType type, - DialectBytecodeWriter &writer) const { - if (type.getExpressedType()) - 
writer.writeVarInt(quant_encoding::kAnyQuantizedTypeWithExpressedType); - else - writer.writeVarInt(quant_encoding::kAnyQuantizedType); - - writer.writeVarInt(type.getFlags()); - writer.writeType(type.getStorageType()); - if (type.getExpressedType()) - writer.writeType(type.getExpressedType()); - writer.writeSignedVarInt(type.getStorageTypeMin()); - writer.writeSignedVarInt(type.getStorageTypeMax()); -} - -CalibratedQuantizedType -QuantDialectBytecodeInterface::readCalibratedQuantizedType( - DialectBytecodeReader &reader) const { - Type expressedType; - FailureOr min, max; - if (failed(reader.readType(expressedType)) || - failed(min = reader.readAPFloatWithKnownSemantics( - llvm::APFloat::IEEEdouble())) || - failed(max = reader.readAPFloatWithKnownSemantics( - llvm::APFloat::IEEEdouble()))) - return reader.emitError("invalid CalibratedQuantizedType"), - CalibratedQuantizedType(); - return CalibratedQuantizedType::get(expressedType, min->convertToDouble(), - max->convertToDouble()); -} -void QuantDialectBytecodeInterface::write(CalibratedQuantizedType type, - DialectBytecodeWriter &writer) const { - writer.writeVarInt(quant_encoding::kCalibratedQuantizedType); - writer.writeType(type.getExpressedType()); - writer.writeAPFloatWithKnownSemantics(APFloat(type.getMin())); - writer.writeAPFloatWithKnownSemantics(APFloat(type.getMax())); -} - -UniformQuantizedType QuantDialectBytecodeInterface::readUniformQuantizedType( - DialectBytecodeReader &reader) const { - uint64_t flags; - Type storageType, expressedType; - FailureOr scale; - int64_t zeroPoint, storageTypeMin, storageTypeMax; - if (failed(reader.readVarInt(flags)) || - failed(reader.readType(storageType)) || - failed(reader.readType(expressedType)) || - failed(scale = reader.readAPFloatWithKnownSemantics( - llvm::APFloat::IEEEdouble())) || - failed(reader.readSignedVarInt(zeroPoint)) || - failed(reader.readSignedVarInt(storageTypeMin)) || - failed(reader.readSignedVarInt(storageTypeMax))) - return 
reader.emitError("invalid UniformQuantizedType"), - UniformQuantizedType(); - return UniformQuantizedType::get(flags, storageType, expressedType, - scale->convertToDouble(), zeroPoint, - storageTypeMin, storageTypeMax); -} -void QuantDialectBytecodeInterface::write(UniformQuantizedType type, - DialectBytecodeWriter &writer) const { - writer.writeVarInt(quant_encoding::kUniformQuantizedType); - writer.writeVarInt(type.getFlags()); - writer.writeType(type.getStorageType()); - writer.writeType(type.getExpressedType()); - writer.writeAPFloatWithKnownSemantics(APFloat(type.getScale())); - writer.writeSignedVarInt(type.getZeroPoint()); - writer.writeSignedVarInt(type.getStorageTypeMin()); - writer.writeSignedVarInt(type.getStorageTypeMax()); -} - -UniformQuantizedPerAxisType -QuantDialectBytecodeInterface::readUniformQuantizedPerAxisType( - DialectBytecodeReader &reader) const { - uint64_t flags; - Type storageType, expressedType; - SmallVector scales; - SmallVector zeroPoints; - uint64_t quantizedDimension; - int64_t storageTypeMin, storageTypeMax; - - auto scalesRead = [&](double &val) -> LogicalResult { - FailureOr fl = - reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble()); - if (succeeded(fl)) { - val = fl->convertToDouble(); - return success(); - } - return failure(); - }; - - if (failed(reader.readVarInt(flags)) || - failed(reader.readType(storageType)) || - failed(reader.readType(expressedType)) || - failed(reader.readList(scales, scalesRead)) || - failed(reader.readSignedVarInts(zeroPoints)) || - failed(reader.readVarInt(quantizedDimension)) || - failed(reader.readSignedVarInt(storageTypeMin)) || - failed(reader.readSignedVarInt(storageTypeMax))) - return reader.emitError("invalid UniformQuantizedPerAxisType"), - UniformQuantizedPerAxisType(); - return UniformQuantizedPerAxisType::get( - flags, storageType, expressedType, scales, zeroPoints, - (int32_t)quantizedDimension, storageTypeMin, storageTypeMax); -} -void 
QuantDialectBytecodeInterface::write(UniformQuantizedPerAxisType type, - DialectBytecodeWriter &writer) const { - writer.writeVarInt(quant_encoding::kUniformQuantizedType); - writer.writeVarInt(type.getFlags()); - writer.writeType(type.getStorageType()); - writer.writeType(type.getExpressedType()); - writer.writeList(type.getScales(), [&](double val) { - writer.writeAPFloatWithKnownSemantics(APFloat(val)); - }); - writer.writeSignedVarInts(type.getZeroPoints()); - writer.writeVarInt(type.getQuantizedDimension()); - writer.writeSignedVarInt(type.getStorageTypeMin()); - writer.writeSignedVarInt(type.getStorageTypeMax()); -} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -8515,6 +8515,25 @@ deps = [":QuantizationOpsTdFiles"], ) +gentbl_cc_library( + name = "QuantDialectBytecodeGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-bytecode", + "-bytecode-dialect=Quant", + ], + "include/mlir/Dialect/Quant/QuantDialectBytecode.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Quant/QuantDialectBytecode.td", + deps = [ + ":BytecodeTdFiles", + ], +) + cc_library( name = "QuantOps", srcs = [ @@ -8540,6 +8559,7 @@ ":IR", ":InferTypeOpInterface", ":Pass", + ":QuantDialectBytecodeGen", ":QuantOpsIncGen", ":SideEffectInterfaces", ":Support", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/mlir-tblgen/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/mlir-tblgen/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/mlir-tblgen/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/mlir-tblgen/BUILD.bazel @@ -20,6 +20,8 @@ "//mlir:include/mlir/Dialect/LLVMIR/LLVMDialect.td", "//mlir:include/mlir/Dialect/LLVMIR/LLVMInterfaces.td", "//mlir:include/mlir/Dialect/LLVMIR/LLVMOpBase.td", + 
"//mlir:include/mlir/IR/BuiltinDialectBytecode.td", + "//mlir:include/mlir/IR/BytecodeBase.td", "//mlir:include/mlir/IR/OpBase.td", "//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "//mlir:include/mlir/Interfaces/SideEffectInterfaces.td",