diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -52,6 +52,30 @@
   }
 };
 
+/// Returns `element` in the same container as `container`: when `container`
+/// is a ShapedType (e.g. a vector), a clone of it with element type
+/// `element`; otherwise `element` itself.
+Type matchContainerType(Type element, Type container) {
+  if (auto shapedTy = container.dyn_cast<ShapedType>())
+    return shapedTy.clone(element);
+
+  return element;
+}
+
+/// Returns a constant attribute holding the integer `value` for `type`: a
+/// splat DenseElementsAttr when `type` is a ShapedType, otherwise a scalar
+/// IntegerAttr. The element value is built with Builder::getIntegerAttr in
+/// both cases so scalar and container lowerings encode it identically.
+Attribute getConstantAttr(Type type, int64_t value, PatternRewriter &rewriter) {
+  if (auto shapedTy = type.dyn_cast<ShapedType>()) {
+    Attribute elementAttr =
+        rewriter.getIntegerAttr(shapedTy.getElementType(), value);
+    return DenseElementsAttr::get(shapedTy, elementAttr);
+  }
+
+  return rewriter.getIntegerAttr(type, value);
+}
+
 // This converts the TOSA ApplyScale operator to a set of StandardOps ops,
 // using 64-bit operations to perform the necessary multiply, bias, and shift.
 // Multiple types are used to use minimal bit width operations.
@@ -65,13 +89,21 @@
     Value value32 = op.value();
     Value multiplier32 = op.multiplier();
     Value shift8 = op.shift();
+    bool doubleRound = op.double_round();
     Type inType = op.value().getType();
+    Type resultTy = op.getType();
+
+    // Intermediate types mirror the result's container (if any), so a
+    // vector-typed apply_scale lowers to element-wise vector arithmetic.
+    Type i8Ty = matchContainerType(rewriter.getIntegerType(8), resultTy);
+    Type i32Ty = matchContainerType(rewriter.getI32Type(), resultTy);
+    Type i64Ty = matchContainerType(rewriter.getI64Type(), resultTy);
 
     Value one8 = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(rewriter.getIntegerType(8), 1));
+        loc, getConstantAttr(i8Ty, 1, rewriter));
 
     Value one64 = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1));
+        loc, getConstantAttr(i64Ty, 1, rewriter));
 
     Value shiftSubOne8 = rewriter.create<arith::SubIOp>(loc, shift8, one8);
 
@@ -85,23 +117,20 @@
     // Note that minimal bitwidth operators are used throughout the block.
 
     Value round64 = rewriter.create<arith::ShLIOp>(
-        loc, one64,
-        rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(),
-                                        shiftSubOne8));
+        loc, one64, rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftSubOne8));
 
     // Double rounding is performing a round operation before the shift
     if (doubleRound) {
       Value one32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 1));
-      Value shift32 =
-          rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), shift8);
+          loc, getConstantAttr(i32Ty, 1, rewriter));
+      Value shift32 = rewriter.create<arith::ExtSIOp>(loc, i32Ty, shift8);
       Value thirty32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 30));
+          loc, getConstantAttr(i32Ty, 30, rewriter));
 
       Value shiftThirty32 =
           rewriter.create<arith::ShLIOp>(loc, one32, thirty32);
-      Value shiftThirty64 = rewriter.create<arith::ExtSIOp>(
-          loc, rewriter.getI64Type(), shiftThirty32);
+      Value shiftThirty64 =
+          rewriter.create<arith::ExtSIOp>(loc, i64Ty, shiftThirty32);
 
       // Round value needs to with be added or subtracted depending on the sign
       // of the input value.
@@ -120,7 +149,7 @@
 
       // We only perform double rounding if the shift value is greater than 32.
       Value thirtyTwo32 = rewriter.create<arith::ConstantOp>(
-          loc, rewriter.getIntegerAttr(rewriter.getI32Type(), 32));
+          loc, getConstantAttr(i32Ty, 32, rewriter));
       Value shiftGreaterThanThirtyTwo = rewriter.create<arith::CmpIOp>(
           loc, arith::CmpIPredicate::sge, shift32, thirtyTwo32);
       round64 = rewriter.create<SelectOp>(loc, shiftGreaterThanThirtyTwo,
@@ -133,20 +162,17 @@
     //
     // Note that multiply and shift need to be perform in i64 to preserve bits.
 
-    Value value64 =
-        rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), value32);
-    Value multiplier64 = rewriter.create<arith::ExtSIOp>(
-        loc, rewriter.getI64Type(), multiplier32);
-    Value shift64 =
-        rewriter.create<arith::ExtSIOp>(loc, rewriter.getI64Type(), shift8);
+    Value value64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, value32);
+    Value multiplier64 =
+        rewriter.create<arith::ExtSIOp>(loc, i64Ty, multiplier32);
+    Value shift64 = rewriter.create<arith::ExtSIOp>(loc, i64Ty, shift8);
 
     // Multiply as a pair of i64 values to guarantee the end value fits.
     Value result64 = rewriter.create<arith::MulIOp>(loc, value64, multiplier64);
     result64 = rewriter.create<arith::AddIOp>(loc, result64, round64);
     result64 = rewriter.create<arith::ShRSIOp>(loc, result64, shift64);
 
