diff --git a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt @@ -2,6 +2,11 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc) add_mlir_interface(TosaInterfaces) +set(LLVM_TARGET_DEFINITIONS TosaOps.td) +mlir_tablegen(TosaOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(TosaOpsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRTosaOpsEnumsGen) + set(LLVM_TARGET_DEFINITIONS TosaOps.td) mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls) mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -173,10 +173,10 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder< (ins "::mlir::Type":$outputType, "::mlir::Value":$input, "::mlir::DenseI64ArrayAttr":$kernel, "::mlir::DenseI64ArrayAttr":$stride, - "::mlir::DenseI64ArrayAttr":$pad), + "::mlir::DenseI64ArrayAttr":$pad, "AccTypeAttr":$acc_type), [{ buildAvgPool2dOpWithQuantInfo($_builder, $_state, outputType, - input, kernel, stride, pad); + input, kernel, stride, pad, acc_type); }]>; // This builder is called on single-parameter unary operators that have a scale diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -25,6 +25,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc" +#include "mlir/Dialect/Tosa/IR/TosaOpsEnums.h.inc" namespace mlir { class PatternRewriter; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- 
a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -53,6 +53,20 @@ ); } +//===----------------------------------------------------------------------===// +// Accumulator types. +//===----------------------------------------------------------------------===// + +def Tosa_AccType : I32EnumAttr<"AccType", "Specify the type of the accumulator", + [Tosa_Enum_I32, Tosa_Enum_F16, Tosa_Enum_F32]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::tosa"; +} + +def Tosa_AccTypeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // Operator: avg_pool2d //===----------------------------------------------------------------------===// @@ -74,6 +88,7 @@ Tosa_IntArrayAttr2:$kernel, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$pad, + Tosa_AccTypeAttr:$acc_type, OptionalAttr:$quantization_info ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -14,6 +14,7 @@ #define TOSA_TYPES_BASE include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" //===----------------------------------------------------------------------===// // Tosa Type Definitions. @@ -99,6 +100,14 @@ def Tosa_AnyNumber_Cast : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], "number_cast">; +//===----------------------------------------------------------------------===// +// Enumeration. 
+//===----------------------------------------------------------------------===//
+
+def Tosa_Enum_I32 : I32EnumAttrCase<"I32", 0, "i32">;
+def Tosa_Enum_F16 : I32EnumAttrCase<"F16", 1, "f16">;
+def Tosa_Enum_F32 : I32EnumAttrCase<"F32", 2, "f32">;
+
 //===----------------------------------------------------------------------===//
 // Tensor types
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -757,6 +757,21 @@
 public:
   using OpRewritePattern<tosa::AvgPool2dOp>::OpRewritePattern;
 
