diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td index d95b45276074..761f6ce34030 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -1,259 +1,270 @@ //===- QuantOps.td - Quantization operation definition -----*- tablegen -*-===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= // // This is the operation definition file for Quantization. // //===----------------------------------------------------------------------===// #ifdef DIALECT_QUANTOPS_QUANT_OPS_ #else #define DIALECT_QUANTOPS_QUANT_OPS_ #ifdef OP_BASE #else include "mlir/IR/OpBase.td" include "mlir/Dialect/QuantOps/QuantPredicates.td" #endif // OP_BASE def quant_Dialect : Dialect { let name = "quant"; } //===----------------------------------------------------------------------===// // Base classes //===----------------------------------------------------------------------===// class quant_Op traits> : Op; //===----------------------------------------------------------------------===// // Quantization casts //===----------------------------------------------------------------------===// // A QuantizeCast (qcast) represents a potential type shift from a quantizable // type to a quantized type. 
// // At runtime, a qcast will apply the transformation expressed by its // operand and result type. For flexibility during transformation, it is also // possible to have a qcast that performs no transformation (both its // operand and result type are quantizable). // // A qcast will typically originate from either: // a) An expressed or implied constraint in the source dialect which signals // that a certain level of quantization is possible or required. // b) An inference made by a quantization algorithm indicating that a // quantized representation may be acceptable. // // Especially early in transformation, it is common to have pairs of // qcast/dcast at points where a transition to a quantized type is // required. In addition, it is also common to have an identity qcast // (where the operand and result type are not quantized) at all points where // it is legal to use a quantized representation (but is not known to be // acceptable). def quant_QuantizeCastOp : quant_Op<"qcast", [NoSideEffect]> { let arguments = (ins quant_RealValueType:$arg); let results = (outs quant_RealValueType); } // A DequantizeCast op (dcast) represents the inverse of a qcast, // converting back from a quantized to quantizable (expressed) type. // // Like qcasts, a dcast is allowed to have both its operand and result // as non quantized types. This facilitates transformations and marks edges // where the computation must be carried out in the expressed type. // // Especially early in transformation, it is common to have dcasts on // all operands to ops that must operate with the expressed type (typically // math ops prior to lowering to target-specific, quantized kernels). def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> { let arguments = (ins quant_RealValueType:$arg); let results = (outs quant_RealValueType); } // A StorageCast (scast) represents a cast from or to a type based on the // storage type and a type based on a corresponding quantized type. 
// // This op exists to ensure type coherency for between parts of the computation // which are operating directly on an underlying storage type and those which // operate on quantized values. // // Examples from storage to quantized type: // i8 -> !quant<"uniform[i8:f32]{1.0}"> // tensor<4xi8> -> tensor<4x!quant<"uniform[i8:f32]{1.0}">> // vector<4xi8> -> vector<4x!quant<"uniform[i8:f32]{1.0}">> def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> { let arguments = (ins quant_RealOrStorageValueType:$arg); let results = (outs quant_RealOrStorageValueType); let hasCanonicalizer = 0b1; } //===----------------------------------------------------------------------===// // Training integration and instrumentation ops //===----------------------------------------------------------------------===// def quant_ConstFakeQuant : quant_Op<"const_fake_quant", [SameOperandsAndResultType, NoSideEffect]> { let summary = "Simulates the effect of uniform quantization with const range."; let description = [{ Given a const min, max, num_bits and narrow_range attribute, applies the same uniform quantization simulation as is done by the TensorFlow fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility method and the quant-convert-simulated-quantization pass for futher details. }]; let arguments = (ins F32Tensor:$inputs, F32Attr:$min, F32Attr:$max, // The bitwidth of the quantization; between 2 and 16, inclusive. I64Attr:$num_bits, // Quantization range starts from 0 or 1; starts from 1 if true. DefaultValuedAttr:$narrow_range, // The sign of the quantization. 
DefaultValuedAttr:$is_signed ); let results = (outs F32Tensor:$outputs ); } def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis", [SameOperandsAndResultType, NoSideEffect]> { let summary = "Simulates the effect of per axis uniform quantization with const range."; let description = [{ Given a const min, max, num_bits and narrow_range attribute, applies the same per axis uniform quantization simulation as is done by the TensorFlow fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType() utility method and the quant-convert-simulated-quantization pass for futher details. }]; let arguments = (ins F32Tensor:$inputs, F32ArrayAttr:$min, F32ArrayAttr:$max, // The quantized dimension of the inputs tensor. I64Attr:$axis, // The bitwidth of the quantization; between 2 and 16, inclusive. I64Attr:$num_bits, // Quantization range starts from 0 or 1; starts from 1 if true. DefaultValuedAttr:$narrow_range, // The sign of the quantization. DefaultValuedAttr:$is_signed ); let results = (outs F32Tensor:$outputs ); } def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> { let summary = "Indicates that statistics are resolved by reference."; let description = [{ This op acts as an identity that, when encountered at runtime, should result in statistics being collected about about the value of its operand/result. Such statistics will be stored with the provided key, allowing this node to later be converted to a 'stats' op if statistics with that key have been encountered. }]; let arguments = (ins quant_RealValueType:$arg, StrAttr:$statsKey ); let results = (outs quant_RealValueType); } def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> { let summary = "Identity op which associates statistics with the value."; let description = [{ Associates statistics about the runtime ranges of values observed for evaluations of this node. 
Statistics about the entire type are reported in the 'layerStats' attribute
     and those for each axis, in the (optional) `axisStats` attribute. The
     interpretation of each is determined by the last dimension of its shape.
     Currently, only dim=2 is supported, which is interpreted as [min, max].

     `layerStats` must be a rank 1 tensor: [2]
