Quant: compare expressed types in ExpressedToQuantizedConverter::convert()

For a non-shaped (primitive) input type, convert() now returns the quantized
elemental type when its expressed type equals the converter's expressedType,
and nullptr otherwise. Previously it tested the QuantizedType argument itself
with isQuantizablePrimitiveType().

NOTE(review): not changed by this patch, but in
UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr) the
std::accumulate call is seeded with the int literal 1, so chunkSize is
accumulated in int (each int64_t product is narrowed back to int), and
chunkIndex is declared int although flattenIndex is int64_t. Both can
overflow for large tensors; consider seeding with int64_t{1} and declaring
chunkIndex as int64_t in a follow-up.

diff --git a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
index 4d5db81e326d..991d7c179f90 100644
--- a/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/Quant/Utils/UniformSupport.cpp
@@ -1,102 +1,102 @@
 //===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Quant/UniformSupport.h"
 #include "mlir/IR/StandardTypes.h"
 #include <numeric>
 
 using namespace mlir;
 using namespace mlir::quant;
 
 static bool isQuantizablePrimitiveType(Type inputType) {
   return inputType.isa<FloatType>();
 }
 
 const ExpressedToQuantizedConverter
 ExpressedToQuantizedConverter::forInputType(Type inputType) {
   switch (inputType.getKind()) {
   default:
     if (isQuantizablePrimitiveType(inputType)) {
       // Supported primitive type (which just is the expressed type).
       return ExpressedToQuantizedConverter{inputType, inputType};
     }
     // Unsupported.
     return ExpressedToQuantizedConverter{inputType, nullptr};
   case StandardTypes::RankedTensor:
   case StandardTypes::UnrankedTensor:
   case StandardTypes::Vector: {
     Type elementType = inputType.cast<ShapedType>().getElementType();
     if (!isQuantizablePrimitiveType(elementType)) {
       // Unsupported.
       return ExpressedToQuantizedConverter{inputType, nullptr};
     }
     return ExpressedToQuantizedConverter{
         inputType, inputType.cast<ShapedType>().getElementType()};
   }
   }
 }
 
 Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
   assert(expressedType && "convert() on unsupported conversion");
 
   switch (inputType.getKind()) {
   default:
-    if (isQuantizablePrimitiveType(elementalType)) {
-      // For primitives, just use the new elemental type.
+    if (elementalType.getExpressedType() == expressedType) {
+      // If the expressed types match, just use the new elemental type.
       return elementalType;
     }
     // Unsupported.
     return nullptr;
   case StandardTypes::RankedTensor:
     return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
                                  elementalType);
   case StandardTypes::UnrankedTensor:
     return UnrankedTensorType::get(elementalType);
   case StandardTypes::Vector:
     return VectorType::get(inputType.cast<VectorType>().getShape(),
                            elementalType);
   }
 }
 
 ElementsAttr
 UniformQuantizedPerAxisValueConverter::convert(Attribute realValue) {
   if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
     return convert(attr);
   }
   // TODO(fengliuai): handles sparse elements attribute
   return nullptr;
 }
 
 DenseElementsAttr
 UniformQuantizedPerAxisValueConverter::convert(DenseFPElementsAttr attr) {
   // Creates the converter for each chunk. Normally the size of the
   // quantization dim is 3, so we can cache all the converters.
   ShapedType type = attr.getType();
   size_t dimSize = type.getDimSize(quantizationDim);
   if (dimSize != scales.size()) {
     return {};
   }
   SmallVector<UniformQuantizedValueConverter, 4> converters;
   converters.reserve(dimSize);
   for (int i = 0, e = dimSize; i != e; ++i) {
     converters.push_back(getPerChunkConverter(i));
   }
 
   // Scan the elements of the dense elements attributes and quantize them by
   // using the right quantization parameters.
   int64_t flattenIndex = 0;
   auto shape = type.getShape();
   int64_t chunkSize =
       std::accumulate(std::next(shape.begin(), quantizationDim + 1),
                       shape.end(), 1, std::multiplies<int64_t>());
   Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
   return attr.mapValues(newElementType, [&](const APFloat &old) {
     int chunkIndex = (flattenIndex++) / chunkSize;
     return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
   });
 }