diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2796,7 +2796,7 @@ Type inElementTy = inputTy.getElementType(); ShapedType resultTy = op.getType().template cast(); - Type resultETy = inputTy.getElementType(); + Type resultETy = op.getType().cast().getElementType(); Type accETy = inElementTy.isa() ? rewriter.getI32Type() : inElementTy; @@ -2810,9 +2810,10 @@ pad.resize(2, 0); getValuesFromIntArrayAttribute(op.pad(), pad); pad.resize(pad.size() + 2, 0); - Attribute initialAttr = rewriter.getZeroAttr(accETy); - Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter); + Attribute padAttr = rewriter.getZeroAttr(inElementTy); + Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter); + Attribute initialAttr = rewriter.getZeroAttr(accETy); Value initialValue = rewriter.create(loc, initialAttr); SmallVector kernel, stride; @@ -2909,8 +2910,7 @@ // to be applied. Value poolVal = args[0]; if (accETy.isa()) { - auto countF = - rewriter.create(loc, inElementTy, countI); + auto countF = rewriter.create(loc, accETy, countI); poolVal = rewriter.create(loc, poolVal, countF)->getResult(0); } else { @@ -2974,8 +2974,11 @@ auto clamp = clampHelper( loc, scaled, min, max, CmpIPredicate::slt, rewriter); + poolVal = clamp; // Convert type. - poolVal = rewriter.create(loc, resultETy, clamp); + if (resultETy != clamp.getType()) { + poolVal = rewriter.create(loc, resultETy, poolVal); + } } // Cast to output type. diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -1465,15 +1465,14 @@ // CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false} // CHECK: %[[OUTZP:.+]] = constant -128 // CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]] - // CHECK: %[[MIN:.+]] = constant -128 - // CHECK: %[[MAX:.+]] = constant 127 + // CHECK: %[[MIN:.+]] = constant -2147483648 + // CHECK: %[[MAX:.+]] = constant 2147483647 // CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]] // CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]] // CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]] // CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]] - // CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]] - // CHECK: linalg.yield %[[TRUNC]] - %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8> + // CHECK: linalg.yield %[[CLMP_MAX]] + %0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi32> return }