diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -3000,8 +3000,6 @@
         }
       }
 
-      // Cast to output type.
-      rewriter.create(loc, poolVal);
     });
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -416,9 +416,9 @@
   if (inputETy.isF32() && resultETy.isF32())
     return success();
-  if (inputETy.isInteger(8) && resultETy.isInteger(32))
+  if (inputETy.isInteger(8) && resultETy.isInteger(8))
     return success();
-  if (inputETy.isInteger(16) && resultETy.isInteger(32))
+  if (inputETy.isInteger(16) && resultETy.isInteger(16))
     return success();
 
   return op.emitOpError("input/output element types are incompatible.");
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1474,14 +1474,44 @@
   // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
   // CHECK: %[[OUTZP:.+]] = arith.constant -128
   // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
-  // CHECK: %[[MIN:.+]] = arith.constant -2147483648
-  // CHECK: %[[MAX:.+]] = arith.constant 2147483647
+  // CHECK: %[[MIN:.+]] = arith.constant -128
+  // CHECK: %[[MAX:.+]] = arith.constant 127
   // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
   // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
   // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
   // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
-  // CHECK: linalg.yield %[[CLMP_MAX]]
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi32>
+  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
+  // CHECK: linalg.yield %[[TRUNC]]
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @avg_pool_i16
+func @avg_pool_i16(%arg0 : tensor<1x128x128x2xi16>) -> () {
+
+  // CHECK: linalg.pooling_nhwc_sum
+  // CHECK: linalg.generic
+
+  // CHECK: %[[INZP:.+]] = arith.constant -128
+  // CHECK: %[[INZP_OFF:.+]] = arith.muli %{{.+}}, %[[INZP]]
+  // CHECK: %[[OFFSETED:.+]] = arith.subi %arg1, %[[INZP_OFF]]
+  // CHECK: %[[NUMERATOR:.+]] = arith.constant 1073741825
+  // CHECK: %[[MULTIPLIER:.+]] = arith.divui %[[NUMERATOR]], %{{.+}}
+  // CHECK: %[[SHIFT:.+]] = arith.constant 30
+  // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
+  // CHECK: %[[OUTZP:.+]] = arith.constant -128
+  // CHECK: %[[OUT:.+]] = arith.addi %[[SCALE]], %[[OUTZP]]
+  // CHECK: %[[MIN:.+]] = arith.constant -32768
+  // CHECK: %[[MAX:.+]] = arith.constant 32767
+  // CHECK: %[[CMP_MIN:.+]] = arith.cmpi slt, %[[OUT]], %[[MIN]]
+  // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
+  // CHECK: %[[CMP_MAX:.+]] = arith.cmpi slt, %[[MAX]], %[[OUT]]
+  // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
+  // CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLMP_MAX]]
+  // CHECK: linalg.yield %[[TRUNC]]
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi16>) -> tensor<1x32x32x2xi16>
   return
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -18,23 +18,23 @@
 // -----
 
 // CHECK-LABEL: avg_pool2d_i8
-func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32> {
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32>
-  return %0 : tensor<1x7x7x9xi32>
+func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
+  return %0 : tensor<1x7x7x9xi8>
 }
 
 // -----
 
 // CHECK-LABEL: avg_pool2d_i16
-func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32> {
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32>
-  return %0 : tensor<1x7x7x9xi32>
+func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16>
+  return %0 : tensor<1x7x7x9xi16>
 }
 
 // -----
 
 // CHECK-LABEL: avg_pool2d_q8
-func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> {
-  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform>
-  return %0 : tensor<1x7x7x9x!quant.uniform>
+func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform>) -> tensor<1x7x7x9x!quant.uniform>
+  return %0 : tensor<1x7x7x9x!quant.uniform>
 }
 
 // -----