diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1771,6 +1771,14 @@
     SmallVector<int8_t> shiftValues;
     getValuesFromIntArrayAttribute(op.shift(), shiftValues);
 
+    // If we shift by more than the bitwidth, this just sets to 0.
+    for (int i = 0, s = multiplierValues.size(); i < s; i++) {
+      if (shiftValues[i] > 63) {
+        shiftValues[i] = 0;
+        multiplierValues[i] = 0;
+      }
+    }
+
     // Double round only occurs if shift is greater than 31, check that this
     // is ever true.
     bool doubleRound =
diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
--- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp
@@ -43,6 +43,13 @@
          "Shifted mantissa exceeds 32-bit signed output type");
 
   multiplier = static_cast<int32_t>(shiftedM);
+
+  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
+  if (shift > 63) {
+    // Shift by at most 31: shifting an int32_t by 32 or more is UB.
+    multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
+    shift = 63;
+  }
 }
 
 /// From a scale value, generates multiplier and shift values where
@@ -71,6 +78,13 @@
          "Shifted mantissa exceeds 32-bit signed output type");
 
   multiplier = static_cast<int32_t>(shiftedM);
+
+  // Shifting tops out at 63 bits. Right shift to make 63 bits the max.
+  if (shift > 63) {
+    // Shift by at most 31: shifting an int32_t by 32 or more is UB.
+    multiplier = multiplier >> std::min<int32_t>(31, shift - 63);
+    shift = 63;
+  }
 }
 
 /// Generates a quantized multiplier/shift from double.
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -931,11 +931,11 @@
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
 // CHECK-LABEL: @rescale_per_channel
-func @rescale_per_channel(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
-  // CHECK: [[MULTIPLIERS:%.+]] = arith.constant dense<[42, 43]>
-  // CHECK: [[SHIFTS:%.+]] = arith.constant dense<[14, 15]>
-  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
-  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[MULTIPLIERS]], [[SHIFTS]] : tensor<2xi8>, tensor<2xi32>, tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
+func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
+  // CHECK: [[MULTIPLIERS:%.+]] = arith.constant dense<[42, 43, 0]>
+  // CHECK: [[SHIFTS:%.+]] = arith.constant dense<[14, 15, 0]>
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [3]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]], #[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[MULTIPLIERS]], [[SHIFTS]] : tensor<3xi8>, tensor<3xi32>, tensor<3xi8>) outs([[INIT]] : tensor<3xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8):
   // CHECK: [[C243:%.+]] = arith.constant 243
   // CHECK: [[C252:%.+]] = arith.constant 252
@@ -952,10 +952,10 @@
   // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32], shift = [14 : i32, 15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>)
+  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [42 : i32, 43 : i32, 44 : i32], shift = [14 : i32, 15 : i32, 64 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> (tensor<3xi8>)
 
   // CHECK: return [[GENERIC]]
-  return %0 : tensor<2xi8>
+  return %0 : tensor<3xi8>
 }
 
 // -----