Index: mlir/include/mlir/Dialect/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/CMakeLists.txt +++ mlir/include/mlir/Dialect/CMakeLists.txt @@ -13,4 +13,5 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) Index: mlir/include/mlir/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) Index: mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaOps.h.inc -gen-op-decls) +mlir_tablegen(TosaOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTosaOpsIncGen) + +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaStructs.h.inc -gen-struct-attr-decls) +mlir_tablegen(TosaStructs.cpp.inc -gen-struct-attr-defs) +add_public_tablegen_target(MLIRTosaStructsIncGen) + + +set(LLVM_TARGET_DEFINITIONS TosaInterfaces.td) +mlir_tablegen(TosaInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(TosaInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRTosaInterfaceIncGen) + +add_mlir_doc(TosaOps -gen-op-doc TosaOps Dialects/) Index: mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td @@ -0,0 +1,34 @@ +//===-- TosaInterfaces.td - TOSA dialect interfaces *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the dialect op interfaces for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef TOSA_OP_INTERFACES +#define TOSA_OP_INTERFACES + +include "mlir/IR/OpBase.td" + +def TosaOpInterface : OpInterface<"TosaOp"> { + let description = [{ + Implements interfaces for general Tosa op utility + }]; + + let methods = [ + InterfaceMethod< + [{Returns the TOSA version.}], + "StringRef", "getTOSAVersion", (ins), [{ + return "0.20"; + }] + >, + ]; + +} + +#endif Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -0,0 +1,406 @@ +//===-- TosaOpBase.td - TOSA dialect op builders *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation builders for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + + +#ifdef TOSA_OP_BASE +#else +#define TOSA_OP_BASE + +// Quantization attributes used across TOSA operators. 
+def Tosa_UnaryOpQuantizationAttr : StructAttr<"UnaryOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"output_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Unary Ops."; +} + +def Tosa_ConvOpQuantizationAttr : StructAttr<"ConvOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>, + StructFieldAttr<"weight_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Convolution Ops."; +} + +def Tosa_MatMulOpQuantizationAttr : StructAttr<"MatMulOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"a_zp", I32Attr>, + StructFieldAttr<"b_zp", I32Attr>] > { + let description = "Attribute holding quantization information for MatMul Ops."; +} + +def Tosa_PadOpQuantizationAttr : StructAttr<"PadOpQuantizationAttr", Tosa_Dialect, [ + StructFieldAttr<"input_zp", I32Attr>] > { + let description = "Attribute holding quantization information for Pad Ops."; +} + +// This builder is called on convolution operation types that need to create their +// OptionalAttr quantization_attr parameter. 
+def Tosa_ConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$filter, "Value":$bias, "ArrayAttr":$strides, "ArrayAttr":$dilations, "ArrayAttr":$padding), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + $_state.addAttribute("strides", strides); + $_state.addAttribute("dilations", dilations); + $_state.addAttribute("padding", padding); + + auto quantattr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weight_bits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16 && weight_bits == 8) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// A variant of the ConvOpQuantInfo builder for the transpose_conv op, which takes outpad and output_shape attributes instead of padding. 
+def Tosa_TransConvOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$filter, "Value":$bias, "ArrayAttr":$strides, "ArrayAttr":$dilations, "ArrayAttr":$outpad, "ArrayAttr":$output_shape), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + $_state.addAttribute("strides", strides); + $_state.addAttribute("dilations", dilations); + $_state.addAttribute("outpad", outpad); + $_state.addAttribute("output_shape", output_shape); + auto quantattr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weight_bits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16 && weight_bits == 8) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// This builder is called on the fully connected operation that needs to create its +// OptionalAttr quantization_attr parameter. 
+def Tosa_FCOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$filter, "Value":$bias), + [{ + $_state.addOperands(input); + $_state.addOperands(filter); + $_state.addOperands(bias); + auto quantattr = mlir::tosa::buildConvOpQuantizationAttr($_builder, + input, + filter); + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = input.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + unsigned weight_bits = filter.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16 && weight_bits == 8) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// The MatMul quantization info builder is similar to FCOpQuantInfoBuilder, but does not +// have a bias field. 
+def Tosa_MatMulOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$a, "Value":$b), + [{ + $_state.addOperands(a); + $_state.addOperands(b); + auto quantattr = mlir::tosa::buildMatMulOpQuantizationAttr($_builder, a, b); + + if ( quantattr ) { + $_state.addAttribute("quantization_info", quantattr); + unsigned input_bits = a.getType().dyn_cast() + .getElementType().dyn_cast() + .getStorageTypeIntegralWidth(); + auto output_shape = output_type.dyn_cast().getShape(); + IntegerType acc_element_type; + if(input_bits == 16) { + acc_element_type = $_builder.getIntegerType(48); + } + else { + acc_element_type = $_builder.getI32Type(); + } + auto acc_type = RankedTensorType::get(output_shape, acc_element_type); + $_state.addTypes(acc_type); + } + else { + $_state.addTypes(output_type); + } + }]>; + +// This builder is for pool operation types that need to create their +// OptionalAttr quantization_attr parameter. +def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "ArrayAttr":$kernel_size, "ArrayAttr":$strides, "ArrayAttr":$padding), + [{ + $_state.addOperands(input); + $_state.addAttribute("kernel_size", kernel_size); + $_state.addAttribute("strides", strides); + $_state.addAttribute("padding", padding); + auto quantattr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + output_type); + if ( quantattr ) + $_state.addAttribute("quantization_info", quantattr); + $_state.types.push_back(output_type); + }]>; + +// This builder is called on single-parameter unary types that need to create their +// OptionalAttr quantization_attr parameter. 
+def Tosa_UnaryOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input), + [{ + $_state.addOperands(input); + auto quantattr = mlir::tosa::buildUnaryOpQuantizationAttr($_builder, + input, + output_type); + if ( quantattr ) + $_state.addAttribute("quantization_info", quantattr); + $_state.types.push_back(output_type); + }]>; + +// This builder is called on the TOSA pad operator that needs to create its own +// OptionalAttr quantization_attr parameter. +def Tosa_PadOpQuantInfoBuilder : OpBuilderDAG< + (ins "Type":$output_type, "Value":$input, "Value":$paddings), + [{ + $_state.addOperands(input); + $_state.addOperands(paddings); + auto quantattr = mlir::tosa::buildPadOpQuantizationAttr($_builder, + input); + if ( quantattr ) + $_state.addAttribute("quantization_info", quantattr); + $_state.types.push_back(output_type); + }]>; + +// This builder is called on elementwise binary broadcastable operators. +def Tosa_BroadcastableBinaryBuilder : OpBuilderDAG< + (ins "Value":$lhs, "Value":$rhs), + [{ + auto result_type = + OpTrait::util::getBroadcastedType(lhs.getType(), rhs.getType()); + if (!result_type) + mlir::emitError($_state.location, "Operands are not broadcastable"); + $_state.addOperands(lhs); + $_state.addOperands(rhs); + $_state.types.push_back(result_type); + }]>; + +// Tosa Operator Definitions. 
+class Tosa_Op traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return ""; } // TBD + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ElemwiseUnaryOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Unary"; } + }]; + +} + +class Tosa_ElemwiseBinaryOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Binary"; } + }]; + +} + +class Tosa_ElemwiseCompareOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Compare"; } + }]; + +} + +class Tosa_ElemwiseTernaryOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Elemwise"; } + static StringRef getTOSAOpSubtype() { return "Ternary"; } + }]; + +} + +class Tosa_DataLayoutOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "DataLayout"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_DataNodeOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "DataNode"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_AggregationOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Aggregation"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_TensorArgOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Tensor"; } + static StringRef getTOSAOpSubtype() { return "Arg"; } + }]; + +} + +class Tosa_TensorConvOp traits = []> : + Op { + + let extraClassDeclaration = [{ 
+ static StringRef getTOSAOpType() { return "Tensor"; } + static StringRef getTOSAOpSubtype() { return "Conv"; } + }]; + +} + +class Tosa_TensorPoolOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Tensor"; } + static StringRef getTOSAOpSubtype() { return "Pool"; } + }]; + +} + +class Tosa_TensorImageOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Image"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ActivationOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Activation"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ReductionOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Reduction"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ImageOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Image"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +class Tosa_ConversionOp traits = []> : + Op { + + let extraClassDeclaration = [{ + static StringRef getTOSAOpType() { return "Conversion"; } + static StringRef getTOSAOpSubtype() { return ""; } // TBD + }]; + +} + +// Specify traits of operators. + +#endif // TOSA_OP_BASE Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -0,0 +1,61 @@ +//===-- TosaOps.h - TOSA dialect operation definitions *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_IR_TOSA_OPS_H +#define DIALECT_TOSA_IR_TOSA_OPS_H + +#include +#include + +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Tosa/IR/TosaTraits.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Support/LLVM.h" + +#include "mlir/Dialect/Tosa/IR/TosaStructs.h.inc" + +namespace mlir { +namespace tosa { + +//===----------------------------------------------------------------------===// +// TOSA Dialect +//===----------------------------------------------------------------------===// +class TosaDialect : public Dialect { + +public: + explicit TosaDialect(MLIRContext *context); + + static StringRef getDialectNamespace() { return "tosa"; } + + /// Materialize a single constant operation from a given attribute value with + /// the desired resultant type. 
+ Operation *materializeConstant(OpBuilder &builder, Attribute value, Type type, + Location loc) override; +}; + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" + +} // end namespace tosa +} // end namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc" + +#endif // TOSA_OPS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -0,0 +1,1686 @@ +//===-- TosaOps.td - TOSA dialect operation definitions *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the operation set for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_OPS +#else +#define TOSA_OPS + +include "mlir/IR/OpBase.td" + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/Dialect/Tosa/IR/TosaInterfaces.td" + +include "mlir/Dialect/Tosa/IR/TosaTypesBase.td" + +//===----------------------------------------------------------------------===// +// The TOSA Dialect. +//===----------------------------------------------------------------------===// +def Tosa_Dialect : Dialect { + let name = "tosa"; + + let description = [{ + The Tosa dialect. 
+ + Invariants: + + * All values are of Tensor type (in particular, scalars are + represented using zero-dimensional tensors); + }]; + + let cppNamespace = "mlir::tosa"; +} + +#ifdef TOSA_OP_BASE +#else +include "mlir/Dialect/Tosa/IR/TosaOpBase.td" +#endif + + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.2 +// Operator Class: Tensor Data Engine Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: argmax +//===----------------------------------------------------------------------===// +def Tosa_ArgMaxOp : Tosa_TensorArgOp<"argmax", [NoSideEffect]> { + let summary = "Perform argmax on the input."; + + let description = [{ + This returns the index with the largest value across the given axis of the input tensor. + }]; + + let arguments = (ins + Tosa_Tensor: $input, + I64Attr: $axis); + + let results = (outs Tosa_Tensor: $output); + +} + +//===----------------------------------------------------------------------===// +// Operator: avg_pool2d +//===----------------------------------------------------------------------===// +def Tosa_AvgPool2dOp : Tosa_TensorPoolOp<"avg_pool2d", [NoSideEffect]> { + let summary = "Performs average pooling on the input."; + + let description = [{ + This performs an average pooling over the given input tensor. A sliding window of size + given by kernel_size is passed over the input tensor, with the mean value being placed + in the output tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + + Confined]>:$kernel_size, + Confined]>:$strides, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_AvgPool2dOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: conv2d +//===----------------------------------------------------------------------===// +def Tosa_Conv2DOp : Tosa_TensorConvOp<"conv2d", [NoSideEffect]> { + let summary = [{ + 2D Convolution Operator + }]; + + let description = [{ + Performs a 2D convolution over the given tensor input, using the weight tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr:$strides, + DefaultValuedAttr:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: conv3d +//===----------------------------------------------------------------------===// +def Tosa_Conv3DOp : Tosa_TensorConvOp<"conv3d", [NoSideEffect]> { + let summary = [{ + 3D Convolution operator + }]; + + let description = [{ + Performs a 3D convolution over the given input tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr]>, "{1, 1, 1, 1, 1}">:$strides, + DefaultValuedAttr]>, "{1, 1, 1, 1, 1}">:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: depthwise_conv2d +//===----------------------------------------------------------------------===// +def Tosa_DepthwiseConv2DOp : Tosa_TensorConvOp<"depthwise_conv2d", [NoSideEffect]> { + let summary = [{ + Depthwise 2D Convolution operator + }]; + + let description = [{ + Performs 2D convolutions separately over each channel of the given tensor input, using the weight tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + DefaultValuedAttr:$strides, + DefaultValuedAttr:$dilations, + DefaultValuedAttr:$padding, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_ConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: fully_connected +//===----------------------------------------------------------------------===// +def Tosa_FullyConnectedOp : Tosa_TensorConvOp<"fully_connected", [NoSideEffect]> { + let summary = "Fully Connected operator"; + + let description = [{ + Performs a fully connected network. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_TensorOfOrNone<[Tosa_AnyNumber]>:$bias, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_FCOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: matmul +//===----------------------------------------------------------------------===// +def Tosa_MatMulOp : Tosa_TensorConvOp<"matmul", [NoSideEffect]> { + let summary = "Matrix multiplication operator"; + + let description = [{ + Performs a two dimensional matrix multiplication. This allows both inputs to be activations, + rather than reserving weights as an attribute in the FULLY_CONNECTED operator. + }]; + + let arguments = (ins + Tosa_Tensor:$a, + Tosa_Tensor:$b, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$c + ); + + let builders = [Tosa_MatMulOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: max_pool2d +//===----------------------------------------------------------------------===// +def Tosa_MaxPool2dOp : Tosa_TensorPoolOp<"max_pool2d", [NoSideEffect]> { + let summary = "Performs max pooling on the input."; + + let description = [{ + This performs a max pooling over the given input tensor. A sliding window of size given by + kernel_size is passed over the input tensor, with the maximum value being placed in the + output tensor. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + + Confined]>:$kernel_size, + Confined]>:$strides, + DefaultValuedAttr:$padding + ); + + let results = (outs + Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose_conv2d +//===----------------------------------------------------------------------===// +def Tosa_TransposeConv2DOp : Tosa_TensorConvOp<"transpose_conv2d", [NoSideEffect]> { + let summary = [{ + Transpose 2D Convolution operator. + }]; + + let description = [{ + Performs a 2D transposed convolution over the given tensor input, using the weights tensor. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Tensor:$filter, + Tosa_Tensor:$bias, + + I64ArrayAttr:$strides, + I64ArrayAttr:$dilations, + I64ArrayAttr:$outpad, + I64ArrayAttr:$output_shape, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_TransConvOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.3 +// Operator Class: Activation Functions. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: clamp +//===----------------------------------------------------------------------===// +def Tosa_ClampOp : Tosa_ActivationOp<"clamp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes clamp(features, min, max)."; + + let description = [{ + Clamp to an arbitrary minimum and maximum value. Note that the maximum and minimum values are + specified as signed quantized values, no scaling happens before or after this operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$min_int, + I64Attr:$max_int, + F32Attr:$min_fp, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reluN +//===----------------------------------------------------------------------===// +def Tosa_ReluNOp : Tosa_ActivationOp<"reluN", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes rectified linear: `max(features, N)`."; + + let description = [{ + ReLU with a scalar maximum value. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$max_int, + F32Attr:$max_fp + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: sigmoid +//===----------------------------------------------------------------------===// +def Tosa_SigmoidOp : Tosa_ActivationOp<"sigmoid", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise sigmoid of input."; + + let description = [{ + Sigmoid function: output = 1 / (1 + exp(-input)) + For quantized integer data types, the TABLE operator should be used instead with the following definition. + The sigmoid table has 513 entries each of 16-bit precision and covering the input range -16.0 to +16.0 + in steps of 1/16. + }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: tanh +//===----------------------------------------------------------------------===// +def Tosa_TanhOp : Tosa_ActivationOp<"tanh", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes elementwise hyperbolic tangent of input"; + + let description = [{ + Parameterized hyperbolic tangent. 
+ For quantized integer data types, the TABLE operator should be used instead with the following definition. + The tanh_table has 513 entries each of 16-bit precision and covering the input range -8.0 to +8.0 in steps of 1/32. + }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.4 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise binary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: add +//===----------------------------------------------------------------------===// +def Tosa_AddOp : Tosa_ElemwiseBinaryOp<"add", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise addition operator"; + + let description = [{ + Elementwise addition of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: arithmetic_right_shift +//===----------------------------------------------------------------------===// +def Tosa_ArithmeticRightShiftOp : Tosa_ElemwiseBinaryOp<"arithmetic_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Arithmetic Right Shift"; + + let description = [{ + Elementwise arithmetic right shift of input1 by the amount specified in input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_and +//===----------------------------------------------------------------------===// +def Tosa_BitwiseAndOp : Tosa_ElemwiseBinaryOp<"bitwise_and", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise AND operator"; + + let description = [{ + Elementwise bitwise AND of input tensor 0 and input tensor 1. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_or +//===----------------------------------------------------------------------===// +def Tosa_BitwiseOrOp : Tosa_ElemwiseBinaryOp<"bitwise_or", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise OR operator"; + + let description = [{ + Elementwise bitwise OR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_xor +//===----------------------------------------------------------------------===// +def Tosa_BitwiseXorOp : Tosa_ElemwiseBinaryOp<"bitwise_xor", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Bitwise XOR operator"; + + let description = [{ + Elementwise bitwise XOR of input1 and input2. 
Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_and +//===----------------------------------------------------------------------===// +def Tosa_LogicalAndOp : Tosa_ElemwiseBinaryOp<"logical_and", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x AND y element-wise."; + + let description = [{ + Elementwise logical AND of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_left_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalLeftShiftOp : Tosa_ElemwiseBinaryOp<"logical_left_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Left Shift"; + + let description = [{ + Elementwise left shift of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_right_shift +//===----------------------------------------------------------------------===// +def Tosa_LogicalRightShiftOp : Tosa_ElemwiseBinaryOp<"logical_right_shift", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise Logical Right Shift"; + + let description = [{ + Elementwise logical right shift of input1 by the amount specified in input2. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_or +//===----------------------------------------------------------------------===// +def Tosa_LogicalOrOp : Tosa_ElemwiseBinaryOp<"logical_or", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x OR y element-wise."; + + let description = [{ + Elementwise logical OR of input1 and input2. Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: logical_xor +//===----------------------------------------------------------------------===// +def Tosa_LogicalXorOp : Tosa_ElemwiseBinaryOp<"logical_xor", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of x XOR y element-wise."; + + let description = [{ + Elementwise logical XOR of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = (ins + I1Tensor:$lhs, + I1Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: maximum +//===----------------------------------------------------------------------===// +def Tosa_MaximumOp : Tosa_ElemwiseBinaryOp<"maximum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Maximum"; + + let description = [{ + Elementwise max of input1 and input2. Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: minimum +//===----------------------------------------------------------------------===// +def Tosa_MinimumOp : Tosa_ElemwiseBinaryOp<"minimum", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Elementwise Minimum"; + + let description = [{ + Elementwise minimum of input tensor 0 and input tensor 1. 
Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: mul +//===----------------------------------------------------------------------===// +def Tosa_MulOp : Tosa_ElemwiseBinaryOp<"mul", [ResultsBroadcastableShape, NoSideEffect, Commutative]> { + let summary = "Multiplication operator"; + + let description = [{ + Elementwise multiplication (Hadamard product) of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: pow +//===----------------------------------------------------------------------===// +def Tosa_PowOp : Tosa_ElemwiseBinaryOp<"pow", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Computes the power of one value to another."; + + let description = [{ + Elementwise input tensor 0 value raised to the power of input 1 tensor. + Axis of size 1 will be broadcast, as necessary. + Rank of input tensors must match. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + Tosa_Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: sub +//===----------------------------------------------------------------------===// +def Tosa_SubOp : Tosa_ElemwiseBinaryOp<"sub", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Elementwise subtraction operator"; + + let description = [{ + Elementwise subtraction of input tensor 0 and input tensor 1. + Axis of size 1 will be broadcast as necessary. + Rank of input tensors must match. + }]; + + let arguments = ( + ins Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs); + + let results = (outs Tosa_Tensor:$output); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: table +//===----------------------------------------------------------------------===// +def Tosa_TableOp : Tosa_ElemwiseBinaryOp<"table", [NoSideEffect]> { + let summary = "Table lookup op"; + + let description = [{ + Interpolated table lookup operation. Input values are scaled to create a fixed-point 9.7 value. + The high 9 bits are used to index into the table. The fractional bits are used to interpolate + based on the looked up value and the index+1 value in the table. The TABLE operator then returns + a 16.7 interpolated value. Note that there must be 513 values to handle the full range of inputs. 
+ + The TABLE operator is expected to be used as follows: + • A RESCALE node is expected before the TABLE operator to scale the input to a full int16_t range + for the table lookup + • If an int16_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 7 + • If an int8_t result is required then follow the TABLE operator with a RESCALE with a right + shift of 15 + }]; + + let arguments = ( ins + Tosa_Tensor: $input, + Tosa_Tensor: $lut + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.5 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise unary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: abs +//===----------------------------------------------------------------------===// +def Tosa_AbsOp : Tosa_ElemwiseUnaryOp<"abs", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise abs op"; + + let description = [{ + Elementwise absolute value operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: bitwise_not +//===----------------------------------------------------------------------===// +def Tosa_BitwiseNotOp : Tosa_ElemwiseUnaryOp<"bitwise_not", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Bitwise NOT operator"; + + let description = [{ + Elementwise bitwise NOT of input tensor. 
+ }]; + + let arguments = ( + ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output); + +} + +//===----------------------------------------------------------------------===// +// Operator: ceil +//===----------------------------------------------------------------------===// +def Tosa_CeilOp : Tosa_ElemwiseUnaryOp<"ceil", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise ceil op"; + + let description = [{ + Elementwise ceiling operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: clz +//===----------------------------------------------------------------------===// +def Tosa_ClzOp : Tosa_ElemwiseUnaryOp<"clz", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise count leading zero op"; + + let description = [{ + Elementwise count leading zeros operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: exp +//===----------------------------------------------------------------------===// +def Tosa_ExpOp : Tosa_ElemwiseUnaryOp<"exp", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise exp op"; + + let description = [{ + Elementwise e to the x operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: floor +//===----------------------------------------------------------------------===// +def Tosa_FloorOp : Tosa_ElemwiseUnaryOp<"floor", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise floor op"; + + let description = [{ + Elementwise floor operation + }]; + + let arguments = (ins Tosa_Tensor:$input + 
); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: log +//===----------------------------------------------------------------------===// +def Tosa_LogOp : Tosa_ElemwiseUnaryOp<"log", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise log op"; + + let description = [{ + Elementwise natural logarithm operation + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: logical_not +//===----------------------------------------------------------------------===// +def Tosa_LogicalNotOp : Tosa_ElemwiseBinaryOp<"logical_not", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Returns the truth value of NOT x element-wise."; + + let description = [{ + Elementwise logical NOT of input. + }]; + + let arguments = (ins + I1Tensor:$x + ); + + let results = (outs + I1Tensor:$y + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: negate +//===----------------------------------------------------------------------===// +def Tosa_NegateOp : Tosa_ElemwiseUnaryOp<"negate", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise negate op"; + + let description = [{ + Elementwise negation operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + OptionalAttr:$quantization_info + ); + let results = (outs Tosa_Tensor:$output + ); + + let builders = [Tosa_UnaryOpQuantInfoBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: reciprocal +//===----------------------------------------------------------------------===// +def Tosa_ReciprocalOp : Tosa_ElemwiseUnaryOp<"reciprocal", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise reciprocal op"; + + let 
description = [{ + Elementwise reciprocal operation. For integer operation, a TABLE should be used + with the appropriate ranges. + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// Operator: rsqrt +//===----------------------------------------------------------------------===// +def Tosa_RsqrtOp : Tosa_ElemwiseUnaryOp<"rsqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Elementwise 1/sqrt op"; + + let description = [{ + Elementwise reciprocal square root operation. For integer operation, a TABLE should be + used with the appropriate ranges. + }]; + + let arguments = (ins Tosa_Tensor:$input + ); + let results = (outs Tosa_Tensor:$output + ); +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.6 +// Operator Class: Elementwise unary/binary/ternary operators. +// Operator Subclass: Elementwise ternary ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: select +//===----------------------------------------------------------------------===// +def Tosa_SelectOp : Tosa_ElemwiseTernaryOp<"select", [NoSideEffect]> { + let summary = "Elementwise select operator"; + + let description = [{ + Elementwise select of the output based on a condition. + }]; + + let arguments = (ins + Tosa_Tensor:$condition, + Tosa_Tensor:$a, + Tosa_Tensor:$b + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.7 +// Operator Class: Logical Operations. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: equal +//===----------------------------------------------------------------------===// +def Tosa_EqualOp : Tosa_ElemwiseCompareOp<"equal", [ResultsBroadcastableShape, Commutative, NoSideEffect]> { + let summary = "Returns the truth value of (x == y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: greater +//===----------------------------------------------------------------------===// +def Tosa_GreaterOp : Tosa_ElemwiseCompareOp<"greater", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x > y) element-wise."; + + let description = [{ + Elementwise greater than comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// Operator: greater_equal +//===----------------------------------------------------------------------===// +def Tosa_GreaterEqualOp : Tosa_ElemwiseCompareOp<"greater_equal", [ResultsBroadcastableShape, NoSideEffect]> { + let summary = "Returns the truth value of (x >= y) element-wise."; + + let description = [{ + Elementwise comparison operation + }]; + + let arguments = (ins + Tosa_Tensor:$lhs, + Tosa_Tensor:$rhs + ); + + let results = (outs + I1Tensor:$z + ); + + let builders = [Tosa_BroadcastableBinaryBuilder]; +} + +//===----------------------------------------------------------------------===// +// TOSA Spec 
Section 2.8 +// Operator Class: Reduction Ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: reduce_all +//===----------------------------------------------------------------------===// +def Tosa_ReduceAllOp : Tosa_ReductionOp<"reduce_all", [NoSideEffect]> { + let summary = [{ + Reduce All operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical AND operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_any +//===----------------------------------------------------------------------===// +def Tosa_ReduceAnyOp : Tosa_ReductionOp<"reduce_any", [NoSideEffect]> { + let summary = [{ + Reduce Any operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a logical OR operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_max +//===----------------------------------------------------------------------===// +def Tosa_ReduceMaxOp : Tosa_ReductionOp<"reduce_max", [NoSideEffect]> { + let summary = [{ + Reduce Max operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a maximum operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_min +//===----------------------------------------------------------------------===// +def Tosa_ReduceMinOp : Tosa_ReductionOp<"reduce_min", [NoSideEffect]> { + let 
summary = [{ + Reduce Min operator + }]; + + let description = [{ + Reduce a tensor along the given axis with a minimum operation + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_prod +//===----------------------------------------------------------------------===// +def Tosa_ReduceProdOp : Tosa_ReductionOp<"reduce_prod", [NoSideEffect]> { + let summary = [{ + Reduce Prod operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the product of the axis. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: reduce_sum +//===----------------------------------------------------------------------===// +def Tosa_ReduceSumOp : Tosa_ReductionOp<"reduce_sum", [NoSideEffect]> { + let summary = [{ + Reduce Sum operator + }]; + + let description = [{ + Reduce a tensor along the given axis by computing the sum of the axis. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.9 +// Operator Class: Data Layout / Memory Reinterpretation. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: concat +//===----------------------------------------------------------------------===// +def Tosa_ConcatOp : Tosa_DataLayoutOp<"concat", [NoSideEffect]> { + let summary = "Concatenates tensors along one dimension."; + + let description = [{ + Concatenate two tensors along a given axis. 
No data conversion happens during a concat operation. + }]; + + let arguments = (ins + Tosa_Tensor:$a, + Tosa_Tensor:$b, + I64Attr:$axis + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: pad +//===----------------------------------------------------------------------===// +def Tosa_PadOp : Tosa_DataLayoutOp<"pad", [NoSideEffect]> { + let summary = "Pads a tensor with zeros."; + + let description = [{ + Zero-pads a tensor along borders of each dimension. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + Tosa_Int32Or64Tensor:$paddings, + OptionalAttr:$quantization_info + ); + + let results = (outs + Tosa_Tensor:$output + ); + + let builders = [Tosa_PadOpQuantInfoBuilder]; + +} + +//===----------------------------------------------------------------------===// +// Operator: reshape +//===----------------------------------------------------------------------===// +def Tosa_ReshapeOp: Tosa_DataLayoutOp<"reshape", [ + NoSideEffect]> { + let summary = "Reshape operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with a new shape specified by the shape + argument. Reshape may operate on tensors of any rank. No data conversion happens during a reshape + operation. + }]; + + let arguments = ( + ins Tosa_Tensor:$input, + I64ArrayAttr:$shape); + + let results = (outs Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: reverse +//===----------------------------------------------------------------------===// +def Tosa_ReverseOp: Tosa_DataLayoutOp<"reverse", [ + NoSideEffect]> { + let summary = "Reverse operator"; + + let description = [{ + Returns a tensor with the same type/values as the input, with the data reversed along the given + axis. No data conversion happens during a reverse operation. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64Attr:$axis); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: slice +//===----------------------------------------------------------------------===// +def Tosa_SliceOp: Tosa_DataLayoutOp<"slice", [ + NoSideEffect]> { + let summary = "Slice operator"; + + let description = [{ + Extracts a slice of the input tensor 0 on the given axis, beginning at the start coordinates, + and extending for size elements in each direction. No data conversion happens during a slice operation. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64ArrayAttr:$begin, + I64ArrayAttr:$size + ); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: tile +//===----------------------------------------------------------------------===// +def Tosa_TileOp: Tosa_DataLayoutOp<"tile", [NoSideEffect]> { + let summary = "Tile operator"; + + let description = [{ + Replicates input 0 multiples times along each dimension. + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I64ArrayAttr:$multiples); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: transpose +//===----------------------------------------------------------------------===// +def Tosa_TransposeOp : Tosa_DataLayoutOp<"transpose", [NoSideEffect]> { + let summary = "Transpose operator"; + + let description = [{ + Permutes the dimensions based on perm. + }]; + + let arguments = (ins + Tosa_Tensor:$x, + Tosa_Int32Or64Tensor:$perm + ); + + let results = ( + outs Tosa_Tensor:$y + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.10 +// Operator Class: Scatter/gather Operations. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: gather +//===----------------------------------------------------------------------===// +def Tosa_GatherOp : Tosa_AggregationOp<"gather", [NoSideEffect]> { + let summary = [{ + Gather operation + }]; + + let description = [{ + Generate a tensor for which each element in the output is a subtensor of the values tensor along + the given axis, based on the value of indices. + }]; + + let arguments = (ins + Tosa_Tensor:$params, + Tosa_Int32Or64Tensor:$indices, + I64Attr:$axis, + + DefaultValuedAttr:$batch_dims + ); + + let results = (outs + Tosa_Tensor:$z + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.11 +// Operator Class: Image Frontend Functions. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: resize +//===----------------------------------------------------------------------===// +def Tosa_ResizeOp : Tosa_ImageOp<"resize", [NoSideEffect]> { + + let summary = "Resize operation, supports various resize/upsample modes"; + + let description = [{ + Resizes a tensor. Resize is only allowed in the H and W dimensions. 
In expected use, + stride_y is approximately (IH< { + + let summary = "Cast operation"; + + let description = [{ + Performs a set of permissible cast operations + Mode Input Output + --------------------------------------- + signed 8 to bool int8 Boolean + signed 16 to bool int16 Boolean + signed 32 to bool int32 Boolean + bool to 8 Boolean int8 + bool to 16 Boolean int16 + bool to 32 Boolean int32 + signed 8 to signed 16 int8 int16 + signed 8 to signed 32 int8 int32 + signed 16 to signed 8 int16 int8 + signed 16 to signed 32 int16 int32 + signed 32 to signed 8 int32 int8 + signed 32 to signed 16 int32 int16 + float to signed 8 float int8 + float to signed 16 float int16 + signed 8 to float int8 float + signed 16 to float int16 float + }]; + + let arguments = ( + ins Tosa_Tensor:$input + ); + + let results = (outs Tosa_Tensor:$output); + +} + +//===----------------------------------------------------------------------===// +// Operator: rescale +//===----------------------------------------------------------------------===// +def Tosa_RescaleOp: Tosa_ConversionOp<"rescale", [NoSideEffect]> { + let summary = "Tosa rescale operator"; + + let description = [{ + Rescale quantized values into a new domain. 
Supported rescalings are: + Mode Input Output + signed 8 to 8 aint8 aint8 + signed 8 to 16 aint8 int16 + signed 8 to 32 aint8 int32 + signed 16 to 8 int16 aint8 + signed 16 to 16 int16 int16 + signed 16 to 32 int16 int32 + signed 32 to 8 int32 aint8 + signed 32 to 16 int32 int16 + signed 32 to 32 int32 int32 + signed 48 to 8 int48 aint8 + signed 48 to 16 int48 int16 + signed 48 to 32 int48 int32 + unsigned 8 to signed 8 uint8 aint8 + signed 8 to unsigned 8 aint8 uint8 + }]; + + let arguments = (ins + Tosa_Tensor:$input, + I32Attr:$input_zp, + I32Attr:$output_zp, + I32ArrayAttr:$multiplier, + I32ArrayAttr:$shift, + BoolAttr:$scale32, + BoolAttr:$double_round, + BoolAttr:$per_channel + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.13 +// Operator Class: Data Node Ops. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: const +//===----------------------------------------------------------------------===// +def Tosa_ConstOp : Tosa_DataNodeOp<"const", [ConstantLike, NoSideEffect, FirstAttrDerivedResultType]> { + let summary = "Constant op."; + + let description = [{ + A node containing constant data for use as the input to an operation. May hold data + in any of the supported data formats. 
+ }]; + + let arguments = (ins ElementsAttr:$value); + + let results = (outs AnyTensor:$output); + + let builders = [ + OpBuilderDAG<(ins "Type":$type, "Attribute":$value)>, + ]; + +} + +//===----------------------------------------------------------------------===// +// Operator: identity +//===----------------------------------------------------------------------===// +def Tosa_IdentityOp: Tosa_DataNodeOp<"identity", [NoSideEffect]> { + let summary = "Identity operator"; + let description = [{ + Returns a tensor with the same shape, size, type + and content as the input. + }]; + + let arguments = (ins + Tosa_Tensor:$input + ); + + let results = (outs + Tosa_Tensor:$output); +} + +//===----------------------------------------------------------------------===// +// Operator: identityn +//===----------------------------------------------------------------------===// +def Tosa_IdentityNOp: Tosa_DataNodeOp<"identityn", [NoSideEffect]> { + let summary = "IdentityN operator"; + let description = [{ + Returns a list of tensors with the same shape, type, and contents as the + input list of tensors. + }]; + + let arguments = (ins + Variadic:$input + ); + + let results = (outs + Variadic:$output); +} + + +//===----------------------------------------------------------------------===// +// Operator: placeholder +//===----------------------------------------------------------------------===// +def Tosa_PlaceholderOp : Tosa_DataNodeOp<"placeholder", [NoSideEffect]> { + let summary = "Placeholder op"; + + let description = [{ + A node where data will be inserted into the network at runtime. Generally used for inputs to the network. + }]; + + let arguments = (ins + ); + + let results = (outs + Tosa_Tensor:$output + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.14 +// Operator Class: Custom Operators. 
+//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: custom +//===----------------------------------------------------------------------===// +def Tosa_CustomOp : Tosa_Op<"custom"> { + + let summary = "Custom operator wrapper for Tosa"; + + let description = [{ + Hardware implementing TOSA may choose to add additional custom operators that are not expressed in + the existing TOSA operations. These operators are not expected to be portable across TOSA + implementations. The input and output signatures must be expressed in the corresponding TOSA node. + }]; + + let arguments = (ins + StrAttr:$identifier, + Variadic:$inputs + ); + + let results = (outs + Variadic:$outputs + ); + +} + +//===----------------------------------------------------------------------===// +// TOSA Spec Section 2.15 +// Operator Class: Control Flow Operators. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Operator: cond_if +//===----------------------------------------------------------------------===// +def Tosa_IfOp : Tosa_Op<"cond_if", [ + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = "Conditional if operator"; + + let description = [{ + Evaluates a Boolean condition and then takes one of two distinct execution paths. This + implements the semantic If-then-else structure. 
+ }]; + + let arguments = (ins + Tosa_Tensor:$cond, + Variadic:$inputs + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region + SizedRegion<1>:$then_branch, + SizedRegion<1>:$else_branch + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: while_loop +//===----------------------------------------------------------------------===// +def Tosa_WhileOp : Tosa_Op<"while_loop", [ + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">]> { + let summary = [{ + output = input; While (Cond(output)) { output = Body(output) } + }]; + + let description = [{ + Generates and evaluates a Bool condition and either executes a loop body or exits to + another control point. This action is performed repeatedly after updating and re-evaluating + the Boolean condition every iteration. This implements the semantic foreach or while + iterative loop structure. + }]; + + let arguments = (ins + Variadic:$inputs + ); + + let results = (outs + Variadic:$output); + + let regions = (region + SizedRegion<1>:$cond, + SizedRegion<1>:$body + ); + +} + +//===----------------------------------------------------------------------===// +// Operator: yield +//===----------------------------------------------------------------------===// +def Tosa_YieldOp : Tosa_Op<"yield", [Terminator]> { + let summary = "yield operator"; + + let description = [{ + return operation within the conditional and body of + structured control flow. Operation takes variadic operands + but produces no results of its own. 
+ }]; + + let arguments = (ins + Variadic:$inputs + ); + +} + +#endif // TOSA_OPS Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTraits.h @@ -0,0 +1,34 @@ +//===-- TosaTraits.h - TOSA dialect operation traits *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect OpTraits in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATRAITS_H +#define MLIR_DIALECT_TOSA_IR_TOSATRAITS_H + +#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Tosa/IR/TosaTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" + +namespace mlir { +namespace OpTrait { +namespace tosa { + +// TBD + +} // namespace tosa +} // namespace OpTrait +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATRAITS_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypes.h @@ -0,0 +1,31 @@ +//===-- TosaTypes.h - TOSA dialect type definitions *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the TOSA Dialect Types in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_IR_TOSATYPES_H +#define MLIR_DIALECT_TOSA_IR_TOSATYPES_H + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" + +namespace mlir { + +namespace tosa { + +// TOSA specific types go here + +} // namespace tosa + +} // end namespace mlir + +#endif // MLIR_DIALECT_TOSA_IR_TOSATYPES_H Index: mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -0,0 +1,135 @@ +//===-- TosaTypesBase.td - TOSA type definitions *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the type definitions for the TOSA dialect. +// +//===----------------------------------------------------------------------===// + +#ifdef TOSA_TYPES_BASE +#else +#define TOSA_TYPES_BASE + +include "mlir/IR/OpBase.td" + +//===----------------------------------------------------------------------===// +// Tosa Type Definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. +// Param tuple is: [bitwidth, zeropt, smantissa, sexp, low_end, high_end]. +// Where low and high ends are 0,255 when unsigned, -128,127 when signed, for the 8-bit case. 
+class Tosa_QuantizedType params, bit signed> + : Type()">, + CPred<"$_self.cast()" # + ".getStorageTypeIntegralWidth() == " # !head(params)>]>, + "Q" # !if (signed, "int", "uint") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + StrJoinInt.result # !if(signed, ", true", ", false"); +} + +//===----------------------------------------------------------------------===// +// Non-Quantized Signed Integer Types. +// Used to express accumulator results or compare results. +// +// Booleans are currently assumed to be expressed using int8 or +// built in U1 / U1Tensor type. +//===----------------------------------------------------------------------===// + +def Tosa_Int32 : SI<32>; +def Tosa_Int48 : SI<48>; +def Tosa_Int64 : SI<64>; + +def Tosa_SignedInt : AnyTypeOf<[Tosa_Int32, + Tosa_Int48, + Tosa_Int64]>; + +// No unsigned unquantized int types. +def Tosa_Int : AnyTypeOf<[Tosa_SignedInt]>; + +def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32, + Tosa_Int64]>; + +def Tosa_IntTensor : TensorOf<[Tosa_SignedInt]>; + +def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>; + +//===----------------------------------------------------------------------===// +// Quantized Integer Types. +// Datatype for network feature map or weight content. 
+//===----------------------------------------------------------------------===// +def Tosa_Quint8 : Tosa_QuantizedType<"Uniform", [8], 0>; +def Tosa_Qint8 : Tosa_QuantizedType<"Uniform", [8], 1>; +def Tosa_Qint16 : Tosa_QuantizedType<"Uniform", [16], 1>; + +//===----------------------------------------------------------------------===// +// Name Symmetry Grouping Sign +//===----------------------------------------------------------------------===// +// aint8 : asymmetric per tensor, signed +// uint8 : asymmetric per tensor , unsigned +// int4 : symmetric per channel, signed +// int8 : symmetric per tensor/per channel, signed +// int16 : symmetric per tensor, signed +//===----------------------------------------------------------------------===// +def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>, + Tosa_QuantizedType<"uint8", [8], 0>, + Tosa_QuantizedType<"int4", [4, 0], 1>, + Tosa_QuantizedType<"int8", [8, 0], 1>, + Tosa_QuantizedType<"int16", [16, 0], 1>]>; + + +//===----------------------------------------------------------------------===// +// Floating-point types +//===----------------------------------------------------------------------===// +def Tosa_Float : AnyTypeOf<[F32, + F16, + BF16]>; + +def Tosa_FpTensor : TensorOf<[Tosa_Float]>; + +//===----------------------------------------------------------------------===// +// Multi-category types +//===----------------------------------------------------------------------===// +def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], + "number">; + +//===----------------------------------------------------------------------===// +// Tensor types +//===----------------------------------------------------------------------===// +def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; + +def Tosa_QTypeTensor : TensorOf<[Tosa_QuantizedInt]>; + +class Tosa_TensorOrNone possibleTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +// Any tensor element type 
allowed in Tosa ops. +def Tosa_ElementType : Type, + "tosa.dtype">; + +class Tosa_TensorOfOrNone allowedTypes, string description = ""> : + AnyTypeOf<[TensorOf, NoneType], description>; + +//===----------------------------------------------------------------------===// +// Iterable attributes +//===----------------------------------------------------------------------===// +// Supported regimes for tosa.resize +def Tosa_ResizeTypeAttr : StringBasedAttr< + CPred<"$_self.cast().getValue() == \"TRANSPOSE\" || " # + "$_self.cast().getValue() == \"BILINEAR\" || " # + "$_self.cast().getValue() == \"NEAREST_NEIGHBOR\"">, + "Supported resize/upsampling strategies">; + +def Tosa_TensorTypeAttr : TypeAttrBase<"TensorType", "Tensor type attribute">; + +// Tensor to buffer types. +def Tosa_Buffer : MemRefOf<[Tosa_AnyNumber]>; +def Tosa_TupleBuffer : NestedTupleOf<[Tosa_Buffer]>; +def Tosa_BufOrTuple : AnyTypeOf<[Tosa_Buffer, Tosa_TupleBuffer]>; + +#endif // TOSA_TYPES_BASE Index: mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name TosaOpt) +add_public_tablegen_target(MLIRTosaPassIncGen) +add_dependencies(mlir-headers MLIRTosaPassIncGen) + +add_mlir_doc(Passes -gen-pass-doc TosaPasses ./) Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h @@ -0,0 +1,36 @@ +//===-- Passes.h - TOSA optimization pass declarations *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H +#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +class FuncOp; +class ModuleOp; +class Pass; +template +class OperationPass; + +namespace tosa { + +std::unique_ptr> createTosaMakeBroadcastablePass(); + +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc" + +} // namespace tosa +} // namespace mlir + +#endif // MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H Index: mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td @@ -0,0 +1,18 @@ +//===-- Passes.td - TOSA optimization pass declarations *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the optimization passes for the TOSA Dialect in MLIR. 
+// +//===----------------------------------------------------------------------===// + +include "mlir/Pass/PassBase.td" + +def TosaBinaryInputReshapePass : Pass<"tosa-make-broadcastable", "FuncOp"> { + let summary = "TOSA rank Reshape to enable Broadcasting"; + let constructor = "createTosaMakeBroadcastablePass()"; +} Index: mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h @@ -0,0 +1,89 @@ +//===-- QuantUtils.h - TOSA numerical support declarations *- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Function declarations for TOSA numerical support functions and quantization +// attribute builders +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_TOSA_UTILS_QUANT_UTILS_H +#define DIALECT_TOSA_UTILS_QUANT_UTILS_H + +//===----------------------------------------------------------------------===// +// Utility functions to support quantization handling in Tosa. 
+//===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include + +#include "mlir/Dialect/Quant/FakeQuantSupport.h" +#include "mlir/Dialect/Quant/UniformSupport.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringSwitch.h" + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +namespace mlir { +namespace tosa { + +/// From a scale value, computes multiplier and shift values +/// for 16 or 32-bit scale widths. +void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth); + +//// Builds ConvOpQuantizationAttr from input and weight. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, Value weight); + +//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B. +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder, + Value a, Value b); + +//// Builds UnaryOpQuantizationAttr for unary operations from input values. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, + Type outputRawType); + +//// Builds PadOpQuantizationAttr for pad operations from input values. +PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder, + Value input); + +/// Builds Tosa quantization attributes from min/max values. 
+Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange); + +/// Builds Tosa quantization attributes from min/max values. +TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDType, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange); + +} // namespace tosa +} // namespace mlir + +#endif // DIALECT_TOSA_UTILS_QUANT_UTILS_H Index: mlir/lib/Dialect/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/CMakeLists.txt +++ mlir/lib/Dialect/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) +add_subdirectory(Tosa) add_subdirectory(Vector) set(LLVM_OPTIONAL_SOURCES Index: mlir/lib/Dialect/Tosa/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -0,0 +1,24 @@ +add_mlir_dialect_library(MLIRTosa + IR/TosaOps.cpp + Utils/QuantUtils.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa + + DEPENDS + MLIRStandardOpsIncGen + MLIRTosaOpsIncGen + MLIRTosaStructsIncGen + MLIRTosaInterfaceIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRStandard + MLIRCallInterfaces + MLIRControlFlowInterfaces + MLIRSideEffectInterfaces + MLIRViewLikeInterface + ) + +# add_subdirectory(Serialization) +add_subdirectory(Transforms) Index: mlir/lib/Dialect/Tosa/IR/TosaOps.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -0,0 +1,145 @@ +//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This file implements the TOSA Specification: +// https://developer.mlplatform.org/w/tosa/ +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOps.h" + +#include +#include +#include +#include +#include +#include + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Parser.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/InliningUtils.h" +#include "mlir/Transforms/RegionUtils.h" + +using namespace mlir; +using namespace mlir::tosa; + +#include "mlir/Dialect/Tosa/IR/TosaStructs.cpp.inc" + +//===----------------------------------------------------------------------===// +// Tosa dialect interfaces. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc" + +//===----------------------------------------------------------------------===// +// Dialect Function Inliner Interface. +//===----------------------------------------------------------------------===// +struct TosaInlinerInterface : public DialectInlinerInterface { + using DialectInlinerInterface::DialectInlinerInterface; + + //===--------------------------------------------------------------------===// + // Analysis Hooks. 
+ //===--------------------------------------------------------------------===// + + /// All operations can be inlined by default. + bool isLegalToInline(Operation *op, Region *region, + BlockAndValueMapping &map) const { + return true; + } + + /// All regions with If and While parent operators can be inlined. + bool isLegalToInline(Region *dest, Region *src, + BlockAndValueMapping &map) const { + return (isa(dest->getParentOp()) || + isa(dest->getParentOp())); + } +}; + +//===----------------------------------------------------------------------===// +// TOSA control flow support. +//===----------------------------------------------------------------------===// + +/// Returns the while loop body. +Region &tosa::WhileOp::getLoopBody() { return body(); } + +/// WIP MLIR enhancements with exposed API. +bool tosa::WhileOp::isDefinedOutsideOfLoop(Value value) { return false; } + +LogicalResult WhileOp::moveOutOfLoop(llvm::ArrayRef ops) { + if (ops.empty()) + return success(); + + Operation *tosaWhileOp = this->getOperation(); + for (auto *op : ops) + op->moveBefore(tosaWhileOp); + + return success(); +} + +struct TosaDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; + + bool shouldMaterializeInto(Region *region) const final { + return isa(region->getParentOp()); + } +}; + +//===----------------------------------------------------------------------===// +// Tosa Dialect. 
+//===----------------------------------------------------------------------===// + +TosaDialect::TosaDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + >(); + addInterfaces(); + + allowUnknownOperations(); +} + +//===----------------------------------------------------------------------===// +// ConstOp +//===----------------------------------------------------------------------===// + +void ConstOp::build(OpBuilder &builder, OperationState &result, Type type, + Attribute value) { + result.addTypes(type); + result.addAttribute("value", value); + return; +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc" + +Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + if (value.isa() || + (value.isa() && value.getType() != type)) + return builder.create(loc, type, value.cast()); + return nullptr; +} Index: mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRTosaTransforms + TosaMakeBroadcastable.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa/Transforms + + DEPENDS + MLIRTosaPassIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTosa + ) Index: mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -0,0 +1,225 @@ +//===- TosaMakeBroadcastable.cpp ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Insert reshape to binary op's input if needed to match rank
+//
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the five standard-header names originally listed here were
+// lost in extraction; these are the standard headers this file visibly uses
+// (int64_t, std::unique_ptr/std::make_unique, std::move).
+#include <cstdint>
+#include <memory>
+#include <utility>
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/IR/Types.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+// NOTE(review): Passes.td in this patch declares this pass as
+// "tosa-make-broadcastable"; the '=' below looks like a typo for '-'. It is
+// left unchanged because this file also registers the pass statically under
+// PASS_NAME; confirm the two registrations cannot collide before renaming.
+#define PASS_NAME "tosa-make=broadcastable"
+#define DEBUG_TYPE PASS_NAME
+
+namespace mlir {
+
+namespace tosa {
+
+namespace {
+/// Pass that reshapes the inputs of binary operators to make them the same
+/// rank, adding a dimension of size 1 where necessary to enable broadcasting.
+class TosaMakeBroadcastable
+    : public PassWrapper<TosaMakeBroadcastable, FunctionPass> {
+public:
+  explicit TosaMakeBroadcastable() {}
+  void runOnFunction() override;
+};
+
+// Expands to nothing and is not referenced anywhere else in this file.
+#define REPLACE_OP_LOGICAL(tosa_op, LHS_VALUE, RHS_VALUE)
+
+// Replaces `op` with a new tosa::<tosa_op>Op built from the two given
+// operands. Expands inside a rewrite body and relies on the locals `rewriter`,
+// `op` and `output_type` being in scope at the expansion site.
+#define REPLACE_OP(tosa_op, LHS_VALUE, RHS_VALUE) \
+  { \
+    rewriter.replaceOpWithNewOp<tosa::tosa_op##Op>(op, output_type, LHS_VALUE, \
+                                                   RHS_VALUE); \
+  }
+
+/// For binary operators, insert a RESHAPE operation if necessary to create
+/// inputs of the same rank, with dimensions of size 1 where broadcasting
+/// will be used.
+// Examples:
+// If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
+// TODO: If lower=[b], target=[a, b, c], [b] should but NOT YET reshaped into
+// [1, b, 1]. If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c]. If
+// lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
If lower=[a,
+// b], target=[a, b, c], [a, b] reshaped into [a, b, 1]. If lower=[b, c],
+// target=[a, b, c], [b, c] reshaped into [1, b, c]. If lower=[a], target=[a,
+// a], [a] reshaped into [1, a] instead of [a, 1]. If lower=[a], target=[a, b,
+// a], [a] reshaped into [1, 1, a]. If lower=[], target=[a, b, c], [] reshaped
+// into [1, 1, 1].
+
+// DECL_TOSACONVERT_OP(X) defines `ConvertTosaXOp`, a RewritePattern rooted on
+// tosa::XOp with benefit 1. The pattern fails to match when either operand or
+// the result is not a ranked tensor, when both operands already have the same
+// rank, or when the result rank differs from the higher operand rank.
+// Otherwise it reshapes the lower-rank operand to the higher rank: equal
+// dimensions are matched first from the right and then from the left against
+// the result shape, every unmatched position becomes 1, and the op is rebuilt
+// with the reshaped operand in its original position.
+#define DECL_TOSACONVERT_OP(tosa_op) \
+  struct ConvertTosa##tosa_op##Op : public RewritePattern { \
+    explicit ConvertTosa##tosa_op##Op(MLIRContext *context) \
+        : RewritePattern(tosa::tosa_op##Op::getOperationName(), 1, context) {} \
+    LogicalResult matchAndRewrite(Operation *op, \
+                                  PatternRewriter &rewriter) const override { \
+      auto tosa_binary_op = cast<tosa::tosa_op##Op>(op); \
+ \
+      Value lhs = tosa_binary_op.lhs(); \
+      Value rhs = tosa_binary_op.rhs(); \
+ \
+      auto lhs_type = lhs.getType().dyn_cast<RankedTensorType>(); \
+      auto rhs_type = rhs.getType().dyn_cast<RankedTensorType>(); \
+      auto output_type = \
+          tosa_binary_op.getResult().getType().dyn_cast<RankedTensorType>(); \
+      /* Unranked operands or results cannot be rank-matched. */ \
+      if (!lhs_type || !rhs_type || !output_type) \
+        return failure(); \
+ \
+      int64_t lhs_rank = lhs_type.getRank(); \
+      int64_t rhs_rank = rhs_type.getRank(); \
+      /* Nothing to do if the ranks already match. */ \
+      if (lhs_rank == rhs_rank) \
+        return failure(); \
+ \
+      const bool lhs_is_higher = lhs_rank > rhs_rank; \
+      Value higher_tensor_value = lhs_is_higher ? lhs : rhs; \
+      Value lower_tensor_value = lhs_is_higher ? rhs : lhs; \
+      RankedTensorType lower_type = lhs_is_higher ? rhs_type : lhs_type; \
+      int64_t higher_rank = lhs_is_higher ? lhs_rank : rhs_rank; \
+ \
+      /* The result shape is indexed with the higher operand rank below. */ \
+      if (output_type.getRank() != higher_rank) \
+        return failure(); \
+ \
+      ArrayRef<int64_t> higher_rank_shape = output_type.getShape(); \
+      ArrayRef<int64_t> lower_rank_shape = lower_type.getShape(); \
+ \
+      /* Start from all ones; matched dimensions are filled in below. */ \
+      SmallVector<int64_t, 8> reshape_output_shape; \
+      reshape_output_shape.assign(higher_rank, 1); \
+ \
+      int64_t higher_left_index = 0; \
+      int64_t higher_right_index = higher_rank; \
+      int64_t lower_left_index = 0; \
+      int64_t lower_right_index = lower_type.getRank(); \
+ \
+      /* Match equal trailing dimensions, scanning right to left. */ \
+      while (higher_right_index > 0 && lower_right_index > 0 && \
+             higher_rank_shape[higher_right_index - 1] == \
+                 lower_rank_shape[lower_right_index - 1]) { \
+        reshape_output_shape[higher_right_index - 1] = \
+            higher_rank_shape[higher_right_index - 1]; \
+        --higher_right_index; \
+        --lower_right_index; \
+      } \
+ \
+      /* Match equal leading dimensions among what is still unmatched. */ \
+      while (higher_left_index < higher_right_index && \
+             lower_left_index < lower_right_index && \
+             higher_rank_shape[higher_left_index] == \
+                 lower_rank_shape[lower_left_index]) { \
+        reshape_output_shape[higher_left_index] = \
+            higher_rank_shape[higher_left_index]; \
+        ++higher_left_index; \
+        ++lower_left_index; \
+      } \
+ \
+      auto reshape_output_type = \
+          RankedTensorType::get(ArrayRef<int64_t>(reshape_output_shape), \
+                                lower_type.getElementType()); \
+ \
+      auto reshape_lower = rewriter.create<tosa::ReshapeOp>( \
+          op->getLoc(), reshape_output_type, lower_tensor_value, \
+          rewriter.getI64ArrayAttr(reshape_output_shape)); \
+ \
+      /* Keep the original operand order when rebuilding the op. */ \
+      if (lhs_is_higher) { \
+        REPLACE_OP(tosa_op, higher_tensor_value, reshape_lower.getResult()); \
+      } else { \
+        REPLACE_OP(tosa_op, reshape_lower.getResult(), higher_tensor_value); \
+      } \
+ \
+      return success(); \
+    } \
+  };
+DECL_TOSACONVERT_OP(Add)
+DECL_TOSACONVERT_OP(Sub)
+DECL_TOSACONVERT_OP(Mul)
+DECL_TOSACONVERT_OP(LogicalLeftShift)
+DECL_TOSACONVERT_OP(ArithmeticRightShift)
+DECL_TOSACONVERT_OP(LogicalRightShift)
+#undef DECL_TOSACONVERT_OP
+
+#undef REPLACE_OP
+
+/// Greedily applies the six rank-matching patterns declared above to the
+/// function this pass is running on.
+void TosaMakeBroadcastable::runOnFunction() {
+  OwningRewritePatternList patterns;
+  auto *ctx = &getContext();
+  auto func = getFunction();
+
+  // Add the generated patterns to the list.
+  patterns.insert<ConvertTosaAddOp, ConvertTosaSubOp, ConvertTosaMulOp,
+                  ConvertTosaLogicalLeftShiftOp,
+                  ConvertTosaArithmeticRightShiftOp,
+                  ConvertTosaLogicalRightShiftOp>(ctx);
+  applyPatternsAndFoldGreedily(func, std::move(patterns));
+}
+
+} // anonymous namespace
+
+/// Creates a new instance of the TosaMakeBroadcastable pass.
+std::unique_ptr<OperationPass<FuncOp>> createTosaMakeBroadcastablePass() {
+  return std::make_unique<TosaMakeBroadcastable>();
+}
+
+// Registers the pass under PASS_NAME when this object is statically
+// initialized.
+static PassRegistration<TosaMakeBroadcastable>
+    pass(PASS_NAME,
+         "Perform broadcast on elementwise TosaOps to ensure same rank");
+
+} // namespace tosa
+
+} // namespace mlir
Index: mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -0,0 +1,347 @@
+//===- QuantUtils.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains TOSA numerical support functions and quantization
+// attribute builders.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+
+namespace mlir {
+namespace tosa {
+
+namespace {
+
+/// From a scale value, generates multiplier and shift values where
+/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
+/// multiplier = mantissa*2^shift for 16-bit scaling.
+///
+/// The mantissa and exponent come from std::frexp(scale). On return,
+/// \p multiplier is the mantissa rounded to 15 fractional bits and \p shift is
+/// the right-shift amount, so that scale ~= multiplier * 2^(-shift).
+void computeMultiplierAndShiftTosaScale16(double scale, int32_t &multiplier,
+                                          int32_t &shift) {
+
+  const double mantissa = std::frexp(scale, &shift);
+  auto shiftedM = std::round(mantissa * (int64_t(1) << 15));
+
+  assert(shiftedM <= (int64_t(1) << 15) &&
+         "mantissa can't be greater than 1.0");
+  // Rounding can push the mantissa up to exactly 1.0; renormalize by halving
+  // it and bumping the exponent.
+  if (shiftedM == (int64_t(1) << 15)) {
+    shiftedM /= 2;
+    shift++;
+  }
+
+  // TOSA expects right shift to be positive and embed (1 << 15) into right
+  // shift bits.
+  shift = (-shift) + 15;
+
+  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
+         "multiplier must fit in int32_t");
+
+  multiplier = static_cast<int32_t>(shiftedM);
+}
+
+/// From a scale value, generates multiplier and shift values where
+/// mantissa is in [-1.0,-0.5] or [0.5, 1.0] such that
+/// multiplier = mantissa*2^shift for 32-bit scaling.
+///
+/// The mantissa and exponent come from std::frexp(scale). On return,
+/// \p multiplier is the mantissa rounded to 31 fractional bits and \p shift is
+/// the right-shift amount, so that scale ~= multiplier * 2^(-shift).
+void computeMultiplierAndShiftTosaScale32(double scale, int32_t &multiplier,
+                                          int32_t &shift) {
+
+  const double mantissa = std::frexp(scale, &shift);
+  auto shiftedM = std::round(mantissa * (int64_t(1) << 31));
+
+  assert(shiftedM <= (int64_t(1) << 31) &&
+         "mantissa can't be greater than 1.0");
+  // Rounding can push the mantissa up to exactly 1.0; renormalize by halving
+  // it and bumping the exponent.
+  if (shiftedM == (int64_t(1) << 31)) {
+    shiftedM /= 2;
+    shift++;
+  }
+
+  // TOSA expects right shift to be positive, and embed (1 << 31) into right
+  // shift bits.
+  shift = (-shift) + 31;
+
+  assert(shiftedM <= std::numeric_limits<int32_t>::max() &&
+         "multiplier must fit in int32_t");
+
+  multiplier = static_cast<int32_t>(shiftedM);
+}
+
+} // namespace
+
+/// Generates a quantized multiplier/shift from double.
+void computeMultiplierAndShift(double scale, int32_t &multiplier, + int32_t &shift, int32_t scaleWidth) { + + switch (scaleWidth) { + case 16: + computeMultiplierAndShiftTosaScale16(scale, multiplier, shift); + return; + case 32: + computeMultiplierAndShiftTosaScale32(scale, multiplier, shift); + return; + default: + assert(0 && "Unsupported Tosa quantized_scale regime specified!"); + } +} + +#define GET_UQTYPE(input) \ + ((input) \ + .getType() \ + .dyn_cast() \ + .getElementType() \ + .dyn_cast()) + +// Method to build ConvOpQuantizationAttr, called from +// ConvOpQuantInfoBuilder/TransConvOpQuantInfoBuilder: input_zp: input zeropoint +// weight_zp: weight zeropoint. +ConvOpQuantizationAttr buildConvOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, Value weight) { + + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); + + if (!inputType || !weightType) + return nullptr; + + bool inputIsQType = + inputType.getElementType().isa(); + bool weightIsQType = + weightType.getElementType().isa(); + + // Either all quantized or all not quantized. + assert(!(inputIsQType ^ weightIsQType)); + + if (inputIsQType) { + + auto inputQType = inputType.getElementType() + .dyn_cast(); + assert(inputQType); // We don't support any other kind of input + // quantization here + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t weightZp = 0; + + if (auto weightQType = weightType.getElementType() + .dyn_cast()) { + // Per tensor quantization. + weightZp = weightQType.getZeroPoint(); + } else if (auto weightQType = + weightType.getElementType() + .dyn_cast()) { + // Per channel quantization. 
+ weightZp = weightQType.getZeroPoints().front(); + } + + auto quantAttr = mlir::tosa::ConvOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(weightZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds MatMulOpQuantizationAttr, called from +/// MatMulOpQuantInfoBuilder: aZp: input a zeropoint bZp: input b zeropoint +MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(mlir::OpBuilder &builder, + Value a, Value b) { + + auto aType = a.getType().dyn_cast(); + auto bType = b.getType().dyn_cast(); + + if (!aType || !bType) + return nullptr; + + bool aIsQType = + aType.getElementType().isa(); + bool bIsQType = + bType.getElementType().isa(); + + // A and B are either all quantized or all not quantized. + assert(!(aIsQType ^ bIsQType)); + + if (aIsQType) { + + auto aQType = GET_UQTYPE(a); + auto bQType = GET_UQTYPE(b); + + assert(aQType && bQType); + + int64_t aZp = aQType.getZeroPoint(); + int64_t bZp = bQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::MatMulOpQuantizationAttr::get( + builder.getI32IntegerAttr(aZp), builder.getI32IntegerAttr(bZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds UnaryOpQuantizationAttr +/// UnaryOpQuantInfoBuilder: inputZp: input zeropoint +/// outputZp: output zeropoint. +UnaryOpQuantizationAttr buildUnaryOpQuantizationAttr(mlir::OpBuilder &builder, + Value input, + Type outputRawType) { + + auto inputType = input.getType().dyn_cast(); + auto outputType = outputRawType.dyn_cast(); + + if (!inputType || !outputType) + return nullptr; + + bool inputIsQType = + inputType.getElementType().isa(); + bool outputIsQType = + outputType.getElementType().isa(); + + // Either all quantized or all not quantized. 
+ assert(!(inputIsQType ^ outputIsQType)); + + if (inputIsQType) { + + auto inputQType = inputType.getElementType() + .dyn_cast(); + auto outputQType = outputType.getElementType() + .dyn_cast(); + assert(inputQType && outputQType); + + int64_t inputZp = inputQType.getZeroPoint(); + int64_t outputZp = outputQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::UnaryOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getI32IntegerAttr(outputZp), + builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds PadOpQuantizationAttr, called from PadOpQuantInfoBuilder: +/// inputZp: input zeropoint +PadOpQuantizationAttr buildPadOpQuantizationAttr(mlir::OpBuilder &builder, + Value input) { + + auto inputType = input.getType().dyn_cast(); + + if (!inputType) + return nullptr; + + bool inputIsQType = + inputType.getElementType().isa(); + + if (inputIsQType) { + + auto inputQType = inputType.getElementType() + .dyn_cast(); + assert(inputQType); + + int64_t inputZp = inputQType.getZeroPoint(); + + auto quantAttr = mlir::tosa::PadOpQuantizationAttr::get( + builder.getI32IntegerAttr(inputZp), builder.getContext()); + + return quantAttr; + } + + return nullptr; +} + +/// Builds Tosa quantization attributes from min/max values. +Type buildQTypeFromMinMax(OpBuilder builder, Type inputDType, Attribute minAttr, + Attribute maxAttr, IntegerAttr quantBits, + int filterQuantDim, bool isSigned, + BoolAttr narrowRange) { + + quant::QuantizedType retType; + + auto convfunc = + quant::ExpressedToQuantizedConverter::forInputType(inputDType); + + auto minElems = minAttr.dyn_cast(); + auto maxElems = maxAttr.dyn_cast(); + + SmallVector min, max; + + if (minElems || + maxElems) { // At least one is per-axis quantized elementsattr. + + // Must have the same number of elements. 
+ if (minElems.getNumElements() != maxElems.getNumElements()) + return {}; + + min.reserve(minElems.getNumElements()); + max.reserve(maxElems.getNumElements()); + for (auto i : minElems) { + min.push_back(FloatAttr::getValueAsDouble(i)); + } + for (auto i : maxElems) { + max.push_back(FloatAttr::getValueAsDouble(i)); + } + } else { // Just a single FP value. + + auto minVal = minAttr.dyn_cast(); + if (minVal) + min.push_back(minVal.getValueAsDouble()); + else + return {}; + auto maxVal = maxAttr.dyn_cast(); + if (maxVal) + max.push_back(maxVal.getValueAsDouble()); + else + return {}; + } + + if (min.size() == max.size()) { + + if (min.size() == 1) { // Per-tensor quantization with one min/max pair. + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), min[0], max[0], + narrowRange.getValue(), convfunc.expressedType, isSigned); + + } else if (min.size() > 1) { // Per-axis quant on filterQuantDim. + + auto shape = inputDType.dyn_cast(); + if (!shape) + return {}; + if ((filterQuantDim) >= 0 && (shape.getRank() > filterQuantDim)) { + + retType = quant::fakeQuantAttrsToType( + builder.getUnknownLoc(), quantBits.getInt(), filterQuantDim, min[0], + max[0], narrowRange.getValue(), convfunc.expressedType, isSigned); + } + + } else { + return {}; + } + } else { + return {}; + } + + if (!retType) + return {}; + + return convfunc.convert(retType); +} + +/// Builds Tosa quantization attributes from min/max values. +TypeAttr buildQTypeAttrFromMinMax(OpBuilder builder, Type inputDtype, + Attribute minAttr, Attribute maxAttr, + IntegerAttr quantBits, int filterQuantDim, + bool isSigned, BoolAttr narrowRange) { + + return TypeAttr::get(buildQTypeFromMinMax(builder, inputDtype, minAttr, + maxAttr, quantBits, filterQuantDim, + isSigned, narrowRange)); +} + +} // namespace tosa +} // namespace mlir