diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -81,7 +81,11 @@
     Value shift32 = rewriter.create<arith::ExtUIOp>(loc, i32Ty, op.getShift());
 
     // Compute the multiplication in 64-bits then select the high / low parts.
-    Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
+    // An i64 operand already has the multiplication width, so it is used
+    // as-is; only other element types are sign-extended to 64 bits.
+    Value value64 = value;
+    if (getElementTypeOrSelf(valueTy) != rewriter.getI64Type())
+      value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value);
     Value multiplier64 =
         rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
     Value multiply64 =
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -118,3 +118,41 @@
   %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i48, i32, i8) -> i32
   return %res : i32
 }
+
+// -----
+
+// CHECK-LABEL: @apply_scale_test_i64
+// SCALE: tosa.apply_scale
+func.func @apply_scale_test_i64(%arg0 : i64, %arg1 : i32, %arg2 : i8) -> (i32) {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i64
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64
+  // CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32
+
+  // Multiply in 64 bits.
+  // The i64 value operand (%arg0) feeds arith.muli directly; only the multiplier is extended.
+  // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64
+  // CHECK-DAG: %[[MUL:.+]] = arith.muli %arg0, %[[M64]]
+
+  // Round normally.
+  // CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32
+  // CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64
+  // CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64
+  // CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]]
+  // CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]]
+
+  // Apply double rounding.
+  // CHECK-DAG: %[[DUP:.+]] = arith.constant 1073741824 : i64
+  // CHECK-DAG: %[[DDOWN:.+]] = arith.constant -1073741824 : i64
+  // CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]]
+  // CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[DUP]], %[[DDOWN]]
+  // CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DBIT]], %[[ROUND]]
+  // CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32
+  // CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64
+
+  // Shift and truncate final answer.
+  // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]]
+  // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32
+  // CHECK: return %[[TRUNC]]
+  %res = tosa.apply_scale %arg0, %arg1, %arg2 {double_round = true} : (i64, i32, i8) -> i32
+  return %res : i32
+}