diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td @@ -174,10 +174,10 @@ def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder< (ins "::mlir::Type":$outputType, "::mlir::Value":$input, "::mlir::DenseI64ArrayAttr":$kernel, "::mlir::DenseI64ArrayAttr":$stride, - "::mlir::DenseI64ArrayAttr":$pad), + "::mlir::DenseI64ArrayAttr":$pad, "::mlir::TypeAttr":$acc_type), [{ buildAvgPool2dOpWithQuantInfo($_builder, $_state, outputType, - input, kernel, stride, pad); + input, kernel, stride, pad, acc_type); }]>; // This builder is called on single-parameter unary operators that have a scale diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -53,6 +53,12 @@ ); } +//===----------------------------------------------------------------------===// +// Accumulator types. +//===----------------------------------------------------------------------===// + +def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>; + //===----------------------------------------------------------------------===// // Operator: avg_pool2d //===----------------------------------------------------------------------===// @@ -74,6 +80,7 @@ Tosa_IntArrayAttr2:$kernel, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$pad, + TypeAttrOf:$acc_type, OptionalAttr:$quantization_info ); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -751,8 +751,7 @@ ShapedType resultTy = cast(op.getType()); Type resultETy = cast(op.getType()).getElementType(); - Type accETy = - isa(inElementTy) ? rewriter.getI32Type() : inElementTy; + Type accETy = op.getAccType(); ShapedType accTy = resultTy.clone(accETy); auto dynamicDimsOr = diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -153,6 +153,17 @@ llvm::dyn_cast(resultETy)) resultETy = quantType.getStorageType(); + auto accType = getAccType(); + if (inputETy.isa() && !accType.isInteger(32)) + return emitOpError("accumulator type for integer tensor is not i32"); + + if ((inputETy.isBF16() || inputETy.isF16()) && + !(accType.isF16() || accType.isF32())) + return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32"); + + if (inputETy.isF32() && !accType.isF32()) + return emitOpError("accumulator type for f32 tensor is not f32"); + if (inputETy.isF32() && resultETy.isF32()) return success(); if (inputETy.isInteger(8) && resultETy.isInteger(8)) @@ -268,13 +279,16 @@ /// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr /// but avg_pool operator has its own builder as it has additional parameters /// not part of the unary ops. -static void buildAvgPool2dOpWithQuantInfo( - OpBuilder &builder, OperationState &result, Type outputType, Value input, - DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad) { +static void +buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result, + Type outputType, Value input, + DenseArrayAttr kernel, DenseArrayAttr stride, + DenseArrayAttr pad, TypeAttr acc_type) { result.addOperands(input); result.addAttribute("kernel", kernel); result.addAttribute("stride", stride); result.addAttribute("pad", pad); + result.addAttribute("acc_type", acc_type); auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType); if (quantAttr) result.addAttribute("quantization_info", quantAttr); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -286,7 +286,7 @@ // CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]] // CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]] // CHECK: linalg.yield %[[DIV]] - %0 = "tosa.avg_pool2d"(%arg0) {pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>) return %0 : tensor<1x5x33x62xf32> } @@ -329,7 +329,7 @@ // CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]] // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]] // CHECK: linalg.yield %[[TRUNC]] - %0 = "tosa.avg_pool2d"(%arg0) {pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) return %0 : tensor<1x5x33x62xi8> } @@ -352,7 +352,7 @@ // CHECK-SAME: outs(%[[FILL]] : tensor) -> tensor // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor // CHECK: %[[GENERIC:.+]] = linalg.generic - %0 = "tosa.avg_pool2d"(%arg0) {pad = array, kernel = array, stride = array} : (tensor) -> (tensor) + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, pad = array, kernel = array, stride = array} : (tensor) -> (tensor) return %0 : tensor } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -12,28 +12,28 @@ // ----- // CHECK-LABEL: avg_pool2d_f32 func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> return %0 : tensor<1x7x7x9xf32> } // ----- // CHECK-LABEL: avg_pool2d_i8 func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> return %0 : tensor<1x7x7x9xi8> } // ----- // CHECK-LABEL: avg_pool2d_i16 func.func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> return %0 : tensor<1x7x7x9xi16> } // ----- // CHECK-LABEL: avg_pool2d_q8 func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> { - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array, pad = array, stride = array} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> return %0 : tensor<1x7x7x9x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -659,7 +659,7 @@ // CHECK-LABEL: @test_pool_static func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) { // CHECK: -> tensor<3x2x4x7xf32> - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor // CHECK: -> tensor<3x2x4x7xf32> %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor @@ -689,7 +689,7 @@ // CHECK-LABEL: @test_pool_dynamic_input func.func @test_pool_dynamic_input(%arg0: tensor) { // CHECK: -> tensor - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor) -> tensor // CHECK: -> tensor %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor) -> tensor @@ -701,7 +701,7 @@ // CHECK-LABEL: @test_pool_padded func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) { // CHECK: -> tensor<3x5x11x7xf32> - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor // CHECK: -> tensor<3x5x11x7xf32> %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x5x6x7xf32>) -> tensor @@ -731,7 +731,7 @@ // CHECK-LABEL: @test_pool_stride func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) { // CHECK: -> tensor<3x4x4x7xf32> - %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor + %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor // CHECK: -> tensor<3x4x4x7xf32> %1 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<3x11x12x7xf32>) -> tensor