-    `axisStats` must be a rank 2 tensor: [N, 2], where N=the rank of `arg`.
+    `axisStats` must be a rank 2 tensor: [N, 2], where N=the slice size
+    defined by the `axis` dimension, i.e. the product of the dimension sizes
+    from `axis` onward. For example:
+    <?x?x3x2>, axis=3 => N=2
+    <?x?x3x2>, axis=2 => N=6
   }];

   let arguments = (ins
     quant_RealValueType:$arg,
     ElementsAttr:$layerStats,
-    OptionalAttr<ElementsAttr>:$axisStats);
+    OptionalAttr<ElementsAttr>:$axisStats,
+    OptionalAttr<I64Attr>:$axis);
   let results = (outs quant_RealValueType);

   let verifier = [{
     auto tensorArg = arg()->getType().dyn_cast<TensorType>();
-    auto argRank = tensorArg ? tensorArg.getRank() : 0;
+    if (!tensorArg) return emitOpError("arg needs to be tensor type.");
+
     // Verify layerStats attribute.
     {
       auto layerStatsType = layerStats().getType();
       if (!layerStatsType.getElementType().isa<FloatType>()) {
         return emitOpError(
             "layerStats must have a floating point element type");
       }
       if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) {
         return emitOpError("layerStats must have shape [2]");
       }
     }
     // Verify axisStats (optional) attribute.
     if (axisStats()) {
+      if (!axis()) return emitOpError("axis must be specified for axisStats");
+
+      // getShape() is only meaningful for a ranked type, so reject an
+      // unranked tensor with a diagnostic before asking for its shape.
+      auto rankedArg = tensorArg.dyn_cast<RankedTensorType>();
+      if (!rankedArg)
+        return emitOpError("axisStats requires a ranked tensor arg");
+
+      // Advancing an iterator past shape.end() (or before shape.begin()) is
+      // undefined behaviour, so bound axis first. axis == rank is still
+      // accepted (empty slice, N = 1) because existing tests rely on it.
+      auto shape = rankedArg.getShape();
+      int64_t axisDim = axis()->getSExtValue();
+      if (axisDim < 0 || axisDim > static_cast<int64_t>(shape.size()))
+        return emitOpError("axis must be in the range [0, rank of arg]");
+
+      // std::accumulate accumulates in the type of its seed; seed with an
+      // int64_t so the product of the dimension sizes is not narrowed to int.
+      int64_t argSliceSize = std::accumulate(
+          std::next(shape.begin(), axisDim), shape.end(), int64_t(1),
+          std::multiplies<int64_t>());
+
       auto axisStatsType = axisStats()->getType();
       if (!axisStatsType.getElementType().isa<FloatType>()) {
         return emitOpError("axisStats must have a floating point element type");
       }
       if (axisStatsType.getRank() != 2 ||
           axisStatsType.getDimSize(1) != 2 ||
-          axisStatsType.getDimSize(0) != argRank) {
+          axisStatsType.getDimSize(0) != argSliceSize) {
         return emitOpError("axisStats must have shape [N,2] "
-                           "where N = the argument rank");
+                           "where N = the slice size defined by the axis dim");
       }
     }
     return success();
   }];
 }

 def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> {
   let summary = "Indicates that one point of the computation is coupled to another.";
   let description = [{
     Ordinarily, relationships between ops for the purposes of determining
     compatible quantized types is explicit based on the use-def chain. However,
     in some situations, a use may be separated from its def by arbitrary
     external connections. In such a case, during analysis, all coupled_ref
     nodes in a module which share a coupledKey will be considered to be
     directly connected as via an identity op for the purpose of type inference.
   }];
   let arguments = (ins
     quant_RealValueType:$arg,
     StrAttr:$coupledKey);
   let results = (outs quant_RealValueType);
 }

 #endif // DIALECT_QUANTOPS_QUANT_OPS_
diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
index 3bd49d43adcf..b618ac07f177 100644
--- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
+++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp
@@ -1,74 +1,75 @@
 //===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
 //
 // Copyright 2019 The MLIR Authors.
//
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //   http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================

 #include "mlir/Dialect/QuantOps/QuantOps.h"
 #include "TypeDetail.h"

 #include "mlir/Dialect/QuantOps/QuantTypes.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/MathExtras.h"
+#include <numeric>

 using namespace mlir;
 using namespace mlir::quant;
 using namespace mlir::quant::detail;

 #define GET_OP_CLASSES
 #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc"

 namespace {

 /// Matches x -> [scast -> scast] -> y, replacing the second scast with the
 /// value of x if the casts invert each other.
class RemoveRedundantStorageCastsRewrite : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; PatternMatchResult matchAndRewrite(StorageCastOp op, PatternRewriter &rewriter) const override { if (!matchPattern(op.arg(), m_Op())) return matchFailure(); auto srcScastOp = cast(op.arg()->getDefiningOp()); if (srcScastOp.arg()->getType() != op.getType()) return matchFailure(); rewriter.replaceOp(op, srcScastOp.arg()); return matchSuccess(); } }; } // end anonymous namespace void StorageCastOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } QuantizationDialect::QuantizationDialect(MLIRContext *context) : Dialect(/*name=*/"quant", context) { addTypes(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" >(); } diff --git a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp index 696c1e2db3a6..a82a288caf37 100644 --- a/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp +++ b/mlir/lib/Quantizer/Transforms/AddDefaultStatsTestPass.cpp @@ -1,129 +1,129 @@ //===- AddDefaultStatsTestPass.cpp - Testing pass to add default stats ----===// // // Copyright 2019 The MLIR Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= // // This file defines a testing pass to add default statistics nodes to every // quantization eligible op. Useful for unit testing. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/QuantOps/QuantOps.h" #include "mlir/Dialect/QuantOps/QuantTypes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/Quantizer/Configurations/FxpMathConfig.h" #include "mlir/Quantizer/Support/Configuration.h" #include "mlir/Quantizer/Support/ConstraintAnalysisGraph.h" #include "mlir/Quantizer/Support/ConstraintAnalysisGraphTraits.h" #include "mlir/Quantizer/Transforms/Passes.h" #include "mlir/Support/LogicalResult.h" #include "llvm/Support/GraphWriter.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::quantizer; using namespace mlir::quant; namespace { class AddDefaultStatsPass : public FunctionPass { public: AddDefaultStatsPass() = default; AddDefaultStatsPass(SolverContext &solverContext, const TargetConfiguration &config) : explicitSolverContext(&solverContext), explicitConfig(&config) {} void runOnFunction() override; void runWithConfig(SolverContext &solverContext, const TargetConfiguration &config); private: SolverContext *explicitSolverContext = nullptr; const TargetConfiguration *explicitConfig = nullptr; }; } // end anonymous namespace void AddDefaultStatsPass::runOnFunction() { if (explicitSolverContext && explicitConfig) { // If explicitly constructed with a config and context. runWithConfig(*explicitSolverContext, *explicitConfig); return; } // For global pass registration, use defaults. 
SolverContext solverContext(*getFunction().getContext()); auto config = FxpMathTargetConfig::create(solverContext); runWithConfig(solverContext, *config); } void AddDefaultStatsPass::runWithConfig(SolverContext &solverContext, const TargetConfiguration &config) { auto func = getFunction(); // Insert stats for each argument. for (auto *arg : func.getArguments()) { if (!config.isHandledType(arg->getType())) continue; OpBuilder b(func.getBody()); APFloat minValue(-1.0f); APFloat maxValue(1.0f); ElementsAttr layerStats = DenseFPElementsAttr::get( b.getTensorType({2}, b.getF32Type()), {minValue, maxValue}); - auto statsOp = - b.create(func.getLoc(), arg, layerStats, nullptr); + auto statsOp = b.create(func.getLoc(), arg, layerStats, + nullptr, nullptr); arg->replaceAllUsesWith(statsOp); // StatsOp contained a use to 'arg' so make sure to reset it after replacing // all of the uses of 'arg'. statsOp.getOperation()->replaceUsesOfWith(statsOp, arg); } // Walk the ops and insert stats. func.walk([&](Operation *op) { if (!config.isRequireStatsOp(op)) { return; } assert(op->getNumResults() == 1); auto originalResult = op->getResult(0); if (!config.isHandledType(originalResult->getType())) return; OpBuilder b(op->getBlock(), ++op->getIterator()); APFloat minValue(-1.0f); APFloat maxValue(1.0f); ElementsAttr layerStats = DenseFPElementsAttr::get( b.getTensorType({2}, b.getF32Type()), {minValue, maxValue}); auto statsOp = b.create(op->getLoc(), op->getResult(0), - layerStats, nullptr); + layerStats, nullptr, nullptr); originalResult->replaceAllUsesWith(statsOp); // StatsOp contained a use to 'op' so make sure to reset it after replacing // all of the uses of 'op'. 
statsOp.getOperation()->replaceUsesOfWith(statsOp, originalResult); }); } std::unique_ptr> mlir::quantizer::createAddDefaultStatsPass() { return std::make_unique(); } static PassRegistration pass( "quantizer-add-default-stats-test", "Adds default (dummy) statistics to all ops that can benefit from " "runtime statistics. This is meant to help in early stage bootstrapping."); diff --git a/mlir/test/Dialect/QuantOps/parse-ops-invalid.mlir b/mlir/test/Dialect/QuantOps/parse-ops-invalid.mlir index 7a9b96bb2cd3..272c53070c7e 100644 --- a/mlir/test/Dialect/QuantOps/parse-ops-invalid.mlir +++ b/mlir/test/Dialect/QuantOps/parse-ops-invalid.mlir @@ -1,77 +1,93 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics // ----- func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // expected-error@+1 {{layerStats must have a floating point element type}} %0 = "quant.stats"(%arg0) { layerStats = dense<[-1, 1]> : tensor<2xi8> } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } // ----- func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // expected-error@+1 {{layerStats must have shape [2]}} %0 = "quant.stats"(%arg0) { layerStats = dense<[[-1.0, 1.0]]> : tensor<1x2xf32> } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } // ----- func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // expected-error@+1 {{layerStats must have shape [2]}} %0 = "quant.stats"(%arg0) { layerStats = dense<[-1.0, 1.0, 2.0]> : tensor<3xf32> } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } // ----- // CHECK-LABEL: validStatistics func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // expected-error@+1 {{axisStats must have a floating point element type}} %0 = "quant.stats"(%0) { layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, axisStats = dense<[ [-1, 1], [-8, 8], [-1, 
0] - ]> : tensor<3x2xi8> + ]> : tensor<3x2xi8>, axis = 3 : i64 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } // ----- -func @invalidStatisticsMismatchedAxisRank(%arg0: tensor<8x4x3xf32>) -> +func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}} + // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}} %0 = "quant.stats"(%arg0) { layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, axisStats = dense<[ [-1.0, 1.0], [-8.0, 8.0], [-0.5, 0.5], [-2.0, 3.5] - ]> : tensor<4x2xf32> + ]> : tensor<4x2xf32>, axis = 3 : i64 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } // ----- func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - // expected-error@+1 {{axisStats must have shape [N,2] where N = the argument rank}} + // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}} %0 = "quant.stats"(%arg0) { layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, axisStats = dense<[ [-1.0, 1.0, 1.0], [-8.0, 8.0, 1.0], [-0.5, 0.5, 1.0] - ]> : tensor<3x3xf32> + ]> : tensor<3x3xf32>, axis = 3 : i64 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } + +// ----- +func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + // expected-error@+1 {{axis must be specified for axisStats}} + %1 = "quant.stats"(%arg0) { + layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, + axisStats = dense<[ + [-1.0, 1.0], + [-8.0, 8.0], + [-0.5, 0.5] + ]> : tensor<3x2xf32> + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %1 : tensor<8x4x3xf32> +} + +// ----- diff --git a/mlir/test/Dialect/QuantOps/parse-ops.mlir b/mlir/test/Dialect/QuantOps/parse-ops.mlir index 7d6d1abb2538..bdcd751a969d 100644 --- a/mlir/test/Dialect/QuantOps/parse-ops.mlir +++ 
b/mlir/test/Dialect/QuantOps/parse-ops.mlir @@ -1,64 +1,64 @@ // RUN: mlir-opt %s -split-input-file | FileCheck %s // ----- // CHECK-LABEL: validConstFakeQuant func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { %0 = "quant.const_fake_quant"(%arg0) { min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> %1 = "quant.const_fake_quant"(%0) { min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = false } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> %2 = "quant.const_fake_quant"(%1) { min = 0.0 : f32, max = 1.0 : f32, num_bits = 8 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %2 : tensor<8x4x3xf32> } // ----- // CHECK-LABEL: validConstFakeQuantPerAxis func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> { %0 = "quant.const_fake_quant_per_axis"(%arg0) { min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> %1 = "quant.const_fake_quant_per_axis"(%0) { min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> %2 = "quant.const_fake_quant_per_axis"(%1) { min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8 } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> return %2 : tensor<8x4x2xf32> } // ----- // CHECK-LABEL: validStatisticsRef func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { %0 = "quant.stats_ref"(%arg0) { statsKey = "foobar" } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> } // ----- // CHECK-LABEL: validStatistics func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { %0 = "quant.stats"(%arg0) { layerStats = dense<[-1.0, 1.0]> : tensor<2xf32> } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> %1 = "quant.stats"(%0) { layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, axisStats = 
dense<[ [-1.0, 1.0], [-8.0, 8.0], [-0.5, 0.5] - ]> : tensor<3x2xf32> + ]> : tensor<3x2xf32>, axis = 2 : i64 } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %1 : tensor<8x4x3xf32> } // ----- // CHECK-LABEL: validCoupledRef func @validCoupledRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { %0 = "quant.coupled_ref"(%arg0) { coupledKey = "foobar" } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> return %0 : tensor<8x4x3xf32> }