Index: mlir/include/mlir/Dialect/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/CMakeLists.txt +++ mlir/include/mlir/Dialect/CMakeLists.txt @@ -13,4 +13,5 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) Index: mlir/include/mlir/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) Index: mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaOps.h.inc -gen-op-decls) +mlir_tablegen(TosaOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTosaOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls) +mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRTosaStructsIncGen) + + +set(LLVM_TARGET_DEFINITIONS TosaInterfaces.td) +mlir_tablegen(TosaInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TosaInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRTosaInterfaceIncGen) + +add_mlir_doc(TosaOps -gen-op-doc TosaOps Dialects/) Index: mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td @@ -0,0 +1,34 @@ +//===-- TosaInterfaces.td - TOSA dialect interfaces --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dialect op interfaces for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_OP_INTERFACES +#define TOSA_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def TosaOpInterface : OpInterface<"TosaOp"> { + let description = [{ + Implements interfaces implemented by ops that correspond to the Tosa specification. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the TOSA version.}], + "StringRef", "getTOSAVersion", (ins), [{ + return "0.20"; + }] + >, + ]; + +} + +#endif Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -0,0 +1,289 @@ +//===-- TosaOpBase.td - TOSA dialect op builders -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the common definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + + +#ifdef TOSA_OP_BASE +#else +#define TOSA_OP_BASE + +// Quantization attributes used across TOSA operators. Quantization attributes feed +// numerical precision parameters to the functional implementation of TOSA operators. +// The functional behavior is defined in the TOSA specification maintained at +// https://developer.mlplatform.org/w/tosa/ . TOSA leverages MLIR's built in +// quantization support: https://mlir.llvm.org/docs/Quantization/ , and supports +// uniform quantization. 
Depending on datatype, asymmetric and symmetric quantization +// are supported. The types themselves are described in TosaTypesBase.td . + +// This quantization attribute expresses numerical behavior of operators where the +// operator has a numerical relationship between a single input and output. +// For example: tosa.negate. +def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"output_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Unary Ops."; +} + +// There is no explicit BinaryOpQuantizationAttr for 2-input/1-output ops. In this +// case, tosa.rescale is used to express the inputs to the same scale. +// TODO: Upload WIP legalization document describing this construction by example. + +// This quantization attribute holds input and weight zero point. Both the ConvOp and +// MatMulOp QuantizationAttrs follow a common design semantic where their own quantization +// attribute only expresses the numerical behavior at the inputs. The scaling of their +// accumulator output is done using an explicit tosa.rescale operator that scales the +// accumulator result to the output scale. +def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"weight_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Convolution Ops."; +} + +def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"a_zp", I32Attr>, + StructFieldAttr<"b_zp", I32Attr>] > { + let description = "Attribute holding quantization information for MatMul Ops."; +} + +// This attribute holds input zero point correction applied to the padding zeros to ensure +// numerical accuracy in the subsequent TOSA operations. 
Its functional application is +// described in the tosa.pad() operator description in the specification. +def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Pad Ops."; +} + +// TOSA Quantization Builders. + +// This builder is called on all convolution operators except for TransposeConv, which has +// specialized output shape semantics. The builder also defines the bitwidth of the output +// given the bit width of the input & weight content. +def Tosa_ConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$filter, "Value":$bias, "ArrayAttr":$strides, "ArrayAttr":$dilations, "ArrayAttr":$padding), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + $_state.addAttribute("strides", strides); + $_state.addAttribute("dilations", dilations); + $_state.addAttribute("padding", padding); + + auto quantAttr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + if (quantAttr) { + // TODO - update to clean up dyn_cast on dyn_cast . + $_state.addAttribute("quantization_info", quantAttr); + unsigned inputBits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weightBits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto outputShape = outputType.dyn_cast().getShape(); + IntegerType accElementType; + if(inputBits == 16 && weightBits == 8) + accElementType = $_builder.getIntegerType(48); + else + accElementType = $_builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + $_state.addTypes(accType); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// Handles tosa.transpose_conv2d which has an outpad and output shape attribute. 
+def Tosa_TransConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$filter, "Value":$bias, "ArrayAttr":$strides, "ArrayAttr":$dilations, "ArrayAttr":$outpad, "ArrayAttr":$outputShape), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + $_state.addAttribute("strides", strides); + $_state.addAttribute("dilations", dilations); + $_state.addAttribute("outpad", outpad); + $_state.addAttribute("output_shape", outputShape); + auto quantAttr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + + if (quantAttr) { + $_state.addAttribute("quantization_info", quantAttr); + unsigned inputBits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weightBits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto outputShape = outputType.dyn_cast().getShape(); + IntegerType accElementType; + if(inputBits == 16 && weightBits == 8) + accElementType = $_builder.getIntegerType(48); + else + accElementType = $_builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + $_state.addTypes(accType); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// The tosa.fully_connected op has its own builder as it does not have +// strides/dilation/padding. +// TODO: clean up the conv, transpose_conv and fully_connected builders with a +// common utility function for all behavior that's common to the three. +def Tosa_FCOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$filter, "Value":$bias), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + auto quantAttr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + if (quantAttr) { + $_state.addAttribute("quantization_info", quantAttr); + // TODO: Use utility function to express this common code. 
+ unsigned inputBits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weightBits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto outputShape = outputType.dyn_cast().getShape(); + IntegerType accElementType; + if(inputBits == 16 && weightBits == 8) + accElementType = $_builder.getIntegerType(48); + else + accElementType = $_builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + $_state.addTypes(accType); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// The tosa.matmul op is also intended to be generated where a fully_connected +// op must be constructed where the weight is not a constant. In this case, +// the fully_connected op must be expressed using matmul. +// TODO: Add link to the legalization document explaining this. +def Tosa_MatMulOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$a, "Value":$b), + [{ + $_state.addOperands(a); + $_state.addOperands(b); + auto quantAttr = mlir::tosa::buildMatMulOpQuantizationAttr($_builder, a, b); + + if (quantAttr) { + $_state.addAttribute("quantization_info", quantAttr); + unsigned inputBits = a.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto outputShape = outputType.dyn_cast().getShape(); + IntegerType accElementType; + if(inputBits == 16) + accElementType = $_builder.getIntegerType(48); + else + accElementType = $_builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + $_state.addTypes(accType); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but +// the avg_pool operator has its own builder as it has additional parameters not part +// of the unary ops. +// TODO: split out the common code into a utility function and invoke from both builders. 
+def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "ArrayAttr":$kernel_size, "ArrayAttr":$strides, "ArrayAttr":$padding), + [{ + $_state.addOperands(input); + $_state.addAttribute("kernel_size", kernel_size); + $_state.addAttribute("strides", strides); + $_state.addAttribute("padding", padding); + auto quantAttr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + outputType); + if (quantAttr) + $_state.addAttribute("quantization_info", quantAttr); + $_state.types.push_back(outputType); + }]>; + +// This builder is called on single-parameter unary operators that have a scale +// relationship between their input and output, expressed by the UnaryOpQuantizationAttr. +def Tosa_UnaryOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input), + [{ + $_state.addOperands(input); + auto quantAttr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + outputType); + if (quantAttr) + $_state.addAttribute("quantization_info", quantAttr); + $_state.types.push_back(outputType); + }]>; + +// This builder is called on the TOSA pad operator that needs to create its own +// OptionalAttr quantization_attr parameter to scale the padding values correctly. +def Tosa_PadOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$paddings), + [{ + $_state.addOperands(input); + $_state.addOperands(paddings); + auto quantAttr = mlir::tosa::buildPadOpQuantizationAttr($_builder, + input); + if (quantAttr) + $_state.addAttribute("quantization_info", quantAttr); + $_state.types.push_back(outputType); + }]>; + +// This builder is called on elementwise binary broadcastable operators. +// TODO: Implement and use InferTypeOpInterface instead. 
+def Tosa_BroadcastableBinaryBuilder : OpBuilderDAG< + (ins "Value":$lhs, "Value":$rhs), + [{ + auto resultType = + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); + if (!resultType) + mlir::emitError($_state.location, "Operands are not broadcastable"); + $_state.addOperands(lhs); + $_state.addOperands(rhs); + $_state.types.push_back(resultType); + }]>; + +// TOSA Specification Chapter-wise Operator Classification. + +class TOSA_StrEnumAttr cases> : + StrEnumAttr { + let predicate = And<[ + StrAttr.predicate, + CPred<"::mlir::tosa::symbolize" # name # "(" + "$_self.cast().getValue()).hasValue()">, + ]>; + let cppNamespace = "::mlir::tosa"; +} + +class Tosa_Op traits = []> : + Op { + +} + +// Specify traits of operators. + +#endif // TOSA_OP_BASE Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -0,0 +1,58 @@ +//===-- TosaOps.h - TOSA dialect operation definitions ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_IR_TOSA_OPS_H +#define DIALECT_TOSA_IR_TOSA_OPS_H + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Tosa/IR/TosaTraits.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" + +#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc" + +namespace mlir { +namespace tosa { + +//===----------------------------------------------------------------------===// +// TOSA Dialect +//===----------------------------------------------------------------------===// +class TosaDialect : public Dialect { + +public: + explicit TosaDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "tosa"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. + Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; +}; + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" + +} // end namespace tosa +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" + +#endif // DIALECT_TOSA_IR_TOSA_OPS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -0,0 +1,1695 @@ +//===-- TosaOps.td - TOSA dialect operation definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation set for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_OPS +#else +#define TOSA_OPS + +include "mlir/IR/OpBase.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + +include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" + +//===----------------------------------------------------------------------===// +// The TOSA Dialect. +//===----------------------------------------------------------------------===// +def Tosa_Dialect : Dialect { + let name = "tosa"; + + let description = [{ + The Tensor Operator Set Architecture (TOSA) dialect. + + This dialect implements the TOSA standard described at + https://developer.mlplatform.org/w/tosa/ . + + Tensor Operator Set Architecture (TOSA) provides a set of whole-tensor operations + commonly employed by Deep Neural Networks. The intent is to enable a variety of + implementations running on a diverse range of processors, with the results at the + TOSA level consistent across those implementations. Applications or frameworks + which target TOSA can therefore be deployed on a wide range of different processors, + such as CPUs or GPUs, with defined accuracy and compatibility constraints. Most + operators from the common ML frameworks should be expressible in TOSA. It is + expected that there will be tools to lower from the ML frameworks into TOSA. + + }]; + + let cppNamespace = "mlir::tosa"; +} + +#ifdef TOSA_OP_BASE +#else +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" +#endif + + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.2 +// Operator Class: Tensor Data Engine Operators. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: argmax +//===----------------------------------------------------------------------===// +def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> { + let summary = "Perform argmax on the input."; + + let description = [{ + This returns the index with the largest value across the given axis of the input tensor. + }]; + + let arguments = (ins + Tosa_Tensor: $input, + I64Attr: $axis); + + let results = (outs Tosa_Tensor: $output); + +} + +//===----------------------------------------------------------------------===// +// Operator: avg_pool2d +//===----------------------------------------------------------------------===// +def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> { + let summary = "Performs average pooling on the input."; + + let description = [{ + This performs an average pooling over the given input tensor. A sliding window of size + given by kernel_size is passed over the input tensor, with the mean value being placed + in the output tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + + // TODO: Create Tosa_Tensor4D ODS type to enforce restriction + Confined]>:$kernel_size, + Confined]>:$strides, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: conv2d +//===----------------------------------------------------------------------===// +def Tosa_Conv2DOp : Tosa_Op<"conv2d", [NoSideEffect]> { + let summary = [{ + 2D Convolution Operator + }]; + + let description = [{ + Performs a 2D convolution over the given tensor input, using the weight tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$weight, + Tosa_Tensor:$bias, + + DefaultValuedAttr:$strides, + DefaultValuedAttr:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: conv3d +//===----------------------------------------------------------------------===// +def Tosa_Conv3DOp : Tosa_Op<"conv3d", [NoSideEffect]> { + let summary = [{ + 3D Convolution operator + }]; + + let description = [{ + Performs a 3D convolution over the given input tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr]>, "{1, 1, 1, 1, 1}">:$strides, + DefaultValuedAttr]>, "{1, 1, 1, 1, 1}">:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: depthwise_conv2d +//===----------------------------------------------------------------------===// +def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> { + let summary = [{ + Depthwise 2D Convolution operator + }]; + + let description = [{ + Performs 2D convolutions separately over each channel of the given tensor input, using the weight tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr:$strides, + DefaultValuedAttr:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: fully_connected +//===----------------------------------------------------------------------===// +def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> { + let summary = "Fully Connected operator"; + + let description = [{ + Performs a fully connected network. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_TensorOfOrNone<[Tosa_AnyNumber]>:$bias, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_FCOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: matmul +//===----------------------------------------------------------------------===// +def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> { + let summary = "Matrix multiplication"; + + let description = [{ + Performs a two dimensional matrix multiplication. This allows both inputs to be activations, + rather than reserving weights as an attribute in the FULLY_CONNECTED operator. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$a, + Tosa_Tensor:$b, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$c + ); + + let builders = [Tosa_MatMulOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: max_pool2d +//===----------------------------------------------------------------------===// +def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> { + let summary = "Performs max pooling on the input."; + + let description = [{ + This performs a max pooling over the given input tensor. A sliding window of size given by + is passed over the input tensor, with the maximum value being placed in the + output tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + + Confined]>:$kernel_size, + Confined]>:$strides, + DefaultValuedAttr:$padding + ); + + let results = (outs + Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose_conv2d +//===----------------------------------------------------------------------===// +def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [NoSideEffect]> { + let summary = [{ + Transpose 2D Convolution operator. + }]; + + let description = [{ + Performs a 2D transposed convolution over the given tensor input, using the weights tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + I64ArrayAttr:$strides, + I64ArrayAttr:$dilations, + I64ArrayAttr:$outpad, + I64ArrayAttr:$output_shape, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_TransConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.3 +// Operator Class: Activation Functions. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: clamp +//===----------------------------------------------------------------------===// +def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes clamp(features, min, max)."; + + let description = [{ + Clamp to an arbitrary minimum and maximum value. Note that the maximum and minimum values are + specified as signed quantized values, no scaling happens before or after this operation. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$min_int, + I64Attr:$max_int, + F32Attr:$min_fp, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reluN +//===----------------------------------------------------------------------===// +def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes rectified linear: `max(features, N)`."; + + let description = [{ + ReLU with a scalar maximum value. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$max_int, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: sigmoid +//===----------------------------------------------------------------------===// +def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise sigmoid of input."; + + let description = [{ + Sigmoid function: output = 1 / (1 + exp(-input)) + For quantized integer data types, the TABLE operator should be used instead with the following definition. + The sigmoid table has 513 entries each of 16-bit precision and covering the input range -16.0 to +16.0 + in steps of 1/16. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: tanh +//===----------------------------------------------------------------------===// +def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise hyperbolic tangent of input"; + + let description = [{ + Parameterized hyperbolic tangent. + For quantized integer data types, the TABLE operator should be used instead with the following definition. + The tanh_table has 513 entries each of 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32. + }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.4 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise binary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: add +//===----------------------------------------------------------------------===// +def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise addition operator"; + + let description = [{ + Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: arithmetic_right_shift +//===----------------------------------------------------------------------===// +def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Arithmetic Right Shift"; + + let description = [{ + Elementwise arithmetic right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_and +//===----------------------------------------------------------------------===// +def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise AND operator"; + + let description = [{ + Elementwise bitwise AND of input tensor 0 and input tensor 1. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_or +//===----------------------------------------------------------------------===// +def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise OR operator"; + + let description = [{ + Elementwise bitwise OR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_xor +//===----------------------------------------------------------------------===// +def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise XOR operator"; + + let description = [{ + Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_and +//===----------------------------------------------------------------------===// +def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x AND y element-wise."; + + let description = [{ + Elementwise logical AND of input1 and input2. Axis of size 1 will be broadcast, as necessary. 
+ Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_left_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Left Shift"; + + let description = [{ + Elementwise left shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_right_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Right Shift"; + + let description = [{ + Elementwise logical right shift of input1 by the amount specified in input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match.
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_or +//===----------------------------------------------------------------------===// +def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x OR y element-wise."; + + let description = [{ + Elementwise logical OR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_xor +//===----------------------------------------------------------------------===// +def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x XOR y element-wise."; + + let description = [{ + Elementwise logical XOR of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: maximum +//===----------------------------------------------------------------------===// +def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Maximum"; + + let description = [{ + Elementwise max of input1 and input2. 
Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: minimum +//===----------------------------------------------------------------------===// +def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Minimum"; + + let description = [{ + Elementwise minimum of input tensor 0 and input tensor 1. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: mul +//===----------------------------------------------------------------------===// +def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Multiplication operator"; + + let description = [{ + Elementwise multiplication (Hadamard product) of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: pow +//===----------------------------------------------------------------------===// +def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Computes the power of one value to another."; + + let description = [{ + Elementwise input tensor 0 value raised to the power of input 1 tensor. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + Tosa_Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: sub +//===----------------------------------------------------------------------===// +def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise subtraction operator"; + + let description = [{ + Elementwise subtraction of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: table +//===----------------------------------------------------------------------===// +def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> { + let summary = "Table lookup op"; + + let description = [{ + Interpolated table lookup operation. Input values are scaled to create a fixed-point 9.7 value. + The high 9 bits are used to index into the table. 
The fractional bits are used to interpolate + based on the looked up value and the index+1 value in the table. The TABLE operator then returns + a 16.7 interpolated value. Note that there must be 513 values to handle the full range of inputs. + + The TABLE operator is expected to be used as follows: + • A RESCALE node is expected before the TABLE operator to scale the input to a full int16_t range + for the table lookup + • If an int16_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 7 + • If an int8_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 15 + }]; + + let arguments = (ins + Tosa_Tensor: $input, + Tosa_Tensor: $lut + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.5 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise unary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: abs +//===----------------------------------------------------------------------===// +def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise abs op"; + + let description = [{ + Elementwise absolute value operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_not +//===----------------------------------------------------------------------===// +def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Bitwise NOT operator"; + + let description = [{ + Elementwise bitwise NOT of input tensor. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output); + +} + +//===----------------------------------------------------------------------===// +// Operator: ceil +//===----------------------------------------------------------------------===// +def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise ceil op"; + + let description = [{ + Elementwise ceiling operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: clz +//===----------------------------------------------------------------------===// +def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise count leading zero op"; + + let description = [{ + Elementwise count leading zeros operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: exp +//===----------------------------------------------------------------------===// +def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise exp op"; + + let description = [{ + Elementwise e to the x operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: floor +//===----------------------------------------------------------------------===// +def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise floor op"; + + let description = [{ + Elementwise floor operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); 
+} + +//===----------------------------------------------------------------------===// +// Operator: log +//===----------------------------------------------------------------------===// +def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise log op"; + + let description = [{ + Elementwise natural logarithm operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_not +//===----------------------------------------------------------------------===// +def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns the truth value of NOT x element-wise."; + + let description = [{ + Elementwise logical NOT of input. + }]; + + let arguments = (ins + I1Tensor:$x + ); + + let results = (outs + I1Tensor:$y + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: negate +//===----------------------------------------------------------------------===// +def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise negate op"; + + let description = [{ + Elementwise negation operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + OptionalAttr:$quantization_info + ); + let results = (outs Tosa_Tensor:$output + ); + + let builders = [Tosa_UnaryOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: reciprocal +//===----------------------------------------------------------------------===// +def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise reciprocal op"; + + let description = [{ + Elementwise reciprocal operation. 
For integer operation, a TABLE should be used + with the appropriate ranges. + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: rsqrt +//===----------------------------------------------------------------------===// +def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise 1/sqrt op"; + + let description = [{ + Elementwise reciprocal square root operation. For integer operation, a TABLE should be + used with the appropriate ranges. + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.6 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise ternary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: select +//===----------------------------------------------------------------------===// +def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> { + let summary = "Elementwise select operator"; + + let description = [{ + Elementwise select of the output based on a condition. + }]; + + let arguments = (ins + Tosa_Tensor:$condition, + Tosa_Tensor:$a, + Tosa_Tensor:$b + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.7 +// Operator Class: Logical Operations. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: equal +//===----------------------------------------------------------------------===// +def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of (x == y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: greater +//===----------------------------------------------------------------------===// +def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x > y) element-wise."; + + let description = [{ + Elementwise greater than comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: greater_equal +//===----------------------------------------------------------------------===// +def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x >= y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.8 +// Operator Class: Reduction Ops. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: reduce_all +//===----------------------------------------------------------------------===// +def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> { + let summary = [{ + Reduce All operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical AND operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_any +//===----------------------------------------------------------------------===// +def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> { + let summary = [{ + Reduce Any operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical OR operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_max +//===----------------------------------------------------------------------===// +def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> { + let summary = [{ + Reduce Max operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a maximum operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_min +//===----------------------------------------------------------------------===// +def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> { + let summary = [{ + Reduce Min operator + }]; + + let description = [{ + Reduce a tensor 
along the given axis with a minimum operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_prod +//===----------------------------------------------------------------------===// +def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> { + let summary = [{ + Reduce Prod operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the product of the axis. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_sum +//===----------------------------------------------------------------------===// +def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> { + let summary = [{ + Reduce Sum operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the sum of the axis. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.9 +// Operator Class: Data Layout / Memory Reinterpretation. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: concat +//===----------------------------------------------------------------------===// +def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> { + let summary = "Concatenates tensors along one dimension."; + + let description = [{ + Concatenate two tensors along a given axis. No data conversion happens during a concat operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$a, + Tosa_Tensor:$b, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: pad +//===----------------------------------------------------------------------===// +def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> { + let summary = "Pads a tensor with zeros."; + + let description = [{ + Zero-pads a tensor along borders of each dimension. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Int32Or64Tensor:$paddings, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_PadOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: reshape +//===----------------------------------------------------------------------===// +def Tosa_ReshapeOp: Tosa_Op<"reshape", [ + NoSideEffect]> { + let summary = "Reshape operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with a new shape specified by the shape + argument. Reshape may operate on tensors of any rank. No data conversion happens during a reshape + operation. + }]; + + let arguments = ( + ins Tosa_Tensor:$input, + I64ArrayAttr:$shape); + + let results = (outs Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: reverse +//===----------------------------------------------------------------------===// +def Tosa_ReverseOp: Tosa_Op<"reverse", [ + NoSideEffect]> { + let summary = "Reverse operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with the data reversed along the given + axis. No data conversion happens during a reverse operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: slice +//===----------------------------------------------------------------------===// +def Tosa_SliceOp: Tosa_Op<"slice", [ + NoSideEffect]> { + let summary = "Slice operator"; + + let description = [{ + Extracts a slice of the input tensor 0 on the given axis, beginning at the start coordinates, + and extending for size elements in each direction. No data conversion happens during a slice operation. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64ArrayAttr:$begin, + I64ArrayAttr:$size + ); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: tile +//===----------------------------------------------------------------------===// +def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> { + let summary = "Tile operator"; + + let description = [{ + Replicates input 0 multiples times along each dimension. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64ArrayAttr:$multiples); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose +//===----------------------------------------------------------------------===// +def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> { + let summary = "Transpose operator"; + + let description = [{ + Permutes the dimensions based on perm. + }]; + + let arguments = (ins + Tosa_Tensor:$x, + Tosa_Int32Or64Tensor:$perm + ); + + let results = ( + outs Tosa_Tensor:$y + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.10 +// Operator Class: Scatter/gather Operations.
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: gather +//===----------------------------------------------------------------------===// +def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> { + let summary = [{ + Gather operation + }]; + + let description = [{ + Generate a tensor for which each element in the output is a subtensor of the values tensor along + the given axis, based on the value of indices. + }]; + + let arguments = (ins + Tosa_Tensor:$params, + Tosa_Int32Or64Tensor:$indices, + I64Attr:$axis, + + DefaultValuedAttr:$batch_dims + ); + + let results = (outs + Tosa_Tensor:$z + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.11 +// Operator Class: Image Frontend Functions. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: resize +//===----------------------------------------------------------------------===// +def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> { + + let summary = "Resize operation, supports various resize/upsample modes"; + + let description = [{ + Resizes a tensor. Resize is only allowed in the H and W dimensions. 
In expected use, + stride_y is approximately (IH< { + + let summary = "Cast operation"; + + let description = [{ + Performs a set of permissible cast operations + Mode Input Output + --------------------------------------- + signed 8 to bool int8 Boolean + signed 16 to bool int16 Boolean + signed 32 to bool int32 Boolean + bool to 8 Boolean int8 + bool to 16 Boolean int16 + bool to 32 Boolean int32 + signed 8 to signed 16 int8 int16 + signed 8 to signed 32 int8 int32 + signed 16 to signed 8 int16 int8 + signed 16 to signed 32 int16 int32 + signed 32 to signed 8 int32 int8 + signed 32 to signed 16 int32 int16 + float to signed 8 float int8 + float to signed 16 float int16 + signed 8 to float int8 float + signed 16 to float int16 float + }]; + + let arguments = ( + ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output); + +} + +//===----------------------------------------------------------------------===// +// Operator: rescale +//===----------------------------------------------------------------------===// +def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> { + let summary = "Tosa rescale operator"; + + let description = [{ + Rescale quantized values into a new domain. 
Supported rescalings are: + Mode Input Output + signed 8 to 8 aint8 aint8 + signed 8 to 16 aint8 int16 + signed 8 to 32 aint8 int32 + signed 16 to 8 int16 aint8 + signed 16 to 16 int16 int16 + signed 16 to 32 int16 int32 + signed 32 to 8 int32 aint8 + signed 32 to 16 int32 int16 + signed 32 to 32 int32 int32 + signed 48 to 8 int48 aint8 + signed 48 to 16 int48 int16 + signed 48 to 32 int48 int32 + unsigned 8 to signed 8 uint8 aint8 + signed 8 to unsigned 8 aint8 uint8 + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I32Attr:$input_zp, + I32Attr:$output_zp, + I32ArrayAttr:$multiplier, + I32ArrayAttr:$shift, + BoolAttr:$scale32, + BoolAttr:$double_round, + BoolAttr:$per_channel + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.13 +// Operator Class: Data Node Ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: const +//===----------------------------------------------------------------------===// +def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, NoSideEffect, FirstAttrDerivedResultType]> { + let summary = "Constant op."; + + let description = [{ + A node containing constant data for use as the input to an operation. May hold data + in any of the supported data formats. 
+ }]; + + let arguments = (ins ElementsAttr:$value); + + let results = (outs Tosa_Tensor:$output); + + let builders = [ + OpBuilderDAG<(ins "Type":$type, "Attribute":$value)>, + ]; + +} + +//===----------------------------------------------------------------------===// +// Operator: identity +//===----------------------------------------------------------------------===// +def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect]> { + let summary = "Identity operator"; + let description = [{ + Returns a tensor with the same shape, size, type + and content as the input. + }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: identityn +//===----------------------------------------------------------------------===// +def Tosa_IdentityNOp: Tosa_Op<"identityn", [NoSideEffect]> { + let summary = "IdentityN operator"; + let description = [{ + Returns a list of tensors with the same shape, type, and contents as the + input list of tensors. + }]; + + let arguments = (ins + Variadic:$input + ); + + let results = (outs + Variadic:$output); +} + + +//===----------------------------------------------------------------------===// +// Operator: placeholder +//===----------------------------------------------------------------------===// +def Tosa_PlaceholderOp : Tosa_Op<"placeholder", [NoSideEffect]> { + let summary = "Placeholder op"; + + let description = [{ + A node where data will be inserted into the network at runtime. Generally used for inputs to the network. + }]; + + let arguments = (ins + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.14 +// Operator Class: Custom Operators. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: custom +//===----------------------------------------------------------------------===// +def Tosa_CustomOp : Tosa_Op<"custom"> { + + let summary = "Custom operator wrapper for Tosa"; + + let description = [{ + Hardware implementing TOSA may choose to add additional custom operators that are not expressed in + the existing TOSA operations. These operators are not expected to be portable across TOSA + implementations. The input and output signatures must be expressed in the corresponding TOSA node. + }]; + + let arguments = (ins + StrAttr:$identifier, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.15 +// Operator Class: Control Flow Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: cond_if +//===----------------------------------------------------------------------===// +def Tosa_IfOp : Tosa_Op<"cond_if", [ + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Conditional if operator"; + + let description = [{ + Evaluates a Boolean condition and then takes one of two distinct execution paths. This + implements the semantic If-then-else structure. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$cond, + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$then_branch, + SizedRegion<1>:$else_branch + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: while_loop +//===----------------------------------------------------------------------===// +def Tosa_WhileOp : Tosa_Op<"while_loop", [ + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = [{ + output = input; While (Cond(output)) { output = Body(output) } + }]; + + let description = [{ + Generates and evaluates a Bool condition and either executes a loop body or exits to + another control point. This action is performed repeatedly after updating and re-evaluating + the Boolean condition every iteration. This implements the semantic foreach or while + iterative loop structure. + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$output); + + let regions = (region + SizedRegion<1>:$cond, + SizedRegion<1>:$body + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: yield +//===----------------------------------------------------------------------===// +def Tosa_YieldOp : Tosa_Op<"yield", [Terminator]> { + let summary = "yield operator"; + + let description = [{ + return operation within the conditional and body of + structured control flow. Operation takes variadic operands + but produces no results of its own. 
+ }]; + + let arguments = (ins + Variadic:$inputs + ); + +} + +#endif // TOSA_OPS Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h @@ -0,0 +1,34 @@ +//===-- TosaTraits.h - TOSA dialect operation traits ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect OpTraits in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATRAITS_H +#define MLIR_DIALECT_TOSA_IR_TOSATRAITS_H + +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { +namespace tosa { + +// TBD + +} // namespace tosa +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATRAITS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h @@ -0,0 +1,31 @@ +//===-- TosaTypes.h - TOSA dialect type definitions -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect Types in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATYPES_H +#define MLIR_DIALECT_TOSA_IR_TOSATYPES_H + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +namespace tosa { + +// TOSA specific types go here + +} // namespace tosa + +} // end namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATYPES_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -0,0 +1,135 @@ +//===-- TosaTypesBase.td - TOSA type definitions -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the type definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_TYPES_BASE +#else +#define TOSA_TYPES_BASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Tosa Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. +// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. +// Where low and high ends are 0,255 when unsigned, -128,127 when signe, for the 8-bit case. 
+class Tosa_QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + StrJoinInt.result # !if(signed, ", true", ", false"); +} + +//===----------------------------------------------------------------------===// +// Non-Quantized Signed Integer Types. +// Used to express accumulator results or compare results. +// +// Booleans are currently assumed to be expressed using int8 or +// built in U1 / U1Tensor type. +//===----------------------------------------------------------------------===// + +def Tosa_Int32 : SI<32>; +def Tosa_Int48 : SI<48>; +def Tosa_Int64 : SI<64>; + +def Tosa_SignedInt : AnyTypeOf<[Tosa_Int32, + Tosa_Int48, + Tosa_Int64]>; + +// No unsigned unquantized int types. +def Tosa_Int : AnyTypeOf<[Tosa_SignedInt]>; + +def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, + Tosa_Int64]>; + +def Tosa_IntTensor : TensorOf<[Tosa_SignedInt]>; + +def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; + +//===----------------------------------------------------------------------===// +// Quantized Integer Types. +// Datatype for network feature map or weight content. 
+//===----------------------------------------------------------------------===// +def Tosa_Quint8 : Tosa_QuantizedType<"Uniform", [8], 0>; +def Tosa_Qint8 : Tosa_QuantizedType<"Uniform", [8], 1>; +def Tosa_Qint16 : Tosa_QuantizedType<"Uniform", [16], 1>; + +//===----------------------------------------------------------------------===// +// Name Symmetry Grouping Sign +//===----------------------------------------------------------------------===// +// aint8 : asymmetric per tensor, signed +// uint8 : asymmetric per tensor , unsigned +// int4 : symmetric per channel, signed +// int8 : symmetric per tensor/per channel, signed +// int16 : symmetric per tensor, signed +//===----------------------------------------------------------------------===// +def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>, + Tosa_QuantizedType<"uint8", [8], 0>, + Tosa_QuantizedType<"int4", [4, 0], 1>, + Tosa_QuantizedType<"int8", [8, 0], 1>, + Tosa_QuantizedType<"int16", [16, 0], 1>]>; + + +//===----------------------------------------------------------------------===// +// Floating-point types +//===----------------------------------------------------------------------===// +def Tosa_Float : AnyTypeOf<[F32, + F16, + BF16]>; + +def Tosa_FpTensor : TensorOf<[Tosa_Float]>; + +//===----------------------------------------------------------------------===// +// Multi-category types +//===----------------------------------------------------------------------===// +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], + "number">; + +//===----------------------------------------------------------------------===// +// Tensor types +//===----------------------------------------------------------------------===// +def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; + +def Tosa_QTypeTensor : TensorOf<[Tosa_QuantizedInt]>; + +class Tosa_TensorOrNone possibleTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +// Any tensor element type 
allowed in Tosa ops. +def Tosa_ElementType : Type, + "tosa.dtype">; + +class Tosa_TensorOfOrNone allowedTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +//===----------------------------------------------------------------------===// +// Iterable attributes +//===----------------------------------------------------------------------===// +// Supported regimes for tosa.resize +def Tosa_ResizeTypeAttr : StringBasedAttr< + CPred<"$_self.cast().getValue() == \"TRANSPOSE\" || " # + "$_self.cast().getValue() == \"BILINEAR\" || " # + "$_self.cast().getValue() == \"NEAREST_NEIGHBOR\"">, + "Supported resize/upsampling strategies">; + +def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; + +// Tensor to buffer types. +def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; +def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; +def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; + +#endif // TOSA_TYPES_BASE Index: mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) +add_public_tablegen_target(MLIRTosaPassIncGen) +add_dependencies(mlir-headers MLIRTosaPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc TosaPasses ./) Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -0,0 +1,36 @@ +//===-- Passes.h - TOSA optimization pass declarations ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class FuncOp; +class ModuleOp; +class Pass; +template +class OperationPass; + +namespace tosa { + +std::unique_ptr> createTosaMakeBroadcastablePass(); + +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -0,0 +1,26 @@ +//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +include "mlir/Pass/PassBase.td" + +def TosaBinaryInputReshapePass : Pass<"tosa-make-broadcastable", "FuncOp"> { + let summary = "TOSA rank Reshape to enable Broadcasting"; + let description = [{ + Pass that enables broadcast by making all input arrays have the same + number of dimensions. Insert RESHAPE operations to prepend dimensions + of size one until the number of dimensions is equal. 
Implements + approach similar to step 1 of Numpy 4-step broadcasting: + https://numpy.org/doc/stable/reference/ufuncs.html#broadcasting + }]; + + let constructor = "createTosaMakeBroadcastablePass()"; +} Index: mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -0,0 +1,66 @@ +//===-- QuantUtils.h - TOSA numerical support declarations ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Function declarations for TOSA numerical support functions and quantization +// attribute builders +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_UTILS_QUANT_UTILS_H +#define DIALECT_TOSA_UTILS_QUANT_UTILS_H + +//===----------------------------------------------------------------------===// +// Utility functions to support quantization handling in Tosa. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +#include "mlir/Dialect/Quant/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/UniformSupport.h" + +namespace mlir { +namespace tosa { + +/// From a scale value, computes multiplier and shift values +/// for 16 or 32-bit scale widths. +void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth); + +/// Builds ConvOpQuantizationAttr from input and weight. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, Value weight); + +/// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
+MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder, + Value a, Value b); + +/// Builds UnaryOpQuantizationAttr for unary operations from input values. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, + Type outputRawType); + +/// Builds PadOpQuantizationAttr for pad operations from input values. +PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder, + Value input); + +/// Builds a Tosa quantized type from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange); + +/// Builds a TypeAttr holding the Tosa quantized type built from min/max values. +TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange); + +} // namespace tosa +} // namespace mlir + +#endif // DIALECT_TOSA_UTILS_QUANT_UTILS_H Index: mlir/lib/Dialect/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/CMakeLists.txt +++ mlir/lib/Dialect/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) set(LLVM_OPTIONAL_SOURCES Index: mlir/lib/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_dialect_library(MLIRTosa + IR/TosaOps.cpp + Utils/QuantUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa + + DEPENDS + MLIRStandardOpsIncGen + MLIRTosaOpsIncGen + MLIRTosaStructsIncGen + MLIRTosaInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRStandard + MLIRCallInterfaces + MLIRControlFlowInterfaces + MLIRSideEffectInterfaces +
MLIRViewLikeInterface + ) + +add_subdirectory(Transforms) Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -0,0 +1,130 @@ +//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This file implements the TOSA Specification: +// https://developer.mlplatform.org/w/tosa/ +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Parser.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" + +using namespace mlir; +using namespace mlir::tosa; + +#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc" + +//===----------------------------------------------------------------------===// +// Tosa dialect interfaces. 
+//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" + +//===----------------------------------------------------------------------===// +// Dialect Function Inliner Interface. +//===----------------------------------------------------------------------===// +struct TosaInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks. + //===--------------------------------------------------------------------===// + + /// All operations can be inlined by default. + bool isLegalToInline(Operation *op, Region *region, + BlockAndValueMapping &map) const { + return true; + } + + /// All regions with If and While parent operators can be inlined. + /// TODO TEST: need test for the inliner + bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &map) const { + return (isa(dest->getParentOp()) || + isa(dest->getParentOp())); + } +}; + +//===----------------------------------------------------------------------===// +// TOSA control flow support. +//===----------------------------------------------------------------------===// + +/// Returns the while loop body. +Region &tosa::WhileOp::getLoopBody() { return body(); } + +bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) { + return !body().isAncestor(value.getParentRegion()); +} + +/// TODO TEST: Test for loop invariant code motion +LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef ops) { + if (ops.empty()) + return success(); + + Operation *tosaWhileOp = this->getOperation(); + for (auto *op : ops) + op->moveBefore(tosaWhileOp); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Tosa Dialect. 
+//===----------------------------------------------------------------------===// + +TosaDialect::TosaDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, + Attribute value) { + result.addTypes(type); + result.addAttribute("value", value); + return; +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + +Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + if (value.isa() || + (value.isa() && value.getType() != type)) + return builder.create(loc, type, value.cast()); + return nullptr; +} Index: mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRTosaTransforms + TosaMakeBroadcastable.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms + + DEPENDS + MLIRTosaPassIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTosa + ) Index: mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -0,0 +1,219 @@ +//===- TosaMakeBroadcastable.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Insert reshape to binary op's input if needed to match rank +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR//TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define PASS_NAME "tosa-make-broadcastable" +#define DEBUG_TYPE PASS_NAME + +namespace mlir { + +namespace tosa { + +namespace { +/// Pass that enables broadcast by making all input arrays have the same +/// number of dimensions. Insert RESHAPE operations to prepend dimensions +/// of size one until the number of dimensions is equal. Implements +/// approach similar to step 1 of Numpy 4-step broadcasting: +/// https://numpy.org/doc/stable/reference/ufuncs.html#broadcasting +/// It is expected that a code generation or other functional expression +/// of the operator implements steps 2-4. +/// TODO: Resolve discrepancy between TensorFlow and Numpy semantics. +class TosaMakeBroadcastable + : public PassWrapper { +public: + explicit TosaMakeBroadcastable() {} + void runOnFunction() override; +}; + +#define REPLACE_OP_LOGICAL(tosa_op, LHS_VALUE, RHS_VALUE) + +#define REPLACE_OP(tosa_op, LHS_VALUE, RHS_VALUE) \ + { \ + rewriter.replaceOpWithNewOp(op, output_type, LHS_VALUE, \ + RHS_VALUE); \ + } + +/// Pass that enables broadcast by making all input arrays have the same +/// number of dimensions. 
Insert RESHAPE operations to prepend dimensions +/// of size one until the number of dimensions is equal. +// Examples: +// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1]. +// TODO: If lower=[b], target=[a, b, c], [b] should but NOT YET reshaped into +// [1, b, 1]. +// If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c]. +// If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1]. +// If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. +// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. + +template +struct ConvertTosaOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy tosa_binary_op, + PatternRewriter &rewriter) const { + + Value lhs = tosa_binary_op.lhs(); + Value rhs = tosa_binary_op.rhs(); + + int64_t lhs_rank = lhs.getType().dyn_cast().getRank(); + int64_t rhs_rank = rhs.getType().dyn_cast().getRank(); + + Value output = tosa_binary_op.output(); + auto output_type = output.getType().dyn_cast(); + + int64_t higher_rank, lower_rank; + Value higher_tensor_value, lower_tensor_value; + /* return if rank already match */ + if (lhs_rank == rhs_rank) { + return failure(); + } else if (lhs_rank > rhs_rank) { + higher_rank = lhs_rank; + lower_rank = rhs_rank; + higher_tensor_value = lhs; + lower_tensor_value = rhs; + } else { + higher_rank = rhs_rank; + lower_rank = lhs_rank; + higher_tensor_value = rhs; + lower_tensor_value = lhs; + } + + ArrayRef higher_rank_shape = output_type.getShape(); + ArrayRef lower_rank_shape = + lower_tensor_value.getType().dyn_cast().getShape(); + + SmallVector reshape_output_shape; + reshape_output_shape.assign(higher_rank, 1); + + int64_t higher_left_index = 0; + int64_t higher_right_index = higher_rank; + int64_t 
lower_left_index = 0; + int64_t lower_right_index = lower_rank; + int64_t higher_rank_dim, lower_rank_dim; + + if (lower_right_index != 0 && higher_right_index != 0) { + while (true) { + higher_rank_dim = higher_rank_shape[higher_right_index - 1]; + lower_rank_dim = lower_rank_shape[lower_right_index - 1]; + if (higher_rank_dim == lower_rank_dim) { + reshape_output_shape[higher_right_index - 1] = higher_rank_dim; + + if (higher_right_index > 0) { + higher_right_index--; + } + + if (lower_right_index > 0) { + lower_right_index--; + } + + if (higher_right_index == 0 || lower_right_index == 0) { + break; + } + } else { + break; + } + } + if (lower_right_index != 0 && higher_right_index != 0) { + while (true) { + higher_rank_dim = higher_rank_shape[higher_left_index]; + lower_rank_dim = lower_rank_shape[lower_left_index]; + if (higher_rank_dim == lower_rank_dim) { + reshape_output_shape[higher_left_index] = higher_rank_dim; + + if (higher_left_index < higher_right_index) { + higher_left_index++; + } + + if (lower_left_index < lower_right_index) { + lower_left_index++; + } + + if (higher_left_index == higher_right_index || + lower_left_index == lower_right_index) { + break; + } + } else { + break; + } + } + } + } + + auto reshape_input_type = + lower_tensor_value.getType().dyn_cast(); + auto reshape_output_type = + RankedTensorType::get(ArrayRef(reshape_output_shape), + reshape_input_type.getElementType()); + + auto reshape_lower = rewriter.create( + tosa_binary_op.getLoc(), reshape_output_type, lower_tensor_value, + rewriter.getI64ArrayAttr(reshape_output_shape)); + + if (lhs_rank > rhs_rank) { + rewriter.replaceOpWithNewOp(tosa_binary_op, output_type, + higher_tensor_value, + reshape_lower.getResult()); + } else { + rewriter.replaceOpWithNewOp(tosa_binary_op, output_type, + reshape_lower.getResult(), + higher_tensor_value); + } + + return success(); + } +}; + +void TosaMakeBroadcastable::runOnFunction() { + OwningRewritePatternList patterns; + auto *ctx = 
&getContext(); + auto func = getFunction(); + + // Add the generated patterns to the list. + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // anonymous namespace + +std::unique_ptr> createTosaMakeBroadcastablePass() { + return std::make_unique(); +} + +static PassRegistration + pass(PASS_NAME, + "Perform broadcast on elementwise TosaOps to ensure same rank"); + +} // namespace tosa + +} // namespace mlir Index: mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -0,0 +1,317 @@ +//===- QuantUtils.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains TOSA numerical support functions and quantization +// attribute builders. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" + +namespace mlir { +namespace tosa { + +namespace { + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 16-bit scaling. +void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 15)); + + assert(shiftedM <= (int64_t(1) << 15)); // Can't be greater than 1.0.
+ if (shiftedM == (int64_t(1) << 15)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive and embed (1 << 15) into right + // shift bits. + shift = (-shift) + 15; + + assert(shiftedM <= std::numeric_limits::max()); + + multiplier = static_cast(shiftedM); +} + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 32-bit scaling. +void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 31)); + + assert(shiftedM <= (int64_t(1) << 31)); // Can't be greater than 1.0. + if (shiftedM == (int64_t(1) << 31)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive, and embed (1 << 31) into right + // shift bits. + shift = (-shift) + 31; + + assert(shiftedM <= std::numeric_limits::max()); + + multiplier = static_cast(shiftedM); +} + +} // namespace + +/// Generates a quantized multiplier/shift from double. +void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth) { + + switch (scaleWidth) { + case 16: + computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); + return; + case 32: + computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); + return; + default: + assert(0 && "Unsupported Tosa quantized_scale regime specified!"); + } +} + +#define GET_UQTYPE(input_type) \ + ((input_type).getElementType().dyn_cast()) + +// Method to build ConvOpQuantizationAttr, called from +// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: input_zp: input zeropoint +// weight_zp: weight zeropoint.
+ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + if (!inputType || !weightType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto weightPerTensorQType = GET_UQTYPE(weightType); + auto weightPerAxisQType = + weightType.getElementType() + .dyn_cast(); + + // Weight either per-tensor quantized or per-axis quantized + assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType)); + + // Either all quantized or all not quantized. + assert(!((bool)inputQType ^ + ((bool)weightPerTensorQType || (bool)weightPerAxisQType))); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t weightZp = 0; + + if (weightPerTensorQType) { + weightZp = weightPerTensorQType.getZeroPoint(); + } else if (weightPerAxisQType) { + weightZp = weightPerAxisQType.getZeroPoints().front(); + } + + auto quantAttr = mlir::tosa::ConvOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds MatMulOpQuantizationAttr, called from +/// MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: input b zeropoint +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder, + Value a, Value b) { + + auto aType = a.getType().dyn_cast(); + auto bType = b.getType().dyn_cast(); + + if (!aType || !bType) + return nullptr; + + auto aQType = GET_UQTYPE(aType); + auto bQType = GET_UQTYPE(bType); + + // A and B are either all quantized or all not quantized. 
+ assert(!((bool)aQType ^ (bool)bQType)); + + if (aQType) { + + int64_t aZp = aQType.getZeroPoint(); + int64_t bZp = bQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::MatMulOpQuantizationAttr::get( + builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds UnaryOpQuantizationAttr +/// UnaryOpQuantInfoBuilder: inputZp: input zeropoint +/// outputZp: output zeropoint. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, + Type outputRawType) { + + auto inputType = input.getType().dyn_cast(); + auto outputType = outputRawType.dyn_cast(); + + if (!inputType || !outputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto outputQType = GET_UQTYPE(outputType); + + // Either all quantized or all not quantized. + assert(!((bool)inputQType ^ (bool)outputQType)); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t outputZp = outputQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::UnaryOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: +/// inputZp: input zeropoint +PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder, + Value input) { + + auto inputType = input.getType().dyn_cast(); + + if (!inputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::PadOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds Tosa quantization attributes from min/max values. 
+Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange) { + + quant::QuantizedType retType; + + auto convfunc = + quant::ExpressedToQuantizedConverter::forInputType(inputDType); + + auto minElems = minAttr.dyn_cast(); + auto maxElems = maxAttr.dyn_cast(); + + SmallVector min, max; + + // At least one is per-axis quantized elementsattr. + if (minElems || maxElems) { + + // Must have the same number of elements. + if (minElems.getNumElements() != maxElems.getNumElements()) + return {}; + + min.reserve(minElems.getNumElements()); + max.reserve(maxElems.getNumElements()); + for (auto i : minElems) { + min.push_back(FloatAttr::getValueAsDouble(i)); + } + for (auto i : maxElems) { + max.push_back(FloatAttr::getValueAsDouble(i)); + } + } else { // Just a single FP value. + + auto minVal = minAttr.dyn_cast(); + if (minVal) + min.push_back(minVal.getValueAsDouble()); + else + return {}; + auto maxVal = maxAttr.dyn_cast(); + if (maxVal) + max.push_back(maxVal.getValueAsDouble()); + else + return {}; + } + + if (min.size() == max.size()) { + + if (min.size() == 1) { // Per-tensor quantization with one min/max pair. + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], + narrowRange.getValue(), convfunc.expressedType, isSigned); + + } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. 
+ + auto shape = inputDType.dyn_cast(); + if (!shape) + return {}; + if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], + max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); + } + + } else { + return {}; + } + } else { + return {}; + } + + if (!retType) + return {}; + + return convfunc.convert(retType); +} + +/// Builds Tosa quantization attributes from min/max values. +TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange) { + + return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, + maxAttr, quantBits, filterQuantDim, + isSigned, narrowRange)); +} + +} // namespace tosa +} // namespace mlir