diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index 4f6eb8cb9855..1000b1fabbf7 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -1,114 +1,164 @@
 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //   http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================
 
 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
 #include "mlir/Dialect/QuantOps/Passes.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"
 #include "mlir/Dialect/QuantOps/UniformSupport.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 using namespace mlir::quant;
 
 namespace {
 
 class ConvertSimulatedQuantPass
     : public FunctionPass<ConvertSimulatedQuantPass> {
 public:
   void runOnFunction() override;
 };
 
 } // end anonymous namespace
 
-/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
-class ConstFakeQuantRewrite : public RewritePattern {
+/// Base class that rewrites a fake-quant op of type FakeQuantOp into a
+/// qbarrier/dbarrier (quant.qcast/quant.dcast) pair. ConcreteRewriteClass is
+/// the CRTP-derived pattern and must provide
+///   QuantizedType convertFakeQuantAttrsToType(FakeQuantOp, Type) const;
+/// which returns a null type, after emitting an error, when conversion fails.
+template <typename ConcreteRewriteClass, typename FakeQuantOp>
+class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
 public:
-  bool *hadFailure;
+  using OpRewritePattern<FakeQuantOp>::OpRewritePattern;
 
-  ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
-      : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
-        hadFailure(hadFailure) {}
+  FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+      : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}
 
-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(FakeQuantOp op,
                                      PatternRewriter &rewriter) const override {
     // TODO: If this pattern comes up more frequently, consider adding core
     // support for failable rewrites.
     if (failableRewrite(op, rewriter)) {
       *hadFailure = true;
-      return matchFailure();
+      return Pattern::matchFailure();
     }
 
-    return matchSuccess();
+    return Pattern::matchSuccess();
   }
 
-  bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
-    auto fqOp = cast<ConstFakeQuant>(op);
+private:
+  bool *hadFailure;
 
-    auto converter =
-        ExpressedToQuantizedConverter::forInputType(fqOp.getType());
+  bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
+    auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
     if (!converter) {
-      return (op->emitError("unsupported quantized type conversion"), true);
+      return (op.emitError("unsupported quantized type conversion"), true);
     }
 
-    UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
-        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
-        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
-        fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
+    QuantizedType elementType =
+        static_cast<const ConcreteRewriteClass *>(this)
+            ->convertFakeQuantAttrsToType(op, converter.expressedType);
 
-    if (!uniformElementType) {
-      // Note that the fakeQuantAttrsToType will have emitted the error.
+    if (!elementType) {
+      // Note that convertFakeQuantAttrsToType will have emitted the error.
       return true;
     }
 
-    Type quantizedType = converter.convert(uniformElementType);
+    Type quantizedType = converter.convert(elementType);
     assert(quantizedType &&
            "Converter accepted a type that it did not convert");
 
     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
     // this is a forced/hard-coded constraint.
-    auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
-                                                    fqOp.inputs());
+    auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
+                                                    op.inputs());
     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
                                                   qbarrier.getResult());
 
     return false;
   }
 };
 
+/// Rewrites ConstFakeQuant, whose min/max attributes are scalars.
+class ConstFakeQuantRewrite
+    : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
+public:
+  using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
+
+  ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+      : BaseRewrite(ctx, hadFailure) {}
+
+  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
+                                            Type expressedType) const {
+    return fakeQuantAttrsToType(
+        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
+        fqOp.narrow_range(), expressedType, fqOp.is_signed());
+  }
+};
+
+/// Rewrites ConstFakeQuantPerAxis, whose min/max attributes are arrays.
+class ConstFakeQuantPerAxisRewrite
+    : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
+                              ConstFakeQuantPerAxis> {
+public:
+  using BaseRewrite =
+      FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
+
+  ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
+      : BaseRewrite(ctx, hadFailure) {}
+
+  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
+                                            Type expressedType) const {
+    SmallVector<double, 4> min, max;
+    min.reserve(fqOp.min().size());
+    max.reserve(fqOp.max().size());
+    for (auto m : fqOp.min())
+      min.push_back(m.cast<FloatAttr>().getValueAsDouble());
+    for (auto m : fqOp.max())
+      max.push_back(m.cast<FloatAttr>().getValueAsDouble());
+
+    return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+                                fqOp.axis().getSExtValue(), min, max,
+                                fqOp.narrow_range(), expressedType,
+                                fqOp.is_signed());
+  }
+};
+
 void ConvertSimulatedQuantPass::runOnFunction() {
   bool hadFailure = false;
   OwningRewritePatternList patterns;
   auto func = getFunction();
-  auto *context = &getContext();
-  patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
+  auto ctx = func.getContext();
+  patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
+      ctx, &hadFailure);
   applyPatternsGreedily(func, patterns);
   if (hadFailure)
     signalPassFailure();
 }
 
 std::unique_ptr<OpPassBase<FuncOp>>
 mlir::quant::createConvertSimulatedQuantPass() {
   return std::make_unique<ConvertSimulatedQuantPass>();
 }
 
 static PassRegistration<ConvertSimulatedQuantPass>
     pass("quant-convert-simulated-quantization",
          "Converts training-time simulated quantization ops to corresponding "
          "quantize/dequantize casts.");
diff --git a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 02f803ac8396..5d4561be81b2 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -1,188 +1,189 @@
 //===- FakeQuantSupport.cpp - Support utilities for FakeQuant ops ---------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //   http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================
 
 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
 #include "mlir/Dialect/QuantOps/QuantTypes.h"
 
 namespace mlir {
 namespace quant {
 namespace {
 bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
                              MLIRContext *ctx, Type &storageType, int64_t &qmin,
                              int64_t &qmax) {
   // Hard-coded type mapping from TFLite.
   if (numBits <= 8) {
     storageType = IntegerType::get(8, ctx);
     if (isSigned) {
       qmin = -128;
       qmax = 127;
     } else {
       qmin = 0;
       qmax = 255;
     }
   } else if (numBits <= 16) {
     storageType = IntegerType::get(16, ctx);
     if (isSigned) {
       qmin = -32768;
       qmax = 32767;
     } else {
       qmin = 0;
       qmax = 65535;
     }
   } else {
     return true;
   }
 
   // Handle narrowRange.
   if (narrowRange) {
     qmin += 1;
   }
   return false;
 }
 
 void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
                           double &scale, int64_t &nudgedZeroPoint) {
   // Determine the scale.
   const double qminDouble = qmin;
   const double qmaxDouble = qmax;
   scale = (rmax - rmin) / (qmaxDouble - qminDouble);
 
   // Zero point computation.
   // In float, solve the affine equation for any known pair
   // (real value, corresponding quantized value), of which, two such pairs
   // are known: (rmin, qmin), (rmax, qmax).
   // The arithmetic error on the zero point computed from either pair will be
   // roughly machine_epsilon * (sum of absolute values of terms).
   // Use the variant that adds the smaller error.
   const double zeroPointFromMin = qminDouble - rmin / scale;
   const double zeroPointFromMinError =
       std::abs(qminDouble) + std::abs(rmin / scale);
   const double zeroPointFromMax = qmaxDouble - rmax / scale;
   const double zeroPointFromMaxError =
       std::abs(qmaxDouble) + std::abs(rmax / scale);
 
   const double zeroPointDouble = (zeroPointFromMinError < zeroPointFromMaxError)
                                      ? zeroPointFromMin
                                      : zeroPointFromMax;
 
   // Now nudge the zero point to be an integer.
   nudgedZeroPoint = 0;
   if (zeroPointDouble < qminDouble) {
     nudgedZeroPoint = qmin;
   } else if (zeroPointDouble > qmaxDouble) {
     nudgedZeroPoint = qmax;
   } else {
     nudgedZeroPoint = round(zeroPointDouble);
   }
 
   // By construction, the nudged zero point should always be in range.
   assert(nudgedZeroPoint >= qmin);
   assert(nudgedZeroPoint <= qmax);
 }
 
 } // end namespace
 
 UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
                                           double rmin, double rmax,
                                           bool narrowRange, Type expressedType,
                                           bool isSigned) {
   // Range must straddle zero.
   // TODO(b/140641593): remove this constraint.
   if (rmin > 0.0 || rmax < 0.0) {
     return (emitError(loc, "FakeQuant range must straddle zero: [")
                 << rmin << "," << rmax << "]",
             nullptr);
   }
 
   MLIRContext *ctx = expressedType.getContext();
   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
   Type storageType;
   int64_t qmin;
   int64_t qmax;
   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
                               qmin, qmax)) {
     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
             nullptr);
   }
 
   // Special case where min/max is close enough. The tensor contents are all
   // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero
   // points and dequantized to 0.0.
   if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
     return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                             1.0, qmin, qmin, qmax, loc);
   }
 
   double scale;
   int64_t nudgedZeroPoint;
   getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
 
   return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                           scale, nudgedZeroPoint, qmin, qmax,
                                           loc);
 }
 
-// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
+// Per-axis variant: derives one scale/zero point pair for each
+// (rmins[i], rmaxs[i]) entry, quantized along `quantizedDimension`.
 UniformQuantizedPerAxisType
 fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
                      ArrayRef<double> rmins, ArrayRef<double> rmaxs,
                      bool narrowRange, Type expressedType, bool isSigned) {
   size_t axis_size = rmins.size();
   if (axis_size != rmaxs.size()) {
     return (emitError(loc, "mismatched per-axis min and max size: ")
                 << axis_size << " vs. " << rmaxs.size(),
             nullptr);
   }
 
   MLIRContext *ctx = expressedType.getContext();
   Type storageType;
   int64_t qmin;
   int64_t qmax;
   if (getDefaultStorageParams(numBits, narrowRange, isSigned, ctx, storageType,
                               qmin, qmax)) {
     return (emitError(loc, "unsupported FakeQuant number of bits: ") << numBits,
             nullptr);
   }
 
   SmallVector<double, 4> scales;
   SmallVector<int64_t, 4> zeroPoints;
   scales.reserve(axis_size);
   zeroPoints.reserve(axis_size);
   for (size_t axis = 0; axis != axis_size; ++axis) {
     double rmin = rmins[axis];
     double rmax = rmaxs[axis];
     if (std::fabs(rmax - rmin) < std::numeric_limits<double>::epsilon()) {
       scales.push_back(1.0);
       zeroPoints.push_back(qmin);
       continue;
     }
 
     double scale;
     int64_t nudgedZeroPoint;
     getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
     scales.push_back(scale);
     zeroPoints.push_back(nudgedZeroPoint);
   }
 
   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
   return UniformQuantizedPerAxisType::getChecked(
-      flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
-      quantizedDimension, loc);
+      flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
+      qmin, qmax, loc);
 }
 
 } // namespace quant
 } // namespace mlir
diff --git a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
index 15de088f39ce..316702cc5288 100644
--- a/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
+++ b/mlir/test/Dialect/QuantOps/convert-fakequant.mlir
@@ -1,182 +1,201 @@
 // RUN: mlir-opt %s -split-input-file -quant-convert-simulated-quantization | FileCheck %s --dump-input=fail
 
 // -----
 // Verifies a quint8 single point.
 // CHECK-LABEL: fakeQuantArgs_Quint8_0
 func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>
   // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8:f32, 1.000000e+00>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 0.0 : f32, num_bits = 8
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a quint8 single point (with narrow_range = true).
 // CHECK-LABEL: fakeQuantArgs_Quint8_0_NarrowRange
 func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>
   // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 1.000000e+00:1>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a quint8 asymmetric 0..1 range.
 // CHECK-LABEL: fakeQuantArgs_Quint8_0_1
 func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 0.0039215686274509803>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<u8:f32, 0.0039215686274509803>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 1.0 : f32, num_bits = 8
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
+// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
 func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 0.003937007874015748:1>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<u8<1:255>:f32, 0.003937007874015748:1>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a quint8 symmetric range of -1..127/128.
-// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
+// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
 func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<u8:f32, 7.812500e-03:128>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<u8:f32, 7.812500e-03:128>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a qint8 single point.
 // CHECK-LABEL: fakeQuantArgs_Qint8_0
 func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>
   // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 1.000000e+00:-128>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, is_signed = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a qint8 single point (with narrow_range = true).
 // CHECK-LABEL: fakeQuantArgs_Qint8_0_NarrowRange
 func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>
   // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 1.000000e+00:-127>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a qint8 asymmetric 0..1 range.
 // CHECK-LABEL: fakeQuantArgs_Qint8_0_1
 func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, is_signed = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK_LABEL: fakeQuantArgs_Qint8_NarrowRange
+// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
 func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8<-127:127>:f32, 0.003937007874015748:-127>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true, is_signed = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a qint8 symmetric range of -1..127/128.
-// CHECK_LABEL: fakeQuantArgs_Qint8_SymmetricRange
+// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
 func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i8:f32, 7.812500e-03>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false, is_signed = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verifies a commonly used -1..1 symmetric 16bit range with a zero point of
 // 0 and range -1.0 .. 32767/32768.
 // CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric
 func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
   // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform<i16:f32, 3.0517578125E-5>>)
   // CHECK-SAME: -> tensor<8x4x3xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16, is_signed = true
   } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
   return %0 : tensor<8x4x3xf32>
 }
 
 // -----
 // Verify that lowering to barriers of unranked tensors functions.
 // CHECK-LABEL: fakeQuantArgs_UnrankedTensor
 func @fakeQuantArgs_UnrankedTensor(tensor<*xf32>) -> tensor<*xf32> {
 ^bb0(%arg0: tensor<*xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<*xf32>)
   // CHECK-SAME: -> tensor<*x!quant.uniform<u8:f32, 0.0039215686274509803>>
   // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<*x!quant.uniform<u8:f32, 0.0039215686274509803>>)
   // CHECK-SAME: -> tensor<*xf32>
   %0 = "quant.const_fake_quant"(%arg0) {
     min = 0.0 : f32, max = 1.0 : f32, num_bits = 8
   } : (tensor<*xf32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// -----
+// Verifies a qint8 per axis
+// CHECK-LABEL: fakeQuantPerAxis
+func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+  // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
+  // CHECK: %[[d:.*]] = "quant.dcast"(%[[q]])
+  // CHECK-SAME: (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
+
+  %0 = "quant.const_fake_quant_per_axis"(%arg0) {
+    min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
+    max = [0.9921875 : f32, 0.0: f32, 1.0 : f32],
+    num_bits = 8, narrow_range = false, is_signed = true, axis = 2
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}