diff --git a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
index 5d11c769b8e0..42366842ada6 100644
--- a/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
+++ b/mlir/include/mlir/Dialect/QuantOps/UniformSupport.h
@@ -1,119 +1,118 @@
 //===- UniformSupport.h - Support utilities for uniform quant ---*- C++ -*-===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //   http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================

 #ifndef MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
 #define MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_

 #include "mlir/Dialect/QuantOps/QuantTypes.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/APSInt.h"

 namespace mlir {
 namespace quant {

 /// Performs type conversion from an arbitrary input type to a type
-/// that is expressed by a UniformQuantizedType.
+/// that is expressed by a QuantizedType.
 ///
 /// This handles cases where the inputType is a supported primitive type
 /// (i.e. f32, bf16, etc) or a vector/tensor type based on a supported
 /// elemental type.
 ///
 /// Since conversion often involves introspecting some attributes of the
 /// input type in order to determine how to represent it, this is a two step
 /// process.
-struct ExpressedToUniformQuantizedConverter {
+struct ExpressedToQuantizedConverter {
   /// Creates a converter for the given input type.
-  static const ExpressedToUniformQuantizedConverter
-  forInputType(Type inputType);
+  static const ExpressedToQuantizedConverter forInputType(Type inputType);

   /// Converts the inputType to be based on the given elemental type,
   /// returning the new type (or nullptr and emit an error on failure).
-  Type convert(UniformQuantizedType elementalType) const;
+  Type convert(QuantizedType elementalType) const;

   /// Whether the conversion is legal.
   explicit operator bool() const { return (bool)expressedType; }

   /// The input type that is being converted from.
   /// This may be an elemental or composite type.
   const Type inputType;

   /// Supported, elemental expressed type (i.e. f32).
   /// Will be nullptr if conversion is not supported.
   const Type expressedType;
 };

 /// Reference implementation of converting between real numbers and values
 /// represented by a UniformQuantizedType.
 /// Note that this is not expected to be speedy and may be superceded eventually
 /// by a more optimal implementation.
 /// Also, the interface assumes that quantization is done per-layer and will
 /// need to be wider for various per-channel schemes. As such, this is a
 /// placeholder.
 class UniformQuantizedValueConverter {
 public:
   UniformQuantizedValueConverter(UniformQuantizedType uniformType)
       : scale(uniformType.getScale()),
         zeroPoint(static_cast<double>(uniformType.getZeroPoint())),
         clampMin(static_cast<double>(uniformType.getStorageTypeMin())),
         clampMax(static_cast<double>(uniformType.getStorageTypeMax())),
         storageBitWidth(uniformType.getStorageTypeIntegralWidth()),
         isSigned(uniformType.isSigned()) {
     assert(uniformType.getExpressedType().isa<FloatType>());
     assert(uniformType.getStorageType().isa<IntegerType>());
   }

   virtual APInt quantizeFloatToInt(APFloat expressedValue) const {
     bool lossy;
     expressedValue.convert(scale.getSemantics(), APFloat::rmNearestTiesToEven,
                            &lossy);
     // fixedpoint = clamp(clampMin, clampMax, (
     //   roundHalfToEven(expressed / scale) + zeroPoint))
     APFloat scaled = (expressedValue / scale);
     scaled.roundToIntegral(APFloat::rmNearestTiesToEven);
     scaled.add(zeroPoint, APFloat::rmNearestTiesToEven);
     APFloat fixedpoint = llvm::minimum(scaled, clampMax);
     fixedpoint = llvm::maximum(fixedpoint, clampMin);

     llvm::APSInt result(storageBitWidth, !isSigned);
     fixedpoint.convertToInteger(result, APFloat::rmNearestTiesToEven, &lossy);
     return std::move(result);
   }

   int64_t quantizeFloatToInt64(APFloat expressedValue) const {
     APInt qValue = quantizeFloatToInt(expressedValue);
     return isSigned ? qValue.getSExtValue() : qValue.getZExtValue();
   }

   virtual ~UniformQuantizedValueConverter() {}

 private:
   const APFloat scale;
   const APFloat zeroPoint;
   const APFloat clampMin;
   const APFloat clampMax;
   const uint32_t storageBitWidth;
   const bool isSigned;
 };

 } // namespace quant
 } // namespace mlir

 #endif // MLIR_DIALECT_QUANTOPS_UNIFORMSUPPORT_H_
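For orientation, here is a minimal sketch of how the renamed two-step API and the reference value converter above are meant to be used. The helper names quantizeExpressedType and quantizeOneValue are hypothetical and introduced only for illustration; the converter calls themselves come from the header.

#include "mlir/Dialect/QuantOps/UniformSupport.h"

using namespace mlir;
using namespace mlir::quant;

// Hypothetical helper: rewrite `inputType` (f32, tensor<...xf32>, vector<...xf32>)
// in terms of the given quantized element type, or return nullptr if unsupported.
static Type quantizeExpressedType(Type inputType, QuantizedType elementalType) {
  // Step 1: introspect the input type and record its expressed element type.
  auto converter = ExpressedToQuantizedConverter::forInputType(inputType);
  if (!converter)
    return nullptr;
  // Step 2: rebuild the same container around the quantized element type.
  return converter.convert(elementalType);
}

// Hypothetical helper: quantize a single real value with the reference
// (per-layer) converter, i.e. clamp(roundHalfToEven(value / scale) + zeroPoint)
// to the storage range.
static APInt quantizeOneValue(UniformQuantizedType uniformType, float value) {
  UniformQuantizedValueConverter valueConverter(uniformType);
  return valueConverter.quantizeFloatToInt(APFloat(value));
}
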
diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
index 129671979caa..4f6eb8cb9855 100644
--- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -1,114 +1,114 @@
 //===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //   http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================

 #include "mlir/Dialect/QuantOps/FakeQuantSupport.h"
 #include "mlir/Dialect/QuantOps/Passes.h"
 #include "mlir/Dialect/QuantOps/QuantOps.h"
 #include "mlir/Dialect/QuantOps/UniformSupport.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Pass/Pass.h"

 using namespace mlir;
 using namespace mlir::quant;

 namespace {
 class ConvertSimulatedQuantPass
     : public FunctionPass<ConvertSimulatedQuantPass> {
 public:
   void runOnFunction() override;
 };
 } // end anonymous namespace

 /// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
 class ConstFakeQuantRewrite : public RewritePattern {
 public:
   bool *hadFailure;

   ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
       : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
         hadFailure(hadFailure) {}

   PatternMatchResult matchAndRewrite(Operation *op,
                                      PatternRewriter &rewriter) const override {
     // TODO: If this pattern comes up more frequently, consider adding core
     // support for failable rewrites.
     if (failableRewrite(op, rewriter)) {
       *hadFailure = true;
       return matchFailure();
     }

     return matchSuccess();
   }

   bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
     auto fqOp = cast<ConstFakeQuant>(op);

     auto converter =
-        ExpressedToUniformQuantizedConverter::forInputType(fqOp.getType());
+        ExpressedToQuantizedConverter::forInputType(fqOp.getType());
     if (!converter) {
       return (op->emitError("unsupported quantized type conversion"), true);
     }

     UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
         fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
         fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
         fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());

     if (!uniformElementType) {
       // Note that the fakeQuantAttrsToType will have emitted the error.
       return true;
     }

     Type quantizedType = converter.convert(uniformElementType);
     assert(quantizedType &&
            "Converter accepted a type that it did not convert");

     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
     // this is a forced/hard-coded constraint.
     auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
                                                     fqOp.inputs());
     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
                                                   qbarrier.getResult());

     return false;
   }
 };

 void ConvertSimulatedQuantPass::runOnFunction() {
   bool hadFailure = false;
   OwningRewritePatternList patterns;
   auto func = getFunction();
   auto *context = &getContext();
   patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
   applyPatternsGreedily(func, patterns);
   if (hadFailure)
     signalPassFailure();
 }

 std::unique_ptr<OpPassBase<FuncOp>>
 mlir::quant::createConvertSimulatedQuantPass() {
   return std::make_unique<ConvertSimulatedQuantPass>();
 }

 static PassRegistration<ConvertSimulatedQuantPass>
     pass("quant-convert-simulated-quantization",
          "Converts training-time simulated quantization ops to corresponding "
          "quantize/dequantize casts.");
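The rewrite above relies on fakeQuantAttrsToType to turn the op's min/max/num_bits/narrow_range/is_signed attributes into a UniformQuantizedType. As a rough, illustrative sketch of the arithmetic such a helper performs (the real implementation also validates the range and nudges the zero point onto an exact storage value, so treat this as an approximation rather than the library code):

#include <cmath>
#include <cstdint>

struct QuantParams {
  double scale;
  int64_t zeroPoint;
};

// Derive uniform quantization parameters from a fake-quant range.
static QuantParams deriveQuantParams(double min, double max, unsigned numBits,
                                     bool isSigned) {
  // Storage range implied by the bit width, e.g. [-128, 127] or [0, 255].
  int64_t qmin = isSigned ? -(1LL << (numBits - 1)) : 0;
  int64_t qmax = isSigned ? (1LL << (numBits - 1)) - 1 : (1LL << numBits) - 1;
  double scale = (max - min) / static_cast<double>(qmax - qmin);
  // The zero point is the storage value that represents real 0.0.
  int64_t zeroPoint = qmin - static_cast<int64_t>(std::round(min / scale));
  return {scale, zeroPoint};
}

For example, an 8-bit unsigned range of [-1.0, 1.0] yields a scale of roughly 2/255 and a zero point of 128.
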
diff --git a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
index db8a58489815..aec45d4076bb 100644
--- a/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
+++ b/mlir/lib/Dialect/QuantOps/Utils/UniformSupport.cpp
@@ -1,73 +1,72 @@
 //===- UniformSupport.cpp - Support utilities for uniform quant -----------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
 //
 //   http://www.apache.org/licenses/LICENSE-2.0
 //
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
 // =============================================================================

 #include "mlir/Dialect/QuantOps/UniformSupport.h"
 #include "mlir/IR/StandardTypes.h"

 using namespace mlir;
 using namespace mlir::quant;

 static bool isQuantizablePrimitiveType(Type inputType) {
   return inputType.isa<FloatType>();
 }

-const ExpressedToUniformQuantizedConverter
-ExpressedToUniformQuantizedConverter::forInputType(Type inputType) {
+const ExpressedToQuantizedConverter
+ExpressedToQuantizedConverter::forInputType(Type inputType) {
   switch (inputType.getKind()) {
   default:
     if (isQuantizablePrimitiveType(inputType)) {
       // Supported primitive type (which just is the expressed type).
-      return ExpressedToUniformQuantizedConverter{inputType, inputType};
+      return ExpressedToQuantizedConverter{inputType, inputType};
     }
     // Unsupported.
-    return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+    return ExpressedToQuantizedConverter{inputType, nullptr};
   case StandardTypes::RankedTensor:
   case StandardTypes::UnrankedTensor:
   case StandardTypes::Vector: {
     Type elementType = inputType.cast<ShapedType>().getElementType();
     if (!isQuantizablePrimitiveType(elementType)) {
       // Unsupported.
-      return ExpressedToUniformQuantizedConverter{inputType, nullptr};
+      return ExpressedToQuantizedConverter{inputType, nullptr};
     }
-    return ExpressedToUniformQuantizedConverter{
+    return ExpressedToQuantizedConverter{
         inputType, inputType.cast<ShapedType>().getElementType()};
   }
   }
 }

-Type ExpressedToUniformQuantizedConverter::convert(
-    UniformQuantizedType elementalType) const {
+Type ExpressedToQuantizedConverter::convert(QuantizedType elementalType) const {
   assert(expressedType && "convert() on unsupported conversion");

   switch (inputType.getKind()) {
   default:
     if (isQuantizablePrimitiveType(elementalType)) {
       // For primitives, just use the new elemental type.
       return elementalType;
     }
     // Unsupported.
     return nullptr;
   case StandardTypes::RankedTensor:
     return RankedTensorType::get(inputType.cast<RankedTensorType>().getShape(),
                                  elementalType);
   case StandardTypes::UnrankedTensor:
     return UnrankedTensorType::get(elementalType);
   case StandardTypes::Vector:
     return VectorType::get(inputType.cast<VectorType>().getShape(),
                            elementalType);
   }
 }
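As a sanity check on the shape handling above, a small test-style sketch might look like the following. Only the ExpressedToQuantizedConverter calls come from this patch; the scaffolding (function name, shapes, the use of asserts) is assumed for illustration.

#include "mlir/Dialect/QuantOps/UniformSupport.h"
#include "mlir/IR/StandardTypes.h"
#include <cassert>

using namespace mlir;
using namespace mlir::quant;

// Assumed test scaffold: verify that convert() keeps the container shape and
// only swaps the element type.
static void checkTensorConversion(MLIRContext *context,
                                  QuantizedType elementalType) {
  Type f32 = FloatType::getF32(context);
  Type tensorOfF32 = RankedTensorType::get({2, 4}, f32);

  auto converter = ExpressedToQuantizedConverter::forInputType(tensorOfF32);
  assert(converter && "tensor of f32 should be convertible");
  assert(converter.expressedType == f32);

  Type converted = converter.convert(elementalType);
  assert(converted.cast<RankedTensorType>().getElementType() == elementalType);
  (void)converted;
}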