diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h --- a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h +++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h @@ -23,6 +23,11 @@ void populateTosaToStandardConversionPatterns( MLIRContext *context, OwningRewritePatternList *patterns); +/// Populates `patterns` with only the tosa.apply_scale to Standard lowering; +/// populateTosaToStandardConversionPatterns also includes that pattern. +void populateTosaRescaleToStandardConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns); + /// Populates passes to convert from TOSA to Standard. void addTosaToStandardPasses(OpPassManager &pm); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1494,6 +1494,29 @@ ); } +def Tosa_ApplyScaleOp: Tosa_Op<"apply_scale", [NoSideEffect]> { + let summary = "Rescale scalar operator for Tosa tensor operators"; + + let description = [{ + Applies rescaling for fixed point values. The result is computed as + (value * multiplier + round) >> shift with 64-bit intermediates and is + truncated to 32 bits, where round = 1 << (shift - 1). If double_round is + set and shift > 31, round is further increased by 1 << 30 for + non-negative values and decreased by 1 << 30 for negative values. + }]; + + let arguments = (ins + Tosa_Int32:$value, + Tosa_Int32:$multiplier, + Tosa_Int8:$shift, + BoolAttr:$double_round + ); + + let results = (outs + Tosa_Int32:$output + ); +} + //===----------------------------------------------------------------------===// // TOSA Spec Section 2.13 // Operator Class: Data Node Ops.
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -31,22 +31,50 @@ template static mlir::ConstantOp createConstFromIntAttribute(Operation *op, std::string attrName, - Type requiredAttrType, PatternRewriter &rewriter) { + Type requiredAttrType, OpBuilder &rewriter) { auto castedN = static_cast( op->getAttr(attrName).cast().getValue().getSExtValue()); return rewriter.create( op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } +/// Appends the integer value of every IntegerAttr in `attr` to `arrayValues`. +template +static void getValuesFromIntArrayAttribute(ArrayAttr attr, + SmallVector &arrayValues) { + for (Attribute val : attr.getValue()) { + arrayValues.push_back(val.cast().getValue().getSExtValue()); + } +} + +/// Returns an indexing map for `type` that maps every dimension to itself, +/// except size-1 dimensions, which map to the constant 0 so they broadcast. +static AffineMap createAffineMapForType(ShapedType type, + OpBuilder &rewriter) { + unsigned rank = type.getRank(); + auto shape = type.getShape(); + SmallVector dimExprs; + dimExprs.reserve(rank); + for (unsigned i = 0; i < rank; ++i) { + // If the dimension is one we can broadcast the input with a constant + // affine expression.
+ if (shape[i] == 1) + dimExprs.push_back(rewriter.getAffineConstantExpr(0)); + else + dimExprs.push_back(rewriter.getAffineDimExpr(i)); + } + return AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, dimExprs, + rewriter.getContext()); +} + template -static mlir::SelectOp clampHelper(Operation *op, ValueRange args, - mlir::ConstantOp min, mlir::ConstantOp max, - P pred, PatternRewriter &rewriter) { - Location loc = op->getLoc(); - auto smallerThanMin = rewriter.create(loc, pred, args[0], min); +static mlir::SelectOp clampHelper(Location loc, Value arg, mlir::ConstantOp min, + mlir::ConstantOp max, P pred, + OpBuilder &rewriter) { + auto smallerThanMin = rewriter.create(loc, pred, arg, min); auto minOrArg = - rewriter.create(loc, smallerThanMin, min, args[0]); - auto largerThanMax = rewriter.create(loc, pred, max, args[0]); + rewriter.create(loc, smallerThanMin, min, arg); + auto largerThanMax = rewriter.create(loc, pred, max, arg); return rewriter.create(loc, largerThanMax, max, minOrArg); } @@ -210,7 +238,7 @@ op->getAttr("min_fp")); auto max = rewriter.create(loc, elementTy, op->getAttr("max_fp")); - return clampHelper(op, args, min, max, CmpFPredicate::OLT, + return clampHelper(loc, args[0], min, max, CmpFPredicate::OLT, rewriter); } @@ -219,7 +247,7 @@ rewriter); auto max = createConstFromIntAttribute(op, "max_int", elementTy, rewriter); - return clampHelper(op, args, min, max, CmpIPredicate::slt, + return clampHelper(loc, args[0], min, max, CmpIPredicate::slt, rewriter); } @@ -229,7 +257,7 @@ rewriter.create(loc, FloatAttr::get(elementTy, 0)); auto n = rewriter.create(loc, elementTy, op->getAttr("max_fp")); - return clampHelper(op, args, zero, n, CmpFPredicate::OLT, + return clampHelper(loc, args[0], zero, n, CmpFPredicate::OLT, rewriter); } @@ -238,7 +266,7 @@ rewriter.create(loc, IntegerAttr::get(elementTy, 0)); auto n = createConstFromIntAttribute(op, "max_int", elementTy, rewriter); - return clampHelper(op, args, zero, n, CmpIPredicate::slt, + return
clampHelper(loc, args[0], zero, n, CmpIPredicate::slt, rewriter); } @@ -289,21 +317,9 @@ indexingMaps.reserve(operation->getNumOperands() + bodyResultTypes.size()); // Input indexing maps may be broadcasted. - for (Type types : operation->getOperandTypes()) { - auto shape = types.cast().getShape(); - SmallVector dimExprs; - dimExprs.reserve(nloops); - for (unsigned i = 0; i < nloops; ++i) { - // If the dimension is one we can broadcast the input with a constant - // affine expression. - if (shape[i] == 1) - dimExprs.push_back(rewriter.getAffineConstantExpr(0)); - else - dimExprs.push_back(rewriter.getAffineDimExpr(i)); - } - indexingMaps.push_back(AffineMap::get(/*dimCount=*/nloops, - /*symbolCount=*/0, dimExprs, - rewriter.getContext())); + for (Type type : operation->getOperandTypes()) { + indexingMaps.push_back( + createAffineMapForType(type.cast(), rewriter)); } indexingMaps.append(operation->getNumResults(), @@ -631,6 +647,143 @@ } }; +// Lowers tosa.rescale to a linalg.generic. For each element, the input zero +// point is subtracted, tosa.apply_scale is applied with a multiplier and shift +// that are broadcast along all but the last dimension, the output zero point +// is added, and the result is saturated to the output integer type. +class RescaleOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::RescaleOp op, + PatternRewriter &rewriter) const final { + auto loc = op.getLoc(); + auto input = op.input(); + auto inputTy = op.input().getType().cast(); + auto outputTy = op.output().getType().cast(); + unsigned rank = inputTy.getRank(); + + if (!outputTy.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "tosa to linalg conversion expects statically shaped tensors"); + + // The shift and multiplier values. + SmallVector multiplierValues; + getValuesFromIntArrayAttribute(op.multiplier(), multiplierValues); + + SmallVector shiftValues; + getValuesFromIntArrayAttribute(op.shift(), shiftValues); + + // Double round only occurs if shift is greater than 31, check that this + // is ever true.
+ bool doubleRound = + op.double_round() && + llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); + + // We need to broadcast along the last dimension, so make all dims 1. + SmallVector multiplierShape; + multiplierShape.resize(rank, 1); + + SmallVector shiftShape; + shiftShape.resize(rank, 1); + + // Set the channel dimension to match the number of shift/multiplier + // channels. + if (!multiplierShape.empty()) + multiplierShape.back() = multiplierValues.size(); + if (!shiftShape.empty()) + shiftShape.back() = shiftValues.size(); + + // Create the tensor types. + auto multiplierType = + RankedTensorType::get(multiplierShape, rewriter.getI32Type()); + auto shiftType = + RankedTensorType::get(shiftShape, rewriter.getIntegerType(8)); + + auto multiplierConst = rewriter.create( + loc, DenseIntElementsAttr::get(multiplierType, multiplierValues)); + + auto shiftConst = rewriter.create( + loc, DenseIntElementsAttr::get(shiftType, shiftValues)); + + // Create the init tensor and indexing maps for the linalg.generic op. + Value initTensor = rewriter.create( + loc, ArrayRef({}), outputTy.getShape(), + outputTy.getElementType()); + + SmallVector indexingMaps; + + // Indexing map for input values. + indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); + + // Shift and multiplier will need to broadcast across their non channel + // values. + indexingMaps.push_back(createAffineMapForType(multiplierType, rewriter)); + indexingMaps.push_back(createAffineMapForType(shiftType, rewriter)); + + // Indexing maps for output values.
+ indexingMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); + + auto linalgOp = rewriter.create( + loc, outputTy, ValueRange{input, multiplierConst, shiftConst}, + ValueRange{initTensor}, indexingMaps, getNParallelLoopsAttrs(rank), + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange blockArgs) { + // Zero-point arithmetic and clamping are done in i32 (narrower inputs + // are sign-extended first); tosa.apply_scale widens to 64 bits + // internally. Computing a minimal bit depth is left for later. + auto inputZp = createConstFromIntAttribute( + op, "input_zp", nestedBuilder.getI32Type(), nestedBuilder); + auto outputZp = createConstFromIntAttribute( + op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder); + + Value value = blockArgs[0]; + Value multiplier = blockArgs[1]; + Value shift = blockArgs[2]; + + if (value.getType().getIntOrFloatBitWidth() < 32) { + value = nestedBuilder.create( + nestedLoc, nestedBuilder.getI32Type(), value); + } + + value = nestedBuilder.create(nestedLoc, value, inputZp); + + value = nestedBuilder.create( + nestedLoc, nestedBuilder.getI32Type(), value, multiplier, shift, + nestedBuilder.getBoolAttr(doubleRound)); + + // Move to the new zero-point. + value = nestedBuilder.create(nestedLoc, value, outputZp); + + // Saturate to the output size.
+ IntegerType outIntType = + blockArgs.back().getType().cast(); + unsigned outBitWidth = outIntType.getWidth(); + auto intMin = nestedBuilder.create( + nestedLoc, nestedBuilder.getIntegerAttr( + nestedBuilder.getI32Type(), + APInt::getSignedMinValue(outBitWidth).getSExtValue())); + auto intMax = nestedBuilder.create( + nestedLoc, nestedBuilder.getIntegerAttr( + nestedBuilder.getI32Type(), + APInt::getSignedMaxValue(outBitWidth).getSExtValue())); + + value = clampHelper(nestedLoc, value, intMin, intMax, + CmpIPredicate::slt, nestedBuilder); + + if (outBitWidth < 32) { + value = + nestedBuilder.create(nestedLoc, outIntType, value); + } + + nestedBuilder.create(nestedLoc, value); + }); + + rewriter.replaceOp(op, linalgOp->getResults()); + return success(); + } +}; + // At the codegen level any identity operations should be removed. Any cases // where identity is load-bearing (e.g. cross device computation) should be // handled before lowering to codegen. @@ -681,5 +834,5 @@ IdentityNConverter, ReduceConverter, ReduceConverter, ReduceConverter, ReduceConverter, ReshapeOpConverter, - TransposeConverter>(context); + TransposeConverter, RescaleOpConverter>(context); } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp @@ -40,6 +40,13 @@ ConversionTarget target(getContext()); target.addLegalDialect(); target.addIllegalDialect(); + + // Not every TOSA op can be legalized to linalg.
+ target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); FuncOp func = getFunction(); diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp @@ -32,9 +32,116 @@ } }; +// This converts the TOSA ApplyScale operator to a set of Standard dialect ops, +// using 64-bit operations to perform the necessary multiply, bias, and shift. +// Narrower types (i8/i32) are used for intermediates wherever they suffice. +class ApplyScaleOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + Value value32 = op.value(); + Value multiplier32 = op.multiplier(); + Value shift8 = op.shift(); + bool doubleRound = op.double_round(); + + Value one8 = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1)); + Value one32 = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1)); + Value one64 = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); + + Value shiftSubOne8 = rewriter.create(loc, shift8, one8); + + // The rounding value semantics below equate to the following code: + // int64_t round = 1 << (shift - 1); + // if (double_round) { + // if (shift > 31 && value >= 0) round += 1<<30; + // if (shift > 31 && value < 0) round -= 1<<30; + // } + // + // Note that minimal bitwidth operators are used throughout the block.
+ + Value shift32 = rewriter.create( + loc, rewriter.getI32Type(), shift8); + + Value round64 = rewriter.create( + loc, one64, + rewriter.create(loc, rewriter.getI64Type(), + shiftSubOne8)); + + // Double rounding performs an additional rounding step before the shift. + if (doubleRound) { + Value zero32 = rewriter.create( + loc, rewriter.getZeroAttr(rewriter.getI32Type())); + Value thirty32 = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30)); + + Value shiftThirty32 = + rewriter.create(loc, one32, thirty32); + Value shiftThirty64 = rewriter.create( + loc, rewriter.getI64Type(), shiftThirty32); + + // The 1 << 30 term is added or subtracted depending on the sign of value. + Value roundAdd64 = + rewriter.create(loc, round64, shiftThirty64); + Value roundSub64 = + rewriter.create(loc, round64, shiftThirty64); + + Value valueNonNegative = rewriter.create( + loc, CmpIPredicate::sge, value32, zero32); + + Value doubleRound64 = rewriter.create( + loc, valueNonNegative, roundAdd64, roundSub64); + + // Double rounding only applies when the shift value is greater than 31. + Value thirtyTwo32 = rewriter.create( + loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32)); + Value shiftAtLeastThirtyTwo = rewriter.create( + loc, CmpIPredicate::sge, shift32, thirtyTwo32); + round64 = rewriter.create(loc, shiftAtLeastThirtyTwo, + doubleRound64, round64); + } + + // The computation below equates to the following pseudocode: + // int64_t result = (int64_t)value * multiplier + round; + // result = result >> shift; + // + // Note that multiply and shift need to be performed in i64 to preserve bits. + + Value value64 = + rewriter.create(loc, rewriter.getI64Type(), value32); + Value multiplier64 = rewriter.create( + loc, rewriter.getI64Type(), multiplier32); + Value shift64 = + rewriter.create(loc, rewriter.getI64Type(), shift8); + + // Multiply as a pair of i64 values to guarantee the end value fits.
+ Value result64 = rewriter.create(loc, value64, multiplier64); + result64 = rewriter.create(loc, result64, round64); + result64 = + rewriter.create(loc, result64, shift64); + + Value result32 = rewriter.create( + loc, rewriter.getI32Type(), result64); + + rewriter.replaceOp(op, result32); + return success(); + } +}; + } // namespace void mlir::tosa::populateTosaToStandardConversionPatterns( MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert(context); + patterns->insert(context); +} + +void mlir::tosa::populateTosaRescaleToStandardConversionPatterns( + MLIRContext *context, OwningRewritePatternList *patterns) { + patterns->insert(context); } diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp --- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp +++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp @@ -32,7 +32,9 @@ OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addIllegalOp(); + target.addIllegalOp(); target.addLegalOp(); + target.addLegalDialect(); auto *op = getOperation(); mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(), diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -433,3 +433,54 @@ %4 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<5x4xi32>) -> tensor<4xi32> return } + +// ----- + +// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)> +// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (0)> + +// CHECK-LABEL: @rescale +func @rescale(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) { + // CHECK: [[C0:%.+]] = constant dense<19689> + // CHECK: [[C1:%.+]] = constant dense<15> + // CHECK: [[INIT:%.+]] = linalg.init_tensor [1] + // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]],
#[[$MAP1]], #[[$MAP1]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%arg0, [[C0]], [[C1]] : tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) outs([[INIT]] : tensor<1xi8>) + // CHECK: ^bb0([[IN:%.+]]: i8, [[MULTIPLIER:%.+]]: i32, [[SHIFT:%.+]]: i8, [[UNUSED:%.+]]: i8): + // CHECK: [[C243:%.+]] = constant 243 + // CHECK: [[C252:%.+]] = constant 252 + + // CHECK-DAG: [[IN32:%.+]] = sexti [[IN]] + // CHECK-DAG: [[IN_ZEROED:%.+]] = subi [[IN32]], [[C243]] + // CHECK-DAG: [[SCALED:%.+]] = "tosa.apply_scale"([[IN_ZEROED]], [[MULTIPLIER]], [[SHIFT]]) {double_round = false} + // CHECK-DAG: [[SCALED_ZEROED:%.+]] = addi [[SCALED]], [[C252]] + // CHECK-DAG: [[CMIN:%.+]] = constant -128 + // CHECK-DAG: [[CMAX:%.+]] = constant 127 + // CHECK-DAG: [[MINLT:%.+]] = cmpi slt, [[SCALED_ZEROED]], [[CMIN]] + // CHECK-DAG: [[MAXLT:%.+]] = cmpi slt, [[CMAX]], [[SCALED_ZEROED]] + // CHECK-DAG: [[LOWER:%.+]] = select [[MINLT]], [[CMIN]], [[SCALED_ZEROED]] + // CHECK-DAG: [[BOUNDED:%.+]] = select [[MAXLT]], [[CMAX]], [[LOWER]] + // CHECK-DAG: [[TRUNC:%.+]] = trunci [[BOUNDED]] + // CHECK-DAG: linalg.yield [[TRUNC]] + %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = false, double_round = false, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>) + + // CHECK: return [[GENERIC]] + return %0 : tensor<1xi8> +} + +// CHECK-LABEL: @rescaleDoubleRound +func @rescaleDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) { + // CHECK: linalg.generic + // CHECK: "tosa.apply_scale" + // CHECK-SAME: {double_round = true} + %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [33 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>) + return %0 : tensor<1xi8> +} + +// CHECK-LABEL: @rescaleUnnecessaryDoubleRound +func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<1xi8>) -> (tensor<1xi8>) { + // CHECK: linalg.generic + //
CHECK: "tosa.apply_scale" + // CHECK-SAME: {double_round = false} + %0 = "tosa.rescale"(%arg0) {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = [19689 : i32], shift = [15 : i32], scale32 = true, double_round = true, per_channel = false} : (tensor<1xi8>) -> (tensor<1xi8>) + return %0 : tensor<1xi8> +} diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir --- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir +++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir @@ -8,3 +8,42 @@ // CHECK: return [[C3]] return %0 : tensor } + + +// ----- + +// CHECK-LABEL: @apply_scale_test +func @apply_scale_test(%arg0 : i32, %arg1 : i32, %arg2 : i8) -> (i32) { + // CHECK: [[C1_8:%.+]] = constant 1 : i8 + // CHECK: [[C1_32:%.+]] = constant 1 : i32 + // CHECK: [[C1_64:%.+]] = constant 1 : i64 + // CHECK: [[SHIFT_MINUS_ONE_8:%.+]] = subi %arg2, [[C1_8]] + + // CHECK: [[SHIFT_32:%.+]] = sexti %arg2 : i8 to i32 + // CHECK: [[SHIFT_MINUS_ONE_64:%.+]] = sexti [[SHIFT_MINUS_ONE_8]] : i8 to i64 + // CHECK: [[SHIFTED_64:%.+]] = shift_left [[C1_64]], [[SHIFT_MINUS_ONE_64]] + + // CHECK: [[C0_32:%.+]] = constant 0 : i32 + // CHECK: [[C30_32:%.+]] = constant 30 : i32 + // CHECK: [[SECOND_BIAS:%.+]] = shift_left [[C1_32]], [[C30_32]] + // CHECK: [[SECOND_BIAS_64:%.+]] = sexti [[SECOND_BIAS]] : i32 to i64 + // CHECK: [[POSITIVE_ROUND:%.+]] = addi [[SHIFTED_64]], [[SECOND_BIAS_64]] + // CHECK: [[NEGATIVE_ROUND:%.+]] = subi [[SHIFTED_64]], [[SECOND_BIAS_64]] + // CHECK: [[VALUE_NON_NEGATIVE:%.+]] = cmpi sge, %arg0, [[C0_32]] : i32 + // CHECK: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NON_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : i64 + // CHECK: [[C32_32:%.+]] = constant 32 : i32 + // CHECK: [[IS_32BIT_SHIFT:%.+]] = cmpi sge, [[SHIFT_32]], [[C32_32]] + // CHECK: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]] + + // CHECK: [[VAL_64:%.+]] = sexti %arg0 : i32 to i64 + // CHECK: [[MULTIPLY_64:%.+]] = sexti
%arg1 : i32 to i64 + // CHECK: [[SHIFT_64:%.+]] = sexti %arg2 : i8 to i64 + // CHECK: [[SCALED:%.+]] = muli [[VAL_64]], [[MULTIPLY_64]] + // CHECK: [[BIASED:%.+]] = addi [[SCALED]], [[ROUND]] + // CHECK: [[DOWNSHIFTED:%.+]] = shift_right_signed [[BIASED]], [[SHIFT_64]] + // CHECK: [[TRUNCATED:%.+]] = trunci [[DOWNSHIFTED]] + + %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (i32, i32, i8) -> i32 + // CHECK: return [[TRUNCATED]] + return %0 : i32 +}