diff --git a/mlir/include/mlir/Quantizer/Support/Statistics.h b/mlir/include/mlir/Quantizer/Support/Statistics.h
--- a/mlir/include/mlir/Quantizer/Support/Statistics.h
+++ b/mlir/include/mlir/Quantizer/Support/Statistics.h
@@ -27,11 +27,34 @@
   double mean = 0;
   double variance = 0;
 
+  /// Per-axis statistics. Each vector holds one entry per index along the
+  /// axis they were computed for, and sampleSizePerAxis is the number of
+  /// elements folded into each entry.
+  int64_t sampleSizePerAxis = 0;
+  SmallVector<double, 4> minValuePerAxis;
+  SmallVector<double, 4> maxValuePerAxis;
+  SmallVector<double, 4> meanPerAxis;
+  SmallVector<double, 4> variancePerAxis;
+
   TensorAxisStatistics() {}
   TensorAxisStatistics(int64_t sampleSize, double minValue, double maxValue,
                        double mean, double variance)
       : sampleSize(sampleSize), minValue(minValue), maxValue(maxValue),
         mean(mean), variance(variance) {}
+  /// Constructs per-axis statistics; all four arrays must have the same size.
+  TensorAxisStatistics(int64_t sampleSize, ArrayRef<double> minValues,
+                       ArrayRef<double> maxValues, ArrayRef<double> means,
+                       ArrayRef<double> variances)
+      : sampleSizePerAxis(sampleSize),
+        minValuePerAxis(minValues.begin(), minValues.end()),
+        maxValuePerAxis(maxValues.begin(), maxValues.end()),
+        meanPerAxis(means.begin(), means.end()),
+        variancePerAxis(variances.begin(), variances.end()) {
+    assert(maxValues.size() == minValues.size() &&
+           means.size() == minValues.size() &&
+           variances.size() == minValues.size() &&
+           "per-axis statistics arrays must have the same size");
+  }
   void clear() { *this = TensorAxisStatistics(); }
 };
 
@@ -70,7 +93,11 @@
 
   bool get(TensorAxisStatistics &stats) const override;
 
-  // TODO: Implement per-axis.
+  bool supportsPerAxis() const override;
+
+  unsigned getAxisCount() const override;
+
+  bool getForAxis(unsigned axis, TensorAxisStatistics &stats) const override;
 
 private:
   Attribute attr;
diff --git a/mlir/lib/Quantizer/Support/Statistics.cpp b/mlir/lib/Quantizer/Support/Statistics.cpp
--- a/mlir/lib/Quantizer/Support/Statistics.cpp
+++ b/mlir/lib/Quantizer/Support/Statistics.cpp
@@ -50,12 +50,64 @@
   }
 }
 
+/// Recursively walks the elements of `attr` and folds them into the per-axis
+/// fields of `statistics`, one entry per index along `axis`.
+///
+/// Every dimension other than `axis` is iterated in turn; once `dim` reaches
+/// the innermost dimension that is not `axis`, each index along `axis` is
+/// visited and the run of values along `dim` is folded into its entry.
+/// `numElements` is the number of elements per entry and is used to
+/// accumulate the mean. `indices` must have one slot per dimension of `shape`
+/// and `shape` must have rank >= 2.
+static void collectElementsStatisticsDimForAxis(
+    unsigned axis, ElementsAttr attr, uint64_t numElements,
+    ArrayRef<int64_t> shape, SmallVectorImpl<uint64_t> &indices, uint64_t dim,
+    TensorAxisStatistics &statistics) {
+  // Recursive terminating condition.
+  if (dim >= shape.size())
+    return;
+
+  // `axis` is iterated together with the collection dimension below, so it is
+  // skipped here.
+  if (dim == axis) {
+    collectElementsStatisticsDimForAxis(axis, attr, numElements, shape, indices,
+                                        dim + 1, statistics);
+    return;
+  }
+
+  // The innermost dimension that is not `axis`; values are collected there.
+  uint64_t rank = shape.size();
+  uint64_t collectionDim = axis == rank - 1 ? rank - 2 : rank - 1;
+  if (dim < collectionDim) {
+    // Recurse past dim.
+    for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
+      indices[dim] = i;
+      collectElementsStatisticsDimForAxis(axis, attr, numElements, shape,
+                                          indices, dim + 1, statistics);
+    }
+    return;
+  }
+
+  // Collection dim: visit every index along `axis` and fold the values along
+  // `dim` into that index's entry.
+  uint64_t axisSize = shape[axis];
+  for (uint64_t axisIdx = 0; axisIdx < axisSize; ++axisIdx) {
+    indices[axis] = axisIdx;
+    for (uint64_t i = 0, s = shape[dim]; i < s; ++i) {
+      indices[dim] = i;
+      double value = attr.getValue<FloatAttr>(indices).getValueAsDouble();
+      statistics.minValuePerAxis[axisIdx] =
+          std::min(statistics.minValuePerAxis[axisIdx], value);
+      statistics.maxValuePerAxis[axisIdx] =
+          std::max(statistics.maxValuePerAxis[axisIdx], value);
+      statistics.meanPerAxis[axisIdx] += value / numElements;
+      // TODO: Calculate a running variance.
+    }
+  }
+}
+
 static bool getElementsStatistics(ElementsAttr attr,
                                   TensorAxisStatistics &statistics) {
-  statistics.clear();
-  statistics.minValue = std::numeric_limits<double>::infinity();
-  statistics.maxValue = -std::numeric_limits<double>::infinity();
-
   ShapedType sType = attr.getType();
   if (!sType.hasStaticShape())
     return false;
@@ -67,6 +119,11 @@
   indices.resize(sType.getRank());
   ArrayRef<int64_t> shape = sType.getShape();
 
+  statistics.minValue = std::numeric_limits<double>::infinity();
+  statistics.maxValue = -std::numeric_limits<double>::infinity();
+  statistics.mean = 0;
+  statistics.variance = 0;
+
   auto numElements = sType.getNumElements();
   collectElementsStatisticsDim(attr, numElements, shape, indices, 0,
                                statistics);
@@ -75,6 +132,43 @@
   return true;
 }
 
