Index: mlir/docs/Dialects/TOSA.md =================================================================== --- /dev/null +++ mlir/docs/Dialects/TOSA.md @@ -0,0 +1,95 @@ +# TOSA Dialect + +[TOC] + +## Rationale + +The MLIR TOSA dialect implements the [TOSA +specification](https://developer.mlplatform.org/w/tosa/). This document +describes the decision process for how TOSA expresses operators in +high level dialects. + +TOSA was developed after parallel efforts to rationalize the top-down picture from multiple high-level frameworks, as well as a bottom-up view of different hardware target concerns (CPU, GPU and NPU), and reflects a set of choices that attempt to manage both sets of requirements. + +## TOSA and Tensor Level Expressiveness + +TOSA endeavors to provide an operator set that tries to fulfil the following expressiveness goals at the *tensor level of abstraction*: + +### Complete + +This is driven by the top-down perspective, needing to express as much of multiple high level frameworks fully in TOSA, as possible. This was originally done from an operator frequency analysis done upon dozens of high level networks in different frameworks, to select the most frequently occurring ones and establish a common set of tensor-level operators that could express them. + +TOSA categorizes its operator set into classes and attempts to address major functional operations at the tensor level, including compute, reduction, elementwise transformations, comparison and control flow. + +### Minimal + +This takes the bottom-up approach - keep the TOSA operator set minimal in order to bound the design of hardware, operator kernels, code generation strategies and associated considerations that affect the executability of TOSA content. + +In this regard TOSA seeks to avoid creating compound operators, instead leaving it to the compiler backend to fuse multiple TOSA ops if required. 
This choice also benefits the numerical precision goal, since it is easier to fuse the numerical functionality of successive operators, than to split the numerical functionality of a compound operator. + +### Numerical Precision + +TOSA began as a means to address operator-level numerical precision for code generation and hardware development. It therefore incorporates precision detail into the operator set. + +In this regard, TOSA operators are best understood as a combination of the visible quantization information embedded within an operation, together with the functional information about how that information is used, as described in the specification of the operation. + +## TOSA Operator Rationale + +The general basis of selection of the operator set that constitutes TOSA is described in the TOSA specification document under Section 1.3 Operator Selection. Explanation of the thinking behind some operators is listed here: + +### IDENTITYN + +tosa.IDENTITYN is used to form a list of Operator results during +lowering of operations such as tf.Split from a sequence of tosa.SLICE +ops. If there are alternate ways to express this lowering without the +tosa.IDENTITYN op, the tosa.IDENTITYN op could be removed from TOSA. 
+ +``` +Value lower_split_op(Value %value, size_t axis, size_t +num_split) { Value %output[] + + size_t slice_size = %value.shape[axis] / num_split + + for (int i = 0; i < num_split; i++) { + vector begin_vals, size_vals + + for (int j = 0; j < %value.rank; j++) { + if (j == axis) { + begin_vals.push_back(slice_size * i) + size_vals.push_back(slice_size) + } else { + begin_vals.push_back(0) + size_vals.push_back(%value.shape[j]) + } + } + + %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals} (tensor<%value.type>) -> tensor + } + + %output_list = tosa.IDENTITYN(%output) (tensor<%output:*.type>) -> tensor<%output_list:*.type> + return %output_list +} +``` + +### COND\_IF and WHILE\_LOOP + +Several neural networks express conditional control flow at the tensor level. A survey of multiple high level frameworks indicated that conditional if and a loop construct are common in all major frameworks, with some variation. Since TOSA endeavors to be complete in expressing tensor level functionality including control flow, it implements these constructs. + +The COND\_IF and WHILE\_LOOP operators implement such structured control flow forms and should be lowerable to corresponding ops in the scf dialect. Since the dialect seeks to remain isomorphic with an external, serialized form, the decision was to keep these ops in the dialect (as opposed to deferring completely to scf), and this may be re-evaluated if this turns out to not yield the expected value. + +## Using TOSA In A Compiler + +The TOSA specification describes each operator in functional detail. It is expected that compilers that use TOSA will use its builders to construct the operators so that the quantization information for the operator is correctly generated. + +The functional steps described in the pseudocode of the specification enable the construction of code generation for that operation, or decisions on the design of underlying hardware. 
The functional pseudocode also describes how the quantization parameters are utilized within the operation. + +### Quantization Parameters in Ops vs Tensors + +TOSA uses the quantization parameters embedded in the input and output tensors to construct the quantization attributes that sit within the operator. Once these attributes are constructed, the quantization information within the tensors is no longer necessary for code generation. + +This enables the tensors to be subsequently interpreted simply as contiguous buffers containing raw data, with no 'meta information' in the form of the quantization_type. Precision-related manipulation of the input or output is instead described by the operator itself, which specifies, for example, when the zero point is applied, or when the scale multiplication is done. + +However, TOSA does *not* eliminate the existing MLIR QuantOps quantization type information within the tensors; this leaves the choice of how to handle quantization information to later backend code generation steps. + +Maintaining the ability to overlap these different representations of quantization parameters (i.e. tensor-carried vs op-carried) is an important capability when considering progressive lowering between uses that expect one scheme vs the other. 
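The op-carried scheme described above can be illustrated with a minimal C++ sketch. It is not the dialect's implementation: `MatMulQuantInfo` and `quantizedDot` are hypothetical names, with the struct only mirroring the `a_zp`/`b_zp` fields of `MatMulOpQuantizationAttr`. The sketch assumes int8 operands and a 32-bit accumulator, and shows the tensors treated as raw buffers while the zero points come solely from the operator.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the a_zp/b_zp fields of MatMulOpQuantizationAttr.
struct MatMulQuantInfo {
  int32_t a_zp;
  int32_t b_zp;
};

// Dot product over raw int8 buffers. The buffers carry no quantization
// metadata; the zero points are taken only from the op-carried info.
int32_t quantizedDot(const std::vector<int8_t> &a,
                     const std::vector<int8_t> &b,
                     const MatMulQuantInfo &info) {
  assert(a.size() == b.size() && "operands must have equal length");
  int32_t acc = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    // Zero points are subtracted before the multiply, in 32-bit arithmetic.
    acc += (static_cast<int32_t>(a[i]) - info.a_zp) *
           (static_cast<int32_t>(b[i]) - info.b_zp);
  }
  return acc;
}
```

The result stays in the accumulator's scale. Per the comments in TosaOpBase.td, scaling the accumulator to the output scale is expressed by a separate tosa.rescale, so this sketch deliberately stops before that step.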
Index: mlir/include/mlir/Dialect/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/CMakeLists.txt +++ mlir/include/mlir/Dialect/CMakeLists.txt @@ -13,4 +13,5 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) Index: mlir/include/mlir/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) Index: mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -0,0 +1,18 @@ +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaDialect.h.inc -gen-dialect-decls) +mlir_tablegen(TosaOps.h.inc -gen-op-decls) +mlir_tablegen(TosaOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTosaOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls) +mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRTosaStructsIncGen) + + +set(LLVM_TARGET_DEFINITIONS TosaInterfaces.td) +mlir_tablegen(TosaInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TosaInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRTosaInterfaceIncGen) + +add_mlir_doc(TosaOps -gen-op-doc TosaOps Dialects/) Index: mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td @@ -0,0 +1,27 @@ +//===-- TosaInterfaces.td - TOSA dialect interfaces --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dialect op interfaces for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_OP_INTERFACES +#define TOSA_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def TosaOpInterface : OpInterface<"TosaOp"> { + let description = [{ + Implements interfaces implemented by ops that correspond to the Tosa + specification. + }]; + + let methods = []; +} + +#endif Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -0,0 +1,166 @@ +//===-- TosaOpBase.td - TOSA dialect op builders -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the common definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + + +#ifndef TOSA_OP_BASE +#define TOSA_OP_BASE + +//===----------------------------------------------------------------------===// +// TOSA Operator Quantization Attributes. +//===----------------------------------------------------------------------===// + +// Quantization attributes used across TOSA operators. Quantization attributes +// feed numerical precision parameters to the functional implementation of TOSA +// operators. +// The functional behavior is defined in the TOSA specification maintained at +// https://developer.mlplatform.org/w/tosa/ . 
TOSA leverages MLIR's built-in +// quantization support: https://mlir.llvm.org/docs/Quantization/, and supports +// uniform quantization. Depending on datatype, asymmetric and symmetric +// quantization are supported. The types themselves are described in +// TosaTypesBase.td . + +// This quantization attribute expresses numerical behavior of operators where +// the operator has a numerical relationship between a single input and output. +// For example: tosa.negate. +def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr", + Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"output_zp", I32Attr> + ]> { + let description = "Attribute for UnaryOp quantization information."; +} + +// There is no explicit BinaryOpQuantizationAttr for 2-input/1-output ops. In +// this case, a tosa.rescale is used to express the inputs to the same scale. +// TODO: Upload WIP legalization document describing this construction by +// example. + +// This quantization attribute holds input and weight zero point. Both the +// ConvOp and MatMulOp QuantizationAttrs follow a common design semantic where +// their own quantization attribute only expresses the numerical behavior at +// the inputs. +// The scaling of their accumulator output is done using an explicit +// tosa.rescale operator that scales the accumulator result to output scale. 
+def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr", + Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"weight_zp", I32Attr> + ]> { + let description = "Attribute for Conv type op quantization information."; +} + +def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr", + Tosa_Dialect, [ + StructFieldAttr<"a_zp", I32Attr>, + StructFieldAttr<"b_zp", I32Attr> + ]> { + let description = "Attribute for MatMulOp quantization information."; +} + +// This attribute holds input zero point correction applied to the padding +// zeros to ensure numerical accuracy in the subsequent TOSA operations. +// Its functional application is described in the tosa.pad() operator +// description in the specification. +def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr", + Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr> + ]> { + let description = "Attribute for PadOp quantization information."; +} + +//===----------------------------------------------------------------------===// +// TOSA Operator Quantization Builders. +//===----------------------------------------------------------------------===// + +// This builder is called on all convolution operators except for TransposeConv, +// which has specialized output shape semantics. The builder also defines the +// bitwidth of the output given the bit width of the input & weight content. +def Tosa_ConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias, + "ArrayAttr":$pad, "ArrayAttr":$stride, "ArrayAttr":$dilation), + [{ + ::ConvOpQuantInfoBuilder($_builder, $_state, outputType, + input, weight, bias, + pad, stride, dilation); + }]>; + +// Handles tosa.transpose_conv2d which has an outpad and output shape attribute. 
+def Tosa_TransConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias, + "ArrayAttr":$outpad, "ArrayAttr":$stride, "ArrayAttr":$dilation, + "ArrayAttr":$outputShape), + [{ + ::TransConvOpQuantInfoBuilder($_builder, $_state, outputType, + input, weight, bias, + outpad, stride, dilation, + outputShape); + }]>; + +// The tosa.fully_connected op has its own builder as it does not have +// strides/dilation/padding. +def Tosa_FCOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias), + [{ + ::FCOpQuantInfoBuilder($_builder, $_state, outputType, + input, weight, bias); + }]>; + +// The tosa.matmul op is also intended to be generated where a fully_connected +// op must be constructed where the weight is not a constant. In this case, +// the fully_connected op must be expressed using matmul. +// TODO: Add link to the legalization document explaining this. +def Tosa_MatMulOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$a, "Value":$b), + [{ + ::MatMulOpQuantInfoBuilder($_builder, $_state, outputType, + a, b); + }]>; + +// Both the tosa.avg_pool2d and unary ops use the same +// UnaryOpQuantizationAttr but the avg_pool operator has its own builder as it +// has additional parameters not part of the unary ops. +def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "ArrayAttr":$kernel, + "ArrayAttr":$stride, "ArrayAttr":$pad), + [{ + ::AvgPool2dOpQuantInfoBuilder($_builder, $_state, outputType, + input, kernel, stride, pad); + }]>; + +// This builder is called on single-parameter unary operators that have a scale +// relationship between their input and output, expressed by the +// UnaryOpQuantizationAttr. 
+def Tosa_UnaryOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input), + [{ + ::UnaryOpQuantInfoBuilder($_builder, $_state, outputType, input); + }]>; + +// This builder is called on the TOSA pad operator that needs to create its own +// OptionalAttr quantization_attr parameter to scale the padding values +// correctly. +def Tosa_PadOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$paddings), + [{ + ::PadOpQuantInfoBuilder($_builder, $_state, outputType, + input, paddings); + }]>; + +//===----------------------------------------------------------------------===// +// TOSA Operator. +//===----------------------------------------------------------------------===// + +class Tosa_Op traits = []> : + Op { +} + +#endif // TOSA_OP_BASE Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -0,0 +1,70 @@ +//===-- TosaOps.h - TOSA dialect operation definitions ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_IR_TOSA_OPS_H +#define DIALECT_TOSA_IR_TOSA_OPS_H + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Tosa/IR/TosaTraits.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" + +//===----------------------------------------------------------------------===// +// TOSA dialect and structs includes. +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/IR/TosaDialect.h.inc" +#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc" + +namespace mlir { +namespace tosa { + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" + +//===----------------------------------------------------------------------===// +// TOSA Operator Quantization Builders. 
+//===----------------------------------------------------------------------===// +void ConvOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias, ArrayAttr pad, ArrayAttr stride, + ArrayAttr dilation); +void TransConvOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias, ArrayAttr outpad, ArrayAttr stride, + ArrayAttr dilation, ArrayAttr outpadShape); +void FCOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias); +void MatMulOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value a, Value b); +void AvgPool2dOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, ArrayAttr kernel, + ArrayAttr stride, ArrayAttr pad); +void UnaryOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input); +void PadOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value paddings); + +} // end namespace tosa +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" + +#endif // TOSA_OPS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -0,0 +1,1748 @@ +//===-- TosaOps.td - TOSA dialect operation definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation set for the TOSA dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_OPS +#define TOSA_OPS + +include "mlir/IR/OpBase.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + +include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" + +//===----------------------------------------------------------------------===// +// The TOSA Dialect. +//===----------------------------------------------------------------------===// +def Tosa_Dialect : Dialect { + let name = "tosa"; + + let description = [{ + The Tensor Operator Set Architecture (TOSA) dialect. + + This dialect implements the TOSA standard described at + https://developer.mlplatform.org/w/tosa/ . + + Tensor Operator Set Architecture (TOSA) provides a set of whole-tensor + operations commonly employed by Deep Neural Networks. The intent is to + enable a variety of implementations running on a diverse range of + processors, with the results at the TOSA level consistent across those + implementations. Applications or frameworks which target TOSA can therefore + be deployed on a wide range of different processors, such as CPUs or GPUs, + with defined accuracy and compatibility constraints. Most operators from the + common ML frameworks should be expressible in TOSA. It is expected that + there will be tools to lower from the ML frameworks into TOSA. + }]; + + let cppNamespace = "mlir::tosa"; +} + +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.2 +// Operator Class: Tensor Data Engine Operators. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: argmax +//===----------------------------------------------------------------------===// +def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> { + let summary = "Perform argmax on the input."; + + let description = [{ + This returns the index with the largest value across the given axis of the + input tensor. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D: $input, + I64Attr: $axis + ); + + let results = (outs + Tosa_TensorUpto4D: $output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: avg_pool2d +//===----------------------------------------------------------------------===// +def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> { + let summary = "Performs average pooling on the input."; + + let description = [{ + This performs an average pooling over the given input tensor. A sliding + window of size given by the kernel attribute is passed over the input + tensor, with the mean value being placed in the output tensor. + }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + + Tosa_IntArrayAttr2:$kernel, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr4:$pad, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: conv2d +//===----------------------------------------------------------------------===// +def Tosa_Conv2DOp : Tosa_Op<"conv2d", [NoSideEffect]> { + let summary = [{ + 2D Convolution Operator + }]; + + let description = [{ + Performs a 2D convolution over the given tensor input, using the weight + tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + Tosa_Tensor4D:$weight, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr4:$pad, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr2:$dilation, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// Operator: conv3d +//===----------------------------------------------------------------------===// +def Tosa_Conv3DOp : Tosa_Op<"conv3d", [NoSideEffect]> { + let summary = [{ + 3D Convolution operator + }]; + + let description = [{ + Performs a 3D convolution over the given input tensor. + }]; + + let arguments = (ins + Tosa_Tensor5D:$input, + Tosa_Tensor5D:$weight, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr6:$pad, + Tosa_IntArrayAttr3:$stride, + Tosa_IntArrayAttr3:$dilation, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor5D:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// Operator: depthwise_conv2d +//===----------------------------------------------------------------------===// +def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> { + let summary = [{ + Depthwise 2D Convolution operator + }]; + + let description = [{ + Performs 2D convolutions separately over each channel of the given tensor + input, using the weight tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + Tosa_Tensor4D:$weight, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr4:$pad, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr2:$dilation, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// Operator: fully_connected +//===----------------------------------------------------------------------===// +def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> { + let summary = "Fully Connected operator"; + + let description = [{ + Performs a fully connected network. + }]; + + let arguments = (ins + Tosa_Tensor2D:$input, + Tosa_Tensor2D:$weight, + Tosa_Tensor1D:$bias, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor2D:$output + ); + + let builders = [Tosa_FCOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; +} + +//===----------------------------------------------------------------------===// +// Operator: matmul +//===----------------------------------------------------------------------===// +def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> { + let summary = "Matrix multiplication with bias"; + + let description = [{ + Performs a two dimensional matrix multiplication. This allows both inputs to + be activations, rather than reserving weights as an attribute in the + FULLY_CONNECTED operator. 
+ }]; + + let arguments = (ins + Tosa_Tensor2D:$a, + Tosa_Tensor2D:$b, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor2D:$c + ); + + let builders = [Tosa_MatMulOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: max_pool2d +//===----------------------------------------------------------------------===// +def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> { + let summary = "Performs max pooling on the input."; + + let description = [{ + This performs a max pooling over the given input tensor. A sliding window of + size given by the kernel attribute is passed over the input tensor, with the + maximum value being placed in the output tensor. + }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + + Tosa_IntArrayAttr2:$kernel, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr4:$pad + ); + + let results = (outs + Tosa_Tensor4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose_conv2d +//===----------------------------------------------------------------------===// +def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [NoSideEffect]> { + let summary = [{ + Transpose 2D Convolution operator. + }]; + + let description = [{ + Performs a 2D transposed convolution over the given tensor input, using the + weights tensor. + }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + Tosa_Tensor4D:$filter, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr2:$out_pad, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr2:$dilation, + Tosa_IntArrayAttrUpto4:$out_shape, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_TransConvOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.3 +// Operator Class: Activation Functions. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: clamp +//===----------------------------------------------------------------------===// +def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes clamp(features, min, max)."; + + let description = [{ + Clamp to an arbitrary minimum and maximum value. Note that the maximum and + minimum values are specified as signed quantized values; no scaling happens + before or after this operation. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$min_int, + I64Attr:$max_int, + F32Attr:$min_fp, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reluN +//===----------------------------------------------------------------------===// +def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes rectified linear: `min(max(features, 0), N)`."; + + let description = [{ + ReLU with a scalar maximum value. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$max_int, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: sigmoid +//===----------------------------------------------------------------------===// +def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect, + SameOperandsAndResultType]> { + let summary = "Computes elementwise sigmoid of input."; + + let description = [{ + Sigmoid function: output = 1 / (1 + exp(-input)) + For quantized integer data types, the TABLE operator should be used instead + with the following definition. 
The sigmoid table has 513 entries each of + 16-bit precision and covering the input range -16.0 to +16.0 + in steps of 1/16. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: tanh +//===----------------------------------------------------------------------===// +def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise hyperbolic tangent of input"; + + let description = [{ + Parameterized hyperbolic tangent. + For quantized integer data types, the TABLE operator should be used instead + with the following definition. The tanh_table has 513 entries each of + 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.4 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise binary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: add +//===----------------------------------------------------------------------===// +def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect, + Commutative]> { + let summary = "Elementwise addition operator"; + + let description = [{ + Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, + as necessary. Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: arithmetic_right_shift +//===----------------------------------------------------------------------===// +def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", + [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Elementwise Arithmetic Right Shift"; + + let description = [{ + Elementwise arithmetic right shift of input1 by the amount specified in + input2. Axis of size 1 will be broadcast, as necessary. Rank of input + tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2, + BoolAttr:$round + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_and +//===----------------------------------------------------------------------===// +def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Bitwise AND operator"; + + let description = [{ + Elementwise bitwise AND of input1 and input2. Axis of size 1 + will be broadcast as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_or +//===----------------------------------------------------------------------===// +def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Bitwise OR operator"; + + let description = [{ + Elementwise bitwise OR of input1 and input2. Axis of size 1 will be + broadcast as necessary. 
Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_xor +//===----------------------------------------------------------------------===// +def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Bitwise XOR operator"; + + let description = [{ + Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be + broadcast as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_and +//===----------------------------------------------------------------------===// +def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape, + Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x AND y element-wise."; + + let description = [{ + Elementwise logical AND of input1 and input2. Axis of size 1 will be + broadcast, as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_left_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", + [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Elementwise Logical Left Shift"; + + let description = [{ + Elementwise logical left shift of input1 by the amount specified in input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_right_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", + [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Elementwise Logical Right Shift"; + + let description = [{ + Elementwise logical right shift of input1 by the amount specified in input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_or +//===----------------------------------------------------------------------===// +def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape, + Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x OR y element-wise."; + + let description = [{ + Elementwise logical OR of input1 and input2. Axis of size 1 will be + broadcast as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_xor +//===----------------------------------------------------------------------===// +def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape, + Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x XOR y element-wise."; + + let description = [{ + Elementwise logical XOR of input1 and input2. Axis of size 1 will be + broadcast as necessary. 
Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: maximum +//===----------------------------------------------------------------------===// +def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Elementwise Maximum"; + + let description = [{ + Elementwise max of input1 and input2. Axis of size 1 will be broadcast, as + necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: minimum +//===----------------------------------------------------------------------===// +def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Elementwise Minimum"; + + let description = [{ + Elementwise minimum of input1 and input2. Axis of size 1 + will be broadcast, as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: mul +//===----------------------------------------------------------------------===// +def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, + Commutative]> { + let summary = "Multiplication operator"; + + let description = [{ + Elementwise multiplication (Hadamard product) of input1 and input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2, + I32Attr:$shift + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: pow +//===----------------------------------------------------------------------===// +def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Computes the power of one value to another."; + + let description = [{ + Elementwise input1 raised to the power of input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: sub +//===----------------------------------------------------------------------===// +def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise subtraction operator"; + + let description = [{ + Elementwise subtraction of input1 and input2. Axis of size 1 will be + broadcast as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: table +//===----------------------------------------------------------------------===// +def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> { + let summary = "Table lookup op"; + + let description = [{ + Interpolated table lookup operation. Input values are scaled to create a + fixed-point 9.7 value. The high 9 bits are used to index into the table. + The fractional bits are used to interpolate based on the looked up value and + the index+1 value in the table. 
The TABLE operator then returns a 16.7 + interpolated value. Note that there must be 513 values to handle the full + range of inputs. + + The TABLE operator is expected to be used as follows: + * A RESCALE node is expected before the TABLE operator to scale the input + to a full int16_t range for the table lookup + * If an int16_t result is required then follow the TABLE operator with a + RESCALE with a right shift of 7 + * If an int8_t result is required then follow the TABLE operator with a + RESCALE with a right shift of 15 + }]; + + let arguments = (ins + Tosa_TensorUpto4D: $input, + Tosa_Tensor1D: $table + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.5 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise unary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: abs +//===----------------------------------------------------------------------===// +def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise abs op"; + + let description = [{ + Elementwise absolute value operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_not +//===----------------------------------------------------------------------===// +def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Bitwise NOT operator"; + + let description = [{ + Elementwise bitwise NOT of input tensor. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: ceil +//===----------------------------------------------------------------------===// +def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise ceil op"; + + let description = [{ + Elementwise ceiling operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: clz +//===----------------------------------------------------------------------===// +def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise count leading zero op"; + + let description = [{ + Elementwise count leading zeros operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: exp +//===----------------------------------------------------------------------===// +def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise exp op"; + + let description = [{ + Elementwise e to the x operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: floor +//===----------------------------------------------------------------------===// +def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise floor op"; + + let description = [{ + Elementwise floor operation + }]; + + let arguments = (ins + 
Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: log +//===----------------------------------------------------------------------===// +def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise log op"; + + let description = [{ + Elementwise natural logarithm operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_not +//===----------------------------------------------------------------------===// +def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect, + SameOperandsAndResultType]> { + let summary = "Returns the truth value of NOT x element-wise."; + + let description = [{ + Elementwise logical NOT of input. + }]; + + let arguments = (ins + I1Tensor:$input1 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: negate +//===----------------------------------------------------------------------===// +def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect, + SameOperandsAndResultType]> { + let summary = "Elementwise negate op"; + + let description = [{ + Elementwise negation operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); + + let builders = [Tosa_UnaryOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: reciprocal +//===----------------------------------------------------------------------===// +def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect, + SameOperandsAndResultType]> { + let summary = "Elementwise 
reciprocal op"; + + let description = [{ + Elementwise reciprocal operation. For integer operation, a TABLE should be + used with the appropriate ranges. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: rsqrt +//===----------------------------------------------------------------------===// +def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise 1/sqrt op"; + + let description = [{ + Elementwise reciprocal square root operation. For integer operation, a TABLE + should be used with the appropriate ranges. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.6 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise ternary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: select +//===----------------------------------------------------------------------===// +def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> { + let summary = "Elementwise select operator"; + + let description = [{ + Elementwise select of the output based on a condition. + }]; + + let arguments = (ins + I1Tensor:$input1, + Tosa_TensorUpto4D:$input2, + Tosa_TensorUpto4D:$input3 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.7 +// Operator Class: Logical Operations. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: equal +//===----------------------------------------------------------------------===// +def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative, + NoSideEffect]> { + let summary = "Returns the truth value of (x == y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: greater +//===----------------------------------------------------------------------===// +def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Returns the truth value of (x > y) element-wise."; + + let description = [{ + Elementwise greater than comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: greater_equal +//===----------------------------------------------------------------------===// +def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Returns the truth value of (x >= y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.8 +// Operator Class: Reduction Ops. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: reduce_all +//===----------------------------------------------------------------------===// +def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> { + let summary = [{ + Reduce All operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical AND operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_any +//===----------------------------------------------------------------------===// +def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> { + let summary = [{ + Reduce Any operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical OR operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_max +//===----------------------------------------------------------------------===// +def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> { + let summary = [{ + Reduce Max operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a maximum operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_min +//===----------------------------------------------------------------------===// +def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> { + let summary = [{ + Reduce Min operator + }]; + + let 
description = [{ + Reduce a tensor along the given axis with a minimum operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_prod +//===----------------------------------------------------------------------===// +def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> { + let summary = [{ + Reduce Prod operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the product of the axis. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_sum +//===----------------------------------------------------------------------===// +def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> { + let summary = [{ + Reduce Sum operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the sum of the axis. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.9 +// Operator Class: Data Layout / Memory Reinterpretation. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: concat +//===----------------------------------------------------------------------===// +def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> { + let summary = "Concatenates tensors along one dimension."; + + let description = [{ + Concatenate two tensors along a given axis. No data conversion happens + during a concat operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input1, + Tosa_Tensor1Dto4D:$input2, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: pad +//===----------------------------------------------------------------------===// +def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> { + let summary = "Pads a tensor with zeros."; + + let description = [{ + Zero-pads a tensor along borders of each dimension. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input1, + Tosa_Int32Or64Tensor:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + + let builders = [Tosa_PadOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: reshape +//===----------------------------------------------------------------------===// +def Tosa_ReshapeOp: Tosa_Op<"reshape", [ + NoSideEffect]> { + let summary = "Reshape operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with a new shape + specified by the shape argument. Reshape may operate on tensors of any rank. + No data conversion happens during a reshape operation. + }]; + + let arguments = (ins + Tosa_TensorUpto6D:$input1, + I64ArrayAttr:$new_shape + ); + + let results = (outs + Tosa_TensorUpto6D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reverse +//===----------------------------------------------------------------------===// +def Tosa_ReverseOp: Tosa_Op<"reverse", [NoSideEffect]> { + let summary = "Reverse operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with the data + reversed along the given axis. No data conversion happens during a reverse + operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: slice +//===----------------------------------------------------------------------===// +def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> { + let summary = "Slice operator"; + + let description = [{ + Extracts a slice of input, beginning at the start coordinates and + extending for size elements in each dimension. No data conversion happens + during a slice operation. + }]; + + let arguments = (ins + Tosa_Tensor1Dto6D:$input, + I64ArrayAttr:$start, + I64ArrayAttr:$size + ); + + let results = (outs + Tosa_Tensor1Dto6D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: tile +//===----------------------------------------------------------------------===// +def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> { + let summary = "Tile operator"; + + let description = [{ + Replicates input1 multiples times along each dimension. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input1, + I64ArrayAttr:$multiples + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose +//===----------------------------------------------------------------------===// +def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> { + let summary = "Transpose operator"; + + let description = [{ + Permutes the dimensions of input1 based on perms. + }]; + + let arguments = (ins + Tosa_Tensor1Dto6D:$input1, + Tosa_Int32Or64Tensor:$perms + ); + + let results = (outs + Tosa_Tensor1Dto6D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.10 +// Operator Class: Scatter/gather Operations. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: gather +//===----------------------------------------------------------------------===// +def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> { + let summary = [{ + Gather operation + }]; + + let description = [{ + Generate a tensor for which each element in the output is a subtensor of the + values tensor along the given axis, based on the value of indices. + }]; + + let arguments = (ins + Tosa_Int32Or64Tensor:$indices, + Tosa_Tensor1Dto4D:$values, + I32Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.11 +// Operator Class: Image Frontend Functions. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: resize +//===----------------------------------------------------------------------===// +def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> { + + let summary = "Resize operation, supports various resize/upsample modes"; + + let description = [{ + Resizes a tensor. Resize is only allowed in the H and W dimensions. 
In + expected use, stride_y is approximately (IH< { + + let summary = "Cast operation"; + + let description = [{ + Performs a set of permissible cast operations + Mode Input Output + --------------------------------------- + signed 8 to bool int8 Boolean + signed 16 to bool int16 Boolean + signed 32 to bool int32 Boolean + bool to 8 Boolean int8 + bool to 16 Boolean int16 + bool to 32 Boolean int32 + signed 8 to signed 16 int8 int16 + signed 8 to signed 32 int8 int32 + signed 16 to signed 8 int16 int8 + signed 16 to signed 32 int16 int32 + signed 32 to signed 8 int32 int8 + signed 32 to signed 16 int32 int16 + float to signed 8 float int8 + float to signed 16 float int16 + signed 8 to float int8 float + signed 16 to float int16 float + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: rescale +//===----------------------------------------------------------------------===// +def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> { + let summary = "Tosa rescale operator"; + + let description = [{ + Rescale quantized values into a new domain. 
Supported rescalings are: + Mode Input Output + signed 8 to 8 aint8 aint8 + signed 8 to 16 aint8 int16 + signed 8 to 32 aint8 int32 + signed 16 to 8 int16 aint8 + signed 16 to 16 int16 int16 + signed 16 to 32 int16 int32 + signed 32 to 8 int32 aint8 + signed 32 to 16 int32 int16 + signed 32 to 32 int32 int32 + signed 48 to 8 int48 aint8 + signed 48 to 16 int48 int16 + signed 48 to 32 int48 int32 + unsigned 8 to signed 8 uint8 aint8 + signed 8 to unsigned 8 aint8 uint8 + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input, + I32Attr:$input_zp, + I32Attr:$output_zp, + I32ArrayAttr:$multiplier, + I32ArrayAttr:$shift, + BoolAttr:$scale32, + BoolAttr:$double_round, + BoolAttr:$per_channel + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.13 +// Operator Class: Data Node Ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: const +//===----------------------------------------------------------------------===// +def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, NoSideEffect, + FirstAttrDerivedResultType]> { + let summary = "Constant op."; + + let description = [{ + A node containing constant data for use as the input to an operation. May + hold data in any of the supported data formats. + }]; + + let arguments = (ins + AnyAttr:$value + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: identity +//===----------------------------------------------------------------------===// +def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect]> { + let summary = "Identity operator"; + let description = [{ + Returns a tensor with the same shape, size, type + and content as the input. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto6D:$input1 + ); + + let results = (outs + Tosa_TensorUpto6D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: identityn +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_IdentityNOp: Tosa_Op<"identityn", [NoSideEffect]> { + let summary = "IdentityN operator"; + let description = [{ + Returns a list of tensors with the same shape, type, and contents as the + input list of tensors. + }]; + + let arguments = (ins + Variadic:$input1 + ); + + let results = (outs + Variadic:$output + ); +} + + +//===----------------------------------------------------------------------===// +// Operator: placeholder +//===----------------------------------------------------------------------===// +def Tosa_PlaceholderOp : Tosa_Op<"placeholder", [NoSideEffect]> { + let summary = "Placeholder op"; + + let description = [{ + A node where data will be inserted into the network at runtime. Generally + used for inputs to the network. + }]; + + let arguments = (ins + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.14 +// Operator Class: Custom Operators. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: custom +//===----------------------------------------------------------------------===// +def Tosa_CustomOp : Tosa_Op<"custom"> { + + let summary = "Custom operator wrapper for Tosa"; + + let description = [{ + Hardware implementing TOSA may choose to add additional custom operators + that are not expressed in the existing TOSA operations. These operators are + not expected to be portable across TOSA implementations. The input and + output signatures must be expressed in the corresponding TOSA node. + }]; + + let arguments = (ins + StrAttr:$identifier, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.15 +// Operator Class: Control Flow Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: cond_if +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_IfOp : Tosa_Op<"cond_if", [ + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Conditional if operator"; + + let description = [{ + Evaluates a Boolean condition and then takes one of two distinct execution + paths. This implements the semantic If-then-else structure. 
+ }]; + + let arguments = (ins + I1Tensor:$cond, + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$then_branch, + SizedRegion<1>:$else_branch + ); +} + +//===----------------------------------------------------------------------===// +// Operator: while_loop +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_WhileOp : Tosa_Op<"while_loop", [ + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = [{ + output = input; While (Cond(output)) { output = Body(output) } + }]; + + let description = [{ + Generates and evaluates a Bool condition and either executes a loop body or + exits to another control point. This action is performed repeatedly after + updating and re-evaluating the Boolean condition every iteration. This + implements the semantic foreach or while iterative loop structure. + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$cond, + SizedRegion<1>:$body + ); +} + +//===----------------------------------------------------------------------===// +// Operator: yield +//===----------------------------------------------------------------------===// +def Tosa_YieldOp : Tosa_Op<"yield", [Terminator]> { + let summary = "yield operator"; + + let description = [{ + return operation within the conditional and body of + structured control flow. Operation takes variadic operands + but produces no results of its own. 
+ }]; + + let arguments = (ins + Variadic:$inputs + ); +} + +#endif // TOSA_OPS Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h @@ -0,0 +1,34 @@ +//===-- TosaTraits.h - TOSA dialect operation traits ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect OpTraits in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATRAITS_H +#define MLIR_DIALECT_TOSA_IR_TOSATRAITS_H + +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { +namespace tosa { + +// TBD + +} // namespace tosa +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATRAITS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h @@ -0,0 +1,31 @@ +//===-- TosaTypes.h - TOSA dialect type definitions -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect Types in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATYPES_H +#define MLIR_DIALECT_TOSA_IR_TOSATYPES_H + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +namespace tosa { + +// TOSA specific types go here + +} // namespace tosa + +} // end namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATYPES_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -0,0 +1,158 @@ +//===-- TosaTypesBase.td - TOSA type definitions -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the type definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_TYPES_BASE +#define TOSA_TYPES_BASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Tosa Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. +// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. +// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for +// the 8-bit case. 
+class Tosa_QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + StrJoinInt.result # !if(signed, ", true", ", false"); +} + +//===----------------------------------------------------------------------===// +// Non-Quantized Signed Integer Types. +// Used to express accumulator results or compare results. +//===----------------------------------------------------------------------===// + +def Tosa_Int32 : I<32>; +def Tosa_Int48 : I<48>; +def Tosa_Int64 : I<64>; + +def Tosa_SignedInt : AnyTypeOf<[Tosa_Int32, + Tosa_Int48, + Tosa_Int64]>; + +def Tosa_Bool : I<1>; + +// No unsigned unquantized int types. +def Tosa_Int : AnyTypeOf<[Tosa_Bool, + Tosa_SignedInt]>; + +def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, + Tosa_Int64]>; + +//===----------------------------------------------------------------------===// +// Quantized Integer Types. +// Datatype for network feature map or weight content. 
+//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Name Symmetry Grouping Sign +//===----------------------------------------------------------------------===// +// aint8 : asymmetric per tensor, signed +// uint8 : asymmetric per tensor , unsigned +// int4 : symmetric per channel, signed +// int8 : symmetric per tensor/per channel, signed +// int16 : symmetric per tensor, signed +//===----------------------------------------------------------------------===// +def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>, + Tosa_QuantizedType<"uint8", [8], 0>, + Tosa_QuantizedType<"int4", [4, 0], 1>, + Tosa_QuantizedType<"int8", [8, 0], 1>, + Tosa_QuantizedType<"int16", [16, 0], 1>]>; + +//===----------------------------------------------------------------------===// +// Floating-point types. +//===----------------------------------------------------------------------===// +def Tosa_Float : AnyTypeOf<[F32, + F16, + BF16]>; + +//===----------------------------------------------------------------------===// +// Multi-category types. +//===----------------------------------------------------------------------===// +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], + "number">; + +//===----------------------------------------------------------------------===// +// Tensor types +//===----------------------------------------------------------------------===// + +def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; + +def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; + +// Any tensor element type allowed in Tosa ops. +def Tosa_ElementType : Type, "tosa.dtype">; + +class Tosa_TensorOfOrNone allowedTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +//===----------------------------------------------------------------------===// +// Tensor types with constrained ranks. 
+//===----------------------------------------------------------------------===// + +// Must be listed rank. +def Tosa_Tensor1D : 1DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor2D : 2DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor4D : 4DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>; +def Tosa_Tensor6D : TensorRankOf<[Tosa_AnyNumber], [6]>; + +// Ranked tensors up to given rank. +def Tosa_Tensor1Dto2D : TensorRankOf<[Tosa_AnyNumber], [1,2]>; +def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>; +def Tosa_Tensor1Dto5D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5]>; +def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>; + +def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>; +def Tosa_TensorUpto6D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4,5,6]>; + +//===----------------------------------------------------------------------===// +// Attribute predicates and classes. +//===----------------------------------------------------------------------===// +class ArrayMaxCt : AttrConstraint< + CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>, + "with at most " # n # " elements">; + +def Tosa_IntArrayAttr2 : Confined]>; +def Tosa_IntArrayAttr3 : Confined]>; +def Tosa_IntArrayAttr4 : Confined]>; +def Tosa_IntArrayAttr5 : Confined]>; +def Tosa_IntArrayAttr6 : Confined]>; + +def Tosa_IntArrayAttrUpto2 : Confined]>; +def Tosa_IntArrayAttrUpto4 : Confined]>; +def Tosa_IntArrayAttrUpto5 : Confined]>; + +//===----------------------------------------------------------------------===// +// Iterable attributes. +//===----------------------------------------------------------------------===// +// Supported regimes for tosa.resize.
+def Tosa_ResizeTypeAttr : StringBasedAttr< + CPred<"$_self.cast().getValue() == \"BILINEAR\" || " # + "$_self.cast().getValue() == \"NEAREST_NEIGHBOR\"">, + "Supported resize/upsampling strategies">; + +def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; + +// Tensor to buffer types. +def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; +def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; +def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; + +#endif // TOSA_TYPES_BASE Index: mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) +add_public_tablegen_target(MLIRTosaPassIncGen) +add_dependencies(mlir-headers MLIRTosaPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc TosaPasses ./) Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -0,0 +1,31 @@ +//===-- Passes.h - TOSA optimization pass declarations ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +namespace tosa { + +std::unique_ptr> createTosaMakeBroadcastablePass(); +std::unique_ptr> createTOSATestQuantUtilAPIPass(); + +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -0,0 +1,37 @@ +//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +include "mlir/Pass/PassBase.td" + +def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> { + let summary = "TOSA rank Reshape to enable Broadcasting"; + let description = [{ + Pass that enables broadcast by making all input arrays have the same + number of dimensions. Insert RESHAPE operations to prepend dimensions + of size one until the number of dimensions is equal. 
Implements an + approach similar to step 1 of NumPy 4-step broadcasting: + https://numpy.org/doc/stable/reference/ufuncs.html#broadcasting + }]; + + let constructor = "createTosaMakeBroadcastablePass()"; +} + +// TOSA Test Passes + +def TosaTestQuantUtils : FunctionPass<"tosa-test-quant-utils"> { + let summary = "TOSA Test: Exercise the APIs in QuantUtils.cpp"; + let description = [{ + Exercises the API that builds a quantized type from min/max quantized range. + }]; + + let constructor = "createTOSATestQuantUtilAPIPass()"; +} Index: mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -0,0 +1,71 @@ +//===-- QuantUtils.h - TOSA numerical support declarations ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Function declarations for TOSA numerical support functions and quantization +// attribute builders +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_UTILS_QUANT_UTILS_H +#define DIALECT_TOSA_UTILS_QUANT_UTILS_H + +//===----------------------------------------------------------------------===// +// Utility functions to support quantization handling in Tosa. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +#include "mlir/Dialect/Quant/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/UniformSupport.h" + +namespace mlir { +namespace tosa { + +/// From a scale value, computes multiplier and shift values +/// for 16 or 32-bit scale widths.
+void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth); + +/// Builds ConvOpQuantizationAttr from input and weight. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, + Value input, Value weight); + +/// Builds MatMulOpQuantizationAttr for MatMul operations from A and B. +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, + Value a, Value b); + +/// Builds UnaryOpQuantizationAttr for unary operations from input values. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, + Value input, + Type outputRawType); + +/// Builds PadOpQuantizationAttr for pad operations from input values. +PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, + Value input); + +/// Constructs ConvOp output type with correct bitwidth based on input/weight +/// width. +Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, + Value weight); + +/// Builds a Tosa quantized type from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange); + +/// Builds a Tosa quantized type attribute from min/max values.
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange); + +} // namespace tosa +} // namespace mlir + +#endif // DIALECT_TOSA_UTILS_QUANT_UTILS_H Index: mlir/include/mlir/InitAllDialects.h =================================================================== --- mlir/include/mlir/InitAllDialects.h +++ mlir/include/mlir/InitAllDialects.h @@ -33,6 +33,7 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Dialect.h" @@ -60,7 +61,8 @@ NVVM::NVVMDialect, ROCDL::ROCDLDialect, SDBMDialect, - shape::ShapeDialect>(); + shape::ShapeDialect, + tosa::TosaDialect>(); // clang-format on } Index: mlir/include/mlir/InitAllPasses.h =================================================================== --- mlir/include/mlir/InitAllPasses.h +++ mlir/include/mlir/InitAllPasses.h @@ -24,6 +24,7 @@ #include "mlir/Dialect/SPIRV/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Transforms/Passes.h" #include @@ -54,6 +55,7 @@ registerShapePasses(); spirv::registerSPIRVPasses(); registerStandardPasses(); + tosa::registerTosaOptPasses(); } } // namespace mlir Index: mlir/lib/Dialect/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/CMakeLists.txt +++ mlir/lib/Dialect/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) set(LLVM_OPTIONAL_SOURCES Index: mlir/lib/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ 
mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_dialect_library(MLIRTosa + IR/TosaOps.cpp + Utils/QuantUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa + + DEPENDS + MLIRStandardOpsIncGen + MLIRTosaOpsIncGen + MLIRTosaStructsIncGen + MLIRTosaInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRStandard + MLIRCallInterfaces + MLIRControlFlowInterfaces + MLIRSideEffectInterfaces + MLIRViewLikeInterface + ) + +add_subdirectory(Transforms) Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -0,0 +1,285 @@ +//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This file implements the TOSA Specification: +// https://developer.mlplatform.org/w/tosa/ +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Parser.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" 
+#include "mlir/Transforms/RegionUtils.h" + +using namespace mlir; +using namespace mlir::tosa; + +//===----------------------------------------------------------------------===// +// Tosa dialect structs and interface includes. +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" +#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc" + +namespace { +//===----------------------------------------------------------------------===// +// Dialect Function Inliner Interface. +//===----------------------------------------------------------------------===// +struct TosaInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks. + //===--------------------------------------------------------------------===// + + /// All operations can be inlined by default. + bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, + BlockAndValueMapping &map) const final { + return true; + } + + /// All regions with If and While parent operators can be inlined. + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + BlockAndValueMapping &map) const final { + return (isa(dest->getParentOp()) || + isa(dest->getParentOp())); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// TOSA control flow support. +//===----------------------------------------------------------------------===// + +/// Returns the while loop body. 
+Region &tosa::WhileOp::getLoopBody() { return body(); } + +bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) { + return !body().isAncestor(value.getParentRegion()); +} + +LogicalResult WhileOp::moveOutOfLoop(ArrayRef ops) { + if (ops.empty()) + return success(); + + Operation *tosaWhileOp = this->getOperation(); + for (auto *op : ops) + op->moveBefore(tosaWhileOp); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Tosa dialect initialization. +//===----------------------------------------------------------------------===// + +void TosaDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// TOSA Operator Verifiers. +//===----------------------------------------------------------------------===// + +template static LogicalResult verifyConvOp(T op) { + + // All TOSA conv ops have an input() and weight(). + auto inputType = op.input().getType().template dyn_cast(); + auto weightType = op.weight().getType().template dyn_cast(); + + // Must be ranked tensor types + if (!inputType || !weightType) + return failure(); + + auto inputQType = + inputType.getElementType().template isa(); + auto weightQType = + weightType.getElementType().template isa(); + + // Either both must be quantized or both unquantized. + if (inputQType != weightQType) + return failure(); + + // Quantized type must have constructed the quantizationattr, and unquantized + // types should not have a quantizationattr. + if ((inputQType && !op.quantization_info()) || + (!inputQType && op.quantization_info())) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// TOSA Operator Quantization Builders. 
+//===----------------------------------------------------------------------===// + +// This builder is called on all convolution operators except for TransposeConv, +// which has specialized output shape semantics. The builder also defines the +// bitwidth of the output given the bit width of the input & weight content. +void ConvOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias, ArrayAttr pad, ArrayAttr stride, + ArrayAttr dilation) { + + result.addOperands({input, weight, bias}); + result.addAttribute("pad", pad); + result.addAttribute("stride", stride); + result.addAttribute("dilation", dilation); + + auto quantAttr = tosa::buildConvOpQuantizationAttr(builder, input, weight); + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + result.addTypes( + tosa::buildConvOpResultTypeInfo(builder, outputType, input, weight)); + } else + result.addTypes(outputType); +} + +// Handles tosa.transpose_conv2d which has an outpad and output shape attribute. +void TransConvOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias, ArrayAttr outpad, ArrayAttr stride, + ArrayAttr dilation, ArrayAttr outputShape) { + result.addOperands({input, weight, bias}); + result.addAttribute("out_pad", outpad); + result.addAttribute("stride", stride); + result.addAttribute("dilation", dilation); + result.addAttribute("out_shape", outputShape); + auto quantAttr = tosa::buildConvOpQuantizationAttr(builder, input, weight); + + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + result.addTypes( + tosa::buildConvOpResultTypeInfo(builder, outputType, input, weight)); + } else + result.addTypes(outputType); +} + +// The tosa.fully_connected op has its own builder as it does not have +// strides/dilation/padding. 
+void FCOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias) { + + result.addOperands({input, weight, bias}); + auto quantAttr = tosa::buildConvOpQuantizationAttr(builder, input, weight); + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + result.addTypes( + tosa::buildConvOpResultTypeInfo(builder, outputType, input, weight)); + } else + result.addTypes(outputType); +} + +// The tosa.matmul op is also intended to be generated when a fully_connected +// op must be constructed but the weight is not a constant. In this case, +// the fully_connected op must be expressed using matmul. +// TODO: Add link to the legalization document explaining this. +void MatMulOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value a, Value b) { + result.addOperands({a, b}); + auto quantAttr = tosa::buildMatMulOpQuantizationAttr(builder, a, b); + + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + + auto inputType = a.getType().dyn_cast(); + assert(inputType && "Input must be a ranked tensor type!"); + + auto inputQType = inputType.getElementType() + .dyn_cast(); + assert(inputQType && "Tensor must have quantized datatype!"); + + unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); + + auto outputShapedType = outputType.dyn_cast(); + assert(outputShapedType && "Output must be a ranked tensor type"); + + auto outputShape = outputShapedType.getShape(); + + IntegerType accElementType; + if (inputBits == 16) + accElementType = builder.getIntegerType(48); + else + accElementType = builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + result.addTypes(accType); + } else + result.addTypes(outputType); +} + +// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr +// but the avg_pool operator has its own builder as it has additional parameters +// not part of the unary ops.
+void AvgPool2dOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, ArrayAttr kernel, + ArrayAttr stride, ArrayAttr pad) { + result.addOperands(input); + result.addAttribute("kernel", kernel); + result.addAttribute("stride", stride); + result.addAttribute("pad", pad); + auto quantAttr = + tosa::buildUnaryOpQuantizationAttr(builder, input, outputType); + if (quantAttr) + result.addAttribute("quantization_info", quantAttr); + result.types.push_back(outputType); +} + +// This builder is called on single-parameter unary operators that have a scale +// relationship between their input and output, expressed by the +// UnaryOpQuantizationAttr. +void UnaryOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input) { + result.addOperands(input); + auto quantAttr = + tosa::buildUnaryOpQuantizationAttr(builder, input, outputType); + if (quantAttr) + result.addAttribute("quantization_info", quantAttr); + result.types.push_back(outputType); +} + +// This builder is called on the TOSA pad operator that needs to create its own +// OptionalAttr quantization_attr parameter to scale the padding values +// correctly. +void PadOpQuantInfoBuilder(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value paddings) { + result.addOperands({input, paddings}); + auto quantAttr = tosa::buildPadOpQuantizationAttr(builder, input); + if (quantAttr) + result.addAttribute("quantization_info", quantAttr); + result.types.push_back(outputType); +} + +//===----------------------------------------------------------------------===// +// TOSA Operator Definitions. 
+//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" Index: mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRTosaTransforms + TosaMakeBroadcastable.cpp + TosaTestPasses.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms + + DEPENDS + MLIRTosaPassIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTosa + ) Index: mlir/lib/Dialect/Tosa/Transforms/PassDetail.h =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/PassDetail.h @@ -0,0 +1,21 @@ +//===- PassDetail.h - TOSA Pass class details -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H +#define DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +} // end namespace mlir + +#endif // DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H Index: mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -0,0 +1,283 @@ +//===- TosaMakeBroadcastable.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Insert a reshape on a binary op's input if needed to match ranks +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::tosa; + +/// There are two potential ways of implementing broadcast: +/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition +/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html +/// TBD: picking option (a) for now. + +/// In this pass, we insert RESHAPE operators to increase the rank of the +/// lower rank operand as a first step in the broadcasting process. The TOSA +/// operators that support broadcast require that the ranks of the operands +/// are equal. + +// Examples: +// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1]. +// TODO: If lower=[b], target=[a, b, c], [b] should be, but is NOT YET, +// reshaped into [1, b, 1]. +// If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c]. +// If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1]. +// If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. 
+// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. + +static void computeReshapeOutput(ArrayRef higherRankShape, + ArrayRef lowerRankShape, + SmallVectorImpl &reshapeOutputShape) { + // Initialize new shapes with [1] * higherRank. + int64_t higherRank = higherRankShape.size(); + int64_t lowerRank = lowerRankShape.size(); + + reshapeOutputShape.assign(higherRank, 1); + + int64_t higherLeftIndex = 0; + int64_t higherRightIndex = higherRank; + int64_t lowerLeftIndex = 0; + int64_t lowerRightIndex = lowerRank; + int64_t higherRankDim, lowerRankDim; + + if (lowerRightIndex != 0 && higherRightIndex != 0) { + // Matches lower rank shape from right dimension first, until not + // matching high rank shape or reaching dimension 0. + while (true) { + higherRankDim = higherRankShape[higherRightIndex - 1]; + lowerRankDim = lowerRankShape[lowerRightIndex - 1]; + if (higherRankDim != lowerRankDim) + break; + + reshapeOutputShape[higherRightIndex - 1] = higherRankDim; + + if (higherRightIndex > 0) + higherRightIndex--; + + if (lowerRightIndex > 0) + lowerRightIndex--; + + if (higherRightIndex == 0 || lowerRightIndex == 0) + break; + } + if (lowerRightIndex != 0 && higherRightIndex != 0) { + // Matches lower rank shape from left dimension, until not matching + // high rank shape or reaching right index. + while (true) { + higherRankDim = higherRankShape[higherLeftIndex]; + lowerRankDim = lowerRankShape[lowerLeftIndex]; + if (higherRankDim != lowerRankDim) + break; + + reshapeOutputShape[higherLeftIndex] = higherRankDim; + + if (higherLeftIndex < higherRightIndex) + higherLeftIndex++; + + if (lowerLeftIndex < lowerRightIndex) + lowerLeftIndex++; + + if (higherLeftIndex == higherRightIndex || + lowerLeftIndex == lowerRightIndex) + break; + } + } + } +} + +/// Common code to create the reshape op where necessary to make the rank of the +/// operations equal. Returns the updated input1 and input2 for the original +/// input. 
The caller is expected to use these to rewrite the original operator +/// with the RESHAPE now in the graph. +int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, + RankedTensorType outputType, Value input1, + Value input2, Value &outInput1, Value &outInput2) { + + int64_t input1Rank = input1.getType().cast().getRank(); + int64_t input2Rank = input2.getType().cast().getRank(); + + Value higherTensorValue, lowerTensorValue; + /* Return early if the ranks already match. */ + if (input1Rank == input2Rank) { + return 1; + } else if (input1Rank > input2Rank) { + higherTensorValue = input1; + lowerTensorValue = input2; + } else { + higherTensorValue = input2; + lowerTensorValue = input1; + } + + ArrayRef outputRankShape = outputType.getShape(); + ArrayRef higherRankShape = + higherTensorValue.getType().cast().getShape(); + ArrayRef lowerRankShape = + lowerTensorValue.getType().cast().getShape(); + + /* outputRank == higherRank == max(input1Rank, input2Rank) */ + assert(higherRankShape.size() == outputRankShape.size()); + + SmallVector reshapeOutputShape; + + computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape); + + auto reshapeInputType = lowerTensorValue.getType().cast(); + auto reshapeOutputType = RankedTensorType::get( + ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); + + auto reshapeLower = rewriter.create( + loc, reshapeOutputType, lowerTensorValue, + rewriter.getI64ArrayAttr(reshapeOutputShape)); + + if (input1Rank > input2Rank) { + outInput1 = higherTensorValue; + outInput2 = reshapeLower.getResult(); + } else { + outInput1 = reshapeLower.getResult(); + outInput2 = higherTensorValue; + } + + return 0; +} + +template struct ConvertTosaOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy tosaBinaryOp, + PatternRewriter &rewriter) const { + + Value input1 = tosaBinaryOp.input1(); + Value input2 = tosaBinaryOp.input2(); + Value output = tosaBinaryOp.getResult(); + auto 
outputType = output.getType().cast(); + + Value outInput1, outInput2; + if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, + input1, input2, outInput1, outInput2)) + return failure(); + + rewriter.replaceOpWithNewOp(tosaBinaryOp, outputType, outInput1, + outInput2); + + return success(); + } +}; + +// The MulOp has an extra parameter 'shift' not present in other elementwise +// binary ops, that necessitates special handling of its builder. +template <> +struct ConvertTosaOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp, + PatternRewriter &rewriter) const { + + Value input1 = tosaBinaryOp.input1(); + Value input2 = tosaBinaryOp.input2(); + int32_t shift = tosaBinaryOp.shift(); + Value output = tosaBinaryOp.getResult(); + auto outputType = output.getType().cast(); + + Value outInput1, outInput2; + if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, + input1, input2, outInput1, outInput2)) + return failure(); + + rewriter.replaceOpWithNewOp(tosaBinaryOp, outputType, + outInput1, outInput2, shift); + + return success(); + } +}; + +// The ArithmeticRightShiftOp has an extra parameter 'round' not present in +// other elementwise binary ops, that necessitates special handling of its +// builder. 
+template <> +struct ConvertTosaOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp, + PatternRewriter &rewriter) const { + + Value input1 = tosaBinaryOp.input1(); + Value input2 = tosaBinaryOp.input2(); + int32_t round = tosaBinaryOp.round(); + Value output = tosaBinaryOp.getResult(); + auto outputType = output.getType().dyn_cast(); + + Value outInput1, outInput2; + if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, + input1, input2, outInput1, outInput2)) + return failure(); + + rewriter.replaceOpWithNewOp( + tosaBinaryOp, outputType, outInput1, outInput2, round); + + return success(); + } +}; + +static void applyTosaBroadcastablePatterns(FuncOp func) { + OwningRewritePatternList patterns; + MLIRContext *ctx = func.getContext(); + // Add the generated patterns to the list. + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +namespace { + +/// Pass that enables broadcast by making all input arrays have the same +/// number of dimensions. 
Inserts RESHAPE operations on the lower rank operand. +struct TosaMakeBroadcastable + : public TosaMakeBroadcastableBase { +public: + void runOnFunction() override { + auto func = getFunction(); + applyTosaBroadcastablePatterns(func); + } +}; +} // anonymous namespace + +std::unique_ptr> +mlir::tosa::createTosaMakeBroadcastablePass() { + return std::make_unique(); +} Index: mlir/lib/Dialect/Tosa/Transforms/TosaTestPasses.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/TosaTestPasses.cpp @@ -0,0 +1,204 @@ +//===- TosaTestPasses.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test passes to exercise TOSA helper functions. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::tosa; + +namespace { + +struct TosaTestQuantUtilAPI + : public TosaTestQuantUtilsBase { + void runOnFunction() override; +}; + +// This transformation converts quantized uint8 to quantized int8. The +// construction of the new type invokes buildQTypeFromMinMax. 
Extracted from +// TOSA legalization infrastructure. +struct ConvertTosaNegateOp : public RewritePattern { + explicit ConvertTosaNegateOp(MLIRContext *context) + : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult +ConvertTosaNegateOp::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + + auto tosaNegateOp = cast(op); + + auto inputType = + tosaNegateOp.input1().getType().dyn_cast(); + // skip if input is not ranked tensor type + if (!inputType) + return failure(); + + // skip if it's not ranked tensor type. + auto outputType = + tosaNegateOp.getResult().getType().dyn_cast(); + if (!outputType) + return failure(); + + // skip if output is not per-tensor quantized type. + auto outputElementType = + outputType.getElementType().dyn_cast(); + if (!outputElementType) + return failure(); + + // skip if output is not uint8. + if (outputElementType.isSigned() || + outputElementType.getStorageTypeIntegralWidth() != 8) + return failure(); + + double typeRangeMin = double(outputElementType.getStorageTypeMin() - + outputElementType.getZeroPoint()) * + outputElementType.getScale(); + double typeRangeMax = double(outputElementType.getStorageTypeMax() - + outputElementType.getZeroPoint()) * + outputElementType.getScale(); + bool narrow_range = outputElementType.getStorageTypeMin() == 1 ? 
true : false; + + auto dstQConstType = RankedTensorType::get( + outputType.getShape(), + buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(), + rewriter.getF64FloatAttr(typeRangeMin), + rewriter.getF64FloatAttr(typeRangeMax), + rewriter.getI32IntegerAttr( + outputElementType.getStorageTypeIntegralWidth()), + 0, true /* signed */, + rewriter.getBoolAttr(narrow_range))); + + ElementsAttr inputElems; + if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems))) + return failure(); + + auto newConstOp = + rewriter.create(op->getLoc(), dstQConstType, inputElems); + auto newNegateOp = rewriter.create( + op->getLoc(), dstQConstType, newConstOp.getResult()); + + rewriter.replaceOp(op, {newNegateOp.getResult()}); + return success(); +} + +// This transformation modifies the quantized output of a test conv2d input and +// appends a TOSA rescale after it. The rescale op requires the invocation of +// computeMultiplierAndShift. From TOSA legalization infrastructure. +struct ConvertTosaConv2DOp : public RewritePattern { + explicit ConvertTosaConv2DOp(MLIRContext *context) + : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult +ConvertTosaConv2DOp::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + + auto tosaConv2DOp = cast(op); + + auto inputType = + tosaConv2DOp.input().getType().dyn_cast(); + + // skip if input is not ranked tensor type + if (!inputType) + return failure(); + + auto weightType = + tosaConv2DOp.weight().getType().dyn_cast(); + + // skip if wt is not ranked tensor type + if (!weightType) + return failure(); + + // skip if it's not ranked tensor type. 
+ auto outputType = + tosaConv2DOp.getResult().getType().dyn_cast(); + if (!outputType) + return failure(); + + auto inputQType = + inputType.getElementType().dyn_cast(); + auto weightQType = + weightType.getElementType().dyn_cast(); + auto outputQType = + outputType.getElementType().dyn_cast(); + + // Works on quantized type only. + if (!(inputQType && weightQType && outputQType)) + return failure(); + + auto newTosaConv2DOpType = + RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); + + auto newTosaConv2DOp = rewriter.create( + op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.input(), + tosaConv2DOp.weight(), tosaConv2DOp.bias(), tosaConv2DOp.pad(), + tosaConv2DOp.stride(), tosaConv2DOp.dilation()); + + // Create rescale to quantized type + double inputScale = inputQType.getScale(); + double weightScale = weightQType.getScale(); + double outputScale = outputQType.getScale(); + int64_t outputZp = outputQType.getZeroPoint(); + + double opTensorScale = (inputScale * weightScale) / outputScale; + + int32_t multiplier; + int32_t shift; + + // Obtain the quantized scale = multiplier and shift. 
+ computeMultiplierAndShift(opTensorScale, multiplier, shift, 32); + + auto newTosaRescaleOp = rewriter.create( + op->getLoc(), outputType, newTosaConv2DOp.getResult(), + rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp), + rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}), + rewriter.getBoolAttr(true), rewriter.getBoolAttr(true), + rewriter.getBoolAttr(false)); + + rewriter.replaceOp(op, {newTosaRescaleOp.getResult()}); + return success(); +} + +void TosaTestQuantUtilAPI::runOnFunction() { + OwningRewritePatternList patterns; + auto *ctx = &getContext(); + auto func = getFunction(); + + patterns.insert(ctx); + patterns.insert(ctx); + applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // anonymous namespace + +std::unique_ptr> +mlir::tosa::createTOSATestQuantUtilAPIPass() { + return std::make_unique(); +} Index: mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -0,0 +1,363 @@ +//===- QuantUtils.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains TOSA numerical support functions and quantization +// attribute builders. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" + +namespace mlir { +namespace tosa { +namespace { + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 16-bit scaling. 
+void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 15)); + + // Can't be greater than 1.0. + assert(shiftedM <= (int64_t(1) << 15) && + "Shifted mantissa exceeds 16 signed bits"); + + if (shiftedM == (int64_t(1) << 15)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive and embed (1 << 15) into right + // shift bits. + shift = (-shift) + 15; + + assert(shiftedM <= std::numeric_limits::max() && + "Shifted mantissa exceeds 32-bit signed output type"); + + multiplier = static_cast(shiftedM); +} + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 32-bit scaling. +void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 31)); + + // Can't be greater than 1.0. + assert(shiftedM <= (int64_t(1) << 31) && + "Shifted mantissa exceeds 32 signed bits"); + if (shiftedM == (int64_t(1) << 31)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive, and embed (1 << 31) into right + // shift bits. + shift = (-shift) + 31; + + assert(shiftedM <= std::numeric_limits::max() && + "Shifted mantissa exceeds 32-bit signed output type"); + + multiplier = static_cast(shiftedM); +} + +} // namespace + +/// Generates a quantized multiplier/shift from double. 
+void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth) { + + switch (scaleWidth) { + case 16: + computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); + return; + case 32: + computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); + return; + default: + assert(0 && "Unsupported Tosa quantized_scale regime specified!"); + } +} + +#define GET_UQTYPE(input_type) \ + ((input_type).getElementType().dyn_cast()) +#define GET_QTYPE(input_type) \ + ((input_type).getElementType().dyn_cast()) + +// Method to build ConvOpQuantizationAttr, called from +// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: input_zp: input zeropoint +// weight_zp: weight zeropoint. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, + Value input, Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + if (!inputType || !weightType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto weightPerTensorQType = GET_UQTYPE(weightType); + auto weightPerAxisQType = weightType.getElementType() + .dyn_cast(); + + // Weights must be either per-tensor quantized or per-axis quantized. + assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && + "Weights must be either per-tensor or per-axis quantized"); + + // Either all quantized or all not quantized. 
+ assert(!((bool)inputQType ^ + ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) && + "Inputs and weights must be all quantized or all not quantized"); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t weightZp = 0; + + if (weightPerTensorQType) { + weightZp = weightPerTensorQType.getZeroPoint(); + } else if (weightPerAxisQType) { + weightZp = weightPerAxisQType.getZeroPoints().front(); + } + + auto quantAttr = tosa::ConvOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds MatMulOpQuantizationAttr, called from +/// MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: input b zeropoint. +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, + Value a, Value b) { + + auto aType = a.getType().dyn_cast(); + auto bType = b.getType().dyn_cast(); + + if (!aType || !bType) + return nullptr; + + auto aQType = GET_UQTYPE(aType); + auto bQType = GET_UQTYPE(bType); + + // A and B are either all quantized or all not quantized. + assert(!((bool)aQType ^ (bool)bQType) && + "Matmul operands must be all quantized or all not quantized"); + + if (aQType) { + + int64_t aZp = aQType.getZeroPoint(); + int64_t bZp = bQType.getZeroPoint(); + + auto quantAttr = tosa::MatMulOpQuantizationAttr::get( + builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds UnaryOpQuantizationAttr +/// UnaryOpQuantInfoBuilder: inputZp: input zeropoint +/// outputZp: output zeropoint. 
+UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, + Value input, + Type outputRawType) { + + auto inputType = input.getType().dyn_cast(); + auto outputType = outputRawType.dyn_cast(); + + if (!inputType || !outputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto outputQType = GET_UQTYPE(outputType); + + // Either all quantized or all not quantized. + assert(!((bool)inputQType ^ (bool)outputQType) && + "Unary inputs/outputs must be all quantized or all not quantized"); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t outputZp = outputQType.getZeroPoint(); + + auto quantAttr = tosa::UnaryOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: +/// inputZp: input zeropoint. +PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, + Value input) { + + auto inputType = input.getType().dyn_cast(); + + if (!inputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + + auto quantAttr = tosa::PadOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds output type for a quantized ConvOp with the right bitwidth. +/// This is called by the builder when dealing with quantized content. 
+Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, + Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + assert(inputType && weightType && + "Could not extract input or weight tensors from Conv op"); + + auto inputQType = GET_QTYPE(inputType); + auto weightQType = GET_QTYPE(weightType); + + assert(inputQType && weightQType && + "Could not extract input or weight tensor types from Conv op"); + + unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); + unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); + + auto outputShapedType = outputType.dyn_cast(); + assert(outputShapedType && + "Could not extract output shape type from Conv op"); + + auto outputShape = outputShapedType.getShape(); + + IntegerType accElementType; + if (inputBits == 16 && weightBits == 8) + accElementType = builder.getIntegerType(48); + else + accElementType = builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + return accType; +} + +/// Builds Tosa quantization attributes from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange) { + + quant::QuantizedType retType; + + auto convfunc = + quant::ExpressedToQuantizedConverter::forInputType(inputDType); + + auto minElems = minAttr.dyn_cast(); + auto maxElems = maxAttr.dyn_cast(); + + SmallVector min, max; + + // At least one is per-axis quantized elementsattr. + if (minElems || maxElems) { + + // Must have the same number of elements. 
+ // Both must be elements attributes; bail out if only one of them is. + if (!minElems || !maxElems || + minElems.getNumElements() != maxElems.getNumElements()) + return {}; + + min.reserve(minElems.getNumElements()); + max.reserve(maxElems.getNumElements()); + for (auto i : minElems) { + min.push_back(FloatAttr::getValueAsDouble(i)); + } + for (auto i : maxElems) { + max.push_back(FloatAttr::getValueAsDouble(i)); + } + } else { // Just a single FP value. + + auto minVal = minAttr.dyn_cast(); + if (minVal) + min.push_back(minVal.getValueAsDouble()); + else + return {}; + auto maxVal = maxAttr.dyn_cast(); + if (maxVal) + max.push_back(maxVal.getValueAsDouble()); + else + return {}; + } + + if (min.size() == max.size()) { + + if (min.size() == 1) { // Per-tensor quantization with one min/max pair. + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], + narrowRange.getValue(), convfunc.expressedType, isSigned); + + } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. + + auto shape = inputDType.dyn_cast(); + if (!shape) + return {}; + if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], + max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); + } + + } else { + return {}; + } + } else { + return {}; + } + + if (!retType) + return {}; + + return convfunc.convert(retType); +} + +/// Builds Tosa quantization attributes from min/max values. 
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange) { + + return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, + maxAttr, quantBits, filterQuantDim, + isSigned, narrowRange)); +} + +} // namespace tosa +} // namespace mlir Index: mlir/test/Dialect/Tosa/broadcast.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/broadcast.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s + +// ----- +// CHECK-LABEL: broadcast0 +func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-NOT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + +// ----- +// CHECK-LABEL: broadcast1 +func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32> + return %0 : tensor<2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast2 +func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32> + return %0 : tensor<2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast3 +func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32> + return %0 : tensor<2x1x1x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast4 +func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32> + return %0 : 
tensor<1x1x1x2xf32> +} + +// ----- +// CHECK-LABEL: broadcast5 +func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32> + return %0 : tensor<1x1x2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast6 +func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast7 +func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32> + return %0 : tensor<17x16x1x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast8 +func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast9 +func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast10 +func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast13 +func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> 
tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast14 +func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> + return %0 : tensor<17x16x1x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast15 +func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast16 +func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast17 +func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast18 +func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> { + // CHECK: add + %0 = "tosa.add"(%arg0, %arg1) : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32> + return %0 : tensor<14x15xf32> +} + +// ----- +// CHECK-LABEL: broadcast_mul +func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { + // CHECK: reshape + %0 = "tosa.mul"(%arg0, %arg1) {shift = 1 : i32 } : (tensor<15x14xi32>, 
tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> + return %0 : tensor<17x16x15x14xi32> +} + +// ----- +// CHECK-LABEL: broadcast_arithmetic_right_shift +func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { + // CHECK: reshape + %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> + return %0 : tensor<17x16x15x14xi32> +} + +// ----- +// CHECK-LABEL: broadcast_scalar +func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> + return %0 : tensor<17x16x15x14xi32> +} Index: mlir/test/Dialect/Tosa/dynamic_shape.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/dynamic_shape.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + + +// ----- +// CHECK-LABEL: argmax +func @test_argmax(%arg0: tensor) -> tensor { + %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor) -> tensor + return %0 : tensor +} + Index: mlir/test/Dialect/Tosa/inlining.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/inlining.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt %s -inline | FileCheck %s + +// These tests verify that regions with operations from TOSA dialect +// can be inlined. 
+ +// CHECK-LABEL: func @inlined_if_fn +func @inlined_if_fn(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: tosa.cond_if + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + %1 = call @add(%arg3, %arg4) : (tensor, tensor) -> tensor + "tosa.yield"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + %1 = call @sub(%arg3, %arg4) : (tensor, tensor) -> tensor + "tosa.yield"(%1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} +func @add(%arg0: tensor, %arg1: tensor) -> tensor attributes {sym_visibility = "private"} { + %0 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} +func @sub(%arg0: tensor, %arg1: tensor) -> tensor attributes {sym_visibility = "private"} { + %0 = "tosa.sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- + +// CHECK-LABEL: func @inlined_while_fn +func @inlined_while_fn(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> tensor<10xi32> { + // CHECK: tosa.while_loop + %1:4 = "tosa.while_loop"(%arg0, %arg1, %arg2, %arg3) ( { + ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor<10xi32>): // no predecessors + %2 = call @while_cond_40(%arg4, %arg5, %arg6, %arg7) : (tensor, tensor, tensor, tensor<10xi32>) -> tensor + "tosa.yield"(%2) : (tensor) -> () + }, { + ^bb0(%arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor<10xi32>): // no predecessors + %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor, tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) + "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor, tensor, tensor, tensor<10xi32>) -> () + }) : (tensor, tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor, tensor<10xi32>) + return %1#3 : tensor<10xi32> +} +func @while_body_50(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> (tensor, tensor, tensor, 
tensor<10xi32>) attributes {sym_visibility = "private"} { + %1 = "tosa.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor) -> tensor<10xi32> + return %1, %arg1, %arg2, %2: tensor, tensor, tensor, tensor<10xi32> +} +func @while_cond_40(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<10xi32>) -> tensor attributes {sym_visibility = "private"} { + %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor, tensor) -> tensor + %1 = "tosa.logical_not"(%0) : (tensor) -> tensor + return %1 : tensor +} Index: mlir/test/Dialect/Tosa/ops.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/ops.mlir @@ -0,0 +1,512 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + + +// ----- +// CHECK-LABEL: argmax +func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> { + %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<14x19xf32>) -> tensor<14xi32> + return %0 : tensor<14xi32> +} + +// ----- +// CHECK-LABEL: avg_pool2d +func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + return %0 : tensor<1x7x7x9xf32> +} + +// ----- +// CHECK-LABEL: conv2d +func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + return %0 : tensor<1x4x4x8xf32> +} + +// ----- +// CHECK-LABEL: depthwise_conv2d +func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> { + %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], 
stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32> + return %2 : tensor<1x4x4x8xf32> +} + +// ----- +// CHECK-LABEL: fully_connected +func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> { + %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<14x19xf32>, tensor<19x28xf32>, tensor<28xf32>) -> tensor<14x28xf32> + return %0 : tensor<14x28xf32> +} + +// ----- +// CHECK-LABEL: test_matmul +func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> { + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32> + return %0 : tensor<14x28xf32> +} + +// ----- +// CHECK-LABEL: max_pool2d +func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- +// CHECK-LABEL: transpose_conv2d +func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- +// CHECK-LABEL: clamp +func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.clamp"(%arg0) {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: relu +func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.reluN"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 0 : i64} : (tensor<13x21x3xf32>) -> 
tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + + +// ----- +// CHECK-LABEL: sigmoid +func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.sigmoid"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: tanh +func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.tanh"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: add +func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: arithmetic_right_shift +func @test_arithmetic_right_shift(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = false } : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + + +// ----- +// CHECK-LABEL: bitwise_and +func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: bitwise_or +func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: bitwise_xor +func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// 
CHECK-LABEL: logical_and +func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> { + %0 = "tosa.logical_and"(%arg0, %arg1) : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: logical_left_shift +func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_right_shift +func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_or +func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = "tosa.logical_or"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: logical_xor +func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = "tosa.logical_xor"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: maximum +func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.maximum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: minimum +func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.minimum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : 
tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: mul +func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: pow +func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: sub +func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: table +func @main(%arg0: tensor<64xi32>, %arg1: tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform> { + %0 = "tosa.table"(%arg0, %arg1) : (tensor<64xi32>, tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform> + return %0 : tensor<64x!quant.uniform> +} + +// ----- +// CHECK-LABEL: abs +func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: bitwise_not +func @test_bitwise_not(%arg0: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> { + %0 = "tosa.bitwise_not"(%arg0) : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32> + return %0 : tensor<13x21x1xi32> +} + +// ----- +// CHECK-LABEL: ceil +func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: clz +func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.clz"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} 
+ +// ----- +// CHECK-LABEL: exp +func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: floor +func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.floor"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: log +func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: logical_not +func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { + %0 = "tosa.logical_not"(%arg0) : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1> + return %0 : tensor<1x21x3xi1> +} + +// ----- +// CHECK-LABEL: negate +func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.negate"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: reciprocal +func @test_reciprocal(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.reciprocal"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: rsqrt +func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: select +func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + + +// ----- +// CHECK-LABEL: equal +func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> { + %0 = 
"tosa.equal"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: greater +func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + %0 = "tosa.greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: greater_equal +func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_all +func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { + %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1> + return %1 : tensor<21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_any +func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { + %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1> + return %1 : tensor<21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_max +func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_min +func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : 
tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_product +func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_sum +func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: concat +func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> { + %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32> + return %0 : tensor<26x21x3xf32> +} + +// ----- +// CHECK-LABEL: pad +func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> { + %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: reshape +func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { + %0 = "tosa.reshape"(%arg0) {new_shape = [1, 819]} : (tensor<13x21x3xf32>) -> tensor<1x819xf32> + return %0 : tensor<1x819xf32> +} + +// ----- +// CHECK-LABEL: reverse +func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: slice +func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { + %0 = "tosa.slice"(%arg0) {start = [6, 8, 0], size = [4, 11, 1]} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32> + return %0 : tensor<4x11x1xf32> +} + +// ----- +// 
CHECK-LABEL: tile +func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> { + %0 = "tosa.tile"(%arg0) {multiples = [3, 1, 2]} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32> + return %0 : tensor<39x21x6xf32> +} + +// ----- +// CHECK-LABEL: transpose +func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32> + return %1 : tensor<3x13x21xf32> +} + +// ----- +// CHECK-LABEL: gather +func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<26xi32>) -> tensor<26x21x3xi32> { + %0 = "tosa.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>) -> tensor<26x21x3xi32> + return %0 : tensor<26x21x3xi32> +} + +// Test TBD +// DISABLED-CHECK-LABEL: resize +//func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { +// %0 = "tosa.const"() {value = dense<64> : tensor<2xi32>} : () -> tensor<2xi32> +// %1 = "tosa.resize"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x32x32x8xf32>, tensor<2xi32>) -> tensor<1x64x64x8xf32> +// return %1 : tensor<1x64x64x8xf32> +//} + +// ----- +// CHECK-LABEL: cast +func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { + %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: cast2 +func @test_cast2(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> { + %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: cast3 +func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> { + %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// 
CHECK-LABEL: rescale +func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { + %0 = "tosa.rescale"(%arg0) {double_round = false, input_zp = 127 : i32, multiplier = [1073741824 : i32], output_zp = -1 : i32, per_channel = false, scale32 = true, shift = [30 : i32]} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: const +func @test_const(%arg0 : index) -> tensor<4xi32> { + %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- +// CHECK-LABEL: identity +func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.identity"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: identityn +func @test_identityn(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> { + %0:2 = "tosa.identityn"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>) + return %0#0 : tensor<1xi32> +} + +// ----- +// CHECK-LABEL: placeholder +func @test_placeholder() -> tensor<1xi32> { + %0 = "tosa.placeholder"() : () -> tensor<1xi32> + return %0 : tensor<1xi32> +} + +// ----- +// CHECK-LABEL: cond_if +func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + %1 = "tosa.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "tosa.yield"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors + %1 = "tosa.sub"(%arg3, %arg4) : (tensor, tensor) -> tensor + "tosa.yield"(%1) : (tensor) -> () + }) : (tensor, tensor, tensor) -> tensor + return %0 : tensor +} + +// ----- +// CHECK-LABEL: while_loop +func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) { + %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> + 
%1:3 = "tosa.while_loop"(%0, %0, %arg0) ( { + ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>): // no predecessors + %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1> + %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1> + "tosa.yield"(%3) : (tensor<i1>) -> () + }, { + ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>): // no predecessors + %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> + %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32> + %4 = "tosa.reshape"(%2) {new_shape = [1]} : (tensor<i32>) -> tensor<1xi32> + %5 = "tosa.add"(%arg4, %4) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32> + %6 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32> + "tosa.yield"(%6, %3, %5) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> () + }) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>) + return +} Index: mlir/test/Dialect/Tosa/quant-ops.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/quant-ops.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + +//===-----------------------------------------------------------------------===// +// Tests to check for parsing of legal TOSA operators with quantized types +// and quantization_info attributes. 
+//===-----------------------------------------------------------------------===// + +// ----- +// CHECK-LABEL: avg_pool2d +func @test_avg_pool2d(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> + return %0 : tensor<1x7x7x9x!quant.uniform> +} + +// ----- +// CHECK-LABEL: conv2d +func @test_conv2d(%arg0: tensor<1x4x4x4x!quant.uniform>, %arg1: tensor<8x5x5x4x!quant.uniform:f32, 0.015746051445603371>>, %arg2: tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32> { + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 1], pad = [4, 4, 2, 2], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32}, stride = [1, 1]} : (tensor<1x4x4x4x!quant.uniform>, tensor<8x5x5x4x!quant.uniform:f32, 0.015746051445603371>>, tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32> + return %0 : tensor<1x4x4x8xi32> +} + +// ----- +// CHECK-LABEL: depthwise_conv2d +func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4x!quant.uniform>, %arg1: tensor<1x1x4x2x!quant.uniform>, %arg2: tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} { + %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1,1], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32 }, pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4x!quant.uniform>, tensor<1x1x4x2x!quant.uniform>, tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32> + return %2 : tensor<1x4x4x8xi32> +} + +// ----- +// CHECK-LABEL: fully_connected +func @test_fully_connected(%arg0: tensor<14x19x!quant.uniform>, %arg1: tensor<19x28x!quant.uniform>, %arg2: tensor<28x!quant.uniform>) -> tensor<14x28x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} { + %0 = "tosa.fully_connected"(%arg0, 
%arg1, %arg2) {quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32 } } : (tensor<14x19x!quant.uniform>, tensor<19x28x!quant.uniform>, tensor<28x!quant.uniform>) -> tensor<14x28x!quant.uniform> + return %0 : tensor<14x28x!quant.uniform> +} + +// ----- +// CHECK-LABEL: test_matmul +func @test_matmul(%arg0: tensor<14x19x!quant.uniform>, %arg1: tensor<19x28x!quant.uniform>) -> tensor<14x28xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} { + %0 = "tosa.matmul"(%arg0, %arg1) { quantization_info = {a_zp = 1 : i32, b_zp = 2: i32}}: (tensor<14x19x!quant.uniform>, tensor<19x28x!quant.uniform>) -> tensor<14x28xi32> + return %0 : tensor<14x28xi32> +} + +// ----- +// CHECK-LABEL: transpose_conv2d +func @test_transpose_conv2d(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1: tensor<16x1x1x8x!quant.uniform>, %arg2: tensor<16x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} { + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32 }, stride = [1, 1]} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> + return %0 : tensor<1x32x32x16x!quant.uniform> +} + +// ----- +// CHECK-LABEL: negate +func @test_negate(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} { + %0 = "tosa.negate"(%arg0) { quantization_info = {input_zp = -1 : i32, output_zp = 1 : i32 }}: (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: pad +func @test_pad(%arg0: tensor<13x21x3x!quant.uniform>, %arg1: 
tensor<3x2xi32>) -> tensor<13x21x3x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} { + %0 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = -1 : i32 }}: (tensor<13x21x3x!quant.uniform>, tensor<3x2xi32>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: rescale +func @test_rescale(%arg0: tensor<1x32x32x16xi32>) -> tensor<1x32x32x16x!quant.uniform> { + %0 = "tosa.rescale"(%arg0) {double_round = true, input_zp = 0 : i32, multiplier = [1723896117 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [39 : i32]} : (tensor<1x32x32x16xi32>) -> tensor<1x32x32x16x!quant.uniform> + return %0 : tensor<1x32x32x16x!quant.uniform> +} Index: mlir/test/Dialect/Tosa/quant-test.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/quant-test.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s + +// ----- +// CHECK-LABEL: test_build_qtype +func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> { + // CHECK: tosa.negate + %0 = "tosa.negate"(%arg0) : (tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> + return %0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> +} + +// ----- +// CHECK-LABEL: test_build_mult_and_shift +func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> { + // CHECK: tosa.conv2d + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 2, 2], dilation = [2, 1], stride = [1, 1], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32}} : (tensor<1x32x32x8x!quant.uniform>, 
tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> + return %0 : tensor<1x32x32x16x!quant.uniform> + +} Index: mlir/test/Dialect/Tosa/types.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/types.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + +//===-----------------------------------------------------------------------===// +// Tests to check for parsing of legal TOSA data types +//===-----------------------------------------------------------------------===// + +// ----- +// CHECK-LABEL: @test_bool +func @test_bool(tensor<2xi1>) -> () + +// ----- +// CHECK-LABEL: @test_int32 +func @test_int32(tensor<2xi32>) -> () + +// ----- +// CHECK-LABEL: @test_i48 +func @test_i48(tensor<2xi48>) -> () + +// ----- +// CHECK-LABEL: @test_i64 +func @test_i64(tensor<2xi64>) -> () + +// ----- +// CHECK-LABEL: @test_f16 +func @test_f16(tensor<2xf16>) -> () + +// ----- +// CHECK-LABEL: @test_f32 +func @test_f32(tensor<2xf32>) -> () + +// ----- +// CHECK-LABEL: @test_bf16 +func @test_bf16(tensor<2xbf16>) -> () + +// ----- +// CHECK-LABEL: @test_quant1 +func @test_quant1(tensor<2x!quant.uniform>) -> () + +// ----- +// CHECK-LABEL: @test_quant2 +func @test_quant2(tensor<2x!quant.uniform>) -> () + +// ----- +// CHECK-LABEL: @test_quant3 +func @test_quant3(tensor<2x!quant.uniform>) -> () + +// ----- +// CHECK-LABEL: @test_quant4 +func @test_quant4(tensor<2x!quant.uniform>) -> () + +// ----- +// CHECK-LABEL: @test_perchannel_quant +func @test_perchannel_quant(tensor<16x3x3x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>) -> ()