diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -38,6 +38,8 @@
 // Used to express accumulator results or compare results.
 //===----------------------------------------------------------------------===//
 
+def Tosa_UInt8 : UI<8>;
+
 def Tosa_Int8 : I<8>;
 def Tosa_Int16 : I<16>;
 def Tosa_Int32 : I<32>;
@@ -54,6 +56,7 @@
 
-// No unsigned unquantized int types.
+// Tosa_UInt8 (ui8) is the only unsigned unquantized int type.
 def Tosa_Int : AnyTypeOf<[Tosa_Bool,
+                          Tosa_UInt8,
                           Tosa_SignedInt]>;
 
 def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1544,12 +1544,12 @@
         [&](OpBuilder &nestedBuilder, Location nestedLoc,
             ValueRange blockArgs) {
           Value value = blockArgs[0];
+          Type valueTy = value.getType();
 
           // For now we do all of our math in 64-bit. This is not optimal but
           // should be correct for now, consider computing correct bit depth
           // later.
-          int32_t inBitwidth =
-              value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32;
+          int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;
 
           auto inputZp = createConstFromIntAttribute<int32_t>(
               op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
@@ -1561,9 +1561,21 @@
                                                 : blockArgs[multiplierArg];
           Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
 
-          if (value.getType().getIntOrFloatBitWidth() < 32) {
-            value = nestedBuilder.create<SignExtendIOp>(
-                nestedLoc, nestedBuilder.getI32Type(), value);
+          if (valueTy.getIntOrFloatBitWidth() < 32) {
+            if (valueTy.isUnsignedInteger()) {
+              value = nestedBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              nestedLoc,
+                              nestedBuilder.getIntegerType(
+                                  valueTy.getIntOrFloatBitWidth()),
+                              value)
+                          .getResult(0);
+              value = nestedBuilder.create<ZeroExtendIOp>(
+                  nestedLoc, nestedBuilder.getI32Type(), value);
+            } else {
+              value = nestedBuilder.create<SignExtendIOp>(
+                  nestedLoc, nestedBuilder.getI32Type(), value);
+            }
           }
 
           value = nestedBuilder.create<SubIOp>(nestedLoc, value, inputZp);
@@ -1579,21 +1591,38 @@
           IntegerType outIntType =
               blockArgs.back().getType().cast<IntegerType>();
           unsigned outBitWidth = outIntType.getWidth();
-          auto intMin = nestedBuilder.create<ConstantOp>(
-              loc, nestedBuilder.getIntegerAttr(
-                       nestedBuilder.getI32Type(),
-                       APInt::getSignedMinValue(outBitWidth).getSExtValue()));
-          auto intMax = nestedBuilder.create<ConstantOp>(
-              loc, nestedBuilder.getIntegerAttr(
-                       nestedBuilder.getI32Type(),
-                       APInt::getSignedMaxValue(outBitWidth).getSExtValue()));
-
-          value = clampHelper<mlir::CmpIOp>(nestedLoc, value, intMin, intMax,
-                                            CmpIPredicate::slt, nestedBuilder);
+
+          int64_t intMin = APInt::getSignedMinValue(outBitWidth).getSExtValue();
+          int64_t intMax = APInt::getSignedMaxValue(outBitWidth).getSExtValue();
+
+          // Unsigned outputs clamp to [0, 2^outBitWidth - 1] instead.
+          if (outIntType.isUnsignedInteger()) {
+            intMin = 0;
+            intMax = APInt::getMaxValue(outBitWidth).getZExtValue();
+          }
+
+          auto intMinVal = nestedBuilder.create<ConstantOp>(
+              loc,
+              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMin));
+          auto intMaxVal = nestedBuilder.create<ConstantOp>(
+              loc,
+              nestedBuilder.getIntegerAttr(nestedBuilder.getI32Type(), intMax));
+
+          value =
+              clampHelper<mlir::CmpIOp>(nestedLoc, value, intMinVal, intMaxVal,
+                                        CmpIPredicate::slt, nestedBuilder);
 
           if (outIntType.getWidth() < 32) {
-            value =
-                nestedBuilder.create<TruncateIOp>(nestedLoc, outIntType, value);
+            value = nestedBuilder.create<TruncateIOp>(
+                nestedLoc, nestedBuilder.getIntegerType(outIntType.getWidth()),
+                value);
+
+            if (outIntType.isUnsignedInteger()) {
+              value = nestedBuilder
+                          .create<UnrealizedConversionCastOp>(nestedLoc,
+                                                              outIntType, value)
+                          .getResult(0);
+            }
           }
 
           nestedBuilder.create<linalg::YieldOp>(loc, value);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -727,20 +727,19 @@
 
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
 
-// CHECK-LABEL: @rescale
-func @rescale(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
+// CHECK-LABEL: @rescale_i8
+func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK: [[C0:%.+]] = constant 19689
   // CHECK: [[C1:%.+]] = constant 15
   // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xi8>)
   // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: i8):
-  // CHECK: [[C243:%.+]] = constant 243
-  // CHECK: [[C252:%.+]] = constant 252
-
+  // CHECK: [[C17:%.+]] = constant 17
+  // CHECK: [[C22:%.+]] = constant 22
   // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]]
-  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C243]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C17]]
   // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false}
-  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C252]]
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C22]]
   // CHECK-DAG: [[CMIN:%.+]] = constant -128
   // CHECK-DAG: [[CMAX:%.+]] = constant 127
   // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
@@ -749,10 +748,63 @@
   // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>)
+  %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> (tensor<2xi8>)
 
-  // CHECK: return [[GENERIC]]
-  return %0 : tensor<2xi8>
+  // CHECK: [[C0:%.+]] = constant 19689
+  // CHECK: [[C1:%.+]] = constant 15
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xi8>) outs([[INIT]] : tensor<2xui8>)
+  // CHECK: ^bb0([[IN:%.+]]: i8, [[UNUSED:%.+]]: ui8):
+  // CHECK: [[C17:%.+]] = constant 17
+  // CHECK: [[C22:%.+]] = constant 22
+  // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = constant 0
+  // CHECK-DAG: [[CMAX:%.+]] = constant 255
+  // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
+  // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
+  // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
+  // CHECK: linalg.yield [[CAST]]
+  %1 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> (tensor<2xui8>)
+
+  // CHECK: return
+  return
+}
+
+// -----
+
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: @rescale_ui8
+func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
+  // CHECK: [[C0:%.+]] = constant 19689
+  // CHECK: [[C1:%.+]] = constant 15
+  // CHECK: [[INIT:%.+]] = linalg.init_tensor [2]
+  // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xui8>) outs([[INIT]] : tensor<2xi8>)
+  // CHECK: ^bb0([[IN:%.+]]: ui8, [[UNUSED:%.+]]: i8):
+  // CHECK: [[C17:%.+]] = constant 17
+  // CHECK: [[C22:%.+]] = constant 22
+  // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[IN]] : ui8 to i8
+  // CHECK-DAG: [[IN32:%.+]] = zexti [[CAST]]
+  // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C17]]
+  // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[C0]], [[C1]]) {double_round = false}
+  // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C22]]
+  // CHECK-DAG: [[CMIN:%.+]] = constant -128
+  // CHECK-DAG: [[CMAX:%.+]] = constant 127
+  // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]]
+  // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]]
+  // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]]
+  // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %0 = "tosa.rescale"(%arg0) {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> (tensor<2xi8>)
+
+  return
 }
 
 // -----