diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -82,6 +82,8 @@
   );

   let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
+
+  let verifier = [{ return verifyAveragePoolOp(*this); }];
 }

 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -2796,7 +2796,7 @@
     Type inElementTy = inputTy.getElementType();

     ShapedType resultTy = op.getType().template cast<ShapedType>();
-    Type resultETy = inputTy.getElementType();
+    Type resultETy = op.getType().cast<ShapedType>().getElementType();

     Type accETy =
         inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
@@ -2810,9 +2810,10 @@
     pad.resize(2, 0);
     getValuesFromIntArrayAttribute(op.pad(), pad);
     pad.resize(pad.size() + 2, 0);
-    Attribute initialAttr = rewriter.getZeroAttr(accETy);
-    Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
+    Attribute padAttr = rewriter.getZeroAttr(inElementTy);
+    Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
+
+    Attribute initialAttr = rewriter.getZeroAttr(accETy);
     Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

     SmallVector<int64_t> kernel, stride;
@@ -2909,8 +2910,7 @@
           // to be applied.
           Value poolVal = args[0];
           if (accETy.isa<FloatType>()) {
-            auto countF =
-                rewriter.create<SIToFPOp>(loc, inElementTy, countI);
+            auto countF = rewriter.create<SIToFPOp>(loc, accETy, countI);
             poolVal =
                 rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
           } else {
@@ -2974,8 +2974,11 @@
             auto clamp = clampHelper<mlir::CmpIOp>(
                 loc, scaled, min, max, CmpIPredicate::slt, rewriter);

+            poolVal = clamp;
             // Convert type.
-            poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
+            if (resultETy != clamp.getType()) {
+              poolVal = rewriter.create<TruncateIOp>(loc, resultETy, poolVal);
+            }
           }

           // Cast to output type.
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -342,6 +342,26 @@
   return success();
 }

+static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) {
+  auto inputETy = op.input().getType().cast<ShapedType>().getElementType();
+  auto resultETy = op.getType().cast<ShapedType>().getElementType();
+
+  if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
+    inputETy = quantType.getStorageType();
+
+  if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
+    resultETy = quantType.getStorageType();
+
+  if (inputETy.isF32())
+    return resultETy.isF32() ? success() : failure();
+  if (inputETy.isInteger(8))
+    return resultETy.isInteger(32) ? success() : failure();
+  if (inputETy.isInteger(16))
+    return resultETy.isInteger(32) ? success() : failure();
+
+  return failure();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1465,15 +1465,14 @@ // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false} // CHECK: %[[OUTZP:.+]] = constant -128 // CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]] - // CHECK: %[[MIN:.+]] = constant -128 - // CHECK: %[[MAX:.+]] = constant 127 + // CHECK: %[[MIN:.+]] = constant -2147483648 + // CHECK: %[[MAX:.+]] = constant 2147483647 // CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]] // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]] // CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]] // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] - // CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]] - // CHECK: linalg.yield %[[TRUNC]] - %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8> + // CHECK: linalg.yield %[[CLMP_MAX]] + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi32> return } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -10,12 +10,33 @@ } // ----- -// CHECK-LABEL: avg_pool2d -func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { +// CHECK-LABEL: avg_pool2d_f32 +func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> { %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride 
= [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
   return %0 : tensor<1x7x7x9xf32>
 }
+// -----
+// CHECK-LABEL: avg_pool2d_i8
+func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32>
+  return %0 : tensor<1x7x7x9xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_i16
+func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32>
+  return %0 : tensor<1x7x7x9xi32>
+}
+
+// -----
+// CHECK-LABEL: avg_pool2d_q8
+func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>> {
+  %0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>>
+  return %0 : tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>>
+}
+
 // -----
 // CHECK-LABEL: conv2d
 func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {