diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -49,103 +50,198 @@ return rewriter.getIntegerAttr(type, value); } +Value getConstantValue(Location loc, Type type, int64_t value, + PatternRewriter &rewriter) { + return rewriter.create( + loc, getConstantAttr(type, value, rewriter)); +} + // This converts the TOSA ApplyScale operator to a set of arithmetic ops, // using 64-bit operations to perform the necessary multiply, bias, and shift. -// Multiple types are used to use minimal bit width operations. -class ApplyScaleOpConverter : public OpRewritePattern { +class ApplyScale48OpConverter : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); - Value value32 = op.value(); + Value value = op.value(); Value multiplier32 = op.multiplier(); - Value shift8 = op.shift(); - bool doubleRound = op.double_round(); - Type inType = op.value().getType(); Type resultTy = op.getType(); - - Type i8Ty = matchContainerType(rewriter.getIntegerType(8), resultTy); + Type valueTy = value.getType(); + Type valueETy = getElementTypeOrSelf(valueTy); Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy); - Value one8 = rewriter.create( - loc, getConstantAttr(i8Ty, 1, rewriter)); - Value one64 = rewriter.create( - loc, getConstantAttr(i64Ty, 1, rewriter)); - - Value shiftSubOne8 = rewriter.create(loc, shift8, one8); - - // The rounding value semantics 
below equate to the following code: - // int64_t round = 1 << (shift - 1); - // if (double_round) { - // if (shift > 31 && value >= 0) round += 1<<30; - // if (shift > 31 && value < 0) round -= 1<<30; - // } - // - // Note that minimal bitwidth operators are used throughout the block. - - Value round64 = rewriter.create( - loc, one64, rewriter.create(loc, i64Ty, shiftSubOne8)); - - // Double rounding is performing a round operation before the shift - if (doubleRound) { - Value one32 = rewriter.create( - loc, getConstantAttr(i32Ty, 1, rewriter)); - Value shift32 = rewriter.create(loc, i32Ty, shift8); - Value thirty32 = rewriter.create( - loc, getConstantAttr(i32Ty, 30, rewriter)); - - Value shiftThirty32 = - rewriter.create(loc, one32, thirty32); - Value shiftThirty64 = - rewriter.create(loc, i64Ty, shiftThirty32); - - // Round value needs to with be added or subtracted depending on the sign - // of the input value. - Value roundAdd64 = - rewriter.create(loc, round64, shiftThirty64); - Value roundSub64 = - rewriter.create(loc, round64, shiftThirty64); - - Value zero32 = - rewriter.create(loc, rewriter.getZeroAttr(inType)); - Value valueGreaterThanZero = rewriter.create( - loc, arith::CmpIPredicate::sge, value32, zero32); + if (valueETy.getIntOrFloatBitWidth() <= 32) + return failure(); + + Value shift32 = rewriter.create(loc, i32Ty, op.shift()); + + Value zero = getConstantValue(loc, valueTy, 0, rewriter); - Value doubleRound64 = rewriter.create( - loc, valueGreaterThanZero, roundAdd64, roundSub64); + Value one64 = getConstantValue(loc, i64Ty, 1, rewriter); + Value negOne64 = getConstantValue(loc, i64Ty, -1, rewriter); + Value thirty64 = getConstantValue(loc, i64Ty, 30, rewriter); + Value thirtyOne32 = getConstantValue(loc, i32Ty, 31, rewriter); - // We only perform double rounding if the shift value is greater than 32. 
- Value thirtyTwo32 = rewriter.create( - loc, getConstantAttr(i32Ty, 32, rewriter)); - Value shiftGreaterThanThirtyTwo = rewriter.create( - loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); - round64 = rewriter.create(loc, shiftGreaterThanThirtyTwo, - doubleRound64, round64); + // Compute the full multiplication in 64-bits. + Value value64 = rewriter.create(loc, i64Ty, value); + Value multiplier64 = + rewriter.create(loc, i64Ty, multiplier32); + Value multiply64 = + rewriter.create(loc, value64, multiplier64); + + // Apply normal rounding. + Value shift64 = rewriter.create(loc, i64Ty, shift32); + Value round = rewriter.create(loc, one64, shift64); + round = rewriter.create(loc, round, one64); + multiply64 = rewriter.create(loc, multiply64, round); + + // Apply double rounding if necessary. + if (op.double_round()) { + Value positive = rewriter.create( + loc, arith::CmpIPredicate::sge, value, zero); + Value dir = + rewriter.create(loc, positive, one64, negOne64); + Value update = rewriter.create(loc, dir, thirty64); + Value val = rewriter.create(loc, update, multiply64); + Value valid = rewriter.create( + loc, arith::CmpIPredicate::sgt, shift32, thirtyOne32); + multiply64 = + rewriter.create(loc, valid, val, multiply64); } - // The computation below equates to the following pseudocode: - // int64_t result = (int64_t)value * multiplier + round; - // result = result >> shift; - // - // Note that multiply and shift need to be perform in i64 to preserve bits. 
+ Value result64 = rewriter.create(loc, multiply64, shift64); + Value result32 = rewriter.create(loc, i32Ty, result64); + + rewriter.replaceOp(op, result32); + return success(); + } +}; + +class ApplyScale32OpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + + Type resultTy = op.getType(); + Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy); + Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy); + Value value = op.value(); + if (getElementTypeOrSelf(value.getType()).getIntOrFloatBitWidth() > 32) { + return failure(); + } + + Value value32 = op.value(); + Value multiplier32 = op.multiplier(); + Value shift32 = rewriter.create(loc, i32Ty, op.shift()); + + // Constants used during the scaling operation. + Value zero32 = getConstantValue(loc, i32Ty, 0, rewriter); + Value one32 = getConstantValue(loc, i32Ty, 1, rewriter); + Value two32 = getConstantValue(loc, i32Ty, 2, rewriter); + Value thirty32 = getConstantValue(loc, i32Ty, 30, rewriter); + Value thirtyTwo32 = getConstantValue(loc, i32Ty, 32, rewriter); + Value thirtyTwo64 = getConstantValue(loc, i64Ty, 32, rewriter); + + // Compute the multiplication in 64-bits then select the high / low parts. Value value64 = rewriter.create(loc, i64Ty, value32); Value multiplier64 = rewriter.create(loc, i64Ty, multiplier32); - Value shift64 = rewriter.create(loc, i64Ty, shift8); + Value multiply64 = + rewriter.create(loc, value64, multiplier64); - // Multiply as a pair of i64 values to guarantee the end value fits. 
- Value result64 = rewriter.create(loc, value64, multiplier64); - result64 = rewriter.create(loc, result64, round64); - result64 = rewriter.create(loc, result64, shift64); + // Grab out the high/low of the computation + Value high64 = + rewriter.create(loc, multiply64, thirtyTwo64); + Value high32 = rewriter.create(loc, i32Ty, high64); + Value low32 = rewriter.create(loc, value32, multiplier32); - Value result32 = rewriter.create(loc, resultTy, result64); + // Determine the direction and amount to shift the high bits. + Value shiftOver32 = rewriter.create( + loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32); + Value roundHighBits = rewriter.create( + loc, arith::CmpIPredicate::sgt, shift32, thirtyTwo32); - rewriter.replaceOp(op, result32); + Value shiftHighL = + rewriter.create(loc, thirtyTwo32, shift32); + Value shiftHighR = + rewriter.create(loc, shift32, thirtyTwo32); + + shiftHighL = + rewriter.create(loc, shiftOver32, zero32, shiftHighL); + shiftHighR = + rewriter.create(loc, shiftOver32, shiftHighR, zero32); + + // Conditionally perform our double round. + if (op.double_round()) { + Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); + Value valuePositive = rewriter.create( + loc, arith::CmpIPredicate::sge, value32, zero32); + + Value roundDir = + rewriter.create(loc, valuePositive, one32, negOne32); + roundDir = + rewriter.create(loc, shiftOver32, roundDir, zero32); + + Value shiftLow = rewriter.create(loc, low32, thirty32); + Value rounded = rewriter.create(loc, shiftLow, roundDir); + Value carry = rewriter.create(loc, rounded, two32); + + Value shiftRound = + rewriter.create(loc, roundDir, thirty32); + + low32 = rewriter.create(loc, low32, shiftRound); + high32 = rewriter.create(loc, high32, carry); + } + + // Conditionally apply rounding in the low bits. 
+ { + Value shiftSubOne = rewriter.create(loc, shift32, one32); + Value roundBit = rewriter.create(loc, one32, shiftSubOne); + roundBit = rewriter.create(loc, roundHighBits, zero32, + roundBit); + + Value newLow32 = rewriter.create(loc, low32, roundBit); + Value wasRounded = rewriter.create( + loc, arith::CmpIPredicate::ugt, low32, newLow32); + low32 = newLow32; + + Value rounded32 = rewriter.create(loc, i32Ty, wasRounded); + high32 = rewriter.create(loc, high32, rounded32); + } + + // Conditionally apply rounding in the high bits. + { + Value shiftSubOne = + rewriter.create(loc, shiftHighR, one32); + Value roundBit = rewriter.create(loc, one32, shiftSubOne); + roundBit = rewriter.create(loc, roundHighBits, roundBit, + zero32); + high32 = rewriter.create(loc, high32, roundBit); + } + + // Combine the correct high/low bits into the final rescale result. + high32 = rewriter.create(loc, high32, shiftHighL); + high32 = rewriter.create(loc, high32, shiftHighR); + low32 = rewriter.create(loc, low32, shift32); + low32 = rewriter.create(loc, shiftOver32, zero32, low32); + + // Apply the rounding behavior and shift to the final alignment. + Value result = rewriter.create(loc, low32, high32); + + // Truncate if necessary. 
+ if (!getElementTypeOrSelf(resultTy).isInteger(32)) { + result = rewriter.create(loc, resultTy, result); + } + + rewriter.replaceOp(op, result); return success(); } }; @@ -159,5 +255,6 @@ void mlir::tosa::populateTosaRescaleToArithConversionPatterns( RewritePatternSet *patterns) { - patterns->add(patterns->getContext()); + patterns->add(patterns->getContext()); + patterns->add(patterns->getContext()); } diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir --- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir +++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir @@ -4,116 +4,117 @@ // CHECK-LABEL: func @const_test func.func @const_test() -> (tensor) { // CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor - %0 = "tosa.const"() {value = dense<3> : tensor} : () -> tensor + %result = "tosa.const"() {value = dense<3> : tensor} : () -> tensor // CHECK: return [[C3]] - return %0 : tensor + return %result : tensor } // ----- // CHECK-LABEL: @apply_scale_test_i32 +// SCALE: "tosa.apply_scale" func.func @apply_scale_test_i32(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) { - // CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8 - // CHECK-DAG: [[C1_32:%.+]] = arith.constant 1 : i32 - // CHECK-DAG: [[C1_64:%.+]] = arith.constant 1 : i64 - // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]] - - // CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : i8 to i32 - // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : i8 to i64 - // CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]] - - // CHECK-DAG: [[C0_32:%.+]] = arith.constant 0 : i32 - // CHECK-DAG: [[C30_32:%.+]] = arith.constant 30 : i32 - // CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]] - // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : i32 to i64 - // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK-DAG: 
[[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : i32 - // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 - // CHECK-DAG: [[C32_32:%.+]] = arith.constant 32 : i32 - // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]] - // CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] - - // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : i32 to i64 - // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : i32 to i64 - // CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : i8 to i64 - // CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]] - // CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]] - // CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]] - // CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]] - - // SCALE: "tosa.apply_scale" - %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32 - return %0 : i32 + // CHECK: %[[S32:.+]] = arith.extui %arg2 : i8 to i32 + // CHECK: %[[C0:.+]] = arith.constant 0 : i32 + // CHECK: %[[C1:.+]] = arith.constant 1 : i32 + // CHECK: %[[C2:.+]] = arith.constant 2 : i32 + // CHECK: %[[C30:.+]] = arith.constant 30 : i32 + // CHECK: %[[C32:.+]] = arith.constant 32 : i32 + // CHECK: %[[C32L:.+]] = arith.constant 32 : i64 + + // Compute the high-low values of the matmul in 64-bits. + // CHECK: %[[V64:.+]] = arith.extsi %arg0 : i32 to i64 + // CHECK: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64 + // CHECK: %[[MUL64:.+]] = arith.muli %[[V64]], %[[M64]] + // CHECK: %[[HI64:.+]] = arith.shrui %[[MUL64]], %[[C32L]] + // CHECK: %[[HI:.+]] = arith.trunci %[[HI64]] : i64 to i32 + // CHECK: %[[LOW:.+]] = arith.muli %arg0, %arg1 + + // Determine whether the high bits need to shift left or right and by how much. 
+ // CHECK: %[[OVER31:.+]] = arith.cmpi sge, %[[S32]], %[[C32]] + // CHECK: %[[OVER32:.+]] = arith.cmpi sgt, %[[S32]], %[[C32]] + // CHECK: %[[HISHLN:.+]] = arith.subi %[[C32]], %[[S32]] + // CHECK: %[[HISHRN:.+]] = arith.subi %[[S32]], %[[C32]] + // CHECK: %[[HISHL:.+]] = arith.select %[[OVER31]], %[[C0]], %[[HISHLN]] + // CHECK: %[[HISHR:.+]] = arith.select %[[OVER31]], %[[HISHRN]], %[[C0]] + + // Apply double rounding. + // CHECK: %[[CN1:.+]] = arith.constant -1 + // CHECK: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]] + // CHECK: %[[DIR:.+]] = arith.select %[[POS]], %[[C1]], %[[CN1]] + // CHECK: %[[DRND:.+]] = arith.select %[[OVER31]], %[[DIR]], %[[C0]] + // CHECK: %[[DSHFTR:.+]] = arith.shrui %[[LOW]], %[[C30]] + // CHECK: %[[DRNDED:.+]] = arith.addi %[[DSHFTR]], %[[DRND]] + // CHECK: %[[DCARRY:.+]] = arith.shrsi %[[DRNDED]], %[[C2]] + // CHECK: %[[DBIT:.+]] = arith.shli %[[DRND]], %[[C30]] + // CHECK: %[[DLOW:.+]] = arith.addi %[[LOW]], %[[DBIT]] + // CHECK: %[[DHI:.+]] = arith.addi %[[HI]], %[[DCARRY]] + + // Apply low-bit rounding. + // CHECK: %[[SHFTM1:.+]] = arith.subi %[[S32]], %[[C1]] + // CHECK: %[[LBIT:.+]] = arith.shli %[[C1]], %[[SHFTM1]] + // CHECK: %[[HALF:.+]] = arith.select %[[OVER32]], %[[C0]], %[[LBIT]] + // CHECK: %[[LADD:.+]] = arith.addi %[[DLOW]], %[[HALF]] + // CHECK: %[[LLO:.+]] = arith.cmpi ugt, %[[DLOW]], %[[LADD]] + // CHECK: %[[LCARRY:.+]] = arith.extui %[[LLO]] : i1 to i32 + // CHECK: %[[LRNDED:.+]] = arith.addi %[[DHI]], %[[LCARRY]] + + // Apply high-bit rounding. + // CHECK: %[[HISHRM1:.+]] = arith.subi %[[HISHR]], %[[C1]] + // CHECK: %[[LHISHFT:.+]] = arith.shli %[[C1]], %[[HISHRM1]] + // CHECK: %[[LHI:.+]] = arith.select %[[OVER32]], %[[LHISHFT]], %[[C0]] + // CHECK: %[[FHI:.+]] = arith.addi %[[LRNDED]], %[[LHI]] + + // Combine hi-low into the final result. 
+ // CHECK: %[[HIL:.+]] = arith.shli %[[FHI]], %[[HISHL]] + // CHECK: %[[HIALIGN:.+]] = arith.shrsi %[[HIL]], %[[HISHR]] + // CHECK: %[[LOR:.+]] = arith.shrui %[[LADD]], %[[S32]] + // CHECK: %[[LOWALIGN:.+]] = arith.select %[[OVER31]], %[[C0]], %[[LOR]] + // CHECK: %[[RESULT:.+]] = arith.addi %[[LOWALIGN]], %[[HIALIGN]] + // CHECK: return %[[RESULT]] + %res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32 + return %res : i32 } // ----- // CHECK-LABEL: @apply_scale_test_vector +// SCALE: "tosa.apply_scale" func.func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) { - // CHECK-DAG: [[C1_8:%.+]] = arith.constant dense<1> : vector<4xi8> - // CHECK-DAG: [[C1_32:%.+]] = arith.constant dense<1> : vector<4xi32> - // CHECK-DAG: [[C1_64:%.+]] = arith.constant dense<1> : vector<4xi64> - // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]] - - // CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi32> - // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : vector<4xi8> to vector<4xi64> - // CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]] - - // CHECK-DAG: [[C0_32:%.+]] = arith.constant dense<0> : vector<4xi32> - // CHECK-DAG: [[C30_32:%.+]] = arith.constant dense<30> : vector<4xi32> - // CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]] - // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : vector<4xi32> to vector<4xi64> - // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : vector<4xi32> - // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64> - // CHECK-DAG: 
[[C32_32:%.+]] = arith.constant dense<32> : vector<4xi32> - // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]] - // CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] - - // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : vector<4xi32> to vector<4xi64> - // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : vector<4xi32> to vector<4xi64> - // CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi64> - // CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]] - // CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]] - // CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]] - // CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]] - - %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32> - return %0 : vector<4xi32> + // CHECK-NOT: "tosa.apply_scale" + %res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32> + return %res : vector<4xi32> } // ----- // CHECK-LABEL: @apply_scale_test_i48 +// SCALE: "tosa.apply_scale" func.func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) { - // CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8 - // CHECK-DAG: [[C1_32:%.+]] = arith.constant 1 : i32 - // CHECK-DAG: [[C1_64:%.+]] = arith.constant 1 : i64 - // CHECK-DAG: [[C30_32:%.+]] = arith.constant 30 : i32 - // CHECK-DAG: [[C0_32:%.+]] = arith.constant 0 : i48 - // CHECK-DAG: [[C32_32:%.+]] = arith.constant 32 : i32 - // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]] - // CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : i8 to i32 - // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : i8 to i64 - // CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]] - // CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], 
[[C30_32]] - // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : i32 to i64 - // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]] - // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : i48 - // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = arith.select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 - // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]] - // CHECK-DAG: [[ROUND:%.+]] = arith.select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] - // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : i48 to i64 - // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : i32 to i64 - // CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : i8 to i64 - // CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]] - // CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]] - // CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]] - // CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]] - %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32 - return %0 : i32 + // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : i48 + // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64 + // CHECK-DAG: %[[CN1:.+]] = arith.constant -1 : i64 + // CHECK-DAG: %[[C30:.+]] = arith.constant 30 : i64 + // CHECK-DAG: %[[C31:.+]] = arith.constant 31 : i32 + // CHECK-DAG: %[[V64:.+]] = arith.extsi %arg0 : i48 to i64 + // CHECK-DAG: %[[M64:.+]] = arith.extsi %arg1 : i32 to i64 + // CHECK-DAG: %[[MUL:.+]] = arith.muli %[[V64]], %[[M64]] + + // CHECK-DAG: %[[S32:.+]] = arith.extui %arg2 : i8 to i32 + // CHECK-DAG: %[[S64:.+]] = arith.extui %[[S32]] : i32 to i64 + // CHECK-DAG: %[[ONEL:.+]] = arith.shli %[[C1]], %[[S64]] : i64 + // CHECK-DAG: %[[ONER:.+]] = arith.shrui %[[ONEL]], %[[C1]] + // CHECK-DAG: %[[ROUND:.+]] = arith.addi %[[MUL]], %[[ONER]] + 
// CHECK-DAG: %[[POS:.+]] = arith.cmpi sge, %arg0, %[[C0]] + // CHECK-DAG: %[[DBIT:.+]] = arith.select %[[POS]], %[[C1]], %[[CN1]] + // CHECK-DAG: %[[DSHFT:.+]] = arith.shli %[[DBIT]], %[[C30]] + // CHECK-DAG: %[[DRND:.+]] = arith.addi %[[DSHFT]], %[[ROUND]] + // CHECK-DAG: %[[USED:.+]] = arith.cmpi sgt, %[[S32]], %[[C31]] : i32 + // CHECK-DAG: %[[RES64:.+]] = arith.select %[[USED]], %[[DRND]], %[[ROUND]] : i64 + // CHECK-DAG: %[[SHR:.+]] = arith.shrsi %[[RES64]], %[[S64]] + // CHECK-DAG: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i64 to i32 + // CHECK: return %[[TRUNC]] + %res = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i48, i32, i8) -> i32 + return %res : i32 }