# Patch summary (free-form preamble; everything from the first "diff --git" header on is unchanged).
#
# Purpose: let tosa.apply_scale and its lowerings accept a `value` operand that is not i32
# (the new test exercises i48), across four files:
#
#   * mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
#       - $value and $multiplier change from Tosa_Int32Like to Tosa_Int.
#       - $output changes from Tosa_Int32 to Tosa_Int.
#       - $shift stays Tosa_Int8Like and $double_round stays BoolAttr.
#       - NOTE(review): the "...Like" constraint is dropped for $value/$multiplier but kept
#         for $shift -- confirm this asymmetry is intended.
#
#   * mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
#       - `Value value = blockArgs[0];` is moved above the zero-point constants so its type
#         can be inspected first.
#       - `inBitwidth` is 48 when value's bit width is greater than 32, otherwise 32.
#       - The "input_zp" constant is created with getIntegerType(inBitwidth) instead of i32.
#       - The "output_zp" constant is still created as i32.
#       - NOTE(review): the pre-existing "all of our math in 64-bit" comment now sits directly
#         above the 32/48-bit selection -- consider rewording it in a follow-up.
#
#   * mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
#       - `inType` captures op.value().getType().
#       - `one32` and `shift32` are no longer created unconditionally; they are created
#         inside the `if (doubleRound)` branch.
#       - `zero32` is now built from getZeroAttr(inType) instead of the i32 zero attribute,
#         and is placed just before the `sge` comparison against `value32`.
#       - NOTE(review): `zero32` / `value32` keep their "32" suffix although they now carry
#         the input type -- presumably a rename is wanted; confirm.
#       - NOTE(review): the `rewriter.create(...)` calls in this copy show no template
#         argument -- looks like text lost in extraction; verify against the real patch.
#
#   * mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
#       - The existing test is renamed to @apply_scale_test_i32 and gains a CHECK-LABEL.
#       - Its CHECK lines become CHECK-DAG, except the final `trunci` check which stays CHECK.
#       - A new @apply_scale_test_i48 is added: (i48, i32, i8) -> i32 with double_round = true.
#         It expects `constant 0 : i48`, `cmpi sge ... : i48` and `sexti %arg0 : i48 to i64`.
#       - NOTE(review): the FileCheck variable C0_32 in the i48 test binds an i48 constant;
#         the name is misleading but harmless.
#       - Both tests use double_round = true; no double_round = false case appears in this patch.
#
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td @@ -37,14 +37,14 @@ }]; let arguments = (ins - Tosa_Int32Like:$value, - Tosa_Int32Like:$multiplier, + Tosa_Int:$value, + Tosa_Int:$multiplier, Tosa_Int8Like:$shift, BoolAttr:$double_round ); let results = (outs - Tosa_Int32:$output + Tosa_Int:$output ); } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -1337,15 +1337,20 @@ getNParallelLoopsAttrs(rank), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { + Value value = blockArgs[0]; + // For now we do all of our math in 64-bit. This is not optimal but // should be correct for now, consider computing correct bit depth // later. + int32_t inBitwidth = + value.getType().getIntOrFloatBitWidth() > 32 ? 48 : 32; + auto inputZp = createConstFromIntAttribute( - op, "input_zp", nestedBuilder.getI32Type(), nestedBuilder); + op, "input_zp", nestedBuilder.getIntegerType(inBitwidth), + nestedBuilder); auto outputZp = createConstFromIntAttribute( op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder); - Value value = blockArgs[0]; Value multiplier = multiplierConstant ? multiplierConstant : blockArgs[multiplierArg]; Value shift = shiftConstant ? 
shiftConstant : blockArgs[shiftArg]; diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -64,11 +64,10 @@ Value multiplier32 = op.multiplier(); Value shift8 = op.shift(); bool doubleRound = op.double_round(); + Type inType = op.value().getType(); Value one8 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1)); - Value one32 = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); Value one64 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); @@ -83,9 +82,6 @@ // // Note that minimal bitwidth operators are used throughout the block. - Value shift32 = rewriter.create( - loc, rewriter.getI32Type(), shift8); - Value round64 = rewriter.create( loc, one64, rewriter.create(loc, rewriter.getI64Type(), @@ -93,8 +89,10 @@ // Double rounding is performing a round operation before the shift if (doubleRound) { - Value zero32 = rewriter.create( - loc, rewriter.getZeroAttr(rewriter.getI32Type())); + Value one32 = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); + Value shift32 = rewriter.create( + loc, rewriter.getI32Type(), shift8); Value thirty32 = rewriter.create( loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30)); @@ -109,6 +107,8 @@ Value roundSub64 = rewriter.create(loc, round64, shiftThirty64); + Value zero32 = + rewriter.create(loc, rewriter.getZeroAttr(inType)); Value valueGreaterThanZero = rewriter.create( loc, CmpIPredicate::sge, value32, zero32); diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir --- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir +++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir @@ -19,36 +19,70 @@ // ----- -func @apply_scale_test(%arg0 : i32, %arg1 : 
i32, %arg2 : i8) -> (i32) { - // CHECK: [[C1_8:%.+]] = constant 1 : i8 - // CHECK: [[C1_32:%.+]] = constant 1 : i32 - // CHECK: [[C1_64:%.+]] = constant 1 : i64 - // CHECK: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]] - - // CHECK: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32 - // CHECK: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64 - // CHECK: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]] - - // CHECK: [[C0_32:%.+]] = constant 0 : i32 - // CHECK: [[C30_32:%.+]] = constant 30 : i32 - // CHECK: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]] - // CHECK: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64 - // CHECK: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32 - // CHECK: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 - // CHECK: [[C32_32:%.+]] = constant 32 : i32 - // CHECK: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]] - // CHECK: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] - - // CHECK: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64 - // CHECK: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64 - // CHECK: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64 - // CHECK: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]] - // CHECK: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]] - // CHECK: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]] +// CHECK-LABEL: @apply_scale_test_i32 +func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) { + // CHECK-DAG: [[C1_8:%.+]] = constant 1 : i8 + // CHECK-DAG: [[C1_32:%.+]] = constant 1 : i32 + // CHECK-DAG: [[C1_64:%.+]] = constant 1 : i64 + // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]] + + // CHECK-DAG: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32 + // CHECK-DAG: 
[[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64 + // CHECK-DAG: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]] + + // CHECK-DAG: [[C0_32:%.+]] = constant 0 : i32 + // CHECK-DAG: [[C30_32:%.+]] = constant 30 : i32 + // CHECK-DAG: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]] + // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64 + // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]] + // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]] + // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32 + // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 + // CHECK-DAG: [[C32_32:%.+]] = constant 32 : i32 + // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]] + // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] + + // CHECK-DAG: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64 + // CHECK-DAG: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64 + // CHECK-DAG: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64 + // CHECK-DAG: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]] + // CHECK-DAG: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]] + // CHECK-DAG: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]] // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]] %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32 return %0 : i32 } + +// ----- + +// CHECK-LABEL: @apply_scale_test_i48 +func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) { + // CHECK-DAG: [[C1_8:%.+]] = constant 1 : i8 + // CHECK-DAG: [[C1_32:%.+]] = constant 1 : i32 + // CHECK-DAG: [[C1_64:%.+]] = constant 1 : i64 + // CHECK-DAG: [[C30_32:%.+]] = constant 30 : i32 + // CHECK-DAG: [[C0_32:%.+]] = constant 0 : i48 + // CHECK-DAG: [[C32_32:%.+]] = constant 32 : i32 + // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = subi 
%arg2, [[C1_8]] + // CHECK-DAG: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32 + // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64 + // CHECK-DAG: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]] + // CHECK-DAG: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]] + // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64 + // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]] + // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]] + // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i48 + // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 + // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]] + // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] + // CHECK-DAG: [[VAL_64:%.+]] = sexti %arg0 : i48 to i64 + // CHECK-DAG: [[MULTIPLY_64:%.+]] = sexti %arg1 : i32 to i64 + // CHECK-DAG: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64 + // CHECK-DAG: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]] + // CHECK-DAG: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]] + // CHECK-DAG: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]] + // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]] + %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32 + return %0 : i32 +}