-    Value result32 =
-        rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), result64);
+    Value result32 = rewriter.create<arith::TruncIOp>(loc, resultTy, result64);
 
     rewriter.replaceOp(op, result32);
     return success();
diff --git a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
--- a/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
+++ b/mlir/test/Conversion/TosaToStandard/tosa-to-standard.mlir
@@ -56,6 +56,43 @@
 
 // -----
 
+// CHECK-LABEL: @apply_scale_test_vector
+func @apply_scale_test_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>, %arg2 : vector<4xi8>) -> (vector<4xi32>) {
+  // CHECK-DAG: [[C1_8:%.+]] = arith.constant dense<1> : vector<4xi8>
+  // CHECK-DAG: [[C1_32:%.+]] = arith.constant dense<1> : vector<4xi32>
+  // CHECK-DAG: [[C1_64:%.+]] = arith.constant dense<1> : vector<4xi64>
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_8:%.+]] = arith.subi %arg2, [[C1_8]]
+
+  // CHECK-DAG: [[SHIFT_32:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi32>
+  // CHECK-DAG: [[SHIFT_MINUS_ONE_64:%.+]] = arith.extsi [[SHIFT_MINUS_ONE_8]] : vector<4xi8> to vector<4xi64>
+  // CHECK-DAG: [[SHIFTED_64:%.+]] = arith.shli [[C1_64]], [[SHIFT_MINUS_ONE_64]]
+
+  // CHECK-DAG: [[C0_32:%.+]] = arith.constant dense<0> : vector<4xi32>
+  // CHECK-DAG: [[C30_32:%.+]] = arith.constant dense<30> : vector<4xi32>
+  // CHECK-DAG: [[SECOND_BIAS:%.+]] = arith.shli [[C1_32]], [[C30_32]]
+  // CHECK-DAG: [[SECOND_BIAS_64:%.+]] = arith.extsi [[SECOND_BIAS]] : vector<4xi32> to vector<4xi64>
+  // CHECK-DAG: [[POSITIVE_ROUND:%.+]] = arith.addi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[NEGATIVE_ROUND:%.+]] = arith.subi [[SHIFTED_64]], [[SECOND_BIAS_64]]
+  // CHECK-DAG: [[VALUE_NEGATIVE:%.+]] = arith.cmpi sge, %arg0, [[C0_32]] : vector<4xi32>
+  // CHECK-DAG: [[DOUBLE_ROUNDED:%.+]] = select [[VALUE_NEGATIVE]], [[POSITIVE_ROUND]], [[NEGATIVE_ROUND]] : vector<4xi1>, vector<4xi64>
+  // CHECK-DAG: [[C32_32:%.+]] = arith.constant dense<32> : vector<4xi32>
+  // CHECK-DAG: [[IS_32BIT_SHIFT:%.+]] = arith.cmpi sge, [[SHIFT_32]], [[C32_32]]
+  // CHECK-DAG: [[ROUND:%.+]] = select [[IS_32BIT_SHIFT]], [[DOUBLE_ROUNDED]], [[SHIFTED_64]]
+
+  // CHECK-DAG: [[VAL_64:%.+]] = arith.extsi %arg0 : vector<4xi32> to vector<4xi64>
+  // CHECK-DAG: [[MULTIPLY_64:%.+]] = arith.extsi %arg1 : vector<4xi32> to vector<4xi64>
+  // CHECK-DAG: [[SHIFT_64:%.+]] = arith.extsi %arg2 : vector<4xi8> to vector<4xi64>
+  // CHECK-DAG: [[SCALED:%.+]] = arith.muli [[VAL_64]], [[MULTIPLY_64]]
+  // CHECK-DAG: [[BIASED:%.+]] = arith.addi [[SCALED]], [[ROUND]]
+  // CHECK-DAG: [[DOWNSHIFTED:%.+]] = arith.shrsi [[BIASED]], [[SHIFT_64]]
+  // CHECK: [[TRUNCATED:%.+]] = arith.trunci [[DOWNSHIFTED]] : vector<4xi64> to vector<4xi32>
+
+  %0 = "tosa.apply_scale"(%arg0, %arg1, %arg2) {double_round = true} : (vector<4xi32>, vector<4xi32>, vector<4xi8>) -> vector<4xi32>
+  return %0 : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: @apply_scale_test_i48
 func @apply_scale_test_i48(%arg0 : i48, %arg1 : i32, %arg2 : i8) -> (i32) {
   // CHECK-DAG: [[C1_8:%.+]] = arith.constant 1 : i8