Index: mlir/docs/Rationale/RationaleTOSADialect.md =================================================================== --- /dev/null +++ mlir/docs/Rationale/RationaleTOSADialect.md @@ -0,0 +1,89 @@ +# TOSA Dialect Rationale + +The MLIR TOSA dialect implements the [TOSA +specification](https://developer.mlplatform.org/w/tosa/). This document +describes the decision process for how TOSA expresses operators in +high level dialects. + +TOSA was developed after parallel efforts to rationalize the top-down picture from multiple high-level frameworks, as well as a bottom-up view of different hardware target concerns (CPU, GPU and NPU), and reflects a set of choices that attempt to manage both sets of requirements. + +## TOSA and Tensor Level Expressiveness + +TOSA endeavors to provide an operator set that tries to fulfil the following expressiveness goals at the *tensor level of abstraction*: + +### Complete + +This is driven by the top-down perspective: the need to express as much of multiple high level frameworks as possible fully in TOSA. The operator set was originally derived from an operator frequency analysis performed on dozens of high level networks in different frameworks, to select the most frequently occurring operators and establish a common set of tensor-level operators that could express them. + +TOSA categorizes its operator set into classes and attempts to address major functional operations at the tensor level, including compute, reduction, elementwise transformations, comparison and control flow. + +### Minimal + +This takes the bottom-up approach - keep the TOSA operator set minimal in order to bound the design of hardware, operator kernels, code generation strategies and associated considerations that affect the executability of TOSA content. + +In this regard TOSA seeks to avoid creating compound operators, instead leaving it to the compiler backend to fuse multiple TOSA ops if required.
This choice also benefits the numerical precision goal, since it is easier to fuse the numerical functionality of successive operators, than to split the numerical functionality of a compound operator. + +### Numerical Precision + +TOSA began as a means to address operator-level numerical precision for code generation and hardware development. It therefore incorporates precision detail into the operator set. + +In this regard, TOSA operators are best understood as a combination of the visible quantization information embedded within an operation, together with the functional information about how that information is used, as described in the specification of the operation. + +## TOSA Operator Rationale + +The general basis of selection of the operator set that constitutes TOSA is described in the TOSA specification document under Section 1.3 Operator Selection. Explanation of the thinking behind some operators is listed here: + +### IDENTITYN + +tosa.IDENTITYN is used to form a list of Operator results during +lowering of operations such as tf.Split from a sequence of tosa.SLICE +ops. If there are alternate ways to express this lowering without the +tosa.IDENTITYN op, the tosa.IDENTITYN op could be removed from TOSA. 
+ +``` +Value lower_split_op(Value %value, size_t axis, size_t +num_split) { Value %output[] + + size_t slice_size = %value.shape[axis] / num_split + + for (int i = 0; i < num_split; i++) { + vector begin_vals, size_vals + + for (int j = 0; j < %value.rank; j++) { + if (j == axis) { + begin_vals.push_back(slice_size * i) + size_vals.push_back(slice_size) + } else { + begin_vals.push_back(0) + size_vals.push_back(%value.shape[j]) + } + } + + %output[i] = tosa.SLICE(%value) {start=begin_vals, size=size_vals} (tensor<%value.type>) -> tensor + + } + + %output_list = tosa.IDENTITYN(%output) (tensor<%output:*.type>) -> tensor<%output_list:*.type> + return %output_list +} +``` + +### COND\_IF and WHILE\_LOOP + +Several neural networks express conditional control flow at the tensor level. A survey of multiple high level frameworks indicated that a conditional if and a loop construct are common in all major frameworks, with some variation. Since TOSA endeavors to be complete in expressing tensor level functionality including control flow, it implements these constructs. + +MLIR has dialects like SCF that express control flow for their own code generation rationale. Where a compiler uses SCF, it can choose to express the legalization directly in SCF.if, for example. Control over the actual legalization expression lies with the compiler designer in this case. + +## Using TOSA In A Compiler + +The TOSA specification describes each operator in functional detail. It is expected that compilers that use TOSA will use its builders to construct the operators so that the quantization information for the operator is correctly generated. + +The functional steps described in the pseudocode of the specification enable the construction of code generation for that operation, or inform the design of underlying hardware. The functional pseudocode also describes how the quantization parameters are utilized within the operation.
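To make the role of the quantization parameters concrete, the following is a minimal C++ sketch of how zero points carried by a convolution-style quantization attribute might be consumed inside an operator kernel. It is an illustration under stated assumptions, not the specification's pseudocode: the `ConvQuantInfo` struct and the `quantizedDot` function are hypothetical names, 8-bit inputs with a 32-bit accumulator are assumed, and the zero points are assumed to be subtracted from each operand before the multiply. Scaling the accumulator to the output scale is assumed to happen in a separate rescale step and is not shown.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the zero points held by a convolution-style
// quantization attribute; the field names mirror input_zp and weight_zp.
struct ConvQuantInfo {
  int32_t input_zp;
  int32_t weight_zp;
};

// Accumulates sum((input[i] - input_zp) * (weight[i] - weight_zp)) into a
// 32-bit accumulator. Only the zero points are applied here; no scale
// multiplication is performed.
int32_t quantizedDot(const std::vector<int8_t> &input,
                     const std::vector<int8_t> &weight,
                     const ConvQuantInfo &q) {
  int32_t acc = 0;
  // Use the shorter length so mismatched inputs never read out of bounds.
  const std::size_t n =
      input.size() < weight.size() ? input.size() : weight.size();
  for (std::size_t i = 0; i < n; ++i) {
    acc += (static_cast<int32_t>(input[i]) - q.input_zp) *
           (static_cast<int32_t>(weight[i]) - q.weight_zp);
  }
  return acc;
}
```

In this sketch the tensors are treated as raw 8-bit buffers, and everything the kernel needs to know about precision arrives through the `ConvQuantInfo` argument rather than through the element type.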
+ +### Quantization Parameters in Ops vs Tensors + +TOSA uses the quantization parameters embedded in the input and output tensors to construct the quantization attributes that sit within the operator. Once these attributes are constructed, the quantization information within the tensors is no longer necessary for code generation. + +This enables the tensors to be subsequently interpreted simply as contiguous buffers containing raw data, with no 'meta information' in the form of the quantization_type. Precision-related manipulation of the input or output is instead described by the operator itself, which describes, for example, when the zero point is applied, or when the scale multiplication is done. + +However, TOSA does *not* eliminate the existing quantization type information within the tensors; this leaves the choice of how to handle quantization information to later backend code generation steps. Index: mlir/include/mlir/Dialect/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/CMakeLists.txt +++ mlir/include/mlir/Dialect/CMakeLists.txt @@ -13,4 +13,5 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) Index: mlir/include/mlir/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) Index: mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaOps.h.inc -gen-op-decls) +mlir_tablegen(TosaOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTosaOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaStructs.h.inc
-gen-struct-attr-decls) +mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRTosaStructsIncGen) + + +set(LLVM_TARGET_DEFINITIONS TosaInterfaces.td) +mlir_tablegen(TosaInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TosaInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRTosaInterfaceIncGen) + +add_mlir_doc(TosaOps -gen-op-doc TosaOps Dialects/) Index: mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td @@ -0,0 +1,27 @@ +//===-- TosaInterfaces.td - TOSA dialect interfaces --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dialect op interfaces for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_OP_INTERFACES +#define TOSA_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def TosaOpInterface : OpInterface<"TosaOp"> { + let description = [{ + Implements interfaces implemented by ops that correspond to the Tosa specification. + }]; + + let methods = []; + +} + +#endif Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -0,0 +1,230 @@ +//===-- TosaOpBase.td - TOSA dialect op builders -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the common definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + + +#ifdef TOSA_OP_BASE +#else +#define TOSA_OP_BASE + +// Quantization attributes used across TOSA operators. Quantization attributes feed +// numerical precision parameters to the functional implementation of TOSA operators. +// The functional behavior is defined in the TOSA specification maintained at +// https://developer.mlplatform.org/w/tosa/ . TOSA leverages MLIR's built in +// quantization support: https://mlir.llvm.org/docs/Quantization/ , and supports +// uniform quantization. Depending on datatype, asymmetric and symmetric quantization +// are supported. The types themselves are described in TosaTypesBase.td . + +// This quantization attribute expresses numerical behavior of operators where the +// operator has a numerical relationship between a single input and output. +// For example: tosa.negate. +def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"output_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Unary Ops."; +} + +// There is no explicit BinaryOpQuantizationAttr for 2-input/1-output ops. In this +// case, tosa.rescale is used to express the inputs to the same scale. +// TODO: Upload WIP legalization document describing this construction by example. + +// This quantization attribute holds input and weight zero point. Both the ConvOp and +// MatMulOp QuantizationAttrs follow a common design semantic where their own quantization +// attribute only expresses the numerical behavior at the inputs. 
The scaling of their +// accumulator output is done using an explicit tosa.rescale operator that scales the +// accumulator result to the output scale. +def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"weight_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Convolution Ops."; +} + +def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"a_zp", I32Attr>, + StructFieldAttr<"b_zp", I32Attr>] > { + let description = "Attribute holding quantization information for MatMul Ops."; +} + +// This attribute holds the input zero point correction applied to the padding zeros to ensure +// numerical accuracy in the subsequent TOSA operations. Its functional application is +// described in the tosa.pad() operator description in the specification. +def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Pad Ops."; +} + +// TOSA Quantization Builders. + +// This builder is called on all convolution operators except for TransposeConv, which has +// specialized output shape semantics. The builder also defines the bit width of the output +// given the bit width of the input & weight content.
+def Tosa_ConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias, "ArrayAttr":$pad, "ArrayAttr":$stride, "ArrayAttr":$dilation), + [{ + $_state.addOperands(input); + $_state.addOperands(weight); + $_state.addOperands(bias); + $_state.addAttribute("pad", pad); + $_state.addAttribute("stride", stride); + $_state.addAttribute("dilation", dilation); + + auto quantAttr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + weight); + if (quantAttr) { + $_state.addAttribute("quantization_info", quantAttr); + $_state.addTypes(mlir::tosa::buildConvOpResultTypeInfo($_builder, outputType, input, weight)); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// Handles tosa.transpose_conv2d which has an outpad and output shape attribute. +def Tosa_TransConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias, "ArrayAttr":$outpad, "ArrayAttr":$stride, "ArrayAttr":$dilation, "ArrayAttr":$outputShape), + [{ + $_state.addOperands(input); + $_state.addOperands(weight); + $_state.addOperands(bias); + $_state.addAttribute("out_pad", outpad); + $_state.addAttribute("stride", stride); + $_state.addAttribute("dilation", dilation); + $_state.addAttribute("output_shape", outputShape); + auto quantAttr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + weight); + + if (quantAttr) { + $_state.addAttribute("quantization_info", quantAttr); + $_state.addTypes(mlir::tosa::buildConvOpResultTypeInfo($_builder, outputType, input, weight)); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// The tosa.fully_connected op has its own builder as it does not have +// strides/dilation/padding. 
+def Tosa_FCOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$weight, "Value":$bias), + [{ + $_state.addOperands(input); + $_state.addOperands(weight); + $_state.addOperands(bias); + auto quantAttr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + weight); + if (quantAttr) { + $_state.addAttribute("quantization_info", quantAttr); + $_state.addTypes(mlir::tosa::buildConvOpResultTypeInfo($_builder, outputType, input, weight)); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// The tosa.matmul op is also intended to be generated where a fully_connected +// op must be constructed where the weight is not a constant. In this case, +// the fully_connected op must be expressed using matmul. +// TODO: Add link to the legalization document explaining this. +def Tosa_MatMulOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$a, "Value":$b), + [{ + $_state.addOperands(a); + $_state.addOperands(b); + auto quantAttr = mlir::tosa::buildMatMulOpQuantizationAttr($_builder, a, b); + + if (quantAttr) { + $_state.addAttribute("quantization_info", quantAttr); + + auto inputType = a.getType().dyn_cast(); + assert(inputType); + + auto inputQType = inputType.getElementType().dyn_cast(); + assert(inputQType); + + unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); + + auto outputShapedType = outputType.dyn_cast(); + assert(outputShapedType); + + auto outputShape = outputShapedType.getShape(); + + IntegerType accElementType; + if (inputBits == 16) + accElementType = $_builder.getIntegerType(48); + else + accElementType = $_builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + $_state.addTypes(accType); + } + else { + $_state.addTypes(outputType); + } + }]>; + +// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr but +// the avg_pool operator has its own builder as it has additional parameters not part +// of the unary ops.
+def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "ArrayAttr":$kernel_size, "ArrayAttr":$strides, "ArrayAttr":$padding), + [{ + $_state.addOperands(input); + $_state.addAttribute("kernel", kernel_size); + $_state.addAttribute("stride", strides); + $_state.addAttribute("pad", padding); + auto quantAttr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + outputType); + if (quantAttr) + $_state.addAttribute("quantization_info", quantAttr); + $_state.types.push_back(outputType); + }]>; + +// This builder is called on single-parameter unary operators that have a scale +// relationship between their input and output, expressed by the UnaryOpQuantizationAttr. +def Tosa_UnaryOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input), + [{ + $_state.addOperands(input); + auto quantAttr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + outputType); + if (quantAttr) + $_state.addAttribute("quantization_info", quantAttr); + $_state.types.push_back(outputType); + }]>; + +// This builder is called on the TOSA pad operator that needs to create its own +// OptionalAttr quantization_attr parameter to scale the padding values correctly. +def Tosa_PadOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$outputType, "Value":$input, "Value":$paddings), + [{ + $_state.addOperands(input); + $_state.addOperands(paddings); + auto quantAttr = mlir::tosa::buildPadOpQuantizationAttr($_builder, + input); + if (quantAttr) + $_state.addAttribute("quantization_info", quantAttr); + $_state.types.push_back(outputType); + }]>; + +// TOSA Operator. + +class Tosa_Op traits = []> : + Op { + +} + +// Specify traits of operators. 
+ +#endif // TOSA_OP_BASE Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -0,0 +1,53 @@ +//===-- TosaOps.h - TOSA dialect operation definitions ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_IR_TOSA_OPS_H +#define DIALECT_TOSA_IR_TOSA_OPS_H + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Tosa/IR/TosaTraits.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" + +#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc" + +namespace mlir { +namespace tosa { + +//===----------------------------------------------------------------------===// +// TOSA Dialect +//===----------------------------------------------------------------------===// +class TosaDialect : public Dialect { + +public: + explicit TosaDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "tosa"; } +}; + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" + +} // end namespace tosa +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" + +#endif // TOSA_OPS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td 
=================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -0,0 +1,1700 @@ +//===-- TosaOps.td - TOSA dialect operation definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation set for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_OPS +#else +#define TOSA_OPS + +include "mlir/IR/OpBase.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + +include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" + +//===----------------------------------------------------------------------===// +// The TOSA Dialect. +//===----------------------------------------------------------------------===// +def Tosa_Dialect : Dialect { + let name = "tosa"; + + let description = [{ + The Tensor Operator Set Architecture (TOSA) dialect. + + This dialect implements the TOSA standard described at + https://developer.mlplatform.org/w/tosa/ . + + Tensor Operator Set Architecture (TOSA) provides a set of whole-tensor operations + commonly employed by Deep Neural Networks. The intent is to enable a variety of + implementations running on a diverse range of processors, with the results at the + TOSA level consistent across those implementations. Applications or frameworks + which target TOSA can therefore be deployed on a wide range of different processors, + such as CPUs or GPUs, with defined accuracy and compatibility constraints. Most + operators from the common ML frameworks should be expressible in TOSA. 
It is + expected that there will be tools to lower from the ML frameworks into TOSA. + + }]; + + let cppNamespace = "mlir::tosa"; +} + +#ifdef TOSA_OP_BASE +#else +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" +#endif + + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.2 +// Operator Class: Tensor Data Engine Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: argmax +//===----------------------------------------------------------------------===// +def Tosa_ArgMaxOp : Tosa_Op<"argmax", [NoSideEffect]> { + let summary = "Perform argmax on the input."; + + let description = [{ + This returns the index with the largest value across the given axis of the input tensor. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D: $input, + I64Attr: $axis); + + let results = (outs Tosa_Tensor1Dto4D: $output); + +} + +//===----------------------------------------------------------------------===// +// Operator: avg_pool2d +//===----------------------------------------------------------------------===// +def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [NoSideEffect]> { + let summary = "Performs average pooling on the input."; + + let description = [{ + This performs an average pooling over the given input tensor. A sliding window of the size + given by the kernel attribute is passed over the input tensor, with the mean value being placed + in the output tensor.
+ }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + + Tosa_IntArrayAttr2:$kernel, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr4:$pad, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: conv2d +//===----------------------------------------------------------------------===// +def Tosa_Conv2DOp : Tosa_Op<"conv2d", [NoSideEffect]> { + let summary = [{ + 2D Convolution Operator + }]; + + let description = [{ + Performs a 2D convolution over the given tensor input, using the weight tensor. + }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + Tosa_Tensor4D:$weight, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr4:$pad, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr2:$dilation, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; + +} + +//===----------------------------------------------------------------------===// +// Operator: conv3d +//===----------------------------------------------------------------------===// +def Tosa_Conv3DOp : Tosa_Op<"conv3d", [NoSideEffect]> { + let summary = [{ + 3D Convolution operator + }]; + + let description = [{ + Performs a 3D convolution over the given input tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor5D:$input, + Tosa_Tensor5D:$weight, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr6:$pad, + Tosa_IntArrayAttr3:$stride, + Tosa_IntArrayAttr3:$dilation, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor5D:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; + +} + +//===----------------------------------------------------------------------===// +// Operator: depthwise_conv2d +//===----------------------------------------------------------------------===// +def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> { + let summary = [{ + Depthwise 2D Convolution operator + }]; + + let description = [{ + Performs 2D convolutions separately over each channel of the given tensor input, using the weight tensor. + }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + Tosa_Tensor4D:$weight, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr4:$pad, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr2:$dilation, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; + +} + +//===----------------------------------------------------------------------===// +// Operator: fully_connected +//===----------------------------------------------------------------------===// +def Tosa_FullyConnectedOp : Tosa_Op<"fully_connected", [NoSideEffect]> { + let summary = "Fully Connected operator"; + + let description = [{ + Performs a fully connected network. 
+ }]; + + let arguments = (ins + Tosa_Tensor2D:$input, + Tosa_Tensor2D:$weight, + Tosa_Tensor1D:$bias, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor2D:$output + ); + + let builders = [Tosa_FCOpQuantInfoBuilder]; + + let verifier = [{ return ::verifyConvOp(*this); }]; + +} + +//===----------------------------------------------------------------------===// +// Operator: matmul +//===----------------------------------------------------------------------===// +def Tosa_MatMulOp : Tosa_Op<"matmul", [NoSideEffect]> { + let summary = "Matrix multiplication"; + + let description = [{ + Performs a two dimensional matrix multiplication. This allows both inputs to be activations, + rather than reserving weights as an attribute in the FULLY_CONNECTED operator. + }]; + + let arguments = (ins + Tosa_Tensor2D:$a, + Tosa_Tensor2D:$b, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor2D:$c + ); + + let builders = [Tosa_MatMulOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: max_pool2d +//===----------------------------------------------------------------------===// +def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [NoSideEffect]> { + let summary = "Performs max pooling on the input."; + + let description = [{ + This performs a max pooling over the given input tensor. A sliding window of the size given by + the kernel attribute is passed over the input tensor, with the maximum value being placed in the + output tensor.
+ }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + + Tosa_IntArrayAttr2:$kernel, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr4:$pad + ); + + let results = (outs + Tosa_Tensor4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose_conv2d +//===----------------------------------------------------------------------===// +def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [NoSideEffect]> { + let summary = [{ + Transpose 2D Convolution operator. + }]; + + let description = [{ + Performs a 2D transposed convolution over the given tensor input, using the weights tensor. + }]; + + let arguments = (ins + Tosa_Tensor4D:$input, + Tosa_Tensor4D:$filter, + Tosa_Tensor1D:$bias, + + Tosa_IntArrayAttr2:$out_pad, + Tosa_IntArrayAttr2:$stride, + Tosa_IntArrayAttr2:$dilation, + Tosa_IntArrayAttrUpto4:$out_shape, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor4D:$output + ); + + let builders = [Tosa_TransConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.3 +// Operator Class: Activation Functions. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: clamp +//===----------------------------------------------------------------------===// +def Tosa_ClampOp : Tosa_Op<"clamp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes clamp(features, min, max)."; + + let description = [{ + Clamp to an arbitrary minimum and maximum value. Note that the maximum and minimum values are + specified as signed quantized values, no scaling happens before or after this operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + // TODO: Merge the int/fp types into a union to match spec + I64Attr:$min_int, + I64Attr:$max_int, + F32Attr:$min_fp, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reluN +//===----------------------------------------------------------------------===// +def Tosa_ReluNOp : Tosa_Op<"reluN", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes rectified linear: `min(max(features, 0), N)`."; + + let description = [{ + ReLU with a scalar maximum value. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + // TODO: Merge the int/fp types into a union to match spec + I64Attr:$max_int, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: sigmoid +//===----------------------------------------------------------------------===// +def Tosa_SigmoidOp : Tosa_Op<"sigmoid", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise sigmoid of input."; + + let description = [{ + Sigmoid function: output = 1 / (1 + exp(-input)) + For quantized integer data types, the TABLE operator should be used instead with the following definition. + The sigmoid table has 513 entries, each of 16-bit precision, covering the input range -16.0 to +16.0 + in steps of 1/16.
+ }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: tanh +//===----------------------------------------------------------------------===// +def Tosa_TanhOp : Tosa_Op<"tanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise hyperbolic tangent of input"; + + let description = [{ + Parameterized hyperbolic tangent. + For quantized integer data types, the TABLE operator should be used instead with the following definition. + The tanh_table has 513 entries each of 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.4 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise binary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: add +//===----------------------------------------------------------------------===// +def Tosa_AddOp : Tosa_Op<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise addition operator"; + + let description = [{ + Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: arithmetic_right_shift +//===----------------------------------------------------------------------===// +def Tosa_ArithmeticRightShiftOp : Tosa_Op<"arithmetic_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Arithmetic Right Shift"; + + let description = [{ + Elementwise arithmetic right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2, + BoolAttr:$round + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_and +//===----------------------------------------------------------------------===// +def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise AND operator"; + + let description = [{ + Elementwise bitwise AND of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_or +//===----------------------------------------------------------------------===// +def Tosa_BitwiseOrOp : Tosa_Op<"bitwise_or", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise OR operator"; + + let description = [{ + Elementwise bitwise OR of input1 and input2. Axis of size 1 will be broadcast as necessary. 
+ Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_xor +//===----------------------------------------------------------------------===// +def Tosa_BitwiseXorOp : Tosa_Op<"bitwise_xor", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise XOR operator"; + + let description = [{ + Elementwise bitwise XOR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_and +//===----------------------------------------------------------------------===// +def Tosa_LogicalAndOp : Tosa_Op<"logical_and", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x AND y element-wise."; + + let description = [{ + Elementwise logical AND of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_left_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalLeftShiftOp : Tosa_Op<"logical_left_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Left Shift"; + + let description = [{ + Elementwise logical left shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_right_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalRightShiftOp : Tosa_Op<"logical_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Right Shift"; + + let description = [{ + Elementwise logical right shift of input1 by the amount specified in input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_or +//===----------------------------------------------------------------------===// +def Tosa_LogicalOrOp : Tosa_Op<"logical_or", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x OR y element-wise."; + + let description = [{ + Elementwise logical OR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_xor +//===----------------------------------------------------------------------===// +def Tosa_LogicalXorOp : Tosa_Op<"logical_xor", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x XOR y element-wise."; + + let description = [{ + Elementwise logical XOR of input1 and input2. + Axis of size 1 will be broadcast as necessary. 
+ Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$input1, + I1Tensor:$input2 + ); + + let results = (outs + I1Tensor:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: maximum +//===----------------------------------------------------------------------===// +def Tosa_MaximumOp : Tosa_Op<"maximum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Maximum"; + + let description = [{ + Elementwise maximum of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: minimum +//===----------------------------------------------------------------------===// +def Tosa_MinimumOp : Tosa_Op<"minimum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Minimum"; + + let description = [{ + Elementwise minimum of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: mul +//===----------------------------------------------------------------------===// +def Tosa_MulOp : Tosa_Op<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Multiplication operator"; + + let description = [{ + Elementwise multiplication (Hadamard product) of input1 and input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2, + I32Attr:$shift + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: pow +//===----------------------------------------------------------------------===// +def Tosa_PowOp : Tosa_Op<"pow", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Computes the power of one value to another."; + + let description = [{ + Elementwise input1 value raised to the power of input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + Tosa_TensorUpto4D:$z + ); +} + +//===----------------------------------------------------------------------===// +// Operator: sub +//===----------------------------------------------------------------------===// +def Tosa_SubOp : Tosa_Op<"sub", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise subtraction operator"; + + let description = [{ + Elementwise subtraction of input1 and input2 (input1 - input2). + Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: table +//===----------------------------------------------------------------------===// +def Tosa_TableOp : Tosa_Op<"table", [NoSideEffect]> { + let summary = "Table lookup op"; + + let description = [{ + Interpolated table lookup operation. Input values are scaled to create a fixed-point 9.7 value. + The high 9 bits are used to index into the table. 
The fractional bits are used to interpolate + based on the looked up value and the index+1 value in the table. The TABLE operator then returns + a 16.7 interpolated value. Note that there must be 513 values to handle the full range of inputs. + + The TABLE operator is expected to be used as follows: + • A RESCALE node is expected before the TABLE operator to scale the input to a full int16_t range + for the table lookup + • If an int16_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 7 + • If an int8_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 15 + }]; + + let arguments = (ins + Tosa_TensorUpto4D: $input, + Tosa_Tensor1D: $table + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.5 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise unary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: abs +//===----------------------------------------------------------------------===// +def Tosa_AbsOp : Tosa_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise abs op"; + + let description = [{ + Elementwise absolute value operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_not +//===----------------------------------------------------------------------===// +def Tosa_BitwiseNotOp : Tosa_Op<"bitwise_not", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Bitwise NOT operator"; + + let description = [{ + Elementwise bitwise NOT of input tensor. 
+ }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs Tosa_TensorUpto4D:$output); + +} + +//===----------------------------------------------------------------------===// +// Operator: ceil +//===----------------------------------------------------------------------===// +def Tosa_CeilOp : Tosa_Op<"ceil", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise ceil op"; + + let description = [{ + Elementwise ceiling operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: clz +//===----------------------------------------------------------------------===// +def Tosa_ClzOp : Tosa_Op<"clz", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise count leading zero op"; + + let description = [{ + Elementwise count leading zeros operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: exp +//===----------------------------------------------------------------------===// +def Tosa_ExpOp : Tosa_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise exp op"; + + let description = [{ + Elementwise e to the x operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: floor +//===----------------------------------------------------------------------===// +def Tosa_FloorOp : Tosa_Op<"floor", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise floor op"; + + let description = [{ + Elementwise floor operation + }]; + + let arguments = (ins + 
Tosa_TensorUpto4D:$input1 + ); + + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: log +//===----------------------------------------------------------------------===// +def Tosa_LogOp : Tosa_Op<"log", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise log op"; + + let description = [{ + Elementwise natural logarithm operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_not +//===----------------------------------------------------------------------===// +def Tosa_LogicalNotOp : Tosa_Op<"logical_not", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns the truth value of NOT x element-wise."; + + let description = [{ + Elementwise logical NOT of input. + }]; + + let arguments = (ins + I1Tensor:$input1 + ); + + let results = (outs + I1Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: negate +//===----------------------------------------------------------------------===// +def Tosa_NegateOp : Tosa_Op<"negate", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise negate op"; + + let description = [{ + Elementwise negation operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + OptionalAttr:$quantization_info + ); + let results = (outs Tosa_TensorUpto4D:$output + ); + + let builders = [Tosa_UnaryOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: reciprocal +//===----------------------------------------------------------------------===// +def Tosa_ReciprocalOp : Tosa_Op<"reciprocal", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise reciprocal 
op"; + + let description = [{ + Elementwise reciprocal operation. For integer operation, a TABLE should be used + with the appropriate ranges. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: rsqrt +//===----------------------------------------------------------------------===// +def Tosa_RsqrtOp : Tosa_Op<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise 1/sqrt op"; + + let description = [{ + Elementwise reciprocal square root operation. For integer operation, a TABLE should be + used with the appropriate ranges. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + let results = (outs Tosa_TensorUpto4D:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.6 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise ternary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: select +//===----------------------------------------------------------------------===// +def Tosa_SelectOp : Tosa_Op<"select", [NoSideEffect]> { + let summary = "Elementwise select operator"; + + let description = [{ + Elementwise select of the output based on a condition. + }]; + + let arguments = (ins + I1Tensor:$input1, + Tosa_TensorUpto4D:$input2, + Tosa_TensorUpto4D:$input3 + ); + + let results = (outs + Tosa_TensorUpto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.7 +// Operator Class: Logical Operations. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: equal +//===----------------------------------------------------------------------===// +def Tosa_EqualOp : Tosa_Op<"equal", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of (x == y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: greater +//===----------------------------------------------------------------------===// +def Tosa_GreaterOp : Tosa_Op<"greater", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x > y) element-wise."; + + let description = [{ + Elementwise greater than comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: greater_equal +//===----------------------------------------------------------------------===// +def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x >= y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1, + Tosa_TensorUpto4D:$input2 + ); + + let results = (outs + I1Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.8 +// Operator Class: Reduction Ops. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: reduce_all +//===----------------------------------------------------------------------===// +def Tosa_ReduceAllOp : Tosa_Op<"reduce_all", [NoSideEffect]> { + let summary = [{ + Reduce All operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical AND operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_any +//===----------------------------------------------------------------------===// +def Tosa_ReduceAnyOp : Tosa_Op<"reduce_any", [NoSideEffect]> { + let summary = [{ + Reduce Any operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical OR operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_max +//===----------------------------------------------------------------------===// +def Tosa_ReduceMaxOp : Tosa_Op<"reduce_max", [NoSideEffect]> { + let summary = [{ + Reduce Max operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a maximum operation + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_min +//===----------------------------------------------------------------------===// +def Tosa_ReduceMinOp : Tosa_Op<"reduce_min", [NoSideEffect]> { + let summary = [{ + Reduce Min operator + }]; + + let 
description = [{ + Reduce a tensor along the given axis with a minimum operation. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_prod +//===----------------------------------------------------------------------===// +def Tosa_ReduceProdOp : Tosa_Op<"reduce_prod", [NoSideEffect]> { + let summary = [{ + Reduce Prod operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the product of the elements along that axis. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_sum +//===----------------------------------------------------------------------===// +def Tosa_ReduceSumOp : Tosa_Op<"reduce_sum", [NoSideEffect]> { + let summary = [{ + Reduce Sum operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the sum of the elements along that axis. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.9 +// Operator Class: Data Layout / Memory Reinterpretation. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: concat +//===----------------------------------------------------------------------===// +def Tosa_ConcatOp : Tosa_Op<"concat", [NoSideEffect]> { + let summary = "Concatenates tensors along one dimension."; + + let description = [{ + Concatenate two tensors along a given axis. No data conversion happens during a concat operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input1, + Tosa_Tensor1Dto4D:$input2, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: pad +//===----------------------------------------------------------------------===// +def Tosa_PadOp : Tosa_Op<"pad", [NoSideEffect]> { + let summary = "Pads a tensor with zeros."; + + let description = [{ + Zero-pads a tensor along borders of each dimension. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input1, + Tosa_Int32Or64Tensor:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + + let builders = [Tosa_PadOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: reshape +//===----------------------------------------------------------------------===// +def Tosa_ReshapeOp: Tosa_Op<"reshape", [ + NoSideEffect]> { + let summary = "Reshape operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with a new shape specified by the shape + argument. Reshape may operate on tensors of any rank. No data conversion happens during a reshape + operation. + }]; + + let arguments = (ins + Tosa_TensorUpto6D:$input1, + I64ArrayAttr:$new_shape + ); + + let results = (outs Tosa_TensorUpto6D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: reverse +//===----------------------------------------------------------------------===// +def Tosa_ReverseOp: Tosa_Op<"reverse", [ + NoSideEffect]> { + let summary = "Reverse operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with the data reversed along the given + axis. No data conversion happens during a reverse operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I64Attr:$axis); + + let results = (outs + Tosa_Tensor1Dto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: slice +//===----------------------------------------------------------------------===// +def Tosa_SliceOp: Tosa_Op<"slice", [ + NoSideEffect]> { + let summary = "Slice operator"; + + let description = [{ + Extracts a slice of the input tensor, beginning at the start coordinates, + and extending for size elements in each dimension. No data conversion happens during a slice operation. + }]; + + let arguments = (ins + Tosa_Tensor1Dto6D:$input, + I64ArrayAttr:$start, + I64ArrayAttr:$size + ); + + let results = (outs + Tosa_Tensor1Dto6D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: tile +//===----------------------------------------------------------------------===// +def Tosa_TileOp: Tosa_Op<"tile", [NoSideEffect]> { + let summary = "Tile operator"; + + let description = [{ + Replicates input1 multiples times along each dimension. + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input1, + I64ArrayAttr:$multiples); + + let results = (outs + Tosa_Tensor1Dto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose +//===----------------------------------------------------------------------===// +def Tosa_TransposeOp : Tosa_Op<"transpose", [NoSideEffect]> { + let summary = "Transpose operator"; + + let description = [{ + Permutes the dimensions of input1 based on perms. + }]; + + let arguments = (ins + Tosa_Tensor1Dto6D:$input1, + Tosa_Int32Or64Tensor:$perms + ); + + let results = ( + outs Tosa_Tensor1Dto6D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.10 +// Operator Class: Scatter/gather Operations. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: gather +//===----------------------------------------------------------------------===// +def Tosa_GatherOp : Tosa_Op<"gather", [NoSideEffect]> { + let summary = [{ + Gather operation + }]; + + let description = [{ + Generate a tensor for which each element in the output is a subtensor of the values tensor along + the given axis, based on the value of indices. + }]; + + let arguments = (ins + Tosa_Int32Or64Tensor:$indices, + Tosa_Tensor1Dto4D:$values, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.11 +// Operator Class: Image Frontend Functions. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: resize +//===----------------------------------------------------------------------===// +def Tosa_ResizeOp : Tosa_Op<"resize", [NoSideEffect]> { + + let summary = "Resize operation, supports various resize/upsample modes"; + + let description = [{ + Resizes a tensor. Resize is only allowed in the H and W dimensions. 
In expected use, + stride_y is approximately (IH< { + + let summary = "Cast operation"; + + let description = [{ + Performs a set of permissible cast operations + Mode Input Output + --------------------------------------- + signed 8 to bool int8 Boolean + signed 16 to bool int16 Boolean + signed 32 to bool int32 Boolean + bool to 8 Boolean int8 + bool to 16 Boolean int16 + bool to 32 Boolean int32 + signed 8 to signed 16 int8 int16 + signed 8 to signed 32 int8 int32 + signed 16 to signed 8 int16 int8 + signed 16 to signed 32 int16 int32 + signed 32 to signed 8 int32 int8 + signed 32 to signed 16 int32 int16 + float to signed 8 float int8 + float to signed 16 float int16 + signed 8 to float int8 float + signed 16 to float int16 float + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input + ); + + let results = (outs Tosa_TensorUpto4D:$output); + +} + +//===----------------------------------------------------------------------===// +// Operator: rescale +//===----------------------------------------------------------------------===// +def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect]> { + let summary = "Tosa rescale operator"; + + let description = [{ + Rescale quantized values into a new domain. 
Supported rescalings are: + Mode Input Output + signed 8 to 8 aint8 aint8 + signed 8 to 16 aint8 int16 + signed 8 to 32 aint8 int32 + signed 16 to 8 int16 aint8 + signed 16 to 16 int16 int16 + signed 16 to 32 int16 int32 + signed 32 to 8 int32 aint8 + signed 32 to 16 int32 int16 + signed 32 to 32 int32 int32 + signed 48 to 8 int48 aint8 + signed 48 to 16 int48 int16 + signed 48 to 32 int48 int32 + unsigned 8 to signed 8 uint8 aint8 + signed 8 to unsigned 8 aint8 uint8 + }]; + + let arguments = (ins + Tosa_Tensor1Dto4D:$input, + I32Attr:$input_zp, + I32Attr:$output_zp, + I32ArrayAttr:$multiplier, + I32ArrayAttr:$shift, + BoolAttr:$scale32, + BoolAttr:$double_round, + BoolAttr:$per_channel + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.13 +// Operator Class: Data Node Ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: const +//===----------------------------------------------------------------------===// +def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, NoSideEffect, FirstAttrDerivedResultType]> { + let summary = "Constant op."; + + let description = [{ + A node containing constant data for use as the input to an operation. May hold data + in any of the supported data formats. 
+ }]; + + let arguments = (ins + ElementsAttr:$value + ); + + let results = (outs Tosa_TensorUpto4D:$output); + + let builders = [ + OpBuilderDAG<(ins "Type":$type, "Attribute":$value)>, + ]; + +} + +//===----------------------------------------------------------------------===// +// Operator: identity +//===----------------------------------------------------------------------===// +def Tosa_IdentityOp: Tosa_Op<"identity", [NoSideEffect]> { + let summary = "Identity operator"; + let description = [{ + Returns a tensor with the same shape, size, type + and content as the input. + }]; + + let arguments = (ins + Tosa_TensorUpto4D:$input1 + ); + + let results = (outs + Tosa_TensorUpto4D:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: identityn +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_IdentityNOp: Tosa_Op<"identityn", [NoSideEffect]> { + let summary = "IdentityN operator"; + let description = [{ + Returns a list of tensors with the same shape, type, and contents as the + input list of tensors. + }]; + + let arguments = (ins + Variadic:$input1 + ); + + let results = (outs + Variadic:$output); +} + + +//===----------------------------------------------------------------------===// +// Operator: placeholder +//===----------------------------------------------------------------------===// +def Tosa_PlaceholderOp : Tosa_Op<"placeholder", [NoSideEffect]> { + let summary = "Placeholder op"; + + let description = [{ + A node where data will be inserted into the network at runtime. Generally used for inputs to the network. 
+ }]; + + let arguments = (ins + ); + + let results = (outs + Tosa_Tensor1Dto4D:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.14 +// Operator Class: Custom Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: custom +//===----------------------------------------------------------------------===// +def Tosa_CustomOp : Tosa_Op<"custom"> { + + let summary = "Custom operator wrapper for Tosa"; + + let description = [{ + Hardware implementing TOSA may choose to add additional custom operators that are not expressed in + the existing TOSA operations. These operators are not expected to be portable across TOSA + implementations. The input and output signatures must be expressed in the corresponding TOSA node. + }]; + + let arguments = (ins + StrAttr:$identifier, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.15 +// Operator Class: Control Flow Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: cond_if +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_IfOp : Tosa_Op<"cond_if", [ + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Conditional if operator"; + + let description = [{ + Evaluates a Boolean condition and then takes one of two distinct execution paths. 
This + implements the semantic If-then-else structure. + }]; + + let arguments = (ins + I1Tensor:$cond, + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$then_branch, + SizedRegion<1>:$else_branch + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: while_loop +//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Further described in docs/Rationale/RationaleTOSADialect.md . +//===----------------------------------------------------------------------===// +def Tosa_WhileOp : Tosa_Op<"while_loop", [ + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = [{ + output = input; While (Cond(output)) { output = Body(output) } + }]; + + let description = [{ + Generates and evaluates a Bool condition and either executes a loop body or exits to + another control point. This action is performed repeatedly after updating and re-evaluating + the Boolean condition every iteration. This implements the semantic foreach or while + iterative loop structure. + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$output); + + let regions = (region + SizedRegion<1>:$cond, + SizedRegion<1>:$body + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: yield +//===----------------------------------------------------------------------===// +def Tosa_YieldOp : Tosa_Op<"yield", [Terminator]> { + let summary = "yield operator"; + + let description = [{ + return operation within the conditional and body of + structured control flow. Operation takes variadic operands + but produces no results of its own. 
+ }]; + + let arguments = (ins + Variadic:$inputs + ); + +} + +#endif // TOSA_OPS Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h @@ -0,0 +1,34 @@ +//===-- TosaTraits.h - TOSA dialect operation traits ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect OpTraits in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATRAITS_H +#define MLIR_DIALECT_TOSA_IR_TOSATRAITS_H + +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { +namespace tosa { + +// TBD + +} // namespace tosa +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATRAITS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h @@ -0,0 +1,31 @@ +//===-- TosaTypes.h - TOSA dialect type definitions -------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect Types in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATYPES_H +#define MLIR_DIALECT_TOSA_IR_TOSATYPES_H + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +namespace tosa { + +// TOSA specific types go here + +} // namespace tosa + +} // end namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATYPES_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -0,0 +1,159 @@ +//===-- TosaTypesBase.td - TOSA type definitions -----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the type definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_TYPES_BASE +#else +#define TOSA_TYPES_BASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Tosa Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. +// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. +// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for the 8-bit case. 
+class Tosa_QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + StrJoinInt.result # !if(signed, ", true", ", false"); +} + +//===----------------------------------------------------------------------===// +// Non-Quantized Signed Integer Types. +// Used to express accumulator results or compare results. +//===----------------------------------------------------------------------===// + +def Tosa_Int32 : I<32>; +def Tosa_Int48 : I<48>; +def Tosa_Int64 : I<64>; + +def Tosa_SignedInt : AnyTypeOf<[Tosa_Int32, + Tosa_Int48, + Tosa_Int64]>; + +def Tosa_Bool : I<1>; + +// No unsigned unquantized int types. +def Tosa_Int : AnyTypeOf<[Tosa_Bool, + Tosa_SignedInt]>; + +def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, + Tosa_Int64]>; + +//===----------------------------------------------------------------------===// +// Quantized Integer Types. +// Datatype for network feature map or weight content. 
+//===----------------------------------------------------------------------===// +//===----------------------------------------------------------------------===// +// Name Symmetry Grouping Sign +//===----------------------------------------------------------------------===// +// aint8 : asymmetric per tensor, signed +// uint8 : asymmetric per tensor , unsigned +// int4 : symmetric per channel, signed +// int8 : symmetric per tensor/per channel, signed +// int16 : symmetric per tensor, signed +//===----------------------------------------------------------------------===// +def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>, + Tosa_QuantizedType<"uint8", [8], 0>, + Tosa_QuantizedType<"int4", [4, 0], 1>, + Tosa_QuantizedType<"int8", [8, 0], 1>, + Tosa_QuantizedType<"int16", [16, 0], 1>]>; + + +//===----------------------------------------------------------------------===// +// Floating-point types +//===----------------------------------------------------------------------===// +def Tosa_Float : AnyTypeOf<[F32, + F16, + BF16]>; + +//===----------------------------------------------------------------------===// +// Multi-category types +//===----------------------------------------------------------------------===// +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], + "number">; + +//===----------------------------------------------------------------------===// +// Tensor types +//===----------------------------------------------------------------------===// + +def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; + +def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; + +// Any tensor element type allowed in Tosa ops. 
+def Tosa_ElementType : Type, + "tosa.dtype">; + +class Tosa_TensorOfOrNone allowedTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +//===----------------------------------------------------------------------===// +// Tensor types with constrained ranks +//===----------------------------------------------------------------------===// + +// must be listed rank +def Tosa_Tensor1D : 1DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor2D : 2DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor4D : 4DTensorOf<[Tosa_AnyNumber]>; +def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>; +def Tosa_Tensor6D : TensorRankOf<[Tosa_AnyNumber], [6]>; + +// Ranked tensors up to given rank +def Tosa_Tensor1Dto2D : TensorRankOf<[Tosa_AnyNumber], [1,2]>; +def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>; +def Tosa_Tensor1Dto5D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5]>; +def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>; + +def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>; +def Tosa_TensorUpto6D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4,5,6]>; + +//===----------------------------------------------------------------------===// +// Attribute predicates and classes +//===----------------------------------------------------------------------===// +class ArrayMaxCt : AttrConstraint< + CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>, + "with at most " # n # " elements">; + +def Tosa_IntArrayAttr2 : Confined]>; +def Tosa_IntArrayAttr3 : Confined]>; +def Tosa_IntArrayAttr4 : Confined]>; +def Tosa_IntArrayAttr5 : Confined]>; +def Tosa_IntArrayAttr6 : Confined]>; + +def Tosa_IntArrayAttrUpto2 : Confined]>; +def Tosa_IntArrayAttrUpto4 : Confined]>; +def Tosa_IntArrayAttrUpto5 : Confined]>; + +//===----------------------------------------------------------------------===// +// Iterable attributes +//===----------------------------------------------------------------------===// +// Supported regimes for
tosa.resize +def Tosa_ResizeTypeAttr : StringBasedAttr< + CPred<"$_self.cast().getValue() == \"BILINEAR\" || " # + "$_self.cast().getValue() == \"NEAREST_NEIGHBOR\"">, + "Supported resize/upsampling strategies">; + +def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; + +// Tensor to buffer types. +def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; +def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; +def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; + +#endif // TOSA_TYPES_BASE Index: mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) +add_public_tablegen_target(MLIRTosaPassIncGen) +add_dependencies(mlir-headers MLIRTosaPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc TosaPasses ./) Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -0,0 +1,38 @@ +//===-- Passes.h - TOSA optimization pass declarations ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class FuncOp; +class ModuleOp; +class Pass; +template +class OperationPass; + +namespace tosa { + +std::unique_ptr> createTosaMakeBroadcastablePass(); + +std::unique_ptr> createTOSATestQuantUtilAPIPass(); + +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -0,0 +1,37 @@ +//===-- Passes.td - TOSA optimization pass declarations ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +include "mlir/Pass/PassBase.td" + +def TosaBinaryInputReshapePass : Pass<"tosa-make-broadcastable", "FuncOp"> { + let summary = "TOSA rank Reshape to enable Broadcasting"; + let description = [{ + Pass that enables broadcast by making all input arrays have the same + number of dimensions. Insert RESHAPE operations to prepend dimensions + of size one until the number of dimensions is equal. 
Implements an + approach similar to step 1 of NumPy 4-step broadcasting: + https://numpy.org/doc/stable/reference/ufuncs.html#broadcasting + }]; + + let constructor = "createTosaMakeBroadcastablePass()"; +} + +// TOSA Test Passes + +def TosaTestBuildQTypeAPIPass : Pass<"tosa-test-quant-utils", "FuncOp"> { + let summary = "TOSA Test: Exercise the BuildQTypeFromMinMax API"; + let description = [{ + Exercises the API that builds a quantized type from min/max quantized range. + }]; + + let constructor = "createTOSATestQuantUtilAPIPass()"; +} Index: mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -0,0 +1,71 @@ +//===-- QuantUtils.h - TOSA numerical support declarations ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Function declarations for TOSA numerical support functions and quantization +// attribute builders +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_UTILS_QUANT_UTILS_H +#define DIALECT_TOSA_UTILS_QUANT_UTILS_H + +//===----------------------------------------------------------------------===// +// Utility functions to support quantization handling in Tosa. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +#include "mlir/Dialect/Quant/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/UniformSupport.h" + +namespace mlir { +namespace tosa { + +/// From a scale value, computes multiplier and shift values +/// for 16 or 32-bit scale widths.
+void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth); + +/// Builds ConvOpQuantizationAttr from input and weight. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, Value weight); + +/// Builds MatMulOpQuantizationAttr for MatMul operations from A and B. +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder, + Value a, Value b); + +/// Builds UnaryOpQuantizationAttr for unary operations from input values. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, + Type outputRawType); + +/// Builds PadOpQuantizationAttr for pad operations from input values. +PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder, + Value input); + +/// Constructs the ConvOp output type with correct bitwidth based on input/weight +/// width. +Type buildConvOpResultTypeInfo(mlir::OpBuilder &builder, Type outputType, + Value input, Value weight); + +/// Builds a Tosa quantized type from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange); + +/// Builds Tosa quantization attributes from min/max values.
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange); + +} // namespace tosa +} // namespace mlir + +#endif // DIALECT_TOSA_UTILS_QUANT_UTILS_H Index: mlir/include/mlir/InitAllDialects.h =================================================================== --- mlir/include/mlir/InitAllDialects.h +++ mlir/include/mlir/InitAllDialects.h @@ -33,6 +33,7 @@ #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/Dialect.h" @@ -60,7 +61,8 @@ NVVM::NVVMDialect, ROCDL::ROCDLDialect, SDBMDialect, - shape::ShapeDialect>(); + shape::ShapeDialect, + tosa::TosaDialect>(); // clang-format on } Index: mlir/include/mlir/InitAllPasses.h =================================================================== --- mlir/include/mlir/InitAllPasses.h +++ mlir/include/mlir/InitAllPasses.h @@ -24,6 +24,7 @@ #include "mlir/Dialect/SPIRV/Passes.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Transforms/Passes.h" #include @@ -54,6 +55,7 @@ registerShapePasses(); spirv::registerSPIRVPasses(); registerStandardPasses(); + tosa::registerTosaOptPasses(); } } // namespace mlir Index: mlir/lib/Dialect/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/CMakeLists.txt +++ mlir/lib/Dialect/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) set(LLVM_OPTIONAL_SOURCES Index: mlir/lib/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ 
mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_dialect_library(MLIRTosa + IR/TosaOps.cpp + Utils/QuantUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa + + DEPENDS + MLIRStandardOpsIncGen + MLIRTosaOpsIncGen + MLIRTosaStructsIncGen + MLIRTosaInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRStandard + MLIRCallInterfaces + MLIRControlFlowInterfaces + MLIRSideEffectInterfaces + MLIRViewLikeInterface + ) + +add_subdirectory(Transforms) Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -0,0 +1,155 @@ +//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This file implements the TOSA Specification: +// https://developer.mlplatform.org/w/tosa/ +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Parser.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" 
+#include "mlir/Transforms/RegionUtils.h" + +using namespace mlir; +using namespace mlir::tosa; + +#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc" + +//===----------------------------------------------------------------------===// +// Tosa dialect interfaces. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" + +//===----------------------------------------------------------------------===// +// Dialect Function Inliner Interface. +//===----------------------------------------------------------------------===// +struct TosaInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks. + //===--------------------------------------------------------------------===// + + /// All operations can be inlined by default. + bool isLegalToInline(Operation *op, Region *region, + BlockAndValueMapping &map) const { + return true; + } + + /// All regions with If and While parent operators can be inlined. + /// TODO TEST: need test for the inliner + bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &map) const { + return (isa(dest->getParentOp()) || + isa(dest->getParentOp())); + } +}; + +//===----------------------------------------------------------------------===// +// TOSA control flow support. +//===----------------------------------------------------------------------===// + +/// Returns the while loop body. 
+Region &tosa::WhileOp::getLoopBody() { return body(); } + +bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) { + return !body().isAncestor(value.getParentRegion()); +} + +/// TODO TEST: Test for loop invariant code motion +LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef ops) { + if (ops.empty()) + return success(); + + Operation *tosaWhileOp = this->getOperation(); + for (auto *op : ops) + op->moveBefore(tosaWhileOp); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Tosa Dialect. +//===----------------------------------------------------------------------===// + +TosaDialect::TosaDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(); + addInterfaces(); +} + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, + Attribute value) { + result.addTypes(type); + result.addAttribute("value", value); + return; +} + +//===----------------------------------------------------------------------===// +// Verifiers +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyConvOp(T op) { + + // All TOSA conv ops have an input() and weight(). + auto inputType = op.input().getType().template dyn_cast(); + auto weightType = op.weight().getType().template dyn_cast(); + + // Must be ranked tensor types + if (!inputType || !weightType) + return failure(); + + auto inputQType = inputType.getElementType() + .template isa(); + auto weightQType = weightType.getElementType() + .template isa(); + + // Either both must be quantized or both unquantized. 
+ if (inputQType ^ weightQType) + return failure(); + + // Quantized type must have constructed the quantizationattr, and unquantized + // types should not have a quantizationattr. + if ((inputQType && !op.quantization_info()) || + (!inputQType && op.quantization_info())) + return failure(); + + return success(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" Index: mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRTosaTransforms + TosaMakeBroadcastable.cpp + TosaTestPasses.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms + + DEPENDS + MLIRTosaPassIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTosa + ) Index: mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -0,0 +1,302 @@ +//===- TosaMakeBroadcastable.cpp ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Insert reshape to binary op's input if needed to match rank +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define PASS_NAME "tosa-make-broadcastable" +#define DEBUG_TYPE PASS_NAME + +namespace mlir { + +namespace tosa { + +namespace { +/// Pass that enables broadcast by making all input arrays have the same +/// number of dimensions. Insert RESHAPE operations to lower rank operand +class TosaMakeBroadcastable + : public PassWrapper { +public: + explicit TosaMakeBroadcastable() {} + void runOnFunction() override; +}; + +/// There are two potential ways of implementing broadcast: +/// a. https://www.tensorflow.org/xla/broadcasting#formal_definition +/// b. https://numpy.org/doc/stable/user/basics.broadcasting.html +/// TBD: picking option (a) now. + +/// In this pass, we only insert RESHAPE to lower rank operand as a first step +/// so that it can be broadcastable with high rank operand. The real broadcast +/// happens in TOSA. + +// Examples: +// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1]. +// TODO: If lower=[b], target=[a, b, c], [b] should be, but is NOT YET, reshaped into +// [1, b, 1]. +// If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c]. +// If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
+// If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1]. +// If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c]. +// If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1]. +// If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. +// If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. + +static void +computeReshapeOutput(ArrayRef higher_rank_shape, + ArrayRef lower_rank_shape, int64_t higher_rank, + int64_t lower_rank, + SmallVector &reshape_output_shape) { + // initialize new shapes with [1] * higher_rank + reshape_output_shape.assign(higher_rank, 1); + + int64_t higher_left_index = 0; + int64_t higher_right_index = higher_rank; + int64_t lower_left_index = 0; + int64_t lower_right_index = lower_rank; + int64_t higher_rank_dim, lower_rank_dim; + + if (lower_right_index != 0 && higher_right_index != 0) { + // matches lower rank shape from right dimension first, until not + // matching high rank shape or reaching dimension 0 + while (true) { + higher_rank_dim = higher_rank_shape[higher_right_index - 1]; + lower_rank_dim = lower_rank_shape[lower_right_index - 1]; + if (higher_rank_dim == lower_rank_dim) { + reshape_output_shape[higher_right_index - 1] = higher_rank_dim; + + if (higher_right_index > 0) { + higher_right_index--; + } + + if (lower_right_index > 0) { + lower_right_index--; + } + + if (higher_right_index == 0 || lower_right_index == 0) { + break; + } + } else { + break; + } + } + if (lower_right_index != 0 && higher_right_index != 0) { + // matches lower rank shape from left dimension, until not matching + // high rank shape or reaching right index + while (true) { + higher_rank_dim = higher_rank_shape[higher_left_index]; + lower_rank_dim = lower_rank_shape[lower_left_index]; + if (higher_rank_dim == lower_rank_dim) { + reshape_output_shape[higher_left_index] = higher_rank_dim; + + if (higher_left_index < higher_right_index) { + higher_left_index++; + } + + if (lower_left_index <
lower_right_index) { + lower_left_index++; + } + + if (higher_left_index == higher_right_index || + lower_left_index == lower_right_index) { + break; + } + } else { + break; + } + } + } + } +} + +int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc, + RankedTensorType output_type, Value in_lhs, + Value in_rhs, Value &out_lhs, Value &out_rhs) { + + int64_t lhs_rank = in_lhs.getType().dyn_cast().getRank(); + int64_t rhs_rank = in_rhs.getType().dyn_cast().getRank(); + + int64_t higher_rank, lower_rank; + Value higher_tensor_value, lower_tensor_value; + /* return if rank already match */ + if (lhs_rank == rhs_rank) { + return 1; + } else if (lhs_rank > rhs_rank) { + higher_rank = lhs_rank; + lower_rank = rhs_rank; + higher_tensor_value = in_lhs; + lower_tensor_value = in_rhs; + } else { + higher_rank = rhs_rank; + lower_rank = lhs_rank; + higher_tensor_value = in_rhs; + lower_tensor_value = in_lhs; + } + + ArrayRef higher_rank_shape = output_type.getShape(); + ArrayRef lower_rank_shape = + lower_tensor_value.getType().dyn_cast().getShape(); + + SmallVector reshape_output_shape; + + computeReshapeOutput(higher_rank_shape, lower_rank_shape, higher_rank, + lower_rank, reshape_output_shape); + + auto reshape_input_type = + lower_tensor_value.getType().dyn_cast(); + auto reshape_output_type = + RankedTensorType::get(ArrayRef(reshape_output_shape), + reshape_input_type.getElementType()); + + auto reshape_lower = rewriter.create( + loc, reshape_output_type, lower_tensor_value, + rewriter.getI64ArrayAttr(reshape_output_shape)); + + if (lhs_rank > rhs_rank) { + out_lhs = higher_tensor_value; + out_rhs = reshape_lower.getResult(); + } else { + out_lhs = reshape_lower.getResult(); + out_rhs = higher_tensor_value; + } + + return 0; +} + +template +struct ConvertTosaOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy tosa_binary_op, + PatternRewriter &rewriter) const { + + Value in_lhs = 
tosa_binary_op.input1(); + Value in_rhs = tosa_binary_op.input2(); + Value output = tosa_binary_op.getResult(); + auto output_type = output.getType().dyn_cast(); + + Value out_lhs, out_rhs; + if (reshapeLowerToHigher(rewriter, tosa_binary_op.getLoc(), output_type, + in_lhs, in_rhs, out_lhs, out_rhs)) { + return failure(); + } + + rewriter.replaceOpWithNewOp(tosa_binary_op, output_type, out_lhs, + out_rhs); + + return success(); + } +}; + +// The MulOp has an extra parameter 'shift' not present in other elementwise +// binary ops, that necessitates special handling of its builder. +template <> +struct ConvertTosaOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::MulOp tosa_binary_op, + PatternRewriter &rewriter) const { + + Value in_lhs = tosa_binary_op.input1(); + Value in_rhs = tosa_binary_op.input2(); + int32_t shift = tosa_binary_op.shift(); + Value output = tosa_binary_op.getResult(); + auto output_type = output.getType().dyn_cast(); + + Value out_lhs, out_rhs; + if (reshapeLowerToHigher(rewriter, tosa_binary_op.getLoc(), output_type, + in_lhs, in_rhs, out_lhs, out_rhs)) { + return failure(); + } + + rewriter.replaceOpWithNewOp(tosa_binary_op, output_type, + out_lhs, out_rhs, shift); + + return success(); + } +}; + +// The ArithmeticRightShiftOp has an extra parameter 'round' not present in +// other elementwise binary ops, that necessitates special handling of its +// builder. 
+template <> +struct ConvertTosaOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosa_binary_op, + PatternRewriter &rewriter) const { + + Value in_lhs = tosa_binary_op.input1(); + Value in_rhs = tosa_binary_op.input2(); + int32_t round = tosa_binary_op.round(); + Value output = tosa_binary_op.getResult(); + auto output_type = output.getType().dyn_cast(); + + Value out_lhs, out_rhs; + if (reshapeLowerToHigher(rewriter, tosa_binary_op.getLoc(), output_type, + in_lhs, in_rhs, out_lhs, out_rhs)) { + return failure(); + } + + rewriter.replaceOpWithNewOp( + tosa_binary_op, output_type, out_lhs, out_rhs, round); + + return success(); + } +}; + +void TosaMakeBroadcastable::runOnFunction() { + OwningRewritePatternList patterns; + auto *ctx = &getContext(); + auto func = getFunction(); + + // Add the generated patterns to the list. + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + patterns.insert>(ctx); + applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // anonymous namespace + +std::unique_ptr> createTosaMakeBroadcastablePass() { + return std::make_unique(); +} + +static PassRegistration + pass(PASS_NAME, + "Perform broadcast on elementwise TosaOps to ensure same rank"); + +} // namespace tosa + +} // namespace mlir Index: mlir/lib/Dialect/Tosa/Transforms/TosaTestPasses.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/TosaTestPasses.cpp @@ -0,0 +1,217 @@ +//===- TosaTestPasses.cpp -------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test passes to exercise TOSA helper functions. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +// #include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#define PASS_NAME "tosa-test-quant-utils" + +namespace mlir { + +namespace tosa { + +namespace { + +class TosaTestQuantUtilAPI + : public PassWrapper { +public: + explicit TosaTestQuantUtilAPI() {} + void runOnFunction() override; +}; + +// This transformation converts quantized uint8 to quantized int8. The +// construction of the new type invokes buildQTypeFromMinMax. Extracted from +// TOSA legalization infrastructure. +struct ConvertTosaNegateOp : public RewritePattern { + explicit ConvertTosaNegateOp(MLIRContext *context) + : RewritePattern(tosa::NegateOp::getOperationName(), 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult +ConvertTosaNegateOp::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + + auto tosaNegateOp = cast(op); + + auto inputType = + tosaNegateOp.input1().getType().dyn_cast(); + // skip if input is not ranked tensor type + if (!inputType) + return failure(); + + // skip if output is not ranked tensor type. 
+ auto outputType = + tosaNegateOp.getResult().getType().dyn_cast(); + if (!outputType) + return failure(); + + // skip if output is not per-tensor quantized type. + auto outputElementType = + outputType.getElementType().dyn_cast(); + if (!outputElementType) + return failure(); + + // skip if output is not uint8. + if (outputElementType.isSigned() || + outputElementType.getStorageTypeIntegralWidth() != 8) { + return failure(); + } + + double typeRangeMin = double(outputElementType.getStorageTypeMin() - + outputElementType.getZeroPoint()) * + outputElementType.getScale(); + double typeRangeMax = double(outputElementType.getStorageTypeMax() - + outputElementType.getZeroPoint()) * + outputElementType.getScale(); + bool narrow_range = outputElementType.getStorageTypeMin() == 1 ? true : false; + + auto dstQConstType = RankedTensorType::get( + outputType.getShape(), + buildQTypeFromMinMax(rewriter, outputElementType.getExpressedType(), + rewriter.getF64FloatAttr(typeRangeMin), + rewriter.getF64FloatAttr(typeRangeMax), + rewriter.getI32IntegerAttr( + outputElementType.getStorageTypeIntegralWidth()), + 0, true /* signed */, + rewriter.getBoolAttr(narrow_range))); + + ElementsAttr inputElems; + if (!matchPattern(tosaNegateOp.input1(), m_Constant(&inputElems))) + return failure(); + + auto newConstOp = + rewriter.create(op->getLoc(), dstQConstType, inputElems); + auto newNegateOp = rewriter.create( + op->getLoc(), dstQConstType, newConstOp.getResult()); + + rewriter.replaceOp(op, {newNegateOp.getResult()}); + return success(); +} + +// This transformation modifies the quantized output of a test conv2d input and +// appends a TOSA rescale after it. The rescale op requires the invocation of +// computeMultiplierAndShift. From TOSA legalization infrastructure. 
+struct ConvertTosaConv2DOp : public RewritePattern { + explicit ConvertTosaConv2DOp(MLIRContext *context) + : RewritePattern(tosa::Conv2DOp::getOperationName(), 1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult +ConvertTosaConv2DOp::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + + auto tosaConv2DOp = cast(op); + + auto inputType = + tosaConv2DOp.input().getType().dyn_cast(); + + // skip if input is not ranked tensor type + if (!inputType) + return failure(); + + auto weightType = + tosaConv2DOp.weight().getType().dyn_cast(); + + // skip if wt is not ranked tensor type + if (!weightType) + return failure(); + + // skip if it's not ranked tensor type. + auto outputType = + tosaConv2DOp.getResult().getType().dyn_cast(); + if (!outputType) + return failure(); + + auto inputQType = + inputType.getElementType().dyn_cast(); + auto weightQType = + weightType.getElementType().dyn_cast(); + auto outputQType = + outputType.getElementType().dyn_cast(); + + // Works on quantized type only. + if (!(inputQType && weightQType && outputQType)) + return failure(); + + auto newTosaConv2DOpType = + RankedTensorType::get(outputType.getShape(), rewriter.getIntegerType(32)); + + auto newTosaConv2DOp = rewriter.create( + op->getLoc(), newTosaConv2DOpType, tosaConv2DOp.input(), + tosaConv2DOp.weight(), tosaConv2DOp.bias(), tosaConv2DOp.pad(), + tosaConv2DOp.stride(), tosaConv2DOp.dilation()); + + // Create rescale to quantized type + double inputScale = inputQType.getScale(); + double weightScale = weightQType.getScale(); + double outputScale = outputQType.getScale(); + int64_t outputZp = outputQType.getZeroPoint(); + + double opTensorScale = (inputScale * weightScale) / outputScale; + + int32_t multiplier; + int32_t shift; + + // Obtain the quantized scale = multiplier and shift. 
+ computeMultiplierAndShift(opTensorScale, multiplier, shift, 32); + + auto newTosaRescaleOp = rewriter.create( + op->getLoc(), outputType, newTosaConv2DOp.getResult(), + rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp), + rewriter.getI32ArrayAttr({multiplier}), rewriter.getI32ArrayAttr({shift}), + rewriter.getBoolAttr(true), rewriter.getBoolAttr(true), + rewriter.getBoolAttr(false)); + + rewriter.replaceOp(op, {newTosaRescaleOp.getResult()}); + return success(); +} + +void TosaTestQuantUtilAPI::runOnFunction() { + OwningRewritePatternList patterns; + auto *ctx = &getContext(); + auto func = getFunction(); + + patterns.insert(ctx); + patterns.insert(ctx); + applyPatternsAndFoldGreedily(func, std::move(patterns)); +} + +} // anonymous namespace + +std::unique_ptr> createTOSATestQuantUtilAPIPass() { + return std::make_unique(); +} + +static PassRegistration + pass(PASS_NAME, + "Test pass that exercises the TOSA quantization utility functions"); + +} // namespace tosa + +} // namespace mlir Index: mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -0,0 +1,351 @@ +//===- QuantUtils.cpp -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains TOSA numerical support functions and quantization +// attribute builders. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" + +namespace mlir { +namespace tosa { + +namespace { + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 16-bit scaling. +void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 15)); + + assert(shiftedM <= (int64_t(1) << 15)); // Can't be greater than 1.0. + if (shiftedM == (int64_t(1) << 15)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive and embeds (1 << 15) into right + // shift bits. + shift = (-shift) + 15; + + assert(shiftedM <= std::numeric_limits::max()); + + multiplier = static_cast(shiftedM); +} + +/// From a scale value, generates multiplier and shift values where +/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that +/// multiplier = mantissa*2^shift for 32-bit scaling. +void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier, + int32_t &shift) { + + const double mantissa = std::frexp(scale, &shift); + auto shiftedM = std::round(mantissa * (int64_t(1) << 31)); + + assert(shiftedM <= (int64_t(1) << 31)); // Can't be greater than 1.0. + if (shiftedM == (int64_t(1) << 31)) { + shiftedM /= 2; + shift++; + } + + // TOSA expects right shift to be positive and embeds (1 << 31) into right + // shift bits. + shift = (-shift) + 31; + + assert(shiftedM <= std::numeric_limits::max()); + + multiplier = static_cast(shiftedM); +} + +} // namespace + +/// Generates a quantized multiplier/shift from double. 
+void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth) { + + switch (scaleWidth) { + case 16: + computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); + return; + case 32: + computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); + return; + default: + assert(0 && "Unsupported Tosa quantized_scale regime specified!"); + } +} + +#define GET_UQTYPE(input_type) \ + ((input_type).getElementType().dyn_cast()) +#define GET_QTYPE(input_type) \ + ((input_type).getElementType().dyn_cast()) + +// Method to build ConvOpQuantizationAttr, called from +// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: input_zp: input zeropoint +// weight_zp: weight zeropoint. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + if (!inputType || !weightType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto weightPerTensorQType = GET_UQTYPE(weightType); + auto weightPerAxisQType = + weightType.getElementType() + .dyn_cast(); + + // Weight either per-tensor quantized or per-axis quantized + assert(!((bool)weightPerTensorQType && (bool)weightPerAxisQType)); + + // Either all quantized or all not quantized. 
+ assert(!((bool)inputQType ^ + ((bool)weightPerTensorQType || (bool)weightPerAxisQType))); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t weightZp = 0; + + if (weightPerTensorQType) { + weightZp = weightPerTensorQType.getZeroPoint(); + } else if (weightPerAxisQType) { + weightZp = weightPerAxisQType.getZeroPoints().front(); + } + + auto quantAttr = mlir::tosa::ConvOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds MatMulOpQuantizationAttr, called from +/// MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: input b zeropoint +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder, + Value a, Value b) { + + auto aType = a.getType().dyn_cast(); + auto bType = b.getType().dyn_cast(); + + if (!aType || !bType) + return nullptr; + + auto aQType = GET_UQTYPE(aType); + auto bQType = GET_UQTYPE(bType); + + // A and B are either all quantized or all not quantized. + assert(!((bool)aQType ^ (bool)bQType)); + + if (aQType) { + + int64_t aZp = aQType.getZeroPoint(); + int64_t bZp = bQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::MatMulOpQuantizationAttr::get( + builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds UnaryOpQuantizationAttr +/// UnaryOpQuantInfoBuilder: inputZp: input zeropoint +/// outputZp: output zeropoint. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, + Type outputRawType) { + + auto inputType = input.getType().dyn_cast(); + auto outputType = outputRawType.dyn_cast(); + + if (!inputType || !outputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + auto outputQType = GET_UQTYPE(outputType); + + // Either all quantized or all not quantized. 
+ assert(!((bool)inputQType ^ (bool)outputQType)); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t outputZp = outputQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::UnaryOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: +/// inputZp: input zeropoint +PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder, + Value input) { + + auto inputType = input.getType().dyn_cast(); + + if (!inputType) + return nullptr; + + auto inputQType = GET_UQTYPE(inputType); + + if (inputQType) { + + int64_t inputZp = inputQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::PadOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds output type for a quantized ConvOp with the right bitwidth. +/// This is called by the builder when dealing with quantized content. 
+Type buildConvOpResultTypeInfo(mlir::OpBuilder &builder, Type outputType, + Value input, Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + assert(inputType && weightType); + + auto inputQType = GET_QTYPE(inputType); + auto weightQType = GET_QTYPE(weightType); + + assert(inputQType && weightQType); + + unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); + unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); + + auto outputShapedType = outputType.dyn_cast(); + assert(outputShapedType); + + auto outputShape = outputShapedType.getShape(); + + IntegerType accElementType; + if (inputBits == 16 && weightBits == 8) + accElementType = builder.getIntegerType(48); + else + accElementType = builder.getI32Type(); + auto accType = RankedTensorType::get(outputShape, accElementType); + return accType; +} + +/// Builds Tosa quantization attributes from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange) { + + quant::QuantizedType retType; + + auto convfunc = + quant::ExpressedToQuantizedConverter::forInputType(inputDType); + + auto minElems = minAttr.dyn_cast(); + auto maxElems = maxAttr.dyn_cast(); + + SmallVector min, max; + + // At least one is per-axis quantized elementsattr. + if (minElems || maxElems) { + + // Must have the same number of elements. + if (minElems.getNumElements() != maxElems.getNumElements()) + return {}; + + min.reserve(minElems.getNumElements()); + max.reserve(maxElems.getNumElements()); + for (auto i : minElems) { + min.push_back(FloatAttr::getValueAsDouble(i)); + } + for (auto i : maxElems) { + max.push_back(FloatAttr::getValueAsDouble(i)); + } + } else { // Just a single FP value. 
+ + auto minVal = minAttr.dyn_cast(); + if (minVal) + min.push_back(minVal.getValueAsDouble()); + else + return {}; + auto maxVal = maxAttr.dyn_cast(); + if (maxVal) + max.push_back(maxVal.getValueAsDouble()); + else + return {}; + } + + if (min.size() == max.size()) { + + if (min.size() == 1) { // Per-tensor quantization with one min/max pair. + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], + narrowRange.getValue(), convfunc.expressedType, isSigned); + + } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. + + auto shape = inputDType.dyn_cast(); + if (!shape) + return {}; + if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], + max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); + } + + } else { + return {}; + } + } else { + return {}; + } + + if (!retType) + return {}; + + return convfunc.convert(retType); +} + +/// Builds Tosa quantization attributes from min/max values. 
+TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange) { + + return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, + maxAttr, quantBits, filterQuantDim, + isSigned, narrowRange)); +} + +} // namespace tosa +} // namespace mlir Index: mlir/test/Dialect/Tosa/broadcast.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/broadcast.mlir @@ -0,0 +1,153 @@ +// RUN: mlir-opt --tosa-make-broadcastable %s | FileCheck %s + +// ----- +// CHECK-LABEL: broadcast0 +func @test_broadcast0(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-NEXT: add + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + +// ----- +// CHECK-LABEL: broadcast1 +func @test_broadcast1(%arg0: tensor<1xf32>, %arg1: tensor<2x1xf32>) -> tensor<2x1xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<2x1xf32>) -> tensor<2x1xf32> + return %0 : tensor<2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast2 +func @test_broadcast2(%arg0: tensor<2x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1xf32>, tensor<1xf32>) -> tensor<2x1xf32> + return %0 : tensor<2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast3 +func @test_broadcast3(%arg0: tensor<2x1x1x1xf32>, %arg1: tensor<1xf32>) -> tensor<2x1x1x1xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x1x1x1xf32>, tensor<1xf32>) -> tensor<2x1x1x1xf32> + return %0 : tensor<2x1x1x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast4 +func @test_broadcast4(%arg0: tensor<1x1x1x2xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x1x2xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x1x2xf32>, tensor<1xf32>) -> tensor<1x1x1x2xf32> + 
return %0 : tensor<1x1x1x2xf32> +} + +// ----- +// CHECK-LABEL: broadcast5 +func @test_broadcast5(%arg0: tensor<1x1x2x1xf32>, %arg1: tensor<1xf32>) -> tensor<1x1x2x1xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1x2x1xf32>, tensor<1xf32>) -> tensor<1x1x2x1xf32> + return %0 : tensor<1x1x2x1xf32> +} + +// ----- +// CHECK-LABEL: broadcast6 +func @test_broadcast6(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast7 +func @test_broadcast7(%arg0: tensor<17x16x1x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x1x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x1x14xf32>, tensor<1x1xf32>) -> tensor<17x16x1x14xf32> + return %0 : tensor<17x16x1x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast8 +func @test_broadcast8(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<1x1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<1x1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast9 +func @test_broadcast9(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x1xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x1xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast10 +func @test_broadcast10(%arg0: tensor<17x16x15x14xf32>, %arg1: tensor<15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<17x16x15x14xf32>, tensor<15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast13 +func @test_broadcast13(%arg0: 
tensor<1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast14 +func @test_broadcast14(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x1x14xf32>) -> tensor<17x16x1x14xf32> + return %0 : tensor<17x16x1x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast15 +func @test_broadcast15(%arg0: tensor<1x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<1x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast16 +func @test_broadcast16(%arg0: tensor<15x1xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x1xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast17 +func @test_broadcast17(%arg0: tensor<15x14xf32>, %arg1: tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> { + // CHECK-NEXT: reshape + %0 = "tosa.add"(%arg0, %arg1) : (tensor<15x14xf32>, tensor<17x16x15x14xf32>) -> tensor<17x16x15x14xf32> + return %0 : tensor<17x16x15x14xf32> +} + +// ----- +// CHECK-LABEL: broadcast18 +func @test_broadcast18(%arg0: tensor<14x1xf32>, %arg1: tensor<1x15xf32>) -> tensor<14x15xf32> { + // CHECK-NEXT: add + %0 = "tosa.add"(%arg0, %arg1) : (tensor<14x1xf32>, tensor<1x15xf32>) -> tensor<14x15xf32> + return %0 : tensor<14x15xf32> +} + +// ----- +// CHECK-LABEL: broadcast_mul +func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { + // CHECK-NEXT: reshape + %0 = 
"tosa.mul"(%arg0, %arg1) {shift = 1 : i32 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> + return %0 : tensor<17x16x15x14xi32> +} + +// ----- +// CHECK-LABEL: broadcast_arithmetic_right_shift +func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> { + // CHECK-NEXT: reshape + %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = true } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> + return %0 : tensor<17x16x15x14xi32> +} \ No newline at end of file Index: mlir/test/Dialect/Tosa/dynamic_shape.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/dynamic_shape.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s + + +// ----- +// CHECK-LABEL: argmax +func @test_argmax(%arg0: tensor) -> tensor { + %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor) -> tensor + return %0 : tensor +} + Index: mlir/test/Dialect/Tosa/illegal-types.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/illegal-types.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +//===-----------------------------------------------------------------------===// +// Tests to check for illegal TOSA data types +//===-----------------------------------------------------------------------===// + +// CHECK: func @legal_i32 +func @legal_i32(tensor<16xi32>) -> () + +// ----- +// expected-error @+1 {{expected non-function type}} +func @illegal_u8(tensor<16xu8>) -> () + +// ----- +// expected-error @+1 {{expected non-function type}} +func @illegal_u16(tensor<16xu16>) -> () + +// ----- +// expected-error @+1 {{expected non-function type}} +func @illegal_u32(tensor<16xu32>) -> () + +// ----- +// expected-error @+1 {{expected 
floating point literal}} +func @illegal_missing_fp_literal(tensor<2x!quant.uniform>) -> () + +// ----- +// expected-error @+1 {{expected floating point literal}} +func @illegal_missing_fp_literal2(tensor<2x!quant.uniform>) -> () + +// ----- +// expected-error @+1 {{expected floating point literal}} +func @illegal_missing_fp_literal3(tensor<2x!quant.uniform>) -> () + +// ----- +// expected-error @+1 {{expected ','}} +func @illegal_missing_comma(tensor<2x!quant.uniform>) -> () + +// ----- +// expected-error @+1 {{expected ','}} +func @illegal_missing_comma2(tensor<2x!quant.uniform>) -> () + +// ----- +// expected-error @+1 {{unknown quantized type multiform}} +func @illegal_missing_uniform(tensor<2x!quant.multiform>) -> () + +// ----- +// expected-error @+1 {{invalid kind of type specified}} +func @illegal_invalid_type_i8(tensor<2x!quant.uniform>) -> () + +// ----- +// expected-error @+1 {{invalid kind of type specified}} +func @illegal_invalid_type_i32(tensor<2x!quant.uniform>) -> () + +// ----- +// expected-error @+1 {{expected non-function type}} +func @illegal_invalid_shape(tensor<-1xf32>) -> () + +// ----- +// expected-error @+1 {{illegal storage type prefix}} +func @illegal_invalid_type_i32(tensor<2x!quant.uniform>) -> () Index: mlir/test/Dialect/Tosa/inlining.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Tosa/inlining.mlir @@ -0,0 +1,69 @@ +// RUN: mlir-opt %s -inline | FileCheck %s + +// These tests verify that regions with operations from TOSA dialect +// can be inlined. 
+
+// CHECK-LABEL: func @inlined_into_if
+func @inlined_into_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor {
+  %0 = call @inlined_if_fn(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor
+  return %0 : tensor
+}
+
+// CHECK-LABEL: func @inlined_if_fn
+func @inlined_if_fn(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor {
+  // CHECK: tosa.cond_if
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
+  ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors
+    %1 = call @add(%arg3, %arg4) : (tensor, tensor) -> tensor
+    "tosa.yield"(%1) : (tensor) -> ()
+  }, {
+  ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors
+    %1 = call @sub(%arg3, %arg4) : (tensor, tensor) -> tensor
+    "tosa.yield"(%1) : (tensor) -> ()
+  }) : (tensor, tensor, tensor) -> tensor
+  return %0 : tensor
+}
+func @add(%arg0: tensor, %arg1: tensor) -> tensor attributes {sym_visibility = "private"} {
+  return %arg0 : tensor
+}
+func @sub(%arg0: tensor, %arg1: tensor) -> tensor attributes {sym_visibility = "private"} {
+  return %arg1 : tensor
+}
+
+// -----
+
+// CHECK-LABEL: func @inlined_into_while
+func @inlined_into_while(%arg0: tensor, %arg1: tensor, %arg2: tensor<10xi32>) -> tensor<10xi32> {
+  %0 = call @inlined_while_fn(%arg0, %arg1, %arg2) : (tensor, tensor, tensor<10xi32>) -> tensor<10xi32>
+  return %0 : tensor<10xi32>
+}
+
+// CHECK-LABEL: func @inlined_while_fn
+func @inlined_while_fn(%arg0: tensor, %arg1: tensor, %arg2: tensor<10xi32>) -> tensor<10xi32> {
+  // CHECK: tosa.while_loop
+  %1:3 = "tosa.while_loop"(%arg0, %arg1, %arg2) ( {
+  ^bb0(%arg5: tensor, %arg3: tensor, %arg4: tensor<10xi32>): // no predecessors
+    %2 = call @while_cond_40(%arg5, %arg3, %arg4) : (tensor, tensor, tensor<10xi32>) -> tensor
+    "tosa.yield"(%2) : (tensor) -> ()
+  }, {
+  ^bb0(%arg5: tensor, %arg3: tensor, %arg4: tensor<10xi32>): // no predecessors
+    %2:3 = call @while_body_50(%arg5, %arg3, %arg4) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>)
+    "tosa.yield"(%2#0, %2#1, %2#2) : (tensor, tensor, tensor<10xi32>) -> ()
+  }) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>)
+  return %1#2 : tensor<10xi32>
+}
+func @while_body_50(%arg0: tensor, %arg1: tensor, %arg2: tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>) attributes {sym_visibility = "private"} {
+  // CHECK: tosa.const
+  // CHECK-NEXT: tosa.add
+  %0 = "tosa.const"() {value = dense<1> : tensor} : () -> tensor
+  %1 = "tosa.add"(%arg0, %0) : (tensor, tensor) -> tensor
+  %2 = "tosa.add"(%arg2, %1) : (tensor<10xi32>, tensor) -> tensor<10xi32>
+  return %1, %arg1, %2 : tensor, tensor, tensor<10xi32>
+}
+func @while_cond_40(%arg0: tensor, %arg1: tensor, %arg2: tensor<10xi32>) -> tensor attributes {sym_visibility = "private"} {
+  // CHECK: tosa.greater_equal
+  // CHECK-NEXT: tosa.logical_not
+  %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor, tensor) -> tensor
+  %1 = "tosa.logical_not"(%0) : (tensor) -> tensor
+  return %1 : tensor
+}
Index: mlir/test/Dialect/Tosa/ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/ops.mlir
@@ -0,0 +1,512 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+
+// -----
+// CHECK-LABEL: argmax
+func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
+  %0 = "tosa.argmax"(%arg0) {axis = 1 : i64} : (tensor<14x19xf32>) -> tensor<14xi32>
+  return %0 : tensor<14xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d
+func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+// CHECK-LABEL: conv2d
+func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %0 : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d
+func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
+  return %2 : tensor<1x4x4x8xf32>
+}
+
+// -----
+// CHECK-LABEL: fully_connected
+func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) : (tensor<14x19xf32>, tensor<19x28xf32>, tensor<28xf32>) -> tensor<14x28xf32>
+  return %0 : tensor<14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: test_matmul
+func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32>
+  return %0 : tensor<14x28xf32>
+}
+
+// -----
+// CHECK-LABEL: max_pool2d
+func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32>
+  return %0 : tensor<1x32x32x8xf32>
+}
+
+// -----
+/// CHECK-LABEL: transpose_conv2d
+func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], stride = [1, 1]} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  return %0 : tensor<1x32x32x16xf32>
+}
+
+// -----
+// CHECK-LABEL: clamp
+func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.clamp"(%arg0) {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: relu
+func @test_relu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reluN"(%arg0) {max_fp = 3.40282347E+38 : f32, max_int = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+
+// -----
+// CHECK-LABEL: sigmoid
+func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.sigmoid"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: tanh
+func @test_tanh(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.tanh"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: add
+func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.add"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: arithmetic_right_shift
+func @test_arithmetic_right_shift(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.arithmetic_right_shift"(%arg0, %arg1) { round = false } : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+
+// -----
+// CHECK-LABEL: bitwise_and
+func @test_bitwise_and(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.bitwise_and"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_or
+func @test_bitwise_or(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.bitwise_or"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_xor
+func @test_bitwise_xor(%arg0: tensor<13x21x1xi32>, %arg1: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.bitwise_xor"(%arg0, %arg1) : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_and
+func @test_logical_and(%arg0: tensor<13x21x3xi1>, %arg1: tensor<13x21x1xi1>) -> tensor<13x21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.logical_and"(%arg0, %arg1) : (tensor<13x21x3xi1>, tensor<13x21x1xi1>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: logical_left_shift
+func @test_logical_left_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.logical_left_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_right_shift
+func @test_logical_right_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x21x1xi32>) -> tensor<13x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.logical_right_shift"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x21x1xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: logical_or
+func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.logical_or"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: logical_xor
+func @test_logical_xor(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) -> tensor<13x21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.logical_xor"(%arg0, %arg1) : (tensor<13x1x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: maximum
+func @test_max(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.maximum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: minimum
+func @test_min(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.minimum"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<1x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: mul
+func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: pow
+func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.pow"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x1xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: sub
+func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.sub"(%arg0, %arg1) : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: table
+func @main(%arg0: tensor<64xi32>, %arg1: tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform> attributes {tf.entry_function = {inputs = "placeholder_0", outputs = "result"}} {
+  %0 = "tosa.table"(%arg0, %arg1) : (tensor<64xi32>, tensor<513x!quant.uniform>) -> tensor<64x!quant.uniform>
+  return %0 : tensor<64x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: abs
+func @test_abs(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.abs"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: bitwise_not
+func @test_bitwise_not(%arg0: tensor<13x21x1xi32>) -> tensor<13x21x1xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.bitwise_not"(%arg0) : (tensor<13x21x1xi32>) -> tensor<13x21x1xi32>
+  return %0 : tensor<13x21x1xi32>
+}
+
+// -----
+// CHECK-LABEL: ceil
+func @test_ceil(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.ceil"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: clz
+func @test_clz(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.clz"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: exp
+func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: floor
+func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.floor"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: log
+func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: logical_not
+func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.logical_not"(%arg0) : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1>
+  return %0 : tensor<1x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: negate
+func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.negate"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reciprocal
+func @test_reciprocal(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reciprocal"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: rsqrt
+func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.rsqrt"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: select
+func @test_select(%arg0: tensor<1x1x1xi1>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.select"(%arg0, %arg1, %arg2) : (tensor<1x1x1xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+
+// -----
+// CHECK-LABEL: equal
+func @test_equal(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.equal"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: greater
+func @test_greater(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.greater"(%arg0, %arg1) : (tensor<13x21x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: greater_equal
+func @test_greater_equal(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.greater_equal"(%arg0, %arg1) : (tensor<13x1x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1>
+  return %0 : tensor<13x21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_all
+func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reduce_all"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+  %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
+  return %1 : tensor<21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_any
+func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xi1>) -> tensor<1x21x3xi1>
+  %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xi1>) -> tensor<21x3xi1>
+  return %1 : tensor<21x3xi1>
+}
+
+// -----
+// CHECK-LABEL: reduce_max
+func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+  return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_min
+func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+  return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_product
+func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+  return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reduce_sum
+func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<1x21x3xf32>
+  %1 = "tosa.reshape"(%0) {new_shape = [21, 3]} : (tensor<1x21x3xf32>) -> tensor<21x3xf32>
+  return %1 : tensor<21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: concat
+func @test_concat(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<26x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<26x21x3xf32>
+  return %0 : tensor<26x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: pad
+func @test_pad(%arg0: tensor<13x21x3xf32>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.pad"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<3x2xi32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: reshape
+func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reshape"(%arg0) {new_shape = [1, 819]} : (tensor<13x21x3xf32>) -> tensor<1x819xf32>
+  return %0 : tensor<1x819xf32>
+}
+
+// -----
+// CHECK-LABEL: reverse
+func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: slice
+func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.slice"(%arg0) {start = [6, 8, 0], size = [4, 11, 1]} : (tensor<13x21x3xf32>) -> tensor<4x11x1xf32>
+  return %0 : tensor<4x11x1xf32>
+}
+
+// -----
+// CHECK-LABEL: tile
+func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<39x21x6xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.tile"(%arg0) {multiples = [3, 1, 2]} : (tensor<13x21x3xf32>) -> tensor<39x21x6xf32>
+  return %0 : tensor<39x21x6xf32>
+}
+
+// -----
+// CHECK-LABEL: transpose
+func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xf32>
+  return %1 : tensor<3x13x21xf32>
+}
+
+// -----
+// CHECK-LABEL: gather
+func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<26xi32>) -> tensor<26x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.gather"(%arg0, %arg1) {axis = 0 : i64, batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>) -> tensor<26x21x3xi32>
+  return %0 : tensor<26x21x3xi32>
+}
+
+// Test TBD
+// DISABLED-CHECK-LABEL: resize
+//func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+//  %0 = "tosa.const"() {value = dense<64> : tensor<2xi32>} : () -> tensor<2xi32>
+//  %1 = "tosa.resize"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x32x32x8xf32>, tensor<2xi32>) -> tensor<1x64x64x8xf32>
+//  return %1 : tensor<1x64x64x8xf32>
+//}
+
+// -----
+// CHECK-LABEL: cast
+func @test_cast1(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: cast2
+func @test_cast2(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform>
+  return %0 : tensor<13x21x3x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: cast3
+func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.cast"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform>
+  return %0 : tensor<13x21x3x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: rescale
+func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> attributes {tf.entry_function = {inputs = "placeholder_0,placeholder_1", outputs = "result"}} {
+  %0 = "tosa.rescale"(%arg0) {double_round = false, input_zp = 127 : i32, multiplier = [1073741824 : i32], output_zp = -1 : i32, per_channel = false, scale32 = true, shift = [30 : i32]} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform>
+  return %0 : tensor<13x21x3x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: const
+func @test_const(%arg0 : index) -> tensor<4xi32> {
+  %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+// CHECK-LABEL: identity
+func @test_identity(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.identity"(%arg0) : (tensor<13x21x3xi32>) -> tensor<13x21x3xi32>
+  return %0 : tensor<13x21x3xi32>
+}
+
+// -----
+// CHECK-LABEL: identityn
+func @test_identityn(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> {
+  %0:2 = "tosa.identityn"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> (tensor<1xi32>, tensor<1xi32>)
+  return %0#0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: placeholder
+func @test_placeholder() -> tensor<1xi32> {
+  %0 = "tosa.placeholder"() : () -> tensor<1xi32>
+  return %0 : tensor<1xi32>
+}
+
+// -----
+// CHECK-LABEL: cond_if
+func @test_cond_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor {
+  %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ( {
+  ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors
+    %1 = "tosa.add"(%arg3, %arg4) : (tensor, tensor) -> tensor
+    "tosa.yield"(%1) : (tensor) -> ()
+  }, {
+  ^bb0(%arg3: tensor, %arg4: tensor): // no predecessors
+    %1 = "tosa.sub"(%arg3, %arg4) : (tensor, tensor) -> tensor
+    "tosa.yield"(%1) : (tensor) -> ()
+  }) : (tensor, tensor, tensor) -> tensor
+  return %0 : tensor
+}
+
+// -----
+// CHECK-LABEL: while_loop
+func @test_while_loop(%arg0: tensor<10xi32>, %arg1: tensor) {
+  %0 = "tosa.const"() {value = dense<0> : tensor} : () -> tensor
+  %1:3 = "tosa.while_loop"(%0, %0, %arg0) ( {
+  ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor<10xi32>): // no predecessors
+    %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor, tensor) -> tensor
+    %3 = "tosa.logical_not"(%2) : (tensor) -> tensor
+    "tosa.yield"(%3) : (tensor) -> ()
+  }, {
+  ^bb0(%arg2: tensor, %arg3: tensor, %arg4: tensor<10xi32>): // no predecessors
+    %2 = "tosa.const"() {value = dense<1> : tensor} : () -> tensor
+    %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor
+    %4 = "tosa.reshape"(%2) {new_shape = [1]} : (tensor) -> tensor<1xi32>
+    %5 = "tosa.add"(%arg4, %4) : (tensor<10xi32>, tensor<1xi32>) -> tensor<10xi32>
+    %6 = "tosa.add"(%arg2, %2) : (tensor, tensor) -> tensor
+    "tosa.yield"(%6, %3, %5) : (tensor, tensor, tensor<10xi32>) -> ()
+  }) : (tensor, tensor, tensor<10xi32>) -> (tensor, tensor, tensor<10xi32>)
+  return
+}
Index: mlir/test/Dialect/Tosa/quant-ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/quant-ops.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+//===-----------------------------------------------------------------------===//
+// Tests to check for parsing of legal TOSA operators with quantized types
+// and quantization_info attributes.
+//===-----------------------------------------------------------------------===//
+
+// -----
+// CHECK-LABEL: avg_pool2d
+func @test_avg_pool2d(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform>
+  return %0 : tensor<1x7x7x9x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: conv2d
+func @test_conv2d(%arg0: tensor<1x4x4x4x!quant.uniform>, %arg1: tensor<8x5x5x4x!quant.uniform:f32, 0.015746051445603371>>, %arg2: tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32> {
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 1], pad = [4, 4, 2, 2], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32}, stride = [1, 1]} : (tensor<1x4x4x4x!quant.uniform>, tensor<8x5x5x4x!quant.uniform:f32, 0.015746051445603371>>, tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32>
+  return %0 : tensor<1x4x4x8xi32>
+}
+
+// -----
+// CHECK-LABEL: depthwise_conv2d
+func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4x!quant.uniform>, %arg1: tensor<1x1x4x2x!quant.uniform>, %arg2: tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %2 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1,1], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32 }, pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<1x4x4x4x!quant.uniform>, tensor<1x1x4x2x!quant.uniform>, tensor<8x!quant.uniform>) -> tensor<1x4x4x8xi32>
+  return %2 : tensor<1x4x4x8xi32>
+}
+
+// -----
+// CHECK-LABEL: fully_connected
+func @test_fully_connected(%arg0: tensor<14x19x!quant.uniform>, %arg1: tensor<19x28x!quant.uniform>, %arg2: tensor<28x!quant.uniform>) -> tensor<14x28x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.fully_connected"(%arg0, %arg1, %arg2) {quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32 } } : (tensor<14x19x!quant.uniform>, tensor<19x28x!quant.uniform>, tensor<28x!quant.uniform>) -> tensor<14x28x!quant.uniform>
+  return %0 : tensor<14x28x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: test_matmul
+func @test_matmul(%arg0: tensor<14x19x!quant.uniform>, %arg1: tensor<19x28x!quant.uniform>) -> tensor<14x28xi32> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0,placeholder_1:0", outputs = "result"}} {
+  %0 = "tosa.matmul"(%arg0, %arg1) { quantization_info = {a_zp = 1 : i32, b_zp = 2: i32}}: (tensor<14x19x!quant.uniform>, tensor<19x28x!quant.uniform>) -> tensor<14x28xi32>
+  return %0 : tensor<14x28xi32>
+}
+
+// -----
+// CHECK-LABEL: transpose_conv2d
+func @test_transpose_conv2d(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1: tensor<16x1x1x8x!quant.uniform>, %arg2: tensor<16x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [1, 32, 32, 16], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32 }, stride = [1, 1]} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform>
+  return %0 : tensor<1x32x32x16x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: negate
+func @test_negate(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.negate"(%arg0) { quantization_info = {input_zp = -1 : i32, output_zp = 1 : i32 }}: (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform>
+  return %0 : tensor<13x21x3x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: pad
+func @test_pad(%arg0: tensor<13x21x3x!quant.uniform>, %arg1: tensor<3x2xi32>) -> tensor<13x21x3x!quant.uniform> attributes {tf.entry_function = {control_outputs = "", inputs = "placeholder_0:0", outputs = "result"}} {
+  %0 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = -1 : i32 }}: (tensor<13x21x3x!quant.uniform>, tensor<3x2xi32>) -> tensor<13x21x3x!quant.uniform>
+  return %0 : tensor<13x21x3x!quant.uniform>
+}
+
+// -----
+// CHECK-LABEL: rescale
+func @test_rescale(%arg0: tensor<1x32x32x16xi32>) -> tensor<1x32x32x16x!quant.uniform> {
+  %0 = "tosa.rescale"(%arg0) {double_round = true, input_zp = 0 : i32, multiplier = [1723896117 : i32], output_zp = 0 : i32, per_channel = false, scale32 = true, shift = [39 : i32]} : (tensor<1x32x32x16xi32>) -> tensor<1x32x32x16x!quant.uniform>
+  return %0 : tensor<1x32x32x16x!quant.uniform>
+}
Index: mlir/test/Dialect/Tosa/quant-test.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/quant-test.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-opt --tosa-test-quant-utils %s | FileCheck %s
+
+// -----
+// CHECK-LABEL: test_build_qtype
+func @test_build_qtype(%arg0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>> {
+  // CHECK: tosa.negate
+  %0 = "tosa.negate"(%arg0) : (tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>) -> tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>
+  return %0 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489:128>>
+}
+
+// -----
+// CHECK-LABEL: test_build_mult_and_shift
+func @test_build_mult_and_shift(%arg0: tensor<1x32x32x8x!quant.uniform>, %arg1 : tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, %arg2 : tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform> {
+  // CHECK: tosa.conv2d
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [1, 1, 2, 2], dilation = [2, 1], stride = [1, 1], quantization_info = {input_zp = -1 : i32, weight_zp = 0 : i32}} : (tensor<1x32x32x8x!quant.uniform>, tensor<16x1x1x8x!quant.uniform:f32, 0.015680249780416489>>, tensor<16xi32>) -> tensor<1x32x32x16x!quant.uniform>
+  return %0 : tensor<1x32x32x16x!quant.uniform>
+
+}
Index: mlir/test/Dialect/Tosa/types.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/Tosa/types.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
+
+//===-----------------------------------------------------------------------===//
+// Tests to check for parsing of legal TOSA data types
+//===-----------------------------------------------------------------------===//
+
+// -----
+// CHECK-LABEL: @test_bool
+func @test_bool(tensor<2xi1>) -> ()
+
+// -----
+// CHECK-LABEL: @test_int32
+func @test_int32(tensor<2xi32>) -> ()
+
+// -----
+// CHECK-LABEL: @test_i48
+func @test_i48(tensor<2xi48>) -> ()
+
+// -----
+// CHECK-LABEL: @test_i64
+func @test_i64(tensor<2xi64>) -> ()
+
+// -----
+// CHECK-LABEL: @test_f16
+func @test_f16(tensor<2xf16>) -> ()
+
+// -----
+// CHECK-LABEL: @test_f32
+func @test_f32(tensor<2xf32>) -> ()
+
+// -----
+// CHECK-LABEL: @test_bf16
+func @test_bf16(tensor<2xbf16>) -> ()
+
+// -----
+// CHECK-LABEL: @test_quant1
+func @test_quant1(tensor<2x!quant.uniform>) -> ()
+
+// -----
+// CHECK-LABEL: @test_quant2
+func @test_quant2(tensor<2x!quant.uniform>) -> ()
+
+// -----
+// CHECK-LABEL: @test_quant3
+func @test_quant3(tensor<2x!quant.uniform>) -> ()
+
+// -----
+// CHECK-LABEL: @test_quant4
+func @test_quant4(tensor<2x!quant.uniform>) -> ()
+
+// -----
+// CHECK-LABEL: @test_perchannel_quant
+func @test_perchannel_quant(tensor<16x3x3x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>) -> ()