+  /// Maps the op's `acc_type` attribute to the accumulator element type.
+  /// Returns a null Type for any value this lowering does not handle.
+  Type getAccumulatorType(PatternRewriter &rewriter,
+                          tosa::AvgPool2dOp op) const {
+    switch (op.getAccType()) {
+    case tosa::AccType::I32:
+      return rewriter.getI32Type();
+    case tosa::AccType::F16:
+      return rewriter.getF16Type();
+    case tosa::AccType::F32:
+      return rewriter.getF32Type();
+    }
+    return Type();
+  }
+
   LogicalResult matchAndRewrite(tosa::AvgPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
@@ -767,8 +782,12 @@
     ShapedType resultTy = op.getType().template cast<ShapedType>();
     Type resultETy = op.getType().cast<ShapedType>().getElementType();
 
-    Type accETy =
-        inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
+    Type accETy = getAccumulatorType(rewriter, op);
+    // A null type signals an acc_type value this lowering cannot handle;
+    // fail the match rather than build accumulators of an invalid type.
+    if (!accETy)
+      return rewriter.notifyMatchFailure(
+          op, "unsupported accumulator type for tosa.avg_pool2d");
     ShapedType accTy = resultTy.clone(accETy);
 
     auto dynamicDimsOr =
diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt
--- a/mlir/lib/Dialect/Tosa/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt
@@ -10,6 +10,7 @@
   DEPENDS
   MLIRTosaAttributesIncGen
   MLIRTosaOpsIncGen
+  MLIRTosaOpsEnumsGen
   MLIRTosaInterfacesIncGen
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -267,13 +267,16 @@
 /// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
 /// but avg_pool operator has its own builder as it has additional parameters
 /// not part of the unary ops.
-static void buildAvgPool2dOpWithQuantInfo(
-    OpBuilder &builder, OperationState &result, Type outputType, Value input,
-    DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad) {
+static void
+buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+                              Type outputType, Value input,
+                              DenseArrayAttr kernel, DenseArrayAttr stride,
+                              DenseArrayAttr pad, AccTypeAttr accType) {
   result.addOperands(input);
   result.addAttribute("kernel", kernel);
   result.addAttribute("stride", stride);
   result.addAttribute("pad", pad);
+  result.addAttribute("acc_type", accType);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
   if (quantAttr)
     result.addAttribute("quantization_info", quantAttr);
@@ -1434,6 +1437,12 @@
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// TOSA Enum Definitions.
+//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tosa/IR/TosaOpsEnums.cpp.inc" + //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -286,7 +286,7 @@ // CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]] // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]] // CHECK: linalg.yield %[[DIV]] - %0 = "tosa.avg_pool2d"(%arg0) {pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) return %0 : tensor<1x5x33x62xf32> } @@ -329,7 +329,7 @@ // CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]] // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]] // CHECK: linalg.yield %[[TRUNC]] - %0 = "tosa.avg_pool2d"(%arg0) {pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) return %0 : tensor<1x5x33x62xi8> } @@ -352,7 +352,7 @@ // CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor // CHECK: %[[GENERIC:.+]] = linalg.generic - %0 = "tosa.avg_pool2d"(%arg0) {pad = array, kernel = array, stride = array} : (tensor) -> (tensor) + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, pad = array, kernel = array, stride = array} : (tensor) -> (tensor) 
return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -12,28 +12,28 @@ // ----- // CHECK-LABEL: avg_pool2d_f32 func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } // ----- // CHECK-LABEL: avg_pool2d_i8 func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } // ----- // CHECK-LABEL: avg_pool2d_i16 func.func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> return %0 : tensor<1x7x7x9xi16> } // ----- // CHECK-LABEL: avg_pool2d_q8 func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> return %0 
: tensor<1x7x7x9x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -659,7 +659,7 @@ // CHECK-LABEL: @test_pool_static func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) { // CHECK: -> tensor<3x2x4x7xf32> - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor // CHECK: -> tensor<3x2x4x7xf32> %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor @@ -689,7 +689,7 @@ // CHECK-LABEL: @test_pool_dynamic_input func.func @test_pool_dynamic_input(%arg0: tensor) { // CHECK: -> tensor - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor) -> tensor // CHECK: -> tensor %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor) -> tensor @@ -701,7 +701,7 @@ // CHECK-LABEL: @test_pool_padded func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) { // CHECK: -> tensor<3x5x11x7xf32> - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor // CHECK: -> tensor<3x5x11x7xf32> %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor @@ -731,7 +731,7 @@ // CHECK-LABEL: @test_pool_stride func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) { // CHECK: -> tensor<3x4x4x7xf32> - %0 = 
"tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = #tosa.accumulator_type, kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor // CHECK: -> tensor<3x4x4x7xf32> %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -9001,6 +9001,14 @@ ["-gen-dialect-defs"], "include/mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc", ), + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/Tosa/IR/TosaOpsEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/Tosa/IR/TosaOpsEnums.cpp.inc", + ), ( ["-gen-attrdef-decls"], "include/mlir/Dialect/Tosa/IR/TosaAttributes.h.inc",