diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -115,12 +115,39 @@ } if (isa(op) && elementTy.isa()) { - auto mul = - rewriter.create(loc, resultTypes, args[0], args[1]); - auto constant = - rewriter.create(loc, elementTy, op->getAttr("shift")); - return rewriter.create(loc, resultTypes, mul, - constant); + Value a = args[0]; + Value b = args[1]; + auto shift = + op->getAttr("shift").cast().getValue().getSExtValue(); + if (shift > 0) { + auto shiftConst = + rewriter.create(loc, shift, /*bitwidth=*/8); + if (!a.getType().isInteger(32)) + a = rewriter.create(loc, rewriter.getI32Type(), a); + + if (!b.getType().isInteger(32)) + b = rewriter.create(loc, rewriter.getI32Type(), b); + + auto result = rewriter.create( + loc, rewriter.getI32Type(), a, b, shiftConst, + rewriter.getBoolAttr(false)); + + if (elementTy.isInteger(32)) + return result; + + return rewriter.create(loc, elementTy, result); + } + + int aWidth = a.getType().getIntOrFloatBitWidth(); + int bWidth = b.getType().getIntOrFloatBitWidth(); + int cWidth = resultTypes[0].getIntOrFloatBitWidth(); + + if (aWidth < cWidth) + a = rewriter.create(loc, resultTypes[0], a); + if (bWidth < cWidth) + b = rewriter.create(loc, resultTypes[0], b); + + return rewriter.create(loc, resultTypes, a, b); } // tosa::NegateOp diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -185,6 +185,19 @@ // ----- +// CHECK-LABEL: @test_simple_i16 +func @test_simple_i16(%arg0: tensor<1xi16>) -> () { + // CHECK: linalg.generic + // CHECK: sext + // CHECK: sext + // CHECK: muli + %0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32> + + return +} + +// ----- + // CHECK-LABEL: @test_simple_i32 func @test_simple_i32(%arg0: tensor<1xi32>) -> () { // CHECK: linalg.generic @@ -199,61 +212,66 @@ // CHECK: muli %2 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic + // CHECK: constant 2 + // CHECK: apply_scale + %3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: linalg.generic // CHECK: muli - %3 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> + %4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: and - %4 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %5 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: or - %5 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %6 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: xor - %6 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %7 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: shift_left - %7 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %8 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: shift_right_unsigned - %8 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi - %9 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %10 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: cmpi - %10 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> + %11 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1> // CHECK: linalg.generic // CHECK: select - %11 = "tosa.select"(%9, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %12 = "tosa.select"(%10, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %12 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %13 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %13 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + %14 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %14 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + %15 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK: cmpi // CHECK: select - %15 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + %16 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> return }