Index: mlir/docs/Dialects/TOSA.md
===================================================================
--- /dev/null
+++ mlir/docs/Dialects/TOSA.md
@@ -0,0 +1,146 @@
+# TOSA Dialect
+
+[TOC]
+
+## Rationale
+
+The MLIR TOSA dialect implements the [TOSA
+specification](https://developer.mlplatform.org/w/tosa/). This document
+describes the decision process for how TOSA expresses operators in
+high level dialects.
+
+TOSA was developed after parallel efforts to rationalize the top-down picture
+from multiple high-level frameworks, as well as a bottom-up view of different
+hardware target concerns (CPU, GPU and NPU), and reflects a set of choices
+that attempt to manage both sets of requirements.
+
+## TOSA and Tensor Level Expressiveness
+
+TOSA endeavors to provide an operator set that fulfills the following
+expressiveness goals at the *tensor level of abstraction*:
+
+### Complete
+
+This is driven by the top-down perspective: the need to express as much of
+multiple high level frameworks in TOSA as possible. The operator set was
+originally derived from a frequency analysis of operators across dozens of
+high level networks in different frameworks, selecting the most frequently
+occurring operators and establishing a common set of tensor-level operators
+that could express them.
+
+TOSA categorizes its operator set into classes and attempts to address major
+functional operations at the tensor level, including compute, reduction,
+elementwise transformations, comparison and control flow.
+
+### Minimal
+
+This takes the bottom-up approach - keep the TOSA operator set minimal in
+order to bound the design of hardware, operator kernels, code generation
+strategies and associated considerations that affect the executability of
+TOSA content.
+
+In this regard TOSA seeks to avoid creating compound operators, instead
+leaving it to the compiler backend to fuse multiple TOSA ops if required. This
+choice also benefits the numerical precision goal, since it is easier to fuse
+the numerical functionality of successive operators than to split the
+numerical functionality of a compound operator.
+
+### Numerical Precision
+
+TOSA began as a means to address operator-level numerical precision for
+code generation and hardware development. It therefore incorporates precision
+detail into the operator set.
+
+In this regard, TOSA operators are best understood as a combination of the
+visible quantization information embedded within an operation, together with
+the functional information about how that information is used, as described
+in the specification of the operation.
+
+## TOSA Operator Rationale
+
+The general basis of selection of the operator set that constitutes TOSA is
+described in the TOSA specification document under Section 1.3 Operator
+Selection. The thinking behind some individual operators is explained here:
+
+### IDENTITYN
+
+tosa.IDENTITYN is used to form a list of operator results during the
+lowering of operations such as tf.Split from a sequence of tosa.SLICE
+ops. If alternate ways emerge to express this lowering without the
+tosa.IDENTITYN op, it could be removed from TOSA.
+
+```
+Value lower_split_op(Value %value, size_t axis, size_t num_split) {
+    Value %output[]
+
+    size_t slice_size = %value.shape[axis] / num_split
+
+    for (int i = 0; i < num_split; i++) {
+        vector <size_t> begin_vals, size_vals
+
+        // Take the full extent of every dimension except the split axis,
+        // which starts at slice_size * i and extends for slice_size.
+        for (int j = 0; j < %value.rank; j++) {
+            if (j == axis) {
+                begin_vals.push_back(slice_size * i)
+                size_vals.push_back(slice_size)
+            } else {
+                begin_vals.push_back(0)
+                size_vals.push_back(%value.shape[j])
+            }
+        }
+
+        %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals} (tensor<%value.type>) -> tensor<%output[i].type>
+    }
+
+    %output_list = tosa.IDENTITYN(%output) (tensor<%output:*.type>) -> tensor<%output_list:*.type>
+    return %output_list
+}
+```
+
+### COND\_IF and WHILE\_LOOP
+
+Several neural networks express conditional control flow at the tensor level.
+A survey of multiple high level frameworks indicated that a conditional if and
+a loop construct are common to all major frameworks, with some variation.
+Since TOSA endeavors to be complete in expressing tensor level functionality,
+including control flow, it implements these constructs.
+
+The COND\_IF and WHILE\_LOOP operators implement such structured control
+flow forms and should be lowerable to corresponding ops in the scf dialect.
+Since the dialect seeks to remain isomorphic with an external, serialized form,
+the decision was to keep these ops in the dialect (as opposed to deferring
+completely to scf), and this may be re-evaluated if it turns out not to yield
+the expected value.
+
+## Using TOSA In A Compiler
+
+The TOSA specification describes each operator in functional detail. It is
+expected that compilers that use TOSA will use its builders to construct the
+operators so that the quantization information for the operator is correctly
+generated.
+
+The functional steps described in the pseudocode of the specification enable
+the construction of code generation for that operation, or decisions on the
+design of underlying hardware. The functional pseudocode also describes
+how the quantization parameters are utilized within the operation.
+
+### Quantization Parameters in Ops vs Tensors
+
+TOSA uses the quantization parameters embedded in the input and output
+tensors to construct the quantization attributes that sit within the operator.
+Once these attributes are constructed, the quantization information within
+the tensors is no longer necessary for code generation.
+
+This enables the tensors to be subsequently interpreted simply as contiguous
+buffers containing raw data, with no 'meta information' in the form of the
+quantization_type. Precision-related manipulation of the input or output is
+instead expressed by the operator itself, which defines, for example, when
+the zero point is applied, or when the scale multiplication is done.
+
+However, TOSA does *not* eliminate the existing MLIR QuantOps quantization
+type information within the tensors; this leaves the choice of how to handle
+quantization information to later backend code generation steps.
+
+Maintaining the ability to overlap these different representations of
+quantization parameters (i.e. tensor-carried vs op-carried) is an important
+capability when considering progressive lowering between uses that expect one
+scheme vs the other.
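To make the builder expectation above concrete, here is a minimal sketch of
how a frontend lowering might invoke the quantization-aware conv2d builder
declared later in this patch (`Tosa_ConvOpQuantInfoBuilder`). The function
name, attribute values and surrounding setup are illustrative assumptions,
not part of the patch:

```
// Sketch only: create a tosa.conv2d through the quantization-aware builder.
// The pad/stride/dilation ArrayAttrs mirror the builder's parameters; the
// builder itself derives the ConvOpQuantizationAttr from the quantized
// element types of `input` and `weight`. All names here are illustrative.
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Builders.h"

mlir::Value createConv2d(mlir::OpBuilder &builder, mlir::Location loc,
                         mlir::Type outputType, mlir::Value input,
                         mlir::Value weight, mlir::Value bias) {
  auto pad = builder.getI64ArrayAttr({0, 0, 0, 0});
  auto stride = builder.getI64ArrayAttr({1, 1});
  auto dilation = builder.getI64ArrayAttr({1, 1});
  return builder.create<mlir::tosa::Conv2DOp>(loc, outputType, input, weight,
                                              bias, pad, stride, dilation);
}
```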
Index: mlir/include/mlir/Dialect/CMakeLists.txt
===================================================================
--- mlir/include/mlir/Dialect/CMakeLists.txt
+++ mlir/include/mlir/Dialect/CMakeLists.txt
@@ -13,4 +13,5 @@
 add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(StandardOps)
+add_subdirectory(Tosa)
 add_subdirectory(Vector)
Index: mlir/include/mlir/Dialect/Tosa/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_subdirectory(IR)
+add_subdirectory(Transforms)
Index: mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt
@@ -0,0 +1,18 @@
+set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaDialect.h.inc -gen-dialect-decls)
+mlir_tablegen(TosaOps.h.inc -gen-op-decls)
+mlir_tablegen(TosaOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRTosaOpsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaOps.td)
+mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls)
+mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs)
+add_public_tablegen_target(MLIRTosaStructsIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TosaInterfaces.td)
+mlir_tablegen(TosaInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(TosaInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(MLIRTosaInterfaceIncGen)
+
+add_mlir_doc(TosaOps -gen-op-doc TosaOps Dialects/)
Index: mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
@@ -0,0 +1,25 @@
+//===-- TosaInterfaces.td - TOSA dialect interfaces --------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the dialect op interfaces for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_OP_INTERFACES
+#define TOSA_OP_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def TosaOpInterface : OpInterface<"TosaOp"> {
+  let description = [{
+    Implemented by ops that correspond to the TOSA specification.
+  }];
+}
+
+#endif // TOSA_OP_INTERFACES
Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -0,0 +1,192 @@
+//===-- TosaOpBase.td - TOSA dialect op builders -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the common definitions for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_OP_BASE
+#define TOSA_OP_BASE
+
+//===----------------------------------------------------------------------===//
+// The TOSA Dialect.
+//===----------------------------------------------------------------------===//
+def Tosa_Dialect : Dialect {
+  let name = "tosa";
+
+  let description = [{
+    The Tensor Operator Set Architecture (TOSA) dialect.
+
+    This dialect implements the TOSA standard described at
+    https://developer.mlplatform.org/w/tosa/ .
+
+    Tensor Operator Set Architecture (TOSA) provides a set of whole-tensor
+    operations commonly employed by Deep Neural Networks. The intent is to
+    enable a variety of implementations running on a diverse range of
+    processors, with the results at the TOSA level consistent across those
+    implementations. Applications or frameworks which target TOSA can therefore
+    be deployed on a wide range of different processors, such as CPUs or GPUs,
+    with defined accuracy and compatibility constraints. Most operators from
+    the common ML frameworks should be expressible in TOSA. It is expected that
+    there will be tools to lower from the ML frameworks into TOSA.
+  }];
+
+  let cppNamespace = "mlir::tosa";
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Quantization Attributes.
+//===----------------------------------------------------------------------===//
+
+// Quantization attributes used across TOSA operators. Quantization attributes
+// feed numerical precision parameters to the functional implementation of TOSA
+// operators.
+// The functional behavior is defined in the TOSA specification maintained at
+// https://developer.mlplatform.org/w/tosa/ . TOSA leverages MLIR's built-in
+// quantization support: https://mlir.llvm.org/docs/Quantization/, and supports
+// uniform quantization. Depending on datatype, asymmetric and symmetric
+// quantization are supported. The types themselves are described in
+// TosaTypesBase.td .
+
+// This quantization attribute expresses numerical behavior of operators where
+// the operator has a numerical relationship between a single input and output.
+// For example: tosa.negate.
+def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr",
+    Tosa_Dialect, [
+      StructFieldAttr<"input_zp", I32Attr>,
+      StructFieldAttr<"output_zp", I32Attr>
+    ]> {
+  let description = "Attribute for UnaryOp quantization information.";
+}
+
+// There is no explicit BinaryOpQuantizationAttr for 2-input/1-output ops. In
+// this case, a tosa.rescale is used to bring the inputs to the same scale.
+// TODO: Upload WIP legalization document describing this construction by
+// example.
+
+// This quantization attribute holds the input and weight zero points. Both the
+// ConvOp and MatMulOp QuantizationAttrs follow a common design semantic where
+// their own quantization attribute only expresses the numerical behavior at
+// the inputs.
+// The scaling of their accumulator output is done using an explicit
+// tosa.rescale operator that scales the accumulator result to output scale.
+def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr",
+    Tosa_Dialect, [
+      StructFieldAttr<"input_zp", I32Attr>,
+      StructFieldAttr<"weight_zp", I32Attr>
+    ]> {
+  let description = "Attribute for Conv type op quantization information.";
+}
+
+def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr",
+    Tosa_Dialect, [
+      StructFieldAttr<"a_zp", I32Attr>,
+      StructFieldAttr<"b_zp", I32Attr>
+    ]> {
+  let description = "Attribute for MatMulOp quantization information.";
+}
+
+// This attribute holds the input zero point correction applied to the padding
+// zeros to ensure numerical accuracy in the subsequent TOSA operations.
+// Its functional application is described in the tosa.pad() operator
+// description in the specification.
+def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr",
+    Tosa_Dialect, [
+      StructFieldAttr<"input_zp", I32Attr>
+    ]> {
+  let description = "Attribute for PadOp quantization information.";
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator Quantization Builders.
+//===----------------------------------------------------------------------===//
+
+// This builder is called on all convolution operators except TransposeConv,
+// which has specialized output shape semantics. The builder also defines the
+// bit width of the output given the bit widths of the input and weight
+// content.
+def Tosa_ConvOpQuantInfoBuilder : OpBuilderDAG<
+  (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias,
+       "ArrayAttr":$pad, "ArrayAttr":$stride, "ArrayAttr":$dilation),
+  [{
+    ::buildConvOpWithQuantInfo($_builder, $_state, outputType,
+                               input, weight, bias,
+                               pad, stride, dilation);
+  }]>;
+
+// Handles tosa.transpose_conv2d, which has outpad and output shape attributes.
+def Tosa_TransConvOpQuantInfoBuilder : OpBuilderDAG<
+  (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias,
+       "ArrayAttr":$outpad, "ArrayAttr":$stride, "ArrayAttr":$dilation,
+       "ArrayAttr":$outputShape),
+  [{
+    ::buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
+                                    input, weight, bias,
+                                    outpad, stride, dilation,
+                                    outputShape);
+  }]>;
+
+// The tosa.fully_connected op has its own builder as it does not have
+// strides/dilation/padding.
+def Tosa_FCOpQuantInfoBuilder : OpBuilderDAG<
+  (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias),
+  [{
+    ::buildFCOpWithQuantInfo($_builder, $_state, outputType,
+                             input, weight, bias);
+  }]>;
+
+// The tosa.matmul op is also intended to be generated where a fully_connected
+// op would otherwise be constructed but the weight is not a constant; in that
+// case the operation must be expressed using matmul.
+// TODO: Add link to the legalization document explaining this.
+def Tosa_MatMulOpQuantInfoBuilder : OpBuilderDAG<
+  (ins "Type":$outputType, "Value":$a, "Value":$b),
+  [{
+    ::buildMatMulOpWithQuantInfo($_builder, $_state, outputType,
+                                 a, b);
+  }]>;
+
+// Both the tosa.avg_pool2d op and the unary ops use the same
+// UnaryOpQuantizationAttr, but the avg_pool operator has its own builder as it
+// has additional parameters not part of the unary ops.
+def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilderDAG<
+  (ins "Type":$outputType, "Value":$input, "ArrayAttr":$kernel,
+       "ArrayAttr":$stride, "ArrayAttr":$pad),
+  [{
+    ::buildAvgPool2dOpWithQuantInfo($_builder, $_state, outputType,
+                                    input, kernel, stride, pad);
+  }]>;
+
+// This builder is called on single-parameter unary operators that have a scale
+// relationship between their input and output, expressed by the
+// UnaryOpQuantizationAttr.
+def Tosa_UnaryOpQuantInfoBuilder : OpBuilderDAG<
+  (ins "Type":$outputType, "Value":$input),
+  [{
+    ::buildUnaryOpWithQuantInfo($_builder, $_state, outputType, input);
+  }]>;
+
+// This builder is called on the TOSA pad operator that needs to create its own
+// OptionalAttr quantization_attr parameter to scale the padding values
+// correctly.
+def Tosa_PadOpQuantInfoBuilder : OpBuilderDAG<
+  (ins "Type":$outputType, "Value":$input, "Value":$paddings),
+  [{
+    ::buildPadOpWithQuantInfo($_builder, $_state, outputType,
+                              input, paddings);
+  }]>;
+
+//===----------------------------------------------------------------------===//
+// TOSA Operator.
+//===----------------------------------------------------------------------===//
+
+class Tosa_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<Tosa_Dialect, mnemonic, !listconcat(traits, [TosaOpInterface])> {
+}
+
+#endif // TOSA_OP_BASE
Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -0,0 +1,39 @@
+//===-- TosaOps.h - TOSA dialect operation definitions ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the TOSA Dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_IR_TOSA_OPS_H
+#define DIALECT_TOSA_IR_TOSA_OPS_H
+
+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Tosa/IR/TosaTraits.h"
+#include "mlir/Dialect/Traits.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+//===----------------------------------------------------------------------===//
+// TOSA dialect and structs includes.
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Tosa/IR/TosaDialect.h.inc"
+#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc"
+
+namespace mlir {
+namespace tosa {
+
+#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
+
+} // end namespace tosa
+} // end namespace mlir
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
+
+#endif // DIALECT_TOSA_IR_TOSA_OPS_H
Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -0,0 +1,1701 @@
+//===-- TosaOps.td - TOSA dialect operation definitions ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the operation set for the TOSA dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TOSA_OPS
+#define TOSA_OPS
+
+include "mlir/IR/OpBase.td"
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
+
+include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
+include "mlir/Dialect/Tosa/IR/TosaOpBase.td"
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.2
+// Operator Class: Tensor Data Engine Operators.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: argmax
+//===----------------------------------------------------------------------===//
+def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> {
+  let summary = "Perform argmax on the input.";
+
+  let description = [{
+    This returns the index with the largest value across the given axis of the
+    input tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto4D:$input,
+    I64Attr:$axis
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: avg_pool2d
+//===----------------------------------------------------------------------===//
+def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> {
+  let summary = "Performs average pooling on the input.";
+
+  let description = [{
+    This performs an average pooling over the given input tensor. A sliding
+    window of the size given by the kernel attribute is passed over the input
+    tensor, with the mean value being placed in the output tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor4D:$input,
+
+    Tosa_IntArrayAttr2:$kernel,
+    Tosa_IntArrayAttr2:$stride,
+    Tosa_IntArrayAttr4:$pad,
+    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: conv2d
+//===----------------------------------------------------------------------===//
+def Tosa_Conv2DOp : Tosa_Op<"conv2d", [NoSideEffect]> {
+  let summary = "2D Convolution Operator";
+
+  let description = [{
+    Performs a 2D convolution over the given tensor input, using the weight
+    tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D:$weight,
+    Tosa_Tensor1D:$bias,
+
+    Tosa_IntArrayAttr4:$pad,
+    Tosa_IntArrayAttr2:$stride,
+    Tosa_IntArrayAttr2:$dilation,
+    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let verifier = [{ return ::verifyConvOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: conv3d
+//===----------------------------------------------------------------------===//
+def Tosa_Conv3DOp : Tosa_Op<"conv3d", [NoSideEffect]> {
+  let summary = "3D Convolution operator";
+
+  let description = [{
+    Performs a 3D convolution over the given input tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor5D:$input,
+    Tosa_Tensor5D:$weight,
+    Tosa_Tensor1D:$bias,
+
+    Tosa_IntArrayAttr6:$pad,
+    Tosa_IntArrayAttr3:$stride,
+    Tosa_IntArrayAttr3:$dilation,
+    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor5D:$output
+  );
+
+  let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let verifier = [{ return ::verifyConvOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: depthwise_conv2d
+//===----------------------------------------------------------------------===//
+def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> {
+  let summary = "Depthwise 2D Convolution operator";
+
+  let description = [{
+    Performs 2D convolutions separately over each channel of the given tensor
+    input, using the weight tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D:$weight,
+    Tosa_Tensor1D:$bias,
+
+    Tosa_IntArrayAttr4:$pad,
+    Tosa_IntArrayAttr2:$stride,
+    Tosa_IntArrayAttr2:$dilation,
+    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  let builders = [Tosa_ConvOpQuantInfoBuilder];
+
+  let verifier = [{ return ::verifyConvOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: fully_connected
+//===----------------------------------------------------------------------===//
+def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> {
+  let summary = "Fully Connected operator";
+
+  let description = [{
+    Performs a fully connected network.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor2D:$input,
+    Tosa_Tensor2D:$weight,
+    Tosa_Tensor1D:$bias,
+    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor2D:$output
+  );
+
+  let builders = [Tosa_FCOpQuantInfoBuilder];
+
+  let verifier = [{ return ::verifyConvOp(*this); }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: matmul
+//===----------------------------------------------------------------------===//
+def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> {
+  let summary = "Matrix multiplication with bias";
+
+  let description = [{
+    Performs a two dimensional matrix multiplication. This allows both inputs
+    to be activations, rather than reserving weights as an attribute in the
+    FULLY_CONNECTED operator.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor2D:$a,
+    Tosa_Tensor2D:$b,
+    OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor2D:$c
+  );
+
+  let builders = [Tosa_MatMulOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: max_pool2d
+//===----------------------------------------------------------------------===//
+def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> {
+  let summary = "Performs max pooling on the input.";
+
+  let description = [{
+    This performs a max pooling over the given input tensor. A sliding window
+    of the size given by the kernel attribute is passed over the input tensor,
+    with the maximum value being placed in the output tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor4D:$input,
+
+    Tosa_IntArrayAttr2:$kernel,
+    Tosa_IntArrayAttr2:$stride,
+    Tosa_IntArrayAttr4:$pad
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: transpose_conv2d
+//===----------------------------------------------------------------------===//
+def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [NoSideEffect]> {
+  let summary = "Transpose 2D Convolution operator.";
+
+  let description = [{
+    Performs a 2D transposed convolution over the given tensor input, using the
+    weights tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor4D:$input,
+    Tosa_Tensor4D:$filter,
+    Tosa_Tensor1D:$bias,
+
+    Tosa_IntArrayAttr2:$out_pad,
+    Tosa_IntArrayAttr2:$stride,
+    Tosa_IntArrayAttr2:$dilation,
+    Tosa_IntArrayAttrUpto4:$out_shape,
+    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+
+  let builders = [Tosa_TransConvOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.3
+// Operator Class: Activation Functions.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: clamp
+//===----------------------------------------------------------------------===//
+def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Computes clamp(features, min, max).";
+
+  let description = [{
+    Clamp to an arbitrary minimum and maximum value. Note that the maximum and
+    minimum values are specified as signed quantized values, no scaling happens
+    before or after this operation.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto4D:$input,
+    I64Attr:$min_int,
+    I64Attr:$max_int,
+    F32Attr:$min_fp,
+    F32Attr:$max_fp
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reluN
+//===----------------------------------------------------------------------===//
+def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Computes rectified linear: `min(max(features, 0), N)`.";
+
+  let description = [{
+    ReLU with a scalar maximum value.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto4D:$input,
+    I64Attr:$max_int,
+    F32Attr:$max_fp
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: sigmoid
+//===----------------------------------------------------------------------===//
+def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect,
+                                         SameOperandsAndResultType]> {
+  let summary = "Computes elementwise sigmoid of input.";
+
+  let description = [{
+    Sigmoid function: output = 1 / (1 + exp(-input))
+    For quantized integer data types, the TABLE operator should be used instead
+    with the following definition. The sigmoid table has 513 entries, each of
+    16-bit precision, covering the input range -16.0 to +16.0
+    in steps of 1/16.
+ }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: tanh +//===----------------------------------------------------------------------===// +def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise hyperbolic tangent of input"; + + let description = [{ + Parameterized hyperbolic tangent. + For quantized integer data types, the TABLE operator should be used instead + with the following definition. The tanh_table has 513 entries each of + 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.4 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise binary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: add +//===----------------------------------------------------------------------===// +def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect, + Commutative]> { + let summary = "Elementwise addition operator"; + + let description = [{ + Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, + as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: arithmetic_right_shift +//===----------------------------------------------------------------------===// +def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", + [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Elementwise Arithmetic Right Shift"; + + let description = [{ + Elementwise arithmetic right shift of input1 by the amount specified in + input2. Axis of size 1 will be broadcast, as necessary. Rank of input + tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2, + BoolAttr:$round + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_and +//===----------------------------------------------------------------------===// +def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Bitwise AND operator"; + + let description = [{ + Elementwise bitwise AND of input1 and input2. Axis of size 1 + will be broadcast as necessary. Rank of input tensors must match. 
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input1,
+    Tosa_TensorUpto4D:$input2
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: bitwise_or
+//===----------------------------------------------------------------------===//
+def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape,
+                                              NoSideEffect, Commutative]> {
+  let summary = "Bitwise OR operator";
+
+  let description = [{
+    Elementwise bitwise OR of input1 and input2. Axis of size 1 will be
+    broadcast as necessary. Rank of input tensors must match.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input1,
+    Tosa_TensorUpto4D:$input2
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: bitwise_xor
+//===----------------------------------------------------------------------===//
+def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape,
+                                                NoSideEffect, Commutative]> {
+  let summary = "Bitwise XOR operator";
+
+  let description = [{
+    Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be
+    broadcast as necessary. Rank of input tensors must match.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input1,
+    Tosa_TensorUpto4D:$input2
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_and
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape,
+                                                Commutative, NoSideEffect]> {
+  let summary = "Returns the truth value of x AND y element-wise.";
+
+  let description = [{
+    Elementwise logical AND of input1 and input2. Axis of size 1 will be
+    broadcast, as necessary. Rank of input tensors must match.
+  }];
+
+  let arguments = (ins
+    I1Tensor:$input1,
+    I1Tensor:$input2
+  );
+
+  let results = (outs
+    I1Tensor:$z
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_left_shift
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift",
+                                      [ResultsBroadcastableShape,
+                                       NoSideEffect]> {
+  let summary = "Elementwise Logical Left Shift";
+
+  let description = [{
+    Elementwise logical left shift of input1 by the amount specified in input2.
+    Axis of size 1 will be broadcast, as necessary. Rank of input tensors must
+    match.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input1,
+    Tosa_TensorUpto4D:$input2
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: logical_right_shift
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift",
+                                       [ResultsBroadcastableShape,
+                                        NoSideEffect]> {
+  let summary = "Elementwise Logical Right Shift";
+
+  let description = [{
+    Elementwise logical right shift of input1 by the amount specified in
+    input2. Axis of size 1 will be broadcast, as necessary. Rank of input
+    tensors must match.
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_or +//===----------------------------------------------------------------------===// +def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape, + Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x OR y element-wise."; + + let description = [{ + Elementwise logical OR of input1 and input2. Axis of size 1 will be + broadcast as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_xor +//===----------------------------------------------------------------------===// +def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape, + Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x XOR y element-wise."; + + let description = [{ + Elementwise logical XOR of input1 and input2. Axis of size 1 will be + broadcast as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: maximum +//===----------------------------------------------------------------------===// +def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Elementwise Maximum"; + + let description = [{ + Elementwise max of input1 and input2. Axis of size 1 will be broadcast, as + necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: minimum +//===----------------------------------------------------------------------===// +def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape, + NoSideEffect, Commutative]> { + let summary = "Elementwise Minimum"; + + let description = [{ + Elementwise minimum of input1 and input2. Axis of size 1 + will be broadcast, as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: mul +//===----------------------------------------------------------------------===// +def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, + Commutative]> { + let summary = "Multiplication operator"; + + let description = [{ + Elementwise multiplication (Hadamard product) of input1 and input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2, + I32Attr:$shift + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: pow +//===----------------------------------------------------------------------===// +def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Computes the power of one value to another."; + + let description = [{ + Elementwise input1 raised to the power of input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: sub +//===----------------------------------------------------------------------===// +def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise subtraction operator"; + + let description = [{ + Elementwise subtraction of input1 and input2. Axis of size 1 will be + broadcast as necessary. Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: table +//===----------------------------------------------------------------------===// +def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> { + let summary = "Table lookup op"; + + let description = [{ + Interpolated table lookup operation. Input values are scaled to create a + fixed-point 9.7 value. The high 9 bits are used to index into the table. + The fractional bits are used to interpolate based on the looked up value and + the index+1 value in the table. The TABLE operator then returns a 16.7 + interpolated value. Note that there must be 513 values to handle the full + range of inputs. + + The TABLE operator is expected to be used as follows: + * A RESCALE node is expected before the TABLE operator to scale the input + to a full int16_t range for the table lookup + * If an int16_t result is required then follow the TABLE operator with a + RESCALE with a right shift of 7 + * If an int8_t result is required then follow the TABLE operator with a + RESCALE with a right shift of 15 + }]; + + let arguments = (ins + Tosa_TensorUpto4D: $input, + Tosa_Tensor1D: $table + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.5 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise unary ops. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: abs +//===----------------------------------------------------------------------===// +def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise abs op"; + + let description = [{ + Elementwise absolute value operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_not +//===----------------------------------------------------------------------===// +def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Bitwise NOT operator"; + + let description = [{ + Elementwise bitwise NOT of input tensor. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: ceil +//===----------------------------------------------------------------------===// +def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise ceil op"; + + let description = [{ + Elementwise ceiling operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: clz +//===----------------------------------------------------------------------===// +def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise count leading zero op"; + + let description = [{ + Elementwise count leading zeros operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: exp +//===----------------------------------------------------------------------===// +def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise exp op"; + + let description = [{ + Elementwise e to the x operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: floor +//===----------------------------------------------------------------------===// +def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise floor op"; + + let description = [{ + Elementwise floor operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: log +//===----------------------------------------------------------------------===// +def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise log op"; + + let description = [{ + Elementwise natural logarithm operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); +} + 
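Since the elementwise unary ops above carry the SameOperandsAndResultType
trait, a lowering can take the result type directly from the operand when
creating them. A minimal sketch, assuming an existing OpBuilder, location and
operand value (all names illustrative, not part of the patch):

```
// Sketch only: chaining two SameOperandsAndResultType unary ops; the result
// type of each op is simply the type of its operand.
mlir::Value emitLogOfExp(mlir::OpBuilder &builder, mlir::Location loc,
                         mlir::Value x) {
  auto e = builder.create<mlir::tosa::ExpOp>(loc, x.getType(), x);
  return builder.create<mlir::tosa::LogOp>(loc, e.getType(), e);
}
```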
+//===----------------------------------------------------------------------===//
+// Operator: logical_not
+//===----------------------------------------------------------------------===//
+def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect,
+                                                SameOperandsAndResultType]> {
+  let summary = "Returns the truth value of NOT x element-wise.";

+  let description = [{
+    Elementwise logical NOT of input.
+  }];
+
+  let arguments = (ins
+    I1Tensor:$input1
+  );
+
+  let results = (outs
+    I1Tensor:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: negate
+//===----------------------------------------------------------------------===//
+def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect,
+                                       SameOperandsAndResultType]> {
+  let summary = "Elementwise negate op";
+
+  let description = [{
+    Elementwise negation operation
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input1,
+    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+
+  let builders = [Tosa_UnaryOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reciprocal
+//===----------------------------------------------------------------------===//
+def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect,
+                                               SameOperandsAndResultType]> {
+  let summary = "Elementwise reciprocal op";
+
+  let description = [{
+    Elementwise reciprocal operation. For integer operation, a TABLE should be
+    used with the appropriate ranges.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input1
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: rsqrt
+//===----------------------------------------------------------------------===//
+def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> {
+  let summary = "Elementwise 1/sqrt op";
+
+  let description = [{
+    Elementwise reciprocal square root operation. For integer operation, a
+    TABLE should be used with the appropriate ranges.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input1
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.6
+// Operator Class: Elementwise unary/binary/ternary operators.
+// Operator Subclass: Elementwise ternary ops.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: select
+//===----------------------------------------------------------------------===//
+def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> {
+  let summary = "Elementwise select operator";
+
+  let description = [{
+    Elementwise select of the output based on a condition.
+  }];
+
+  let arguments = (ins
+    I1Tensor:$input1,
+    Tosa_TensorUpto4D:$input2,
+    Tosa_TensorUpto4D:$input3
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.7
+// Operator Class: Logical Operations.
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: equal +//===----------------------------------------------------------------------===// +def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative, + NoSideEffect]> { + let summary = "Returns the truth value of (x == y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: greater +//===----------------------------------------------------------------------===// +def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Returns the truth value of (x > y) element-wise."; + + let description = [{ + Elementwise greater than comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: greater_equal +//===----------------------------------------------------------------------===// +def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape, + NoSideEffect]> { + let summary = "Returns the truth value of (x >= y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.8 +// Operator Class: Reduction Ops. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: reduce_all +//===----------------------------------------------------------------------===// +def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> { + let summary = "Reduce All operator"; + + let description = [{ + Reduce a tensor along the given axis with a logical AND operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_any +//===----------------------------------------------------------------------===// +def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> { + let summary = "Reduce Any operator"; + + let description = [{ + Reduce a tensor along the given axis with a logical OR operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_max +//===----------------------------------------------------------------------===// +def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> { + let summary = "Reduce Max operator"; + + let description = [{ + Reduce a tensor along the given axis with a maximum operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_min +//===----------------------------------------------------------------------===// +def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> { + let summary = "Reduce Min operator"; + + let description = [{ + Reduce a tensor along the given axis with a minimum operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_prod +//===----------------------------------------------------------------------===// +def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> { + let summary = "Reduce Prod operator"; + + let description = [{ + Reduce a tensor along the given axis by computing the product of the axis. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_sum +//===----------------------------------------------------------------------===// +def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> { + let summary = "Reduce Sum operator"; + + let description = [{ + Reduce a tensor along the given axis by computing the sum of the axis. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.9 +// Operator Class: Data Layout / Memory Reinterpretation. 
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: concat
+//===----------------------------------------------------------------------===//
+def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> {
+  let summary = "Concatenates tensors along one dimension.";
+
+  let description = [{
+    Concatenate two tensors along a given axis. No data conversion happens
+    during a concat operation.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto4D:$input1,
+    Tosa_Tensor1Dto4D:$input2,
+    I64Attr:$axis
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: pad
+//===----------------------------------------------------------------------===//
+def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> {
+  let summary = "Pads a tensor with zeros.";
+
+  let description = [{
+    Zero-pads a tensor along borders of each dimension.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto4D:$input1,
+    Tosa_Int32Or64Tensor:$padding,
+    OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto4D:$output
+  );
+
+  let builders = [Tosa_PadOpQuantInfoBuilder];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reshape
+//===----------------------------------------------------------------------===//
+def Tosa_ReshapeOp: Tosa_Op<"reshape", [NoSideEffect]> {
+  let summary = "Reshape operator";
+
+  let description = [{
+    Returns a tensor with the same type/values as the input, with a new shape
+    specified by the shape argument. Reshape may operate on tensors of any
+    rank. No data conversion happens during a reshape operation.
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto6D:$input1,
+    I64ArrayAttr:$new_shape
+  );
+
+  let results = (outs
+    Tosa_TensorUpto6D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: reverse
+//===----------------------------------------------------------------------===//
+def Tosa_ReverseOp: Tosa_Op<"reverse", [NoSideEffect]> {
+  let summary = "Reverse operator";
+
+  let description = [{
+    Returns a tensor with the same type/values as the input, with the data
+    reversed along the given axis. No data conversion happens during a reverse
+    operation.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto4D:$input,
+    I64Attr:$axis
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: slice
+//===----------------------------------------------------------------------===//
+def Tosa_SliceOp: Tosa_Op<"slice", [NoSideEffect]> {
+  let summary = "Slice operator";
+
+  let description = [{
+    Extracts a slice of input1 on the given axis, beginning at the start
+    coordinates and extending for size elements in each direction. No data
+    conversion happens during a slice operation.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto6D:$input,
+    I64ArrayAttr:$start,
+    I64ArrayAttr:$size
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto6D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: tile
+//===----------------------------------------------------------------------===//
+def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> {
+  let summary = "Tile operator";
+
+  let description = [{
+    Replicates input1 multiples times along each dimension.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto4D:$input1,
+    I64ArrayAttr:$multiples
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: transpose
+//===----------------------------------------------------------------------===//
+def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> {
+  let summary = "Transpose operator";
+
+  let description = [{
+    Permutes the dimensions based on perm.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor1Dto6D:$input1,
+    Tosa_Int32Or64Tensor:$perms
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto6D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.10
+// Operator Class: Scatter/gather Operations.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: gather
+//===----------------------------------------------------------------------===//
+def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> {
+  let summary = "Gather operation.";
+
+  let description = [{
+    Generate a tensor for which each element in the output is a subtensor of
+    the values tensor along the given axis, based on the value of indices.
+  }];
+
+  let arguments = (ins
+    Tosa_Int32Or64Tensor:$indices,
+    Tosa_Tensor1Dto4D:$values,
+    I32Attr:$axis
+  );
+
+  let results = (outs
+    Tosa_Tensor1Dto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.11
+// Operator Class: Image Frontend Functions.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: resize
+//===----------------------------------------------------------------------===//
+def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> {
+
+  let summary = "Resize operation, supports various resize/upsample modes";
+
+  let description = [{
+    Resizes a tensor. Resize is only allowed in the H and W dimensions.
+    expected use, stride_y is approximately (IH<<shift)/OH and stride_x is
+    approximately (IW<<shift)/OW. OH and OW are also supplied as inputs since
+    there may be off-by-one errors when calculating OH and OW from the strides.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor4D:$input,
+    Tosa_IntArrayAttr2:$output_size,
+    Tosa_IntArrayAttr2:$stride,
+    Tosa_IntArrayAttr2:$offset,
+    I32Attr:$shift,
+    Tosa_ResizeTypeAttr:$mode
+  );
+
+  let results = (outs
+    Tosa_Tensor4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.12
+// Operator Class: Type Conversion.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: cast
+//===----------------------------------------------------------------------===//
+def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect]> {
+
+  let summary = "Cast operation";
+
+  let description = [{
+    Performs a set of permissible cast operations:
+
+    Mode                    Input   Output
+    ---------------------------------------
+    signed 8 to bool        int8    Boolean
+    signed 16 to bool       int16   Boolean
+    signed 32 to bool       int32   Boolean
+    bool to 8               Boolean int8
+    bool to 16              Boolean int16
+    bool to 32              Boolean int32
+    signed 8 to signed 16   int8    int16
+    signed 8 to signed 32   int8    int32
+    signed 16 to signed 8   int16   int8
+    signed 16 to signed 32  int16   int32
+    signed 32 to signed 8   int32   int8
+    signed 32 to signed 16  int32   int16
+    float to signed 8       float   int8
+    float to signed 16      float   int16
+    signed 8 to float       int8    float
+    signed 16 to float      int16   float
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: rescale
+//===----------------------------------------------------------------------===//
+def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> {
+  let summary = "Tosa rescale operator";
+
+  let description = [{
+    Rescale quantized values into a new domain. Supported rescalings are:
+
+    Mode                    Input   Output
+    ---------------------------------------
+    signed 8 to 8           aint8   aint8
+    signed 8 to 16          aint8   int16
+    signed 8 to 32          aint8   int32
+    signed 16 to 8          int16   aint8
+    signed 16 to 16         int16   int16
+    signed 16 to 32         int16   int32
+    signed 32 to 8          int32   aint8
+    signed 32 to 16         int32   int16
+    signed 32 to 32         int32   int32
+    signed 48 to 8          int48   aint8
+    signed 48 to 16         int48   int16
+    signed 48 to 32         int48   int32
+    unsigned 8 to signed 8  uint8   aint8
+    signed 8 to unsigned 8  aint8   uint8
+  }];
+
+  let arguments = (ins
+    Tosa_TensorUpto4D:$input,
+    I32Attr:$input_zp,
+    I32Attr:$output_zp,
+    I32ArrayAttr:$multiplier,
+    I32ArrayAttr:$shift,
+    BoolAttr:$scale32,
+    BoolAttr:$double_round,
+    BoolAttr:$per_channel
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// TOSA Spec Section 2.13
+// Operator Class: Data Node Ops.
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Operator: const
+//===----------------------------------------------------------------------===//
+def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, NoSideEffect,
+                                     FirstAttrDerivedResultType]> {
+  let summary = "Constant op.";
+
+  let description = [{
+    A node containing constant data for use as the input to an operation. May
+    hold data in any of the supported data formats.
+  }];
+
+  let arguments = (ins
+    AnyAttr:$value
+  );
+
+  let results = (outs
+    Tosa_TensorUpto4D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: identity
+//===----------------------------------------------------------------------===//
+def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect]> {
+  let summary = "Identity operator";
+  let description = [{
+    Returns a tensor with the same shape, size, type and content as the input.
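+
+    For example (the shape and element type are illustrative):
+
+    ```mlir
+    %0 = "tosa.identity"(%arg0) : (tensor<4x2xf32>) -> tensor<4x2xf32>
+    ```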
+ }]; + + let arguments = (ins + Tosa_TensorUpto6D:$input1 + ); + + let results = (outs + Tosa_TensorUpto6D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: identityn +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_IdentityNOp: Tosa_Op<"identityn", [NoSideEffect]> { + let summary = "IdentityN operator"; + let description = [{ + Returns a list of tensors with the same shape, type, and contents as the + input list of tensors. + }]; + + let arguments = (ins + Variadic:$input1 + ); + + let results = (outs + Variadic:$output + ); +} + + +//===----------------------------------------------------------------------===// +// Operator: placeholder +//===----------------------------------------------------------------------===// +def Tosa_PlaceholderOp : Tosa_Op<"placeholder", [NoSideEffect]> { + let summary = "Placeholder op"; + + let description = [{ + A node where data will be inserted into the network at runtime. Generally + used for inputs to the network. + }]; + + let arguments = (ins + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.14 +// Operator Class: Custom Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: custom +//===----------------------------------------------------------------------===// +def Tosa_CustomOp : Tosa_Op<"custom"> { + + let summary = "Custom operator wrapper for Tosa"; + + let description = [{ + Hardware implementing TOSA may choose to add additional custom operators + that are not expressed in the existing TOSA operations. These operators are + not expected to be portable across TOSA implementations. The input and + output signatures must be expressed in the corresponding TOSA node. + }]; + + let arguments = (ins + StrAttr:$identifier, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.15 +// Operator Class: Control Flow Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: cond_if +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_IfOp : Tosa_Op<"cond_if", [ + SingleBlockImplicitTerminator<"YieldOp">, + RecursiveSideEffects]> { + let summary = "Conditional if operator"; + + let description = [{ + Evaluates a Boolean condition and then takes one of two distinct execution + paths. This implements the semantic If-then-else structure. 
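+
+    For example (generic syntax; element types are illustrative):
+
+    ```mlir
+    %r = "tosa.cond_if"(%cond, %a, %b) ({
+      ^bb0(%x: tensor<f32>, %y: tensor<f32>):
+        %0 = "tosa.add"(%x, %y) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+        "tosa.yield"(%0) : (tensor<f32>) -> ()
+    }, {
+      ^bb0(%x: tensor<f32>, %y: tensor<f32>):
+        %1 = "tosa.sub"(%x, %y) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+        "tosa.yield"(%1) : (tensor<f32>) -> ()
+    }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+    ```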
+ }]; + + let arguments = (ins + I1Tensor:$cond, + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$then_branch, + SizedRegion<1>:$else_branch + ); +} + +//===----------------------------------------------------------------------===// +// Operator: while_loop +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_WhileOp : Tosa_Op<"while_loop", [ + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">, + RecursiveSideEffects]> { + let summary = "output = input; While (Cond(output)) {output = Body(output)}"; + + let description = [{ + Generates and evaluates a Bool condition and either executes a loop body or + exits to another control point. This action is performed repeatedly after + updating and re-evaluating the Boolean condition every iteration. This + implements the semantic foreach or while iterative loop structure. + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$cond, + SizedRegion<1>:$body + ); +} + +//===----------------------------------------------------------------------===// +// Operator: yield +//===----------------------------------------------------------------------===// +def Tosa_YieldOp : Tosa_Op<"yield", [ + Terminator, + NoSideEffect]> { + let summary = "yield operator"; + + let description = [{ + return operation within the conditional and body of + structured control flow. Operation takes variadic operands + but produces no results of its own. + }]; + + let arguments = (ins + Variadic:$inputs + ); +} + +#endif // TOSA_OPS Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h @@ -0,0 +1,34 @@ +//===-- TosaTraits.h - TOSA dialect operation traits ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect OpTraits in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATRAITS_H +#define MLIR_DIALECT_TOSA_IR_TOSATRAITS_H + +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { +namespace tosa { + +// TBD + +} // namespace tosa +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATRAITS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h @@ -0,0 +1,31 @@ +//===-- TosaTypes.h - TOSA dialect type definitions -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect Types in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATYPES_H +#define MLIR_DIALECT_TOSA_IR_TOSATYPES_H + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +namespace tosa { + +// TOSA specific types go here + +} // namespace tosa + +} // end namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATYPES_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -0,0 +1,159 @@ +//===-- TosaTypesBase.td - TOSA type definitions -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the type definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_TYPES_BASE +#define TOSA_TYPES_BASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Tosa Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. +// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. +// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for +// the 8-bit case. +class Tosa_QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + StrJoinInt.result # !if(signed, ", true", ", false"); +} + +//===----------------------------------------------------------------------===// +// Non-Quantized Signed Integer Types. +// Used to express accumulator results or compare results. 
+//===----------------------------------------------------------------------===// + +def Tosa_Int32 : I<32>; +def Tosa_Int48 : I<48>; +def Tosa_Int64 : I<64>; + +def Tosa_SignedInt : AnyTypeOf<[Tosa_Int32, + Tosa_Int48, + Tosa_Int64]>; + +def Tosa_Bool : I<1>; + +// No unsigned unquantized int types. +def Tosa_Int : AnyTypeOf<[Tosa_Bool, + Tosa_SignedInt]>; + +def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, + Tosa_Int64]>; + +//===----------------------------------------------------------------------===// +// Quantized Integer Types. +// Datatype for network feature map or weight content. +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Name Symmetry Grouping Sign +//===----------------------------------------------------------------------===// +// aint8 : asymmetric per tensor, signed +// uint8 : asymmetric per tensor , unsigned +// int4 : symmetric per channel, signed +// int8 : symmetric per tensor/per channel, signed +// int16 : symmetric per tensor, signed +//===----------------------------------------------------------------------===// +def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>, + Tosa_QuantizedType<"uint8", [8], 0>, + Tosa_QuantizedType<"int4", [4, 0], 1>, + Tosa_QuantizedType<"int8", [8, 0], 1>, + Tosa_QuantizedType<"int16", [16, 0], 1>]>; + +//===----------------------------------------------------------------------===// +// Floating-point types. +//===----------------------------------------------------------------------===// +def Tosa_Float : AnyTypeOf<[ + F32, + F16, + BF16]>; + +//===----------------------------------------------------------------------===// +// Multi-category types. +//===----------------------------------------------------------------------===// +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], + "number">; + +//===----------------------------------------------------------------------===// +// Tensor types +//===----------------------------------------------------------------------===// + +def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; + +def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; + +// Any tensor element type allowed in Tosa ops. +def Tosa_ElementType : Type, "tosa.dtype">; + +class Tosa_TensorOfOrNone allowedTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +//===----------------------------------------------------------------------===// +// Tensor types with constrained ranks. +//===----------------------------------------------------------------------===// + +// Must be listed rank. +def Tosa_Tensor1D : 1DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor2D : 2DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor4D : 4DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>; +def Tosa_Tensor6D : TensorRankOf<[Tosa_AnyNumber], [6]>; + +// Ranked tensors up to given rank. +def Tosa_Tensor1Dto2D : TensorRankOf<[Tosa_AnyNumber], [1,2]>; +def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>; +def Tosa_Tensor1Dto5D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5]>; +def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>; + +def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>; +def Tosa_TensorUpto6D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4,5,6]>; + +//===----------------------------------------------------------------------===// +// Attribute predicates and classes. 
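+// These constrain the I64ArrayAttr operands used by operators such as
+// tosa.resize to an exact (ArrayCount) or maximum (ArrayMaxCt) element count.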
+//===----------------------------------------------------------------------===//
+class ArrayMaxCt<int n> : AttrConstraint<
+    CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>,
+    "with at most " # n # " elements">;
+
+def Tosa_IntArrayAttr2 : Confined<I64ArrayAttr, [ArrayCount<2>]>;
+def Tosa_IntArrayAttr3 : Confined<I64ArrayAttr, [ArrayCount<3>]>;
+def Tosa_IntArrayAttr4 : Confined<I64ArrayAttr, [ArrayCount<4>]>;
+def Tosa_IntArrayAttr5 : Confined<I64ArrayAttr, [ArrayCount<5>]>;
+def Tosa_IntArrayAttr6 : Confined<I64ArrayAttr, [ArrayCount<6>]>;
+
+def Tosa_IntArrayAttrUpto2 : Confined<I64ArrayAttr, [ArrayMaxCt<2>]>;
+def Tosa_IntArrayAttrUpto4 : Confined<I64ArrayAttr, [ArrayMaxCt<4>]>;
+def Tosa_IntArrayAttrUpto5 : Confined<I64ArrayAttr, [ArrayMaxCt<5>]>;
+
+//===----------------------------------------------------------------------===//
+// Iterable attributes.
+//===----------------------------------------------------------------------===//
+// Supported regimes for tosa.resize.
+def Tosa_ResizeTypeAttr : StringBasedAttr<
+    CPred<"$_self.cast<StringAttr>().getValue() == \"BILINEAR\" || " #
+          "$_self.cast<StringAttr>().getValue() == \"NEAREST_NEIGHBOR\"">,
+    "Supported resize/upsampling strategies">;
+
+def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">;
+
+// Tensor to buffer types.
+def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>;
+def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>;
+def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>;
+
+#endif // TOSA_TYPES_BASE
Index: mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt
@@ -0,0 +1,6 @@
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt)
+add_public_tablegen_target(MLIRTosaPassIncGen)
+add_dependencies(mlir-headers MLIRTosaPassIncGen)
+
+add_mlir_doc(Passes -gen-pass-doc TosaPasses ./)
Index: mlir/include/mlir/Dialect/Tosa/Transforms/PassDetail.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/Transforms/PassDetail.h
@@ -0,0 +1,21 @@
+//===- PassDetail.h - TOSA Pass class details -------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H
+#define DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+#define GEN_PASS_CLASSES
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+} // end namespace mlir
+
+#endif // DIALECT_TOSA_TRANSFORMS_PASSDETAIL_H
Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h
@@ -0,0 +1,31 @@
+//===-- Passes.h - TOSA optimization pass declarations ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the optimization passes for the TOSA Dialect in MLIR.
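+// (Currently these are the tosa-make-broadcastable legalization pass and a
+// test pass that exercises the quantization utility APIs.)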
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
+#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
+
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+
+namespace tosa {
+
+std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
+std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
+
+#define GEN_PASS_REGISTRATION
+#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -0,0 +1,37 @@
+//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the optimization passes for the TOSA Dialect in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+include "mlir/Pass/PassBase.td"
+
+def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
+  let summary = "TOSA rank Reshape to enable Broadcasting";
+  let description = [{
+    Pass that enables broadcast by making all input arrays have the same
+    number of dimensions. Inserts RESHAPE operations to prepend dimensions
+    of size one until the number of dimensions is equal. Implements an
+    approach similar to step 1 of Numpy 4-step broadcasting:
+    https://numpy.org/doc/stable/reference/ufuncs.html#broadcasting
+  }];
+
+  let constructor = "createTosaMakeBroadcastablePass()";
+}
+
+// TOSA Test Passes
+
+def TosaTestQuantUtils : FunctionPass<"tosa-test-quant-utils"> {
+  let summary = "TOSA Test: Exercise the APIs in QuantUtils.cpp";
+  let description = [{
+    Exercises the API that builds a quantized type from a min/max quantized
+    range.
+  }];
+
+  let constructor = "createTosaTestQuantUtilAPIPass()";
+}
Index: mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -0,0 +1,68 @@
+//===-- QuantUtils.h - TOSA numerical support declarations ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Function declarations for TOSA numerical support functions and quantization
+// attribute builders.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_TOSA_UTILS_QUANT_UTILS_H
+#define DIALECT_TOSA_UTILS_QUANT_UTILS_H
+
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+
+#include "mlir/Dialect/Quant/FakeQuantSupport.h"
+#include "mlir/Dialect/Quant/UniformSupport.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+//===----------------------------------------------------------------------===//
+// Utility functions to support quantization handling in Tosa.
+//===----------------------------------------------------------------------===// + +/// From a scale value, computes multiplier and shift values +/// for 16 or 32-bit scale widths. +void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth); + +//// Builds ConvOpQuantizationAttr from input and weight. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, + Value input, Value weight); + +//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B. +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, + Value a, Value b); + +//// Builds UnaryOpQuantizationAttr for unary operations from input values. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, + Value input, + Type outputRawType); + +//// Builds PadOpQuantizationAttr for pad operations from input values. +PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, + Value input); + +//// construct ConvOp output type with correct bitwidth based on input/weight +/// width. +Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, + Value weight); + +/// Builds Tosa quantization attributes from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange); + +/// Builds Tosa quantization attributes from min/max values. +TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange); + +#endif // DIALECT_TOSA_UTILS_QUANT_UTILS_H Index: mlir/include/mlir/InitAllDialects.h =================================================================== --- mlir/include/mlir/InitAllDialects.h +++ mlir/include/mlir/InitAllDialects.h @@ -33,6 +33,7 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Dialect.h" @@ -60,7 +61,8 @@ NVVM::NVVMDialect, ROCDL::ROCDLDialect, SDBMDialect, - shape::ShapeDialect>(); + shape::ShapeDialect, + tosa::TosaDialect>(); // clang-format on } Index: mlir/include/mlir/InitAllPasses.h =================================================================== --- mlir/include/mlir/InitAllPasses.h +++ mlir/include/mlir/InitAllPasses.h @@ -24,6 +24,7 @@ #include "mlir/Dialect/SPIRV/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Transforms/Passes.h" #include @@ -54,6 +55,7 @@ registerShapePasses(); spirv::registerSPIRVPasses(); registerStandardPasses(); + tosa::registerTosaOptPasses(); } } // namespace mlir Index: mlir/lib/Dialect/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/CMakeLists.txt +++ mlir/lib/Dialect/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) set(LLVM_OPTIONAL_SOURCES Index: mlir/lib/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_dialect_library(MLIRTosa + Utils/QuantUtils.cpp + IR/TosaOps.cpp + + 
ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa + + DEPENDS + MLIRStandardOpsIncGen + MLIRTosaOpsIncGen + MLIRTosaStructsIncGen + MLIRTosaInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRStandard + MLIRCallInterfaces + MLIRControlFlowInterfaces + MLIRSideEffectInterfaces + MLIRViewLikeInterface + ) + +add_subdirectory(Transforms) Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -0,0 +1,273 @@ +//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This file implements the TOSA Specification: +// https://developer.mlplatform.org/w/tosa/ +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" + +using namespace mlir; +using namespace mlir::tosa; + +//===----------------------------------------------------------------------===// +// Tosa dialect structs and interface includes. +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" +#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc" + +namespace { +//===----------------------------------------------------------------------===// +// Dialect Function Inliner Interface. +//===----------------------------------------------------------------------===// +struct TosaInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks. + //===--------------------------------------------------------------------===// + + /// All operations can be inlined by default. + bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, + BlockAndValueMapping &map) const final { + return true; + } + + /// All regions with If and While parent operators can be inlined. + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, + BlockAndValueMapping &map) const final { + return (isa(dest->getParentOp()) || + isa(dest->getParentOp())); + } +}; +} // end anonymous namespace + +//===----------------------------------------------------------------------===// +// TOSA control flow support. +//===----------------------------------------------------------------------===// + +/// Returns the while loop body. +Region &tosa::WhileOp::getLoopBody() { return body(); } + +bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) { + return !body().isAncestor(value.getParentRegion()); +} + +LogicalResult WhileOp::moveOutOfLoop(ArrayRef ops) { + if (ops.empty()) + return success(); + + Operation *tosaWhileOp = this->getOperation(); + for (auto *op : ops) + op->moveBefore(tosaWhileOp); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Tosa dialect initialization. 
+//===----------------------------------------------------------------------===// + +void TosaDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// TOSA Operator Verifiers. +//===----------------------------------------------------------------------===// + +template static LogicalResult verifyConvOp(T op) { + // All TOSA conv ops have an input() and weight(). + auto inputType = op.input().getType().template dyn_cast(); + auto weightType = op.weight().getType().template dyn_cast(); + + // Must be ranked tensor types + if (!inputType || !weightType) + return failure(); + + auto inputQType = + inputType.getElementType().template isa(); + auto weightQType = + weightType.getElementType().template isa(); + + // Either both must be quantized or both unquantized. + if (inputQType != weightQType) + return failure(); + + // Quantized type must have constructed the quantizationattr, and unquantized + // types should not have a quantizationattr. + if ((inputQType && !op.quantization_info()) || + (!inputQType && op.quantization_info())) + return failure(); + + return success(); +} + +//===----------------------------------------------------------------------===// +// TOSA Operator Quantization Builders. +//===----------------------------------------------------------------------===// + +/// This builder is called on all convolution operators except TransposeConv, +/// which has specialized output shape semantics. The builder also defines the +/// bitwidth of the output given the bit width of the input & weight content. +void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias, ArrayAttr pad, ArrayAttr stride, + ArrayAttr dilation) { + + result.addOperands({input, weight, bias}); + result.addAttribute("pad", pad); + result.addAttribute("stride", stride); + result.addAttribute("dilation", dilation); + + auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight); + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + result.addTypes( + buildConvOpResultTypeInfo(builder, outputType, input, weight)); + } else { + result.addTypes(outputType); + } +} + +/// Handles tosa.transpose_conv2d which has outpad and output shape attributes. +void buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias, ArrayAttr outpad, + ArrayAttr stride, ArrayAttr dilation, + ArrayAttr outputShape) { + result.addOperands({input, weight, bias}); + result.addAttribute("out_pad", outpad); + result.addAttribute("stride", stride); + result.addAttribute("dilation", dilation); + result.addAttribute("out_shape", outputShape); + auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); + + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + result.addTypes( + buildConvOpResultTypeInfo(builder, outputType, input, weight)); + } else { + result.addTypes(outputType); + } +} + +/// The tosa.fully_connected op has its own builder as it does not have +/// strides/dilation/padding. 
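+///
+/// A hypothetical call site (normally reached through the op's declared
+/// quantization-info builder rather than called directly) might look like:
+///   buildFCOpWithQuantInfo(builder, state, outputType, input, weight, bias);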
+void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value weight, + Value bias) { + + result.addOperands({input, weight, bias}); + auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight); + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + result.addTypes( + buildConvOpResultTypeInfo(builder, outputType, input, weight)); + } else { + result.addTypes(outputType); + } +} + +/// The tosa.matmul op is also intended to be generated where a fully_connected +/// op must be constructed where the weight is not a constant. In this case, +/// the fully_connected op must be expressed using matmul. +/// TODO: Add link to the leglization document explaining this. +void buildMatMulOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value a, Value b) { + result.addOperands({a, b}); + auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b); + + if (quantAttr) { + result.addAttribute("quantization_info", quantAttr); + + auto inputType = a.getType().dyn_cast(); + assert(inputType && "Input must be a ranked tensor type!"); + + auto inputQType = inputType.getElementType() + .dyn_cast(); + assert(inputQType && "Tensor must have quantized datatype!"); + + unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); + + auto outputShapedType = outputType.dyn_cast(); + assert(outputShapedType && "Output must be a ranked tensor type"); + + auto outputShape = outputShapedType.getShape(); + + IntegerType accElementType; + if (inputBits == 16) + accElementType = builder.getIntegerType(48); + else + accElementType = builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + result.addTypes(accType); + } else { + result.addTypes(outputType); + } +} + +/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr +/// but avg_pool operator has its own builder as it has additional parameters +/// not part of the unary ops. +void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value input, + ArrayAttr kernel, ArrayAttr stride, + ArrayAttr pad) { + result.addOperands(input); + result.addAttribute("kernel", kernel); + result.addAttribute("stride", stride); + result.addAttribute("pad", pad); + auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); + if (quantAttr) + result.addAttribute("quantization_info", quantAttr); + result.types.push_back(outputType); +} + +/// This builder is called on single-parameter unary operators that have scale +/// relationship between their input and output, expressed by the +/// UnaryOpQuantizationAttr. +void buildUnaryOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value input) { + result.addOperands(input); + auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); + if (quantAttr) + result.addAttribute("quantization_info", quantAttr); + result.types.push_back(outputType); +} + +/// This builder is called on TOSA pad operator that needs to create its own +/// OptionalAttr quantization_attr parameter to scale the padding values +/// correctly. 
+void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value input, Value paddings) { + result.addOperands({input, paddings}); + auto quantAttr = buildPadOpQuantizationAttr(builder, input); + if (quantAttr) + result.addAttribute("quantization_info", quantAttr); + result.types.push_back(outputType); +} + +//===----------------------------------------------------------------------===// +// TOSA Operator Definitions. +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" Index: mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRTosaTransforms + TosaMakeBroadcastable.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms + + DEPENDS + MLIRTosaPassIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTosa + ) Index: mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -0,0 +1,272 @@ +//===- TosaMakeBroadcastable.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Insert reshape to binary op's input if needed to match rank +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR//TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/PassDetail.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::tosa; + +/// There are two potential ways implementing broadcast: +/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition +/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html +/// TBD: picking option (a) now. + +/// In this pass, we insert RESHAPE operators to increase the rank of the +/// lower rank operand as a first step in the broadcasting process. The TOSA +/// operators that support broadcast require that the rank of the operands +/// are equal. + +// Examples: +// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1]. +// TODO: If lower=[b], target=[a, b, c], [b] should but NOT YET reshaped into +// [1, b, 1]. +// If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c]. +// If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1]. +// If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. +// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. + +static void computeReshapeOutput(ArrayRef higherRankShape, + ArrayRef lowerRankShape, + SmallVectorImpl &reshapeOutputShape) { + // Intialize new shapes with [1] * higherRank. 
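+  // For example (illustrative shapes), higherRankShape = [17, 16, 15, 14]
+  // with lowerRankShape = [15, 14] produces reshapeOutputShape =
+  // [1, 1, 15, 14]: the two right-most dimensions match, and every unmatched
+  // dimension is left as 1.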
+  int64_t higherRank = higherRankShape.size();
+  int64_t lowerRank = lowerRankShape.size();
+
+  reshapeOutputShape.assign(higherRank, 1);
+
+  int64_t higherLeftIndex = 0;
+  int64_t higherRightIndex = higherRank;
+  int64_t lowerLeftIndex = 0;
+  int64_t lowerRightIndex = lowerRank;
+  int64_t higherRankDim, lowerRankDim;
+
+  if (lowerRightIndex != 0 && higherRightIndex != 0) {
+    // Match the lower rank shape from the right-most dimension first, until
+    // it stops matching the higher rank shape or dimension 0 is reached.
+    while (true) {
+      higherRankDim = higherRankShape[higherRightIndex - 1];
+      lowerRankDim = lowerRankShape[lowerRightIndex - 1];
+      if (higherRankDim != lowerRankDim)
+        break;
+
+      reshapeOutputShape[higherRightIndex - 1] = higherRankDim;
+
+      if (higherRightIndex > 0)
+        higherRightIndex--;
+
+      if (lowerRightIndex > 0)
+        lowerRightIndex--;
+
+      if (higherRightIndex == 0 || lowerRightIndex == 0)
+        break;
+    }
+    if (lowerRightIndex != 0 && higherRightIndex != 0) {
+      // Then match the lower rank shape from the left-most dimension, until
+      // it stops matching the higher rank shape or the right index is reached.
+      while (true) {
+        higherRankDim = higherRankShape[higherLeftIndex];
+        lowerRankDim = lowerRankShape[lowerLeftIndex];
+        if (higherRankDim != lowerRankDim)
+          break;
+
+        reshapeOutputShape[higherLeftIndex] = higherRankDim;
+
+        if (higherLeftIndex < higherRightIndex)
+          higherLeftIndex++;
+
+        if (lowerLeftIndex < lowerRightIndex)
+          lowerLeftIndex++;
+
+        if (higherLeftIndex == higherRightIndex ||
+            lowerLeftIndex == lowerRightIndex)
+          break;
+      }
+    }
+  }
+}
+
+/// Common code to create the reshape op where necessary to make the ranks of
+/// two operands equal. Returns the updated input1 and input2 for the original
+/// inputs. The caller is expected to use these to rewrite the original
+/// operator with the RESHAPE now in the graph.
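+/// For example, adding tensor<15x14xf32> to tensor<17x16x15x14xf32> reshapes
+/// the lower-rank operand to tensor<1x1x15x14xf32>, so that both operands are
+/// rank 4 as the TOSA broadcastable ops require.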
+static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, + RankedTensorType outputType, Value input1, + Value input2, Value &outInput1, + Value &outInput2) { + + int64_t input1Rank = input1.getType().cast().getRank(); + int64_t input2Rank = input2.getType().cast().getRank(); + + Value higherTensorValue, lowerTensorValue; + // return if rank already match + if (input1Rank == input2Rank) { + return 1; + } else if (input1Rank > input2Rank) { + higherTensorValue = input1; + lowerTensorValue = input2; + } else { + higherTensorValue = input2; + lowerTensorValue = input1; + } + + ArrayRef outputRankShape = outputType.getShape(); + ArrayRef higherRankShape = + higherTensorValue.getType().cast().getShape(); + ArrayRef lowerRankShape = + lowerTensorValue.getType().cast().getShape(); + + // outputRank == higherRank == max(input1Rank, input2Rank) + assert(higherRankShape.size() == outputRankShape.size()); + + SmallVector reshapeOutputShape; + + computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape); + + auto reshapeInputType = lowerTensorValue.getType().cast(); + auto reshapeOutputType = RankedTensorType::get( + ArrayRef(reshapeOutputShape), reshapeInputType.getElementType()); + + auto reshapeLower = rewriter.create( + loc, reshapeOutputType, lowerTensorValue, + rewriter.getI64ArrayAttr(reshapeOutputShape)); + + if (input1Rank > input2Rank) { + outInput1 = higherTensorValue; + outInput2 = reshapeLower.getResult(); + } else { + outInput1 = reshapeLower.getResult(); + outInput2 = higherTensorValue; + } + + return 0; +} + +namespace { +template struct ConvertTosaOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy tosaBinaryOp, + PatternRewriter &rewriter) const { + + Value input1 = tosaBinaryOp.input1(); + Value input2 = tosaBinaryOp.input2(); + Value output = tosaBinaryOp.getResult(); + auto outputType = output.getType().cast(); + + Value outInput1, outInput2; + if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, + input1, input2, outInput1, outInput2)) + return failure(); + + rewriter.replaceOpWithNewOp(tosaBinaryOp, outputType, outInput1, + outInput2); + + return success(); + } +}; + +// The MulOp has an extra parameter 'shift' not present in other elementwise +// binary ops, that necessitates special handling of its builder. +template <> +struct ConvertTosaOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp, + PatternRewriter &rewriter) const { + + Value input1 = tosaBinaryOp.input1(); + Value input2 = tosaBinaryOp.input2(); + int32_t shift = tosaBinaryOp.shift(); + Value output = tosaBinaryOp.getResult(); + auto outputType = output.getType().cast(); + + Value outInput1, outInput2; + if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, + input1, input2, outInput1, outInput2)) + return failure(); + + rewriter.replaceOpWithNewOp(tosaBinaryOp, outputType, + outInput1, outInput2, shift); + + return success(); + } +}; + +// The ArithmeticRightShiftOp has an extra parameter 'round' not present in +// other elementwise binary ops, that necessitates special handling of its +// builder. 
+template <> +struct ConvertTosaOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp, + PatternRewriter &rewriter) const { + + Value input1 = tosaBinaryOp.input1(); + Value input2 = tosaBinaryOp.input2(); + int32_t round = tosaBinaryOp.round(); + Value output = tosaBinaryOp.getResult(); + auto outputType = output.getType().dyn_cast(); + + Value outInput1, outInput2; + if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType, + input1, input2, outInput1, outInput2)) + return failure(); + + rewriter.replaceOpWithNewOp( + tosaBinaryOp, outputType, outInput1, outInput2, round); + + return success(); + } +}; +} // end anonymous namespace + +namespace { +/// Pass that enables broadcast by making all input arrays have the same +/// number of dimensions. Insert RESHAPE operations to lower rank operand +struct TosaMakeBroadcastable + : public TosaMakeBroadcastableBase { +public: + void runOnFunction() override { + auto func = getFunction(); + OwningRewritePatternList patterns; + MLIRContext *ctx = func.getContext(); + // Add the generated patterns to the list. + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + applyPatternsAndFoldGreedily(func, std::move(patterns)); + } +}; +} // end anonymous namespace + +std::unique_ptr mlir::tosa::createTosaMakeBroadcastablePass() { + return std::make_unique(); +} Index: mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -0,0 +1,350 @@ +//===- QuantUtils.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains TOSA numerical support functions and quantization +// attribute builders. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" + +using namespace mlir; +using namespace mlir::tosa; + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 16-bit scaling. +void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 15)); + + // Can't be greater than 1.0. + assert(shiftedM <= (int64_t(1) << 15) && + "Shifted mantissa exceeds 16 signed bits"); + + if (shiftedM == (int64_t(1) << 15)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive and embed (1 << 15) into right + // shift bits. 
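+  //
+  // Worked example (assumed values): scale = 0.75 gives a frexp() mantissa of
+  // 0.75 and an exponent of 0, so shiftedM = round(0.75 * 2^15) = 24576 and
+  // the final shift = (-0) + 15 = 15; the runtime rescale
+  // (value * 24576) >> 15 then reproduces value * 0.75.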
+ shift = (-shift) + 15; + + assert(shiftedM <= std::numeric_limits::max() && + "Shifted mantissa exceeds 32-bit signed output type"); + + multiplier = static_cast(shiftedM); +} + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 32-bit scaling. +void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 31)); + + // Can't be greater than 1.0. + assert(shiftedM <= (int64_t(1) << 31) && + "Shifted mantissa exceeds 32 signed bits"); + if (shiftedM == (int64_t(1) << 31)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive, and embed (1 << 31) into right + // shift bits. + shift = (-shift) + 31; + + assert(shiftedM <= std::numeric_limits::max() && + "Shifted mantissa exceeds 32-bit signed output type"); + + multiplier = static_cast(shiftedM); +} + +/// Generates a quantized multiplier/shift from double. +void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth) { + + switch (scaleWidth) { + case 16: + computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); + return; + case 32: + computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); + return; + default: + assert(0 && "Unsupported Tosa quantized_scale regime specified!"); + } +} + +#define GET_UQTYPE(input_type) \ + ((input_type).getElementType().dyn_cast()) +#define GET_QTYPE(input_type) \ + ((input_type).getElementType().dyn_cast()) + +/// Method to build ConvOpQuantizationAttr, called from +/// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: +/// input_zp: input zeropoint +/// weight_zp: weight zeropoint. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder, + Value input, Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + if (!inputType || !weightType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto weightPerTensorQType = GET_UQTYPE(weightType); + auto weightPerAxisQType = weightType.getElementType() + .dyn_cast(); + + // Weights must be either per-tensor quantized or per-axis quantized. + assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType) && + "Weights must be either per-tensor or per-axis quantized"); + + // Either all quantized or all not quantized. + assert(!((bool)inputQType ^ + ((bool)weightPerTensorQType || (bool)weightPerAxisQType)) && + "Inputs and weights must be all quantized or all not quantized"); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t weightZp = 0; + + if (weightPerTensorQType) { + weightZp = weightPerTensorQType.getZeroPoint(); + } else if (weightPerAxisQType) { + weightZp = weightPerAxisQType.getZeroPoints().front(); + } + + auto quantAttr = tosa::ConvOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds MatMulOpQuantizationAttr, called from +/// MatMulOpQuantInfoBuilder: +/// aZp: input a zeropoint +/// bZp: input b zeropoint. 
+MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder, + Value a, Value b) { + + auto aType = a.getType().dyn_cast(); + auto bType = b.getType().dyn_cast(); + + if (!aType || !bType) + return nullptr; + + auto aQType = GET_UQTYPE(aType); + auto bQType = GET_UQTYPE(bType); + + // A and B are either all quantized or all not quantized. + assert(!((bool)aQType ^ (bool)bQType) && + "Matmul operands must be all quantized or all not quantized"); + + if (aQType) { + + int64_t aZp = aQType.getZeroPoint(); + int64_t bZp = bQType.getZeroPoint(); + + auto quantAttr = tosa::MatMulOpQuantizationAttr::get( + builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds UnaryOpQuantizationAttr +/// UnaryOpQuantInfoBuilder: +/// inputZp: input zeropoint +/// outputZp: output zeropoint. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(OpBuilder &builder, + Value input, + Type outputRawType) { + + auto inputType = input.getType().dyn_cast(); + auto outputType = outputRawType.dyn_cast(); + + if (!inputType || !outputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto outputQType = GET_UQTYPE(outputType); + + // Either all quantized or all not quantized. + assert(!((bool)inputQType ^ (bool)outputQType) && + "Unary inputs/outputs must be all quantized or all not quantized"); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t outputZp = outputQType.getZeroPoint(); + + auto quantAttr = tosa::UnaryOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: +/// inputZp: input zeropoint. +PadOpQuantizationAttr buildPadOpQuantizationAttr(OpBuilder &builder, + Value input) { + + auto inputType = input.getType().dyn_cast(); + + if (!inputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + + auto quantAttr = tosa::PadOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds output type for a quantized ConvOp with the right bitwidth. +/// This is called by the builder when dealing with quantized content. 
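+/// For example, an int16 input convolved with int8 weights accumulates into a
+/// 48-bit integer tensor; all other supported input/weight combinations
+/// accumulate into i32.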
+Type buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, + Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + assert(inputType && weightType && + "Could not extract input or weight tensors from Conv op"); + + auto inputQType = GET_QTYPE(inputType); + auto weightQType = GET_QTYPE(weightType); + + assert(inputQType && weightQType && + "Could not extract input or weight tensor types from Conv op"); + + unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); + unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); + + auto outputShapedType = outputType.dyn_cast(); + assert(outputShapedType && + "Could not extract output shape type from Conv op"); + + auto outputShape = outputShapedType.getShape(); + + IntegerType accElementType; + if (inputBits == 16 && weightBits == 8) + accElementType = builder.getIntegerType(48); + else + accElementType = builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + return accType; +} + +/// Builds Tosa quantization attributes from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange) { + + quant::QuantizedType retType; + + auto convfunc = + quant::ExpressedToQuantizedConverter::forInputType(inputDType); + + auto minElems = minAttr.dyn_cast(); + auto maxElems = maxAttr.dyn_cast(); + + SmallVector min, max; + + // At least one is per-axis quantized elementsattr. + if (minElems || maxElems) { + // Must have the same number of elements. + if (minElems.getNumElements() != maxElems.getNumElements()) + return {}; + min.reserve(minElems.getNumElements()); + max.reserve(maxElems.getNumElements()); + for (auto i : minElems) + min.push_back(FloatAttr::getValueAsDouble(i)); + for (auto i : maxElems) + max.push_back(FloatAttr::getValueAsDouble(i)); + } else { // Just a single FP value. + auto minVal = minAttr.dyn_cast(); + if (minVal) + min.push_back(minVal.getValueAsDouble()); + else + return {}; + auto maxVal = maxAttr.dyn_cast(); + if (maxVal) + max.push_back(maxVal.getValueAsDouble()); + else + return {}; + } + + if (min.size() == max.size()) { + if (min.size() == 1) { // Per-tensor quantization with one min/max pair. + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], + narrowRange.getValue(), convfunc.expressedType, isSigned); + } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. + auto shape = inputDType.dyn_cast(); + if (!shape) + return {}; + if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], + max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); + } + } else { + return {}; + } + } else { + return {}; + } + + if (!retType) + return {}; + + return convfunc.convert(retType); +} + +/// Builds Tosa quantization attributes from min/max values. 
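+/// This variant wraps the type produced by buildQTypeFromMinMax in a TypeAttr
+/// so the result can be attached to an operation as an attribute.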
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange) { + + return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, + maxAttr, quantBits, filterQuantDim, + isSigned, narrowRange)); +} Index: mlir/test/Dialect/Tosa/broadcast.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/broadcast.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s + +// ----- +// CHECK-LABEL: broadcast0 +func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-NOT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + +// ----- +// CHECK-LABEL: broadcast1 +func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32> + return %0 : tensor<2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast2 +func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32> + return %0 : tensor<2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast3 +func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32> + return %0 : tensor<2x1x1x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast4 +func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32> + return %0 : tensor<1x1x1x2xf32> +} + +// ----- +// CHECK-LABEL: broadcast5 +func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32> + return %0 : tensor<1x1x2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast6 +func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast7 +func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32> + return %0 : tensor<17x16x1x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast8 +func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast9 +func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast10 +func 
@test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> {
+  // CHECK: reshape
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast13
+func @test_broadcast13(%arg0: tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+  // CHECK: reshape
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast14
+func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> {
+  // CHECK: reshape
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32>
+  return %0 : tensor<17x16x1x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast15
+func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+  // CHECK: reshape
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast16
+func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+  // CHECK: reshape
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast17
+func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> {
+  // CHECK: reshape
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32>
+  return %0 : tensor<17x16x15x14xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast18
+func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> {
+  // CHECK: add
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32>
+  return %0 : tensor<14x15xf32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_mul
+func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
+  // CHECK: reshape
+  %0 = "tosa.mul"(%arg0, %arg1) {shift = 1 : i32} : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
+  return %0 : tensor<17x16x15x14xi32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_arithmetic_right_shift
+func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
+  // CHECK: reshape
+  %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
+  return %0 : tensor<17x16x15x14xi32>
+}
+
+// -----
+// CHECK-LABEL: broadcast_scalar
+func @test_broadcast_scalar(%arg0: tensor<i32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
+  // CHECK-NEXT: reshape
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
+  return %0 : tensor<17x16x15x14xi32>
+}
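+
+// Note (illustrative): the --tosa-make-broadcastable pass equalizes operand
+// ranks by reshaping the lower-rank operand, so a tensor<15x14xf32> operand of
+// a rank-4 add would first become e.g. tensor<1x1x15x14xf32>. The tests above
+// only check whether a tosa.reshape appears.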
Index: mlir/test/Dialect/Tosa/constrained_shapes.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/constrained_shapes.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// -----
+// Uses argmax as canonical example to validate constrained TOSA tensor shapes.
+// CHECK-LABEL: argmax
+func @test_argmax(%arg0: tensor<?x?xf32>) -> tensor<?xi32> {
+  %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<?x?xf32>) -> tensor<?xi32>
+  return %0 : tensor<?xi32>
+}
Index: mlir/test/Dialect/Tosa/inlining.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/inlining.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s -inline | FileCheck %s
+
+// These tests verify that regions with operations from TOSA dialect
+// can be inlined.
+
+// CHECK-LABEL: func @inlined_if_fn
+// Check that both the calls and the functions are eliminated after inlining:
+// CHECK-NOT: @add
+// CHECK-NOT: @sub
+func @inlined_if_fn(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
+    %1 = call @add(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "tosa.yield"(%1) : (tensor<f32>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
+    %1 = call @sub(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "tosa.yield"(%1) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+func @add(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {sym_visibility = "private"} {
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+func @sub(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> attributes {sym_visibility = "private"} {
+  %0 = "tosa.sub"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+
+// CHECK-LABEL: func @inlined_while_fn
+func @inlined_while_fn(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<10xi32> {
+  // Check that calls are inlined and functions eliminated:
+  // CHECK-NOT: @while
+  %1:4 = "tosa.while_loop"(%arg0, %arg1, %arg2, %arg3) ( {
+  ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<10xi32>):  // no predecessors
+    %2 = call @while_cond_40(%arg4, %arg5, %arg6, %arg7) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> tensor<i1>
+    "tosa.yield"(%2) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg4: tensor<i32>, %arg5: tensor<i32>, %arg6: tensor<i32>, %arg7: tensor<10xi32>):  // no predecessors
+    %2:4 = call @while_body_50(%arg4, %arg5, %arg6, %arg7) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>)
+    "tosa.yield"(%2#0, %2#1, %2#2, %2#3) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> ()
+  }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>)
+  return %1#3 : tensor<10xi32>
+}
+func @while_body_50(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>) attributes {sym_visibility = "private"} {
+  %1 = "tosa.add"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %2 = "tosa.add"(%arg3, %1) : (tensor<10xi32>, tensor<i32>) -> tensor<10xi32>
+  return %1, %arg1, %arg2, %2: tensor<i32>, tensor<i32>, tensor<i32>, tensor<10xi32>
+}
+func @while_cond_40(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<10xi32>) -> tensor<i1> attributes {sym_visibility = "private"} {
+  %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+  %1 = "tosa.logical_not"(%0) : (tensor<i1>) -> tensor<i1>
+  return %1 : tensor<i1>
+}
Index: mlir/test/Dialect/Tosa/ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/ops.mlir
@@ -0,0 +1,512 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// -----
+// CHECK-LABEL: argmax
+func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
+  %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<14x19xf32>) -> tensor<14xi32>
+  return %0 : tensor<14xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d
+func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+// CHECK-LABEL: conv2d
+func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d
+func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
+  %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %2 : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: fully_connected
+func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<14x19xf32>, tensor<19x28xf32>, tensor<28xf32>) -> tensor<14x28xf32>
+  return %0 : tensor<14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul
+func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> {
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32>
+  return %0 : tensor<14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d
+func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> {
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d
+func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: clamp
+func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.clamp"(%arg0) {min_fp = 0.0 : f32, max_fp = 1.0 : f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: relu
+func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.reluN"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: sigmoid
+func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.sigmoid"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: tanh
+func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.tanh"(%arg0) : (tensor<13x21x3xf32>) ->
tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: add +func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: arithmetic_right_shift +func @test_arithmetic_right_shift(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = false } : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + + +// ----- +// CHECK-LABEL: bitwise_and +func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: bitwise_or +func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: bitwise_xor +func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_and +func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> { + %0 = "tosa.logical_and"(%arg0, %arg1) : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: logical_left_shift +func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_right_shift +func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: logical_or +func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = "tosa.logical_or"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: logical_xor +func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> { + %0 = "tosa.logical_xor"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: maximum +func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.maximum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: minimum +func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.minimum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> 
+} + +// ----- +// CHECK-LABEL: mul +func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: pow +func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: sub +func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: table +func @main(%arg0: tensor<64xi32>, %arg1: tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform> { + %0 = "tosa.table"(%arg0, %arg1) : (tensor<64xi32>, tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform> + return %0 : tensor<64x!quant.uniform> +} + +// ----- +// CHECK-LABEL: abs +func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: bitwise_not +func @test_bitwise_not(%arg0: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> { + %0 = "tosa.bitwise_not"(%arg0) : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32> + return %0 : tensor<13x21x1xi32> +} + +// ----- +// CHECK-LABEL: ceil +func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: clz +func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> { + %0 = "tosa.clz"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32> + return %0 : tensor<13x21x3xi32> +} + +// ----- +// CHECK-LABEL: exp +func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: floor +func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.floor"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: log +func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: logical_not +func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { + %0 = "tosa.logical_not"(%arg0) : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1> + return %0 : tensor<1x21x3xi1> +} + +// ----- +// CHECK-LABEL: negate +func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.negate"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: reciprocal +func @test_reciprocal(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.reciprocal"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: rsqrt +func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: select +func 
@test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + + +// ----- +// CHECK-LABEL: equal +func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> { + %0 = "tosa.equal"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: greater +func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + %0 = "tosa.greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: greater_equal +func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> { + %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> + return %0 : tensor<13x21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_all +func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { + %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1> + return %1 : tensor<21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_any +func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { + %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1> + return %1 : tensor<21x3xi1> +} + +// ----- +// CHECK-LABEL: reduce_max +func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_min +func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_product +func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: reduce_sum +func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { + %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32> + %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32> + return %1 : tensor<21x3xf32> +} + +// ----- +// CHECK-LABEL: concat +func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> { + %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32> + return %0 : tensor<26x21x3xf32> +} + +// ----- +// CHECK-LABEL: pad +func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> { + %0 = 
"tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: reshape +func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { + %0 = "tosa.reshape"(%arg0) {new_shape = [1, 819]} : (tensor<13x21x3xf32>) -> tensor<1x819xf32> + return %0 : tensor<1x819xf32> +} + +// ----- +// CHECK-LABEL: reverse +func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: slice +func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { + %0 = "tosa.slice"(%arg0) {start = [6, 8, 0], size = [4, 11, 1]} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32> + return %0 : tensor<4x11x1xf32> +} + +// ----- +// CHECK-LABEL: tile +func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> { + %0 = "tosa.tile"(%arg0) {multiples = [3, 1, 2]} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32> + return %0 : tensor<39x21x6xf32> +} + +// ----- +// CHECK-LABEL: transpose +func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32> + return %1 : tensor<3x13x21xf32> +} + +// ----- +// CHECK-LABEL: gather +func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<26xi32>) -> tensor<26x21x3xi32> { + %0 = "tosa.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>) -> tensor<26x21x3xi32> + return %0 : tensor<26x21x3xi32> +} + +// Test TBD +// DISABLED-CHECK-LABEL: resize +//func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { +// %0 = "tosa.const"() {value = dense<64> : tensor<2xi32>} : () -> tensor<2xi32> +// %1 = "tosa.resize"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x32x32x8xf32>, tensor<2xi32>) -> tensor<1x64x64x8xf32> +// return %1 : tensor<1x64x64x8xf32> +//} + +// ----- +// CHECK-LABEL: cast +func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> { + %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: cast2 +func @test_cast2(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> { + %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: cast3 +func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> { + %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: rescale +func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { + %0 = "tosa.rescale"(%arg0) {double_round = false, input_zp = 127 : i32, multiplier = [1073741824 : i32], output_zp = -1 : i32, per_channel = false, scale32 = true, shift = [30 : i32]} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- +// CHECK-LABEL: const +func @test_const(%arg0 : index) -> tensor<4xi32> { + %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + return %0 : tensor<4xi32> +} + +// ----- +// CHECK-LABEL: identity +func 
@test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> {
+  %0 = "tosa.identity"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: identityn
+func @test_identityn(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  %0:2 = "tosa.identityn"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
+  return %0#0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: placeholder
+func @test_placeholder() -> tensor<1xi32> {
+  %0 = "tosa.placeholder"() : () -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: cond_if
+func @test_cond_if(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> tensor<f32> {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
+    %1 = "tosa.add"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "tosa.yield"(%1) : (tensor<f32>) -> ()
+  },  {
+  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):  // no predecessors
+    %1 = "tosa.sub"(%arg3, %arg4) : (tensor<f32>, tensor<f32>) -> tensor<f32>
+    "tosa.yield"(%1) : (tensor<f32>) -> ()
+  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+// -----
+// CHECK-LABEL: while_loop
+func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor<i32>) {
+  %0 = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
+  %1:3 = "tosa.while_loop"(%0, %0, %arg0) ( {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):  // no predecessors
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<i1>
+    %3 = "tosa.logical_not"(%2) : (tensor<i1>) -> tensor<i1>
+    "tosa.yield"(%3) : (tensor<i1>) -> ()
+  },  {
+  ^bb0(%arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<10xi32>):  // no predecessors
+    %2 = "tosa.const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
+    %3 = "tosa.add"(%arg3, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    %4 = "tosa.reshape"(%2) {new_shape = [1]} : (tensor<i32>) -> tensor<1xi32>
+    %5 = "tosa.add"(%arg4, %4) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+    %6 = "tosa.add"(%arg2, %2) : (tensor<i32>, tensor<i32>) -> tensor<i32>
+    "tosa.yield"(%6, %3, %5) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> ()
+  }) : (tensor<i32>, tensor<i32>, tensor<10xi32>) -> (tensor<i32>, tensor<i32>, tensor<10xi32>)
+  return
+}
Index: mlir/test/Dialect/Tosa/quant-test.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/quant-test.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: test_build_qtype
+func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>> {
+  // CHECK: tosa.negate
+  %0 = "tosa.negate"(%arg0) : (tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+  return %0 : tensor<16x1x1x8x!quant.uniform<u8<1:255>:f32, 0.015680249780416489:128>>
+}
+
+// -----
+// CHECK-LABEL: test_build_mult_and_shift
+func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> {
+  // CHECK: tosa.conv2d
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 2, 2], dilation = [2, 1], stride = [1, 1], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32}} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform<i8<-127:127>:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform>
+  return %0 : tensor<1x32x32x16x!quant.uniform>
+
+}
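+
+// Note (illustrative): --tosa-test-quant-utils is expected to rewrite the
+// quantized conv2d above into an i32-accumulator tosa.conv2d followed by a
+// tosa.rescale whose multiplier/shift encode
+// (inputScale * weightScale) / outputScale; the CHECK line only asserts that
+// a tosa.conv2d remains after the rewrite.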
Index: mlir/test/lib/Dialect/CMakeLists.txt
===================================================================
--- mlir/test/lib/Dialect/CMakeLists.txt
+++ mlir/test/lib/Dialect/CMakeLists.txt
@@ -1,3 +1,4 @@
 add_subdirectory(Affine)
 add_subdirectory(SPIRV)
 add_subdirectory(Test)
+add_subdirectory(Tosa)
Index: mlir/test/lib/Dialect/Tosa/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/test/lib/Dialect/Tosa/CMakeLists.txt
@@ -0,0 +1,15 @@
+add_mlir_dialect_library(MLIRTosaTestPasses
+  TosaTestPasses.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms
+
+  DEPENDS
+  MLIRTosaPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRPass
+  MLIRTosa
+  )
Index: mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
===================================================================
--- /dev/null
+++ mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -0,0 +1,197 @@
+//===- TosaTestPasses.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Test passes to exercise TOSA helper functions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::tosa;
+
+// This transformation converts quantized uint8 to quantized int8. The
+// construction of the new type invokes buildQTypeFromMinMax. Extracted from
+// the TOSA legalization infrastructure.
+struct ConvertTosaNegateOp : public RewritePattern {
+  explicit ConvertTosaNegateOp(MLIRContext *context)
+      : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
+LogicalResult
+ConvertTosaNegateOp::matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+
+  auto tosaNegateOp = cast<tosa::NegateOp>(op);
+
+  auto inputType =
+      tosaNegateOp.input1().getType().dyn_cast<RankedTensorType>();
+  // Skip if input is not a ranked tensor type.
+  if (!inputType)
+    return failure();
+
+  // Skip if output is not a shaped tensor type.
+  auto outputType =
+      tosaNegateOp.getResult().getType().dyn_cast<ShapedType>();
+  if (!outputType)
+    return failure();
+
+  // Skip if output is not a per-tensor quantized type.
+  auto outputElementType =
+      outputType.getElementType().dyn_cast<quant::UniformQuantizedType>();
+  if (!outputElementType)
+    return failure();
+
+  // Skip if output is not uint8.
+  if (outputElementType.isSigned() ||
+      outputElementType.getStorageTypeIntegralWidth() != 8)
+    return failure();
+
+  double typeRangeMin = double(outputElementType.getStorageTypeMin() -
+                               outputElementType.getZeroPoint()) *
+                        outputElementType.getScale();
+  double typeRangeMax = double(outputElementType.getStorageTypeMax() -
+                               outputElementType.getZeroPoint()) *
+                        outputElementType.getScale();
+  bool narrow_range = outputElementType.getStorageTypeMin() == 1;
+
+  auto dstQConstType = RankedTensorType::get(
+      outputType.getShape(),
+      buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(),
+                           rewriter.getF64FloatAttr(typeRangeMin),
+                           rewriter.getF64FloatAttr(typeRangeMax),
+                           rewriter.getI32IntegerAttr(
+                               outputElementType.getStorageTypeIntegralWidth()),
+                           0, true /* signed */,
+                           rewriter.getBoolAttr(narrow_range)));
+
+  ElementsAttr inputElems;
+  if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems)))
+    return failure();
+
+  auto newConstOp =
+      rewriter.create<tosa::ConstOp>(op->getLoc(), dstQConstType, inputElems);
+  auto newNegateOp = rewriter.create<tosa::NegateOp>(
+      op->getLoc(), dstQConstType, newConstOp.getResult());
+
+  rewriter.replaceOp(op, {newNegateOp.getResult()});
+  return success();
+}
+
+// This transformation modifies the quantized output type of a test conv2d and
+// appends a TOSA rescale after it. The rescale op requires the invocation of
+// computeMultiplierAndShift. From the TOSA legalization infrastructure.
+struct ConvertTosaConv2DOp : public RewritePattern {
+  explicit ConvertTosaConv2DOp(MLIRContext *context)
+      : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+};
+
+LogicalResult
+ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
+                                     PatternRewriter &rewriter) const {
+
+  auto tosaConv2DOp = cast<tosa::Conv2DOp>(op);
+
+  auto inputType =
+      tosaConv2DOp.input().getType().dyn_cast<RankedTensorType>();
+
+  // Skip if input is not a ranked tensor type.
+  if (!inputType)
+    return failure();
+
+  auto weightType =
+      tosaConv2DOp.weight().getType().dyn_cast<RankedTensorType>();
+
+  // Skip if weight is not a ranked tensor type.
+  if (!weightType)
+    return failure();
+
+  // Skip if output is not a shaped tensor type.
+  auto outputType =
+      tosaConv2DOp.getResult().getType().dyn_cast<ShapedType>();
+  if (!outputType)
+    return failure();
+
+  auto inputQType =
+      inputType.getElementType().dyn_cast<quant::UniformQuantizedType>();
+  auto weightQType =
+      weightType.getElementType().dyn_cast<quant::UniformQuantizedType>();
+  auto outputQType =
+      outputType.getElementType().dyn_cast<quant::UniformQuantizedType>();
+
+  // Works on quantized types only.
+  if (!(inputQType && weightQType && outputQType))
+    return failure();
+
+  auto newTosaConv2DOpType =
+      RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32));
+
+  auto newTosaConv2DOp = rewriter.create<tosa::Conv2DOp>(
+      op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.input(),
+      tosaConv2DOp.weight(), tosaConv2DOp.bias(), tosaConv2DOp.pad(),
+      tosaConv2DOp.stride(), tosaConv2DOp.dilation());
+
+  // Create rescale to quantized type.
+  double inputScale = inputQType.getScale();
+  double weightScale = weightQType.getScale();
+  double outputScale = outputQType.getScale();
+  int64_t outputZp = outputQType.getZeroPoint();
+
+  double opTensorScale = (inputScale * weightScale) / outputScale;
+
+  int32_t multiplier;
+  int32_t shift;
+
+  // Obtain the quantized scale = multiplier and shift.
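+  // Illustrative note: computeMultiplierAndShift represents opTensorScale as
+  // a fixed-point value, opTensorScale ~= multiplier * 2^(-shift), with a
+  // 32-bit multiplier here (scale32 mode). The rescale op then applies
+  // (value * multiplier) >> shift to requantize the i32 accumulator.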
+  computeMultiplierAndShift(opTensorScale, multiplier, shift, 32);
+
+  auto newTosaRescaleOp = rewriter.create<tosa::RescaleOp>(
+      op->getLoc(), outputType, newTosaConv2DOp.getResult(),
+      rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
+      rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}),
+      rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
+      rewriter.getBoolAttr(false));
+
+  rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
+  return success();
+}
+
+namespace {
+
+struct TosaTestQuantUtilAPI
+    : public TosaTestQuantUtilsBase<TosaTestQuantUtilAPI> {
+  void runOnFunction() override;
+};
+
+void TosaTestQuantUtilAPI::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto *ctx = &getContext();
+  auto func = getFunction();
+
+  patterns.insert<ConvertTosaNegateOp>(ctx);
+  patterns.insert<ConvertTosaConv2DOp>(ctx);
+  applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+
+} // anonymous namespace
+
+std::unique_ptr<Pass> mlir::tosa::createTosaTestQuantUtilAPIPass() {
+  return std::make_unique<TosaTestQuantUtilAPI>();
+}