diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -28,6 +28,7 @@ struct AnyQuantizedTypeStorage; struct UniformQuantizedTypeStorage; struct UniformQuantizedPerAxisTypeStorage; +struct CalibratedQuantizedTypeStorage; } // namespace detail @@ -371,6 +372,34 @@ } }; +/// A quantized type that infers its range from given min/max values. +/// +/// Typical syntax: +/// quant.calibrated<f32<-0.922:0.981>> +class CalibratedQuantizedType + : public Type::TypeBase { +public: + using Base::Base; + + /// Gets an instance of the type with all parameters specified but not + /// checked. + static CalibratedQuantizedType get(Type expressedType, double min, + double max); + + /// Gets an instance of the type with all specified parameters checked. + /// Returns a nullptr convertible type on failure. + static CalibratedQuantizedType getChecked(Type expressedType, double min, + double max, Location location); + + /// Verifies construction invariants and issues errors/warnings. 
+ static LogicalResult verifyConstructionInvariants(Location loc, + Type expressedType, + double min, double max); + double getMin() const; + double getMax() const; +}; + } // namespace quant } // namespace mlir diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -24,7 +24,7 @@ using namespace mlir::quant::detail; void QuantizationDialect::initialize() { - addTypes(); addOperations< #define GET_OP_LIST diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -354,3 +354,32 @@ int32_t UniformQuantizedPerAxisType::getQuantizedDimension() const { return getImpl()->quantizedDimension; } + +CalibratedQuantizedType CalibratedQuantizedType::get(Type expressedType, + double min, double max) { + return Base::get(expressedType.getContext(), expressedType, min, max); +} + +CalibratedQuantizedType CalibratedQuantizedType::getChecked(Type expressedType, + double min, + double max, + Location location) { + return Base::getChecked(location, expressedType, min, max); +} + +LogicalResult CalibratedQuantizedType::verifyConstructionInvariants( + Location loc, Type expressedType, double min, double max) { + // Verify that the expressed type is floating point. + // If this restriction is ever eliminated, the parser/printer must be + // extended. 
+ if (!expressedType.isa()) + return emitError(loc, "expressed type must be floating point"); + if (max <= min) + return emitError(loc, "illegal min and max: (") << min << ":" << max << ")"; + + return success(); +} + +double CalibratedQuantizedType::getMin() const { return getImpl()->min; } + +double CalibratedQuantizedType::getMax() const { return getImpl()->max; } diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h --- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h +++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h @@ -253,6 +253,56 @@ int32_t quantizedDimension; }; +struct CalibratedQuantizedTypeStorage : public QuantizedTypeStorage { + struct KeyTy { + KeyTy(Type expressedType, double min, double max) + : expressedType(expressedType), min(min), max(max) {} + // Floating point type that the quantized type approximates. + Type expressedType; + + double min; + double max; + + // Check for equality of two structures that share KeyTy data members + // (by name). + template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return lhs.expressedType == rhs.expressedType && lhs.min == rhs.min && + lhs.max == rhs.max; + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + int64_t minBits = llvm::bit_cast(min); + int64_t maxBits = llvm::bit_cast(max); + return llvm::hash_combine(expressedType, minBits, maxBits); + } + }; + + CalibratedQuantizedTypeStorage(const KeyTy &key) + : QuantizedTypeStorage(0, NoneType(), key.expressedType, 0, 0), + min(key.min), max(key.max) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + /// Construction. 
+ static CalibratedQuantizedTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + CalibratedQuantizedTypeStorage(key); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + double min; + double max; +}; + } // namespace detail } // namespace quant } // namespace mlir diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -91,9 +91,28 @@ return success(); } -/// Parses a UniformQuantizedType. +static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, + double &min, double &max) { + auto typeLoc = parser.getCurrentLocation(); + FloatType type; + + if (failed(parser.parseType(type))) { + parser.emitError(typeLoc, "expecting float expressed type"); + return nullptr; + } + + // Calibrated min and max values. + if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() || + parser.parseFloat(max) || parser.parseGreater()) { + parser.emitError(typeLoc, "calibrated values must be present"); + return nullptr; + } + return type; +} + +/// Parses an AnyQuantizedType. /// -/// uniform_per_layer ::= `any<` storage-spec (expressed-type-spec)?`>` +/// any ::= `any<` storage-spec (expressed-type-spec)?`>` /// storage-spec ::= storage-type (`<` storage-range `>`)? /// storage-range ::= integer-literal `:` integer-literal /// storage-type ::= (`i` | `u`) integer-literal @@ -269,6 +288,34 @@ storageTypeMin, storageTypeMax, loc); } +/// Parses a CalibratedQuantizedType. 
+///
+///   calibrated ::= `calibrated<` expressed-spec `>`
+///   expressed-spec ::= expressed-type `<` calibrated-range `>`
+///   expressed-type ::= `f` integer-literal
+///   calibrated-range ::= float-literal `:` float-literal
+static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
+  FloatType expressedType;
+  double min;
+  double max;
+
+  // Type specification.
+  if (parser.parseLess())
+    return nullptr;
+
+  // Expressed type.
+  expressedType = parseExpressedTypeAndRange(parser, min, max);
+  if (!expressedType) {
+    return nullptr;
+  }
+
+  if (parser.parseGreater()) {
+    return nullptr;
+  }
+
+  return CalibratedQuantizedType::getChecked(expressedType, min, max, loc);
+}
+
 /// Parse a type registered to this dialect.
 Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
   Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