+/// Computes statistics for every index along `axis` of `attr` and stores them
+/// in the per-axis fields of `statistics`; the per-layer fields are left
+/// untouched. Returns false if `attr` does not have a static shape, has rank
+/// below two, `axis` is out of range, or the element type is not a float.
+static bool getElementsStatisticsForAxis(unsigned axis, ElementsAttr attr,
+                                         TensorAxisStatistics &statistics) {
+  ShapedType sType = attr.getType();
+  if (!sType.hasStaticShape() || sType.getRank() < 2 ||
+      axis >= sType.getRank())
+    return false;
+  Type elementTy = sType.getElementType();
+  if (!elementTy.isa<FloatType>())
+    return false;
+
+  SmallVector<uint64_t, 4> indices;
+  indices.resize(sType.getRank());
+  ArrayRef<int64_t> shape = sType.getShape();
+
+  uint64_t axisSize = shape[axis];
+  statistics.minValuePerAxis.assign(axisSize,
+                                    std::numeric_limits<double>::infinity());
+  statistics.maxValuePerAxis.assign(axisSize,
+                                    -std::numeric_limits<double>::infinity());
+  statistics.meanPerAxis.assign(axisSize, 0);
+  statistics.variancePerAxis.assign(axisSize, 0);
+
+  // Number of elements folded into each per-axis entry. A zero-extent axis
+  // has no entries at all and must not be used as a divisor.
+  uint64_t numElements =
+      axisSize == 0 ? 0 : sType.getNumElements() / axisSize;
+  collectElementsStatisticsDimForAxis(axis, attr, numElements, shape, indices,
+                                      0, statistics);
+  statistics.sampleSizePerAxis = numElements;
+
+  return true;
+}
+
 bool AttributeTensorStatistics::get(TensorAxisStatistics &stats) const {
   if (FloatAttr floatAttr = attr.dyn_cast<FloatAttr>()) {
     double value = floatAttr.getValueAsDouble();
@@ -86,10 +180,41 @@
   return false;
 }
 
+bool AttributeTensorStatistics::supportsPerAxis() const {
+  if (auto eltAttr = attr.dyn_cast<ElementsAttr>())
+    return eltAttr.getType().getRank() > 1;
+  return false;
+}
+
+unsigned AttributeTensorStatistics::getAxisCount() const {
+  if (!supportsPerAxis())
+    return 0;
+  return attr.cast<ElementsAttr>().getType().getRank();
+}
+
+bool AttributeTensorStatistics::getForAxis(unsigned axis,
+                                           TensorAxisStatistics &stats) const {
+  if (!supportsPerAxis())
+    return false;
+  auto eltAttr = attr.cast<ElementsAttr>();
+  return getElementsStatisticsForAxis(axis, eltAttr, stats);
+}
+
 raw_ostream &mlir::quantizer::operator<<(raw_ostream &os,
                                          const TensorAxisStatistics &stats) {
-  os << "STATS[sampleSize=" << stats.sampleSize << ", min=" << stats.minValue
-     << ", maxValue=" << stats.maxValue << ", mean=" << stats.mean
-     << ", variance=" << stats.variance << "]";
+  os << "STATS[sampleSizeLayer=" << stats.sampleSize
+     << ", minValueLayer=" << stats.minValue
+     << ", maxValueLayer=" << stats.maxValue << ", meanLayer=" << stats.mean
+     << ", varianceLayer=" << stats.variance
+     << ", sampleSizePerAxis=" << stats.sampleSizePerAxis << ", statsPerAxis={";
+  for (unsigned i = 0, n = stats.minValuePerAxis.size(); i < n; ++i) {
+    os << "minValue=" << stats.minValuePerAxis[i]
+       << ", maxValue=" << stats.maxValuePerAxis[i]
+       << ", mean=" << stats.meanPerAxis[i]
+       << ", variance=" << stats.variancePerAxis[i];
+    if (i != n - 1)
+      os << "; ";
+  }
+  os << "}]";
   return os;
 }