diff --git a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Quant/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Quant/CMakeLists.txt @@ -1,8 +1,2 @@ add_mlir_dialect(QuantOps quant) add_mlir_doc(QuantOps QuantDialect Dialects/ -gen-dialect-doc) - -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name Quant) -add_public_tablegen_target(MLIRQuantPassIncGen) - -add_mlir_doc(Passes QuantPasses ./ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/Quant/Passes.h b/mlir/include/mlir/Dialect/Quant/Passes.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Quant/Passes.h +++ /dev/null @@ -1,46 +0,0 @@ -//===- Passes.h - Quantization Passes ------ --------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines all of the passes owned by the quantization dialect. As -// things mature, it is expected that passes specific to certain frontend or -// backend dialects will move to those dialects directly. For now, they are -// incubated here. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_QUANT_PASSES_H -#define MLIR_DIALECT_QUANT_PASSES_H - -#include "mlir/Pass/Pass.h" - -namespace mlir { -namespace quant { - -/// Creates a pass that converts quantization simulation operations (i.e. -/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes. -std::unique_ptr> createConvertSimulatedQuantPass(); - -/// Creates a pass that converts constants followed by a qbarrier to a -/// constant whose value is quantized. This is typically one of the last -/// passes done when lowering to express actual quantized arithmetic in a -/// low level representation. Because it modifies the constant, it is -/// destructive and cannot be undone. -std::unique_ptr> createConvertConstPass(); - -//===----------------------------------------------------------------------===// -// Registration -//===----------------------------------------------------------------------===// - -/// Generate the code for registering passes. -#define GEN_PASS_REGISTRATION -#include "mlir/Dialect/Quant/Passes.h.inc" - -} // namespace quant -} // namespace mlir - -#endif // MLIR_DIALECT_QUANT_PASSES_H diff --git a/mlir/include/mlir/Dialect/Quant/Passes.td b/mlir/include/mlir/Dialect/Quant/Passes.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Quant/Passes.td +++ /dev/null @@ -1,27 +0,0 @@ -//===-- Passes.td - Quant pass definition file -------------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_QUANT_PASSES -#define MLIR_DIALECT_QUANT_PASSES - -include "mlir/Pass/PassBase.td" - -def QuantConvertConst : Pass<"quant-convert-const", "FuncOp"> { - let summary = "Converts constants followed by qbarrier to actual quantized " - "values"; - let constructor = "mlir::quant::createConvertConstPass()"; -} - -def QuantConvertSimulatedQuant - : Pass<"quant-convert-simulated-quantization", "FuncOp"> { - let summary = "Converts training-time simulated quantization ops to " - "corresponding quantize/dequantize casts"; - let constructor = "mlir::quant::createConvertSimulatedQuantPass()"; -} - -#endif // MLIR_DIALECT_QUANT_PASSES diff --git a/mlir/include/mlir/Dialect/Quant/QuantOps.td b/mlir/include/mlir/Dialect/Quant/QuantOps.td --- a/mlir/include/mlir/Dialect/Quant/QuantOps.td +++ b/mlir/include/mlir/Dialect/Quant/QuantOps.td @@ -84,170 +84,4 @@ let hasFolder = 1; } -// A QuantizeRegion (region) represents a quantization unit which wraps -// high-precision ops with quantization specifications for all the inputs -// and outputs. Some quantization specifications can be undetermined and -// derived from other ports by the target specification of the kernel. -def quant_QuantizeRegionOp : quant_Op<"region", [ - NoSideEffect, - IsolatedFromAbove, - SingleBlockImplicitTerminator<"ReturnOp">]> { - let summary = [{ - The `region` operation wraps high-precision ops as a logical low-precision - quantized kernel. - }]; - - let arguments = (ins Variadic:$inputs, - TypeArrayAttr:$input_specs, - TypeArrayAttr:$output_specs, - StrAttr:$logical_kernel); - let results = (outs Variadic:$outputs); - let regions = (region SizedRegion<1>:$body); - let hasVerifier = 1; -} - -def quant_ReturnOp : quant_Op<"return", [Terminator]> { - let summary = [{ - The `return` operation terminates a quantize region and returns values. - }]; - - let arguments = (ins Variadic:$results); -} - -//===----------------------------------------------------------------------===// -// Training integration and instrumentation ops -//===----------------------------------------------------------------------===// - -def quant_ConstFakeQuant : quant_Op<"const_fake_quant", - [SameOperandsAndResultType, NoSideEffect]> { - let summary = [{ - Simulates the effect of uniform quantization with const range. - }]; - - let description = [{ - Given a const min, max, num_bits and narrow_range attribute, applies the - same uniform quantization simulation as is done by the TensorFlow - fake_quant_with_min_max_args op. See the fakeQuantAttrsToType() utility - method and the quant-convert-simulated-quantization pass for further details. - }]; - - let arguments = (ins - F32Tensor:$inputs, - F32Attr:$min, - F32Attr:$max, - // The bitwidth of the quantization; between 2 and 16, inclusive. - I64Attr:$num_bits, - // Quantization range starts from 0 or 1; starts from 1 if true. - DefaultValuedAttr:$narrow_range, - // The sign of the quantization. - DefaultValuedAttr:$is_signed - ); - - let results = (outs - F32Tensor:$outputs - ); -} - -def quant_ConstFakeQuantPerAxis : quant_Op<"const_fake_quant_per_axis", - [SameOperandsAndResultType, NoSideEffect]> { - let summary = [{ - Simulates the effect of per axis uniform quantization with const range. - }]; - - let description = [{ - Given a const min, max, num_bits and narrow_range attribute, applies the - same per axis uniform quantization simulation as is done by the TensorFlow - fake_quant_with_min_max_vars_per_channel op. See the fakeQuantAttrsToType() - utility method and the quant-convert-simulated-quantization pass for further - details. - }]; - - let arguments = (ins - F32Tensor:$inputs, - F32ArrayAttr:$min, - F32ArrayAttr:$max, - // The quantized dimension of the inputs tensor. - I64Attr:$axis, - // The bitwidth of the quantization; between 2 and 16, inclusive. - I64Attr:$num_bits, - // Quantization range starts from 0 or 1; starts from 1 if true. - DefaultValuedAttr:$narrow_range, - // The sign of the quantization. - DefaultValuedAttr:$is_signed - ); - - let results = (outs - F32Tensor:$outputs - ); -} - -def quant_StatisticsRefOp : quant_Op<"stats_ref", [SameOperandsAndResultType]> { - let summary = "Indicates that statistics are resolved by reference."; - - let description = [{ - This op acts as an identity that, when encountered at runtime, should result - in statistics being collected about about the value of its operand/result. - Such statistics will be stored with the provided key, allowing this node - to later be converted to a 'stats' op if statistics with that key have been - encountered. - }]; - - let arguments = (ins - quant_RealValueType:$arg, - StrAttr:$statsKey - ); - let results = (outs quant_RealValueType); -} - -def quant_StatisticsOp : quant_Op<"stats", [SameOperandsAndResultType]> { - let summary = "Identity op which associates statistics with the value."; - - let description = [{ - Associates statistics about the runtime ranges of values observed for - evaluations of this node. - - Statistics about the entire type are reported in the 'layerStats' attribute - and those for each axis, in the (optional) `axisStats` attribute. The - interpretation of each is determined by the last dimension of its shape. - Currently, only dim=2 is supported, which is interpreted as [min, max]. - - `layerStats` must be a rank 1 tensor: [2] - `axisStats` must be a rank 2 tensor: [N, 2], where N=the slice size - splitted by the `axis` dimension. For example: - - ``` - , axis=3 => N=2 - , axis=2 => N=6 - ``` - }]; - - let arguments = (ins - quant_RealValueType:$arg, - ElementsAttr:$layerStats, - OptionalAttr:$axisStats, - OptionalAttr:$axis); - let results = (outs quant_RealValueType); - let hasVerifier = 1; -} - -def quant_CoupledRefOp : quant_Op<"coupled_ref", [SameOperandsAndResultType]> { - let summary = [{ - Indicates that one point of the computation is coupled to another. - }]; - - let description = [{ - Ordinarily, relationships between ops for the purposes of determining - compatible quantized types is explicit based on the use-def chain. However, - in some situations, a use may be separated from its def by arbitrary - external connections. In such a case, during analysis, all coupled_ref - nodes in a module which share a coupledKey will be considered to be - directly connected as via an identity op for the purpose of type inference. - }]; - - let arguments = (ins - quant_RealValueType:$arg, - StrAttr:$coupledKey); - let results = (outs quant_RealValueType); -} - #endif // DIALECT_QUANT_QUANT_OPS_ diff --git a/mlir/include/mlir/Dialect/Quant/QuantizeUtils.h b/mlir/include/mlir/Dialect/Quant/QuantizeUtils.h deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Quant/QuantizeUtils.h +++ /dev/null @@ -1,61 +0,0 @@ -//===- QuantizeUtils.h - Support utilities for quantization -----*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_QUANT_QUANTIZEUTILS_H_ -#define MLIR_DIALECT_QUANT_QUANTIZEUTILS_H_ - -namespace mlir { -class Attribute; -class Type; - -namespace quant { -class QuantizedType; -class UniformQuantizedType; -class UniformQuantizedValueConverter; - -/// Converts an attribute from a type based on -/// quantizedElementType.getExpressedType() to one based on -/// quantizedElementType.getStorageType(), where quantizedElementType is as from -/// QuantizedType::getQuantizedElementType(). -/// Returns nullptr if the conversion is not supported. On success, stores the -/// converted type in outConvertedType. -/// -/// Examples: -/// 1. realValue is a primitive value attribute: -/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) -/// -> (IntegerAttr, outConvertedType: i8) -/// 2. realValue is an elements attribute: -/// (realValue: DenseElementsAttr[tensor<2x2xf32>], -/// quantizedElementType: UniformQuantizedType[i8:f32]) -/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) -Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, - Type &outConvertedType); - -/// Converts an attribute from a type based on -/// quantizedElementType.getExpressedType() to one based on -/// quantizedElementType.getStorageType(), where quantizedElementType is as from -/// QuantizedType::getQuantizedElementType() and casted to an -/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On -/// success, stores the converted type in outConvertedType. -/// -/// Examples: -/// 1. realValue is a primitive value attribute: -/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) -/// -> (IntegerAttr, outConvertedType: i8) -/// 2. realValue is an elements attribute: -/// (realValue: DenseElementsAttr[tensor<2x2xf32>], -/// quantizedElementType: UniformQuantizedType[i8:f32]) -/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) -Attribute quantizeAttrUniform(Attribute realValue, - UniformQuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, - Type &outConvertedType); -} // namespace quant -} // namespace mlir - -#endif // MLIR_DIALECT_QUANT_QUANTIZEUTILS_H_ diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -23,7 +23,6 @@ #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" -#include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SPIRV/Transforms/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" @@ -65,7 +64,6 @@ registerSparseTensorPasses(); LLVM::registerLLVMPasses(); memref::registerMemRefPasses(); - quant::registerQuantPasses(); registerSCFPasses(); registerShapePasses(); spirv::registerSPIRVPasses(); diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt --- a/mlir/lib/Dialect/Quant/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/CMakeLists.txt @@ -1,3 +1,2 @@ add_subdirectory(IR) -add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -43,93 +43,5 @@ return srcScastOp.arg(); } -/// The quantization specification should match the expressed type. -static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { - if (auto typeAttr = quantSpec.dyn_cast()) { - Type spec = typeAttr.getValue(); - if (spec.isa()) - return false; - - // The spec should be either a quantized type which is compatible to the - // expressed type, or a primitive type which is as same as the - // (element type of) the expressed type. - if (auto quantizedType = spec.dyn_cast()) - return quantizedType.isCompatibleExpressedType(expressed); - - if (auto tensorType = expressed.dyn_cast()) - return spec == tensorType.getElementType(); - - if (auto vectorType = expressed.dyn_cast()) - return spec == vectorType.getElementType(); - } - return false; -} - -LogicalResult QuantizeRegionOp::verify() { - // There are specifications for both inputs and outputs. - if (getNumOperands() != input_specs().size() || - getNumResults() != output_specs().size()) - return emitOpError( - "has unmatched operands/results number and spec attributes number"); - - // Verify that quantization specifications are valid. - for (auto input : llvm::zip(getOperandTypes(), input_specs())) { - Type inputType = std::get<0>(input); - Attribute inputSpec = std::get<1>(input); - if (!isValidQuantizationSpec(inputSpec, inputType)) { - return emitOpError() << "has incompatible specification " << inputSpec - << " and input type " << inputType; - } - } - - for (auto result : llvm::zip(getResultTypes(), output_specs())) { - Type outputType = std::get<0>(result); - Attribute outputSpec = std::get<1>(result); - if (!isValidQuantizationSpec(outputSpec, outputType)) { - return emitOpError() << "has incompatible specification " << outputSpec - << " and output type " << outputType; - } - } - return success(); -} - -LogicalResult StatisticsOp::verify() { - auto tensorArg = arg().getType().dyn_cast(); - if (!tensorArg) - return emitOpError("arg needs to be tensor type."); - - // Verify layerStats attribute. - { - auto layerStatsType = layerStats().getType(); - if (!layerStatsType.getElementType().isa()) { - return emitOpError("layerStats must have a floating point element type"); - } - if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { - return emitOpError("layerStats must have shape [2]"); - } - } - // Verify axisStats (optional) attribute. - if (axisStats()) { - if (!axis()) - return emitOpError("axis must be specified for axisStats"); - - auto shape = tensorArg.getShape(); - auto argSliceSize = - std::accumulate(std::next(shape.begin(), *axis()), shape.end(), 1, - std::multiplies()); - - auto axisStatsType = axisStats()->getType(); - if (!axisStatsType.getElementType().isa()) { - return emitOpError("axisStats must have a floating point element type"); - } - if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || - axisStatsType.getDimSize(0) != argSliceSize) { - return emitOpError("axisStats must have shape [N,2] " - "where N = the slice size defined by the axis dim"); - } - } - return success(); -} - #define GET_OP_CLASSES #include "mlir/Dialect/Quant/QuantOps.cpp.inc" diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt deleted file mode 100644 --- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_dialect_library(MLIRQuantTransforms - ConvertConst.cpp - ConvertSimQuant.cpp - - ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/QuantOps/Transforms - - DEPENDS - MLIRQuantPassIncGen - - LINK_LIBS PUBLIC - MLIRArithmetic - MLIRIR - MLIRQuant - MLIRQuantUtils - MLIRPass - MLIRSupport - MLIRTransformUtils - ) diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp +++ /dev/null @@ -1,103 +0,0 @@ -//===- ConvertConst.cpp - Quantizes constant ops --------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Quant/Passes.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantizeUtils.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Matchers.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; -using namespace mlir::quant; - -namespace { -struct ConvertConstPass : public QuantConvertConstBase { - void runOnOperation() override; -}; - -struct QuantizedConstRewrite : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, - PatternRewriter &rewriter) const override; -}; - -} // namespace - -/// Matches a [constant] -> [qbarrier] where the qbarrier results type is -/// quantized and the operand type is quantizable. - -LogicalResult -QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, - PatternRewriter &rewriter) const { - Attribute value; - - // Is the operand a constant? - if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { - return failure(); - } - - // Does the qbarrier convert to a quantized type. This will not be true - // if a quantized type has not yet been chosen or if the cast to an equivalent - // storage type is not supported. - Type qbarrierResultType = qbarrier.getResult().getType(); - QuantizedType quantizedElementType = - QuantizedType::getQuantizedElementType(qbarrierResultType); - if (!quantizedElementType) { - return failure(); - } - if (!QuantizedType::castToStorageType(qbarrierResultType)) { - return failure(); - } - - // Is the operand type compatible with the expressed type of the quantized - // type? This will not be true if the qbarrier is superfluous (converts - // from and to a quantized type). - if (!quantizedElementType.isCompatibleExpressedType( - qbarrier.arg().getType())) { - return failure(); - } - - // Is the constant value a type expressed in a way that we support? - if (!value.isa()) { - return failure(); - } - - Type newConstValueType; - auto newConstValue = - quantizeAttr(value, quantizedElementType, newConstValueType); - if (!newConstValue) { - return failure(); - } - - // When creating the new const op, use a fused location that combines the - // original const and the qbarrier that led to the quantization. - auto fusedLoc = rewriter.getFusedLoc( - {qbarrier.arg().getDefiningOp()->getLoc(), qbarrier.getLoc()}); - auto newConstOp = rewriter.create( - fusedLoc, newConstValueType, newConstValue); - rewriter.replaceOpWithNewOp(qbarrier, qbarrier.getType(), - newConstOp); - return success(); -} - -void ConvertConstPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - auto func = getOperation(); - auto *context = &getContext(); - patterns.add(context); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); -} - -std::unique_ptr> mlir::quant::createConvertConstPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp +++ /dev/null @@ -1,140 +0,0 @@ -//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PassDetail.h" -#include "mlir/Dialect/Quant/FakeQuantSupport.h" -#include "mlir/Dialect/Quant/Passes.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; -using namespace mlir::quant; - -namespace { -struct ConvertSimulatedQuantPass - : public QuantConvertSimulatedQuantBase { - void runOnOperation() override; -}; - -/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. -template -class FakeQuantRewrite : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) - : OpRewritePattern(ctx), hadFailure(hadFailure) {} - - LogicalResult matchAndRewrite(FakeQuantOp op, - PatternRewriter &rewriter) const override { - // TODO: If this pattern comes up more frequently, consider adding core - // support for failable rewrites. - if (failableRewrite(op, rewriter)) { - *hadFailure = true; - return failure(); - } - - return success(); - } - -private: - bool *hadFailure; - - bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { - auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); - if (!converter) { - return (op.emitError("unsupported quantized type conversion"), true); - } - - QuantizedType elementType = - static_cast(this) - ->convertFakeQuantAttrsToType(op, converter.expressedType); - - if (!elementType) { - // Note that the fakeQuantAttrsToType will have emitted the error. - return true; - } - - Type quantizedType = converter.convert(elementType); - assert(quantizedType && - "Converter accepted a type that it did not convert"); - - // TODO: Map to a qbarrier with an attribute like [Forced] to signal that - // this is a forced/hard-coded constraint. - auto qbarrier = rewriter.create(op.getLoc(), quantizedType, - op.inputs()); - rewriter.replaceOpWithNewOp(op, converter.inputType, - qbarrier.getResult()); - - return false; - } -}; - -class ConstFakeQuantRewrite - : public FakeQuantRewrite { -public: - using BaseRewrite = FakeQuantRewrite; - - ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) - : BaseRewrite(ctx, hadFailure) {} - - QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, - Type expressedType) const { - return fakeQuantAttrsToType( - fqOp.getLoc(), fqOp.num_bits(), fqOp.min().convertToFloat(), - fqOp.max().convertToFloat(), fqOp.narrow_range(), expressedType, - fqOp.is_signed()); - } -}; - -class ConstFakeQuantPerAxisRewrite - : public FakeQuantRewrite { -public: - using BaseRewrite = - FakeQuantRewrite; - - ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) - : BaseRewrite(ctx, hadFailure) {} - - QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, - Type expressedType) const { - SmallVector min, max; - min.reserve(fqOp.min().size()); - max.reserve(fqOp.max().size()); - for (auto m : fqOp.min()) - min.push_back(m.cast().getValueAsDouble()); - for (auto m : fqOp.max()) - max.push_back(m.cast().getValueAsDouble()); - - return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits(), fqOp.axis(), - min, max, fqOp.narrow_range(), expressedType, - fqOp.is_signed()); - } -}; - -} // namespace - -void ConvertSimulatedQuantPass::runOnOperation() { - bool hadFailure = false; - auto func = getOperation(); - RewritePatternSet patterns(func.getContext()); - auto *ctx = func.getContext(); - patterns.add( - ctx, &hadFailure); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); - if (hadFailure) - signalPassFailure(); -} - -std::unique_ptr> -mlir::quant::createConvertSimulatedQuantPass() { - return std::make_unique(); -} diff --git a/mlir/lib/Dialect/Quant/Transforms/PassDetail.h b/mlir/lib/Dialect/Quant/Transforms/PassDetail.h deleted file mode 100644 --- a/mlir/lib/Dialect/Quant/Transforms/PassDetail.h +++ /dev/null @@ -1,21 +0,0 @@ -//===- PassDetail.h - Quant Pass class details ------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_ -#define DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_ - -#include "mlir/Pass/Pass.h" - -namespace mlir { - -#define GEN_PASS_CLASSES -#include "mlir/Dialect/Quant/Passes.h.inc" - -} // namespace mlir - -#endif // DIALECT_QUANT_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt b/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/Utils/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_dialect_library(MLIRQuantUtils - QuantizeUtils.cpp UniformSupport.cpp FakeQuantSupport.cpp diff --git a/mlir/lib/Dialect/Quant/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/Quant/Utils/QuantizeUtils.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Quant/Utils/QuantizeUtils.cpp +++ /dev/null @@ -1,147 +0,0 @@ -//===- QuantizeUtils.cpp - Support utilities for quantization -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Quant/QuantizeUtils.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" - -using namespace mlir; -using namespace mlir::quant; - -/// Converts a possible primitive, real expressed value attribute to a -/// corresponding storage attribute (typically FloatAttr -> IntegerAttr). -/// quantizedElementType is the QuantizedType that describes the expressed -/// origValue. -/// Returns a converter Attribute or nullptr if conversion is not possible. -static Attribute convertPrimitiveValueAttr( - Attribute origRealValue, QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, Type &outConvertedType) { - if (origRealValue.isa()) { - FloatAttr floatAttr = origRealValue.cast(); - outConvertedType = quantizedElementType.getStorageType(); - return IntegerAttr::get(quantizedElementType.getStorageType(), - converter.quantizeFloatToInt(floatAttr.getValue())); - } - - return nullptr; -} - -/// Converts a real expressed DenseFPElementsAttr to a corresponding -/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized -/// storage values assuming the given quantizedElementType and converter. -static DenseElementsAttr -convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, - QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter) { - // Convert to corresponding quantized value attributes. - SmallVector quantValues; - if (realFPElementsAttr.isSplat()) { - quantValues.push_back( - converter.quantizeFloatToInt(*realFPElementsAttr.begin())); - } else { - quantValues.reserve(realFPElementsAttr.getNumElements()); - for (APFloat realVal : realFPElementsAttr) { - quantValues.push_back(converter.quantizeFloatToInt(realVal)); - } - } - - // Cast from an expressed-type-based type to storage-type-based type, - // preserving the dense shape (i.e. tensor<4xf32> -> tensor<4xi8>). - ShapedType newDenseType = - quantizedElementType - .castExpressedToStorageType(realFPElementsAttr.getType()) - .dyn_cast_or_null(); - if (!newDenseType) { - return nullptr; - } - return DenseIntElementsAttr::get(newDenseType, quantValues); -} - -/// Converts a real expressed SplatElementsAttr to a corresponding -/// SplatElementsAttr containing quantized storage values assuming the given -/// quantizedElementType and converter. -static SparseElementsAttr -convertSparseElementsAttr(SparseElementsAttr realSparseAttr, - QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter) { - DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); - if (!realDenseAttr.isa()) { - return nullptr; - } - DenseElementsAttr quantDenseAttr = - convertDenseFPElementsAttr(realDenseAttr.cast(), - quantizedElementType, converter); - if (!quantDenseAttr) { - return nullptr; - } - - // Cast from an expressed-type-based type to storage-type-based type, - // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). - ShapedType newSparseType = - quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) - .dyn_cast_or_null(); - if (!newSparseType) { - return nullptr; - } - return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(), - quantDenseAttr); -} - -/// Converts a real expressed Attribute to a corresponding Attribute containing -/// quantized storage values assuming the given uniform quantizedElementType and -/// converter. -Attribute mlir::quant::quantizeAttrUniform( - Attribute realValue, UniformQuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, Type &outConvertedType) { - // Fork to handle different variants of constants supported. - if (realValue.isa()) { - // Dense tensor or vector constant. - auto converted = convertDenseFPElementsAttr( - realValue.cast(), quantizedElementType, converter); - outConvertedType = converted.getType(); - return converted; - } - if (realValue.isa()) { - // Sparse tensor or vector constant. - auto converted = convertSparseElementsAttr( - realValue.cast(), quantizedElementType, converter); - outConvertedType = converted.getType(); - return converted; - } - // Nothing else matched: try to convert a primitive. - return convertPrimitiveValueAttr(realValue, quantizedElementType, converter, - outConvertedType); -} - -/// Convert an attribute from a type based on -/// quantizedElementType.getExpressedType() to one based on -/// quantizedElementType.getStorageType(). -/// Returns nullptr if the conversion is not supported. -/// On success, stores the converted type in outConvertedType. -Attribute mlir::quant::quantizeAttr(Attribute realValue, - QuantizedType quantizedElementType, - Type &outConvertedType) { - if (auto uniformQuantized = - quantizedElementType.dyn_cast()) { - UniformQuantizedValueConverter converter(uniformQuantized); - return quantizeAttrUniform(realValue, uniformQuantized, converter, - outConvertedType); - } - if (auto uniformQuantizedPerAxis = - quantizedElementType.dyn_cast()) { - UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis); - auto converted = converter.convert(realValue); - // TODO: why we need this outConvertedType? remove it? - if (converted) { - outConvertedType = converted.getType(); - } - return converted; - } - return nullptr; -} diff --git a/mlir/test/Dialect/Quant/convert-const.mlir b/mlir/test/Dialect/Quant/convert-const.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Quant/convert-const.mlir +++ /dev/null @@ -1,193 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -quant-convert-const | FileCheck %s - -// Magic numbers: -// 7.8125e-03 = 1/128 = 2/256 : real range = [-1.0, 0.9921875] (for 8bit, zeroPoint=128) -// 1.250000e-01 = 1/8 = 2/16 : real range = [-1.0, 0.875] (for 4bit, zeroPoint=8) - -// ----- -// Verifies u8 affine quantization on a splat tensor. -// Note that MLIR prints int attributes as signed, so the constant, when -// quantized, is the signed printed version of an unsigned quantity -// (-64 signed == 192 unsigned). -// CHECK-LABEL: constant_splat_tensor_u8_affine -func @constant_splat_tensor_u8_affine() -> tensor<4xf32> { - // CHECK: %cst = arith.constant dense<-64> : tensor<4xi8> - // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - %cst = arith.constant dense<0.5> : tensor<4xf32> - %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform>) -> (tensor<4xf32>) - return %2 : tensor<4xf32> -} - -// ----- -// Verifies i8 affine quantization on a splat tensor. -// CHECK-LABEL: constant_splat_tensor_i8_affine -func @constant_splat_tensor_i8_affine() -> tensor<4xf32> { - // CHECK: %cst = arith.constant dense<63> : tensor<4xi8> - // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - %cst = arith.constant dense<0.5> : tensor<4xf32> - %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform>) -> (tensor<4xf32>) - return %2 : tensor<4xf32> -} - -// ----- -// Verifies i8 fixedpoint quantization on a splat tensor. -// CHECK-LABEL: const_splat_tensor_i8_fixedpoint -func @const_splat_tensor_i8_fixedpoint() -> tensor<4xf32> { - // CHECK: %cst = arith.constant dense<64> : tensor<4xi8> - // CHECK-NEXT: %0 = "quant.scast"(%cst) : (tensor<4xi8>) -> tensor<4x!quant.uniform> - %cst = arith.constant dense<0.5> : tensor<4xf32> - %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform>) -> (tensor<4xf32>) - return %2 : tensor<4xf32> -} - -// ----- -// Verifies i8 fixedpoint quantization on a splat tensor resulting in a negative storage value. -// CHECK-LABEL: const_splat_tensor_i8_fixedpoint_neg -func @const_splat_tensor_i8_fixedpoint_neg() -> tensor<4xf32> { - // CHECK: %cst = arith.constant dense<-64> : tensor<4xi8> - %cst = arith.constant dense<-0.5> : tensor<4xf32> - %1 = "quant.qcast"(%cst) : (tensor<4xf32>) -> tensor<4x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<4x!quant.uniform>) -> (tensor<4xf32>) - return %2 : tensor<4xf32> -} - -// ----- -// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values. -// CHECK-LABEL: const_dense_tensor_i8_fixedpoint -func @const_dense_tensor_i8_fixedpoint() -> tensor<7xf32> { - // CHECK: %cst = arith.constant dense<[-128, -128, -64, 0, 64, 127, 127]> : tensor<7xi8> - %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform>) -> (tensor<7xf32>) - return %2 : tensor<7xf32> -} - -// ----- -// Verifies i8 fixedpoint quantization on a sparse tensor, sweeping values. -// CHECK-LABEL: const_sparse_tensor_i8_fixedpoint -func @const_sparse_tensor_i8_fixedpoint() -> tensor<2x7xf32> { - // NOTE: Ugly regex match pattern for opening "[[" of indices tensor. - // CHECK: %cst = arith.constant sparse<{{\[}}[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], [-128, -128, -64, 0, 64, 127, 127]> : tensor<2x7xi8> - %cst = arith.constant sparse< - [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6]], - [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<2x7xf32> - %1 = "quant.qcast"(%cst) : (tensor<2x7xf32>) -> tensor<2x7x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<2x7x!quant.uniform>) -> (tensor<2x7xf32>) - return %2 : tensor<2x7xf32> -} - -// ----- -// Verifies i8 fixedpoint quantization on a primitive const. -// CHECK-LABEL: const_primitive_float_i8_fixedpoint -func @const_primitive_float_i8_fixedpoint() -> f32 { - // CHECK: %c64_i8 = arith.constant 64 : i8 - // CHECK-NEXT: %0 = "quant.scast"(%c64_i8) : (i8) -> !quant.uniform - %cst = arith.constant 0.5 : f32 - %1 = "quant.qcast"(%cst) : (f32) -> !quant.uniform - %2 = "quant.dcast"(%1) : (!quant.uniform) -> (f32) - return %2 : f32 -} - -// ----- -// Verifies u4 affine quantization on a dense tensor, sweeping values. -// CHECK-LABEL: const_dense_tensor_u4_affine -func @const_dense_tensor_u4_affine() -> tensor<7xf32> { - // NOTE: Unsigned quantities printed by MLIR as signed. - // CHECK: %cst = arith.constant dense<[0, 0, 4, -8, -4, -1, -1]> : tensor<7xi4> - %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform>) -> (tensor<7xf32>) - return %2 : tensor<7xf32> -} - -// ----- -// Verifies i4 affine quantization on a dense tensor, sweeping values. -// CHECK-LABEL: const_dense_tensor_i4_affine -func @const_dense_tensor_i4_affine() -> tensor<7xf32> { - // NOTE: Unsigned quantities printed by MLIR as signed. - // CHECK: %cst = arith.constant dense<[-8, -8, -5, -1, 3, 7, 7]> : tensor<7xi4> - %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform>) -> (tensor<7xf32>) - return %2 : tensor<7xf32> -} - -// ----- -// Verifies i4 fixed point quantization on a dense tensor, sweeping values. -// CHECK-LABEL: const_dense_tensor_i4_fixedpoint -func @const_dense_tensor_i4_fixedpoint() -> tensor<7xf32> { - // CHECK: %cst = arith.constant dense<[-8, -8, -4, 0, 4, 7, 7]> : tensor<7xi4> - %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform>) -> (tensor<7xf32>) - return %2 : tensor<7xf32> -} - -// ----- -// Verifies i8 fixedpoint quantization on a dense tensor, sweeping values, and -// custom storage range. (the -128 should be clamped to -100, and the 127 should -// be clamped to 100). -// CHECK-LABEL: const_custom_storage_range_i8_fixedpoint -func @const_custom_storage_range_i8_fixedpoint() -> tensor<7xf32> { - // CHECK: %cst = arith.constant dense<[-100, -100, -64, 0, 64, 100, 100]> : tensor<7xi8> - %cst = arith.constant dense<[-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]> : tensor<7xf32> - %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform:f32, 7.812500e-03>> - %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform:f32, 7.812500e-03>>) -> (tensor<7xf32>) - return %2 : tensor<7xf32> -} - -// ----- -// Verifies quantization results of all-0.0 tensors are quantized to zero points. -// CHECK-LABEL: zero_tensors_to_zero_points -func @zero_tensors_to_zero_points() -> (tensor<7xf32>, tensor<7xf32>, tensor<7xf32>, tensor<7xf32>) { - -// CHECK-DAG: %[[cst1:.*]] = arith.constant dense<1> : tensor<7xi8> -// CHECK-DAG: %[[cst:.*]] = arith.constant dense<-127> : tensor<7xi8> -// CHECK-DAG: %[[cst0:.*]] = arith.constant dense<0> : tensor<7xi8> -// CHECK: "quant.scast"(%[[cst0]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform> -// CHECK: "quant.scast"(%[[cst]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform:f32, 1.000000e+00:-127>> -// CHECK: "quant.scast"(%[[cst0]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform> -// CHECK: "quant.scast"(%[[cst1]]) : (tensor<7xi8>) -> tensor<7x!quant.uniform:f32, 1.000000e+00:1>> - - %cst = arith.constant dense<0.0> : tensor<7xf32> - %1 = "quant.qcast"(%cst) : (tensor<7xf32>) -> tensor<7x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<7x!quant.uniform>) -> (tensor<7xf32>) - - %cst0 = arith.constant dense<0.0> : tensor<7xf32> - %3 = "quant.qcast"(%cst0) : (tensor<7xf32>) -> tensor<7x!quant.uniform:f32, 1.0:-127>> - %4 = "quant.dcast"(%3) : (tensor<7x!quant.uniform:f32, 1.0:-127>>) -> (tensor<7xf32>) - - %cst1 = arith.constant dense<0.0> : tensor<7xf32> - %5 = "quant.qcast"(%cst1) : (tensor<7xf32>) -> tensor<7x!quant.uniform> - %6 = "quant.dcast"(%5) : (tensor<7x!quant.uniform>) -> (tensor<7xf32>) - - %cst2 = arith.constant dense<0.0> : tensor<7xf32> - %7 = "quant.qcast"(%cst2) : (tensor<7xf32>) -> tensor<7x!quant.uniform:f32, 1.0:1>> - %8 = "quant.dcast"(%7) : (tensor<7x!quant.uniform:f32, 1.0:1>>) -> (tensor<7xf32>) - - return %2, %4, %6, %8 : tensor<7xf32>, tensor<7xf32>, tensor<7xf32>, tensor<7xf32> -} - -// ----- -// Verifies per-axis quantization results for dense. -// CHECK-LABEL: per_axis_dense_quantization -func @per_axis_dense_quantization() -> (tensor<2x3xf32>, tensor<2x3xf32>) { - -// CHECK-DAG: %[[cst0:.*]] = arith.constant dense<{{\[}}[-128, -1, 1], [127, 1, 3]]> : tensor<2x3xi8> -// CHECK-DAG: %[[cst:.*]] = arith.constant dense<{{\[}}[-128, 64, 127], [0, 1, 2]]> : tensor<2x3xi8> -// CHECK: "quant.scast"(%[[cst]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform> -// CHECK: "quant.scast"(%[[cst0]]) : (tensor<2x3xi8>) -> tensor<2x3x!quant.uniform> - - %cst = arith.constant dense<[[-2.0, -0.5, 0.0], [0.0, 1.0, 2.0]]> : tensor<2x3xf32> - %1 = "quant.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> - %2 = "quant.dcast"(%1) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>) - - %cst0 = arith.constant dense<[[-2.0, -0.5, 0.0], [0.0, 1.0, 2.0]]> : tensor<2x3xf32> - %3 = "quant.qcast"(%cst0) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> - %4 = "quant.dcast"(%3) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>) - - return %2, %4 : tensor<2x3xf32>, tensor<2x3xf32> -} diff --git a/mlir/test/Dialect/Quant/convert-fakequant-invalid.mlir b/mlir/test/Dialect/Quant/convert-fakequant-invalid.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Quant/convert-fakequant-invalid.mlir +++ /dev/null @@ -1,12 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -verify-diagnostics -quant-convert-simulated-quantization - -// ----- -// Unsupported quantizable type (i1 is currently not a supported element type). -func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> { -^bb0(%arg0: tensor<8x4x3xi1>): - // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values}} - %0 = "quant.const_fake_quant"(%arg0) { - min = 1.1 : f32, max = 1.0 : f32, num_bits = 8 - } : (tensor<8x4x3xi1>) -> tensor<8x4x3xi1> - return %0 : tensor<8x4x3xi1> -} diff --git a/mlir/test/Dialect/Quant/convert-fakequant.mlir b/mlir/test/Dialect/Quant/convert-fakequant.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Quant/convert-fakequant.mlir +++ /dev/null @@ -1,233 +0,0 @@ -// RUN: mlir-opt %s -split-input-file -quant-convert-simulated-quantization | FileCheck %s - -// ----- -// Verifies a quint8 single point. -// CHECK-LABEL: fakeQuantArgs_Quint8_0 -func @fakeQuantArgs_Quint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 0.0 : f32, num_bits = 8 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a quint8 single point (with narrow_range = true). -// CHECK-LABEL: fakeQuantArgs_Quint8_0_NarrowRange -func @fakeQuantArgs_Quint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform:f32, 1.000000e+00:1>> - // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform:f32, 1.000000e+00:1>>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a quint8 asymmetric 0..1 range. -// CHECK-LABEL: fakeQuantArgs_Quint8_0_1 -func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true). -// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange -func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform:f32, 0.003937007874015748:1>> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform:f32, 0.003937007874015748:1>>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a quint8 symmetric range of -1..127/128. -// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange -func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a qint8 single point. -// CHECK-LABEL: fakeQuantArgs_Qint8_0 -func @fakeQuantArgs_Qint8_0(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a qint8 single point (with narrow_range = true). -// CHECK-LABEL: fakeQuantArgs_Qint8_0_NarrowRange -func @fakeQuantArgs_Qint8_0_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform:f32, 1.000000e+00:-127>> - // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform:f32, 1.000000e+00:-127>>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 0.0 : f32, num_bits = 8, narrow_range = true, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a qint8 asymmetric 0..1 range. -// CHECK-LABEL: fakeQuantArgs_Qint8_0_1 -func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true). -// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange -func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform:f32, 0.003937007874015748:-127>> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform:f32, 0.003937007874015748:-127>>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a qint8 symmetric range of -1..127/128. -// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange -func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = -1.0 : f32, max = 0.9921875 : f32, num_bits = 8, narrow_range = false, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a commonly used -1..1 symmetric 16bit range with a zero point of -// 0 and range -1.0 .. 32767/32768. -// CHECK-LABEL: fakeQuantArgs_Qint16_Symmetric -func @fakeQuantArgs_Qint16_Symmetric(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - %0 = "quant.const_fake_quant"(%arg0) { - min = -1.0 : f32, max = 0.999969482 : f32, num_bits = 16, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verify that lowering to barriers of unranked tensors functions. -// CHECK-LABEL: fakeQuantArgs_UnrankedTensor -func @fakeQuantArgs_UnrankedTensor(tensor) -> tensor { -^bb0(%arg0: tensor): - // CHECK: %0 = "quant.qcast"(%arg0) : (tensor) - // CHECK-SAME: -> tensor> - // CHECK-NEXT: %1 = "quant.dcast"(%0) : (tensor>) - // CHECK-SAME: -> tensor - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8 - } : (tensor) -> tensor - return %0 : tensor -} - -// ----- -// CHECK-LABEL: fakeQuantArgs_all_positive -func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - - // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// CHECK-LABEL: fakeQuantArgs_all_negative -func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - - // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform>) - // CHECK-SAME: -> tensor<8x4x3xf32> - - %0 = "quant.const_fake_quant"(%arg0) { - min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// Verifies a qint8 per axis -// CHECK-LABEL: fakeQuantPerAxis -func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { -^bb0(%arg0: tensor<8x4x3xf32>): - - // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) - // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> - // CHECK: %[[d:.*]] = "quant.dcast"(%[[q]]) - // CHECK-SAME: (tensor<8x4x3x!quant.uniform>) - - %0 = "quant.const_fake_quant_per_axis"(%arg0) { - min = [-1.0 : f32, 0.0 : f32, 0.0 : f32], - max = [0.9921875 : f32, 0.0: f32, 1.0 : f32], - num_bits = 8, narrow_range = false, is_signed = true, axis = 2 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} diff --git a/mlir/test/Dialect/Quant/parse-ops-invalid.mlir b/mlir/test/Dialect/Quant/parse-ops-invalid.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Quant/parse-ops-invalid.mlir +++ /dev/null @@ -1,93 +0,0 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics - -// ----- -func @invalidStatisticsMismatchedLayerType(%arg0: tensor<8x4x3xf32>) -> - tensor<8x4x3xf32> { - // expected-error@+1 {{layerStats must have a floating point element type}} - %0 = "quant.stats"(%arg0) { - layerStats = dense<[-1, 1]> : tensor<2xi8> - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -func @invalidStatisticsMismatchedLayerRank(%arg0: tensor<8x4x3xf32>) -> - tensor<8x4x3xf32> { - // expected-error@+1 {{layerStats must have shape [2]}} - %0 = "quant.stats"(%arg0) { - layerStats = dense<[[-1.0, 1.0]]> : tensor<1x2xf32> - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -func @invalidStatisticsMismatchedLayerShape(%arg0: tensor<8x4x3xf32>) -> - tensor<8x4x3xf32> { - // expected-error@+1 {{layerStats must have shape [2]}} - %0 = "quant.stats"(%arg0) { - layerStats = dense<[-1.0, 1.0, 2.0]> : tensor<3xf32> - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// CHECK-LABEL: validStatistics -func @invalidStatisticsMismatchedAxisType(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - // expected-error@+1 {{axisStats must have a floating point element type}} - %0 = "quant.stats"(%0) { - layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, - axisStats = dense<[ - [-1, 1], - [-8, 8], - [-1, 0] - ]> : tensor<3x2xi8>, axis = 3 : i64 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -func @invalidStatisticsMismatchedAxisSize(%arg0: tensor<8x4x3xf32>) -> - tensor<8x4x3xf32> { - // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}} - %0 = "quant.stats"(%arg0) { - layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, - axisStats = dense<[ - [-1.0, 1.0], - [-8.0, 8.0], - [-0.5, 0.5], - [-2.0, 3.5] - ]> : tensor<4x2xf32>, axis = 3 : i64 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -func @invalidStatisticsMismatchedAxisShape(%arg0: tensor<8x4x3xf32>) -> - tensor<8x4x3xf32> { - // expected-error@+1 {{axisStats must have shape [N,2] where N = the slice size defined by the axis dim}} - %0 = "quant.stats"(%arg0) { - layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, - axisStats = dense<[ - [-1.0, 1.0, 1.0], - [-8.0, 8.0, 1.0], - [-0.5, 0.5, 1.0] - ]> : tensor<3x3xf32>, axis = 3 : i64 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -func @axisIsRequiredForAxisStats(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - // expected-error@+1 {{axis must be specified for axisStats}} - %1 = "quant.stats"(%arg0) { - layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, - axisStats = dense<[ - [-1.0, 1.0], - [-8.0, 8.0], - [-0.5, 0.5] - ]> : tensor<3x2xf32> - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %1 : tensor<8x4x3xf32> -} - -// ----- diff --git a/mlir/test/Dialect/Quant/parse-ops.mlir b/mlir/test/Dialect/Quant/parse-ops.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Quant/parse-ops.mlir +++ /dev/null @@ -1,64 +0,0 @@ -// RUN: mlir-opt %s -split-input-file | FileCheck %s - -// ----- -// CHECK-LABEL: validConstFakeQuant -func @validConstFakeQuant(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - %0 = "quant.const_fake_quant"(%arg0) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = true - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - %1 = "quant.const_fake_quant"(%0) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8, narrow_range = false - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - %2 = "quant.const_fake_quant"(%1) { - min = 0.0 : f32, max = 1.0 : f32, num_bits = 8 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %2 : tensor<8x4x3xf32> -} - -// ----- -// CHECK-LABEL: validConstFakeQuantPerAxis -func @validConstFakeQuantPerAxis(%arg0: tensor<8x4x2xf32>) -> tensor<8x4x2xf32> { - %0 = "quant.const_fake_quant_per_axis"(%arg0) { - min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = true - } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> - %1 = "quant.const_fake_quant_per_axis"(%0) { - min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8, narrow_range = false - } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> - %2 = "quant.const_fake_quant_per_axis"(%1) { - min = [0.0 : f32, 1.0 : f32], max = [2.0 : f32, 3.0 : f32], axis = 2, num_bits = 8 - } : (tensor<8x4x2xf32>) -> tensor<8x4x2xf32> - return %2 : tensor<8x4x2xf32> -} - -// ----- -// CHECK-LABEL: validStatisticsRef -func @validStatisticsRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - %0 = "quant.stats_ref"(%arg0) { statsKey = "foobar" } : - (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} - -// ----- -// CHECK-LABEL: validStatistics -func @validStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - %0 = "quant.stats"(%arg0) { - layerStats = dense<[-1.0, 1.0]> : tensor<2xf32> - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - %1 = "quant.stats"(%0) { - layerStats = dense<[-1.0, 1.0]> : tensor<2xf32>, - axisStats = dense<[ - [-1.0, 1.0], - [-8.0, 8.0], - [-0.5, 0.5] - ]> : tensor<3x2xf32>, axis = 2 : i64 - } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %1 : tensor<8x4x3xf32> -} - -// ----- -// CHECK-LABEL: validCoupledRef -func @validCoupledRef(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { - %0 = "quant.coupled_ref"(%arg0) { coupledKey = "foobar" } : - (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> - return %0 : tensor<8x4x3xf32> -} diff --git a/mlir/test/Dialect/Quant/quant_region.mlir b/mlir/test/Dialect/Quant/quant_region.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Quant/quant_region.mlir +++ /dev/null @@ -1,101 +0,0 @@ -// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s - -// CHECK-LABEL: @source -func @source(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - %0 = "quant.region"(%arg0, %arg1, %arg2) ({ - ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): - %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "quant.return"(%14) : (tensor<4xf32>) -> () - }) {input_specs = [f32, f32, f32], output_specs = [f32], logical_kernel = "xyz"} - : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: @annotated -func @annotated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - %0 = "quant.region"(%arg0, %arg1, %arg2) ({ - ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): - %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "quant.return"(%14) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform, f32], - output_specs = [!quant.uniform], logical_kernel = "xyz"} - : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) - return %0 : tensor<4xf32> -} - -// CHECK-LABEL: @quantized -func @quantized(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - %0 = "quant.region"(%arg0, %arg1, %arg2) ({ - ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): - %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "quant.return"(%14) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform, !quant.uniform], - output_specs = [!quant.uniform], logical_kernel = "xyz"} - : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) - return %0 : tensor<4xf32> -} - -// ----- - -func @unmatched_quantize(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - // @expected-error @+1 {{'quant.region' op has incompatible specification !quant.uniform and input type 'tensor<4xf32>'}} - %0 = "quant.region"(%arg0, %arg1, %arg2) ({ - ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): - %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "quant.return"(%14) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform, !quant.uniform], - output_specs = [!quant.uniform], logical_kernel = "xyz"} - : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) - return %0 : tensor<4xf32> -} - -// ----- - -func @unmatched_primitive(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - // @expected-error @+1 {{'quant.region' op has incompatible specification i32 and input type 'tensor<4xf32>'}} - %0 = "quant.region"(%arg0, %arg1, %arg2) ({ - ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): - %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "quant.return"(%14) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform, i32], - output_specs = [!quant.uniform], logical_kernel = "xyz"} - : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) - return %0 : tensor<4xf32> -} - -// ----- - -func @unmatched_number(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - // @expected-error @+1 {{'quant.region' op has unmatched operands/results number and spec attributes number}} - %0 = "quant.region"(%arg0, %arg1, %arg2) ({ - ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>, %12: tensor<4xf32>): - %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - %14 = "bar"(%13, %12) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "quant.return"(%14) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform], - output_specs = [!quant.uniform], logical_kernel = "xyz"} - : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) - return %0 : tensor<4xf32> -} - -// ----- - -func @isolated(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<4xf32>) -> (tensor<4xf32>) { - // @expected-note @+1 {{required by region isolation constraints}} - %0 = "quant.region"(%arg0, %arg1) ({ - ^bb0(%10: tensor<4xf32>, %11: tensor<4xf32>): - %13 = "foo"(%10, %11) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - // @expected-error @+1 {{'bar' op using value defined outside the region}} - %14 = "bar"(%13, %arg2) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> - "quant.return"(%14) : (tensor<4xf32>) -> () - }) {input_specs = [!quant.uniform, !quant.uniform], - output_specs = [!quant.uniform], logical_kernel = "xyz"} - : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xf32>) - return %0 : tensor<4xf32> -} - diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -7,7 +7,6 @@ MLIRDialect) add_subdirectory(Affine) -add_subdirectory(Quant) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) add_subdirectory(Utils) diff --git a/mlir/unittests/Dialect/Quant/CMakeLists.txt b/mlir/unittests/Dialect/Quant/CMakeLists.txt deleted file mode 100644 --- a/mlir/unittests/Dialect/Quant/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -add_mlir_unittest(MLIRQuantTests - QuantizationUtilsTest.cpp -) -target_link_libraries(MLIRQuantTests - PRIVATE - MLIRQuant - MLIRQuantUtils - ) diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp deleted file mode 100644 --- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp +++ /dev/null @@ -1,172 +0,0 @@ -//===- QuantizationUtilsTest.cpp - unit tests for quantization utils ------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantizeUtils.h" -#include "mlir/Dialect/Quant/UniformSupport.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/BuiltinTypes.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using namespace mlir; -using namespace mlir::quant; - -namespace { - -// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5. -class TestUniformQuantizedValueConverter - : public UniformQuantizedValueConverter { -public: - TestUniformQuantizedValueConverter(UniformQuantizedType type) - : UniformQuantizedValueConverter(type), qtype(type) {} - APInt quantizeFloatToInt(APFloat expressedValue) const override { - return APInt(qtype.getStorageType().cast().getWidth(), 5L); - } - -private: - UniformQuantizedType qtype; -}; - -Attribute getTestFloatAttr(double value, MLIRContext *ctx) { - return FloatAttr::get(FloatType::getF32(ctx), value); -} - -template -ConcreteAttrClass getTestElementsAttr(MLIRContext *ctx, ArrayRef shape, - Arg... value) { - auto eleType = FloatType::getF32(ctx); - ShapedType tensorType; - if (shape.size() == 1 && shape[0] == -1) { - tensorType = UnrankedTensorType::get(eleType); - } else { - tensorType = RankedTensorType::get(shape, eleType); - } - return ConcreteAttrClass::get(tensorType, value...); -} - -ElementsAttr getTestSparseElementsAttr(MLIRContext *ctx, - ArrayRef shape) { - auto eleType = FloatType::getF32(ctx); - ShapedType tensorType; - if (shape.size() == 1 && shape[0] == -1) { - tensorType = UnrankedTensorType::get(eleType); - } else { - tensorType = RankedTensorType::get(shape, eleType); - } - auto indicesType = RankedTensorType::get({1, 2}, IntegerType::get(ctx, 64)); - auto indices = - DenseIntElementsAttr::get(indicesType, {APInt(64, 0), APInt(64, 0)}); - auto valuesType = RankedTensorType::get({1}, eleType); - auto values = DenseFPElementsAttr::get(valuesType, {APFloat(0.0f)}); - return SparseElementsAttr::get(tensorType, indices, values); -} - -UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) { - return UniformQuantizedType::get(/*flags=*/false, storageType, - FloatType::getF32(ctx), /*scale=*/1.0, - /*zeroPoint=*/0, /*storageTypeMin=*/0, - /*storageTypeMax=*/255); -} - -TEST(QuantizationUtilsTest, convertFloatAttrUniform) { - MLIRContext ctx; - ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(&ctx, 8); - auto quantizedType = getTestQuantizedType(convertedType, &ctx); - TestUniformQuantizedValueConverter converter(quantizedType); - - auto realValue = getTestFloatAttr(1.0, &ctx); - Type typeResult; - auto valueResult = - quantizeAttrUniform(realValue, quantizedType, converter, typeResult); - - EXPECT_EQ(valueResult.cast().getInt(), 5); - EXPECT_EQ( - valueResult.cast().getType().cast().getWidth(), - convertedType.getWidth()); -} - -TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) { - MLIRContext ctx; - ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(&ctx, 8); - auto quantizedType = getTestQuantizedType(convertedType, &ctx); - TestUniformQuantizedValueConverter converter(quantizedType); - auto realValue = getTestElementsAttr>( - &ctx, {1, 2}, {getTestFloatAttr(1.0, &ctx), getTestFloatAttr(2.0, &ctx)}); - - Type returnedType; - auto returnedValue = - quantizeAttrUniform(realValue, quantizedType, converter, returnedType); - - // Check Elements attribute shape and kind are not changed. - auto tensorType = returnedType.cast(); - auto expectedTensorType = realValue.getType().cast(); - EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); - EXPECT_EQ(tensorType.getElementType(), convertedType); - EXPECT_TRUE(returnedValue.isa()); - - // Check Elements attribute element value is expected. - auto firstValue = - returnedValue.cast().getValues()[{0, 0}]; - EXPECT_EQ(firstValue.cast().getInt(), 5); -} - -TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { - MLIRContext ctx; - ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(&ctx, 8); - auto quantizedType = getTestQuantizedType(convertedType, &ctx); - TestUniformQuantizedValueConverter converter(quantizedType); - auto realValue = getTestElementsAttr( - &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx)); - - Type returnedType; - auto returnedValue = - quantizeAttrUniform(realValue, quantizedType, converter, returnedType); - - // Check Elements attribute shape and kind are not changed. - auto tensorType = returnedType.cast(); - auto expectedTensorType = realValue.getType().cast(); - EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); - EXPECT_EQ(tensorType.getElementType(), convertedType); - EXPECT_TRUE(returnedValue.isa()); - - // Check Elements attribute element value is expected. - auto firstValue = - returnedValue.cast().getValues()[{0, 0}]; - EXPECT_EQ(firstValue.cast().getInt(), 5); -} - -TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) { - MLIRContext ctx; - ctx.getOrLoadDialect(); - IntegerType convertedType = IntegerType::get(&ctx, 8); - auto quantizedType = getTestQuantizedType(convertedType, &ctx); - TestUniformQuantizedValueConverter converter(quantizedType); - auto realValue = getTestSparseElementsAttr(&ctx, {1, 2}); - - Type returnedType; - auto returnedValue = - quantizeAttrUniform(realValue, quantizedType, converter, returnedType); - - // Check Elements attribute shape and kind are not changed. - auto tensorType = returnedType.cast(); - auto expectedTensorType = realValue.getType().cast(); - EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape()); - EXPECT_EQ(tensorType.getElementType(), convertedType); - EXPECT_TRUE(returnedValue.isa()); - - // Check Elements attribute element value is expected. - auto firstValue = - returnedValue.cast().getValues()[{0, 0}]; - EXPECT_EQ(firstValue.cast().getInt(), 5); -} - -} // namespace