@@ -282,6 +329,8 @@
     return parseUniformType(parser, loc);
   if (typeNameSpelling == "any")
     return parseAnyType(parser, loc);
+  if (typeNameSpelling == "calibrated")
+    return parseCalibratedType(parser, loc);
 
   parser.emitError(parser.getNameLoc(),
                    "unknown quantized type " + typeNameSpelling);
@@ -318,7 +367,7 @@
   }
 }
 
-/// Helper that prints a UniformQuantizedType.
+/// Helper that prints an AnyQuantizedType.
 static void printAnyQuantizedType(AnyQuantizedType type,
                                   DialectAsmPrinter &out) {
   out << "any<";
@@ -363,6 +412,14 @@
   out << "}>";
 }
 
+/// Helper that prints a CalibratedQuantizedType: calibrated<type<min:max>>.
+static void printCalibratedQuantizedType(CalibratedQuantizedType type,
+                                         DialectAsmPrinter &out) {
+  out << "calibrated<" << type.getExpressedType();
+  out << "<" << type.getMin() << ":" << type.getMax() << ">";
+  out << ">";
+}
+
 /// Print a type registered to this dialect.
void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { if (auto anyType = type.dyn_cast()) @@ -371,6 +428,8 @@ printUniformQuantizedType(uniformType, os); else if (auto perAxisType = type.dyn_cast()) printUniformQuantizedPerAxisType(perAxisType, os); + else if (auto calibratedType = type.dyn_cast()) + printCalibratedQuantizedType(calibratedType, os); else llvm_unreachable("Unhandled quantized type"); } diff --git a/mlir/test/Dialect/Quant/parse-calibrated-invalid.mlir b/mlir/test/Dialect/Quant/parse-calibrated-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Quant/parse-calibrated-invalid.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- +// Unrecognized token: missing calibrated type maximum +// expected-error@+2 {{calibrated values must be present}} +// expected-error@+1 {{expected ':'}} +!qalias = type !quant.calibrated> + +// ----- +// Unrecognized token: missing closing angle bracket +// expected-error@+1 {{expected '>'}} +!qalias = type !quant<"calibrated"> + +// ----- +// Unrecognized expressed type: integer type +// expected-error@+2 {{invalid kind of type specified}} +// expected-error@+1 {{expecting float expressed type}} +!qalias = type !quant.calibrated> + +// ----- +// Illegal calibrated min/max: max - min < 0 +// expected-error@+1 {{illegal min and max: (1.000000e+00:-1.000000e+00)}} +!qalias = type !quant.calibrated> + +// ----- +// Illegal calibrated min/max: max - min == 0 +// expected-error@+1 {{illegal min and max: (1.000000e+00:1.000000e+00)}} +!qalias = type !quant.calibrated> diff --git a/mlir/test/Dialect/Quant/parse-calibrated.mlir b/mlir/test/Dialect/Quant/parse-calibrated.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Quant/parse-calibrated.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file | FileCheck %s + +// ----- +// CHECK-LABEL: parseCalibrated +// CHECK: !quant.calibrated +!qalias = type 
!quant.calibrated> +func @parseCalibrated() -> !qalias { + %0 = "foo"() : () -> !qalias + return %0 : !qalias +}