diff --git a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
--- a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -191,9 +191,12 @@
   ElementsAttr convert(Attribute realValue);
 
 private:
-  /// Quantize an DenseFPElementsAttr by the quantization parameters.
+  /// Quantize a DenseFPElementsAttr by the quantization parameters.
   DenseElementsAttr convert(DenseFPElementsAttr attr);
 
+  /// Quantize a SparseElementsAttr by the quantization parameters.
+  ElementsAttr convert(SparseElementsAttr attr);
+
   /// Get a uniform converter for the index-th chunk along the quantizationDim.
   /// All the elements in this chunk is quantized by the returned converter.
   UniformQuantizedValueConverter getPerChunkConverter(int index) const {
@@ -203,6 +206,11 @@
     return converter;
   }
 
+  int64_t getChunkSize(ArrayRef<int64_t> shape) {
+    return std::accumulate(std::next(shape.begin(), quantizationDim + 1),
+                           shape.end(), 1, std::multiplies<int64_t>());
+  }
+
   const ArrayRef<double> scales;
   const ArrayRef<int64_t> zeroPoints;
   const APFloat clampMin;
diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -235,6 +235,9 @@
   /// shape or if its elemental type does not have a known bit width.
   int64_t getSizeInBits() const;
 
+  /// Returns the same kind of type with the same shape and new element type.
+  ShapedType withElementType(Type newElementType) const;
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Type type) {
     return type.getKind() == StandardTypes::Vector ||
diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
--- a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
@@ -35,8 +35,7 @@
       // Unsupported.
       return ExpressedToQuantizedConverter{inputType, nullptr};
     }
-    return ExpressedToQuantizedConverter{
-        inputType, inputType.cast<ShapedType>().getElementType()};
+    return ExpressedToQuantizedConverter{inputType, elementType};
   }
   }
 }
@@ -44,23 +43,15 @@
 Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
   assert(expressedType && "convert() on unsupported conversion");
 
-  switch (inputType.getKind()) {
-  default:
-    if (isQuantizablePrimitiveType(elementalType)) {
-      // For primitives, just use the new elemental type.
-      return elementalType;
-    }
-    // Unsupported.
-    return nullptr;
-  case StandardTypes::RankedTensor:
-    return RankedTensorType::get(inputType.cast<ShapedType>().getShape(),
-                                 elementalType);
-  case StandardTypes::UnrankedTensor:
-    return UnrankedTensorType::get(elementalType);
-  case StandardTypes::Vector:
-    return VectorType::get(inputType.cast<ShapedType>().getShape(),
-                           elementalType);
+  if (auto shapedInputType = inputType.dyn_cast<ShapedType>())
+    return shapedInputType.withElementType(elementalType);
+
+  if (isQuantizablePrimitiveType(elementalType)) {
+    // For primitives, just use the new elemental type.
+    return elementalType;
   }
+  // Unsupported.
+  return nullptr;
 }
 
 ElementsAttr
@@ -68,7 +59,9 @@
   if (auto attr = realValue.dyn_cast<DenseFPElementsAttr>()) {
     return convert(attr);
   }
-  // TODO(fengliuai): handles sparse elements attribute
+  if (auto attr = realValue.dyn_cast<SparseElementsAttr>()) {
+    return convert(attr);
+  }
   return nullptr;
 }
 
@@ -90,13 +83,106 @@
   // Scan the elements of the dense elements attributes and quantize them by
   // using the right quantization parameters.
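+  // flattenIndex walks the elements in row-major order; every chunkSize
+  // consecutive elements share one index along quantizationDim, so the
+  // converter chosen below cycles with period dimSize.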
   int64_t flattenIndex = 0;
-  auto shape = type.getShape();
-  int64_t chunkSize =
-      std::accumulate(std::next(shape.begin(), quantizationDim + 1),
-                      shape.end(), 1, std::multiplies<int64_t>());
+  int64_t chunkSize = getChunkSize(type.getShape());
   Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
   return attr.mapValues(newElementType, [&](const APFloat &old) {
     int chunkIndex = (flattenIndex++) / chunkSize;
     return converters[chunkIndex % dimSize].quantizeFloatToInt(old);
   });
 }
+
+/// Returns the 1-dimensional flattened row-major index from the given
+/// multi-dimensional index.
+static uint64_t getFlattenedIndex(ArrayRef<int64_t> shape,
+                                  ArrayRef<uint64_t> index) {
+  // Duplicates ElementsAttr::getFlattenedIndex logic.
+  size_t rank = shape.size();
+  uint64_t valueIndex = 0;
+  uint64_t dimMultiplier = 1;
+  for (int i = rank - 1; i >= 0; --i) {
+    valueIndex += index[i] * dimMultiplier;
+    dimMultiplier *= shape[i];
+  }
+  return valueIndex;
+}
+
+ElementsAttr
+UniformQuantizedPerAxisValueConverter::convert(SparseElementsAttr attr) {
+  if (auto valuesAttr = attr.getValues().dyn_cast<DenseFPElementsAttr>()) {
+    ShapedType type = attr.getType();
+    size_t dimSize = type.getDimSize(quantizationDim);
+    if (dimSize != scales.size()) {
+      return {};
+    }
+    SmallVector<UniformQuantizedValueConverter, 4> converters;
+    converters.reserve(dimSize);
+    for (int i = 0, e = dimSize; i != e; ++i) {
+      converters.push_back(getPerChunkConverter(i));
+    }
+
+    Type newElementType = IntegerType::get(storageBitWidth, attr.getContext());
+    auto numValues = static_cast<size_t>(valuesAttr.getType().getNumElements());
+    DenseIntElementsAttr indicesAttr = attr.getIndices();
+    SmallVector<APInt, 8> newValues;
+    // If all zero-points are 0, emit a sparse attribute with the same indices.
+    if (llvm::all_of(zeroPoints,
+                     [](int64_t zeroPoint) { return zeroPoint == 0; })) {
+      newValues.reserve(numValues);
+      for (size_t i = 0; i < numValues; ++i) {
+        auto value = valuesAttr.getValue<APFloat>(i);
+        uint64_t indexInQuantDim =
+            indicesAttr
+                .getValue<APInt>({i, static_cast<uint64_t>(quantizationDim)})
+                .getZExtValue();
+        APInt quantized = converters[indexInQuantDim].quantizeFloatToInt(value);
+        newValues.push_back(quantized);
+      }
+
+      auto newValuesAttr = DenseIntElementsAttr::get(
+          valuesAttr.getType().withElementType(newElementType), newValues);
+      return SparseElementsAttr::get(type.withElementType(newElementType),
+                                     indicesAttr, newValuesAttr);
+    }
+
+    // There are non-zero zero-points.
+    // TODO: Optimize the size in this case (e.g. when most zero-points are 0).
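+    // The implicit sparse zeros quantize to the per-chunk zero-point rather
+    // than to storage 0, so materialize a dense result pre-filled with the
+    // zero-points and then overwrite the explicitly stored values.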
+    auto numElements = static_cast<size_t>(type.getNumElements());
+    newValues.reserve(numElements);
+    ArrayRef<int64_t> shape = type.getShape();
+    size_t chunkSize = getChunkSize(shape);
+    size_t numChunks = numElements / chunkSize;
+    // Initialize the zero-points vector.
+    SmallVector<APInt, 4> apZeroPoints;
+    apZeroPoints.reserve(dimSize);
+    APFloat zero =
+        APFloat::getZero(valuesAttr.getValue<APFloat>(0).getSemantics());
+    for (size_t i = 0; i < dimSize; ++i) {
+      // Can't just use APInt(zeroPoints[i]): it can have the wrong bitwidth.
+      apZeroPoints.push_back(converters[i].quantizeFloatToInt(zero));
+    }
+    // Initialize the newValues vector from the zero-points.
+    for (size_t chunkIndex = 0; chunkIndex < numChunks; ++chunkIndex) {
+      APInt zeroPoint = apZeroPoints[chunkIndex % dimSize];
+      for (size_t i = 0; i < chunkSize; ++i)
+        newValues.push_back(zeroPoint);
+    }
+    // Finally, overwrite from the explicitly stored values in valuesAttr.
+    size_t rank = shape.size();
+    SmallVector<uint64_t, 4> index(rank);
+    auto indicesIter = indicesAttr.getValues<APInt>().begin();
+    for (size_t i = 0; i < numValues; ++i) {
+      auto value = valuesAttr.getValue<APFloat>(i);
+      for (size_t j = 0; j < rank; ++j) {
+        index[j] = (*indicesIter++).getZExtValue();
+      }
+      APInt quantized =
+          converters[index[quantizationDim]].quantizeFloatToInt(value);
+      newValues[getFlattenedIndex(shape, index)] = quantized;
+    }
+
+    return DenseIntElementsAttr::get(type.withElementType(newElementType),
+                                     newValues);
+  }
+
+  return {};
+}
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -192,6 +192,26 @@
   return hasStaticShape() && getShape() == shape;
 }
 
+ShapedType ShapedType::withElementType(Type newElementType) const {
+  switch (getKind()) {
+  case StandardTypes::Kind::Vector:
+    return VectorType::get(cast<VectorType>().getShape(), newElementType);
+  case StandardTypes::Kind::RankedTensor:
+    return RankedTensorType::get(cast<RankedTensorType>().getShape(),
+                                 newElementType);
+  case StandardTypes::Kind::UnrankedTensor:
+    return UnrankedTensorType::get(newElementType);
+  case StandardTypes::Kind::MemRef:
+    return MemRefType::Builder(cast<MemRefType>())
+        .setElementType(newElementType);
+  case StandardTypes::Kind::UnrankedMemRef:
+    return UnrankedMemRefType::get(newElementType,
+                                   cast<UnrankedMemRefType>().getMemorySpace());
+  default:
+    return {};
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // VectorType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/QuantOps/convert-const.mlir b/mlir/test/Dialect/QuantOps/convert-const.mlir
--- a/mlir/test/Dialect/QuantOps/convert-const.mlir
+++ b/mlir/test/Dialect/QuantOps/convert-const.mlir
@@ -191,3 +191,30 @@
 
   return %2, %4 : tensor<2x3xf32>, tensor<2x3xf32>
 }
+
+// -----
+// Verifies per-axis quantization results for sparse.
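+// A constant quantized with all-zero zero-points keeps its sparse encoding;
+// with non-zero zero-points the quantized constant is materialized as dense.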
+// CHECK-LABEL: per_axis_sparse_quantization
+func @per_axis_sparse_quantization() -> (tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>) {
+
+// CHECK-NEXT: %[[cst:.*]] = constant sparse<{{\[}}[0, 0], [0, 1], [1, 1], [1, 2]], [-128, -64, 1, 2]> : tensor<2x3xi8>
+// CHECK-NEXT: %[[cst0:.*]] = constant dense<{{\[}}[-128, -1, 1], [127, 1, 3]]> : tensor<2x3xi8>
+// CHECK-NEXT: %[[cst1:.*]] = constant dense<{{\[}}[-128, 64, 127], [0, 1, 2]]> : tensor<2x3xi8>
+// CHECK: "quant.scast"(%[[cst]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform>
+// CHECK: "quant.scast"(%cst_0) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform>
+// CHECK: "quant.scast"(%[[cst1]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform>
+
+  %cst = constant sparse<[[0, 0], [0, 1], [1, 1], [1, 2]], [-2.0, -0.5, 1.0, 2.0]> : tensor<2x3xf32>
+  %1 = "quant.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
+  %2 = "quant.dcast"(%1) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>)
+
+  %cst0 = constant sparse<[[0, 0], [0, 1], [1, 1], [1, 2]], [-2.0, -0.5, 1.0, 2.0]> : tensor<2x3xf32>
+  %3 = "quant.qcast"(%cst0) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
+  %4 = "quant.dcast"(%3) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>)
+
+  %cst1 = constant sparse<[[0, 0], [0, 1], [1, 1], [1, 2]], [-2.0, -0.5, 1.0, 2.0]> : tensor<2x3xf32>
+  %5 = "quant.qcast"(%cst1) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform>
+  %6 = "quant.dcast"(%5) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>)
+
+  return %2, %4, %6 : tensor<2x3xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+}