diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -126,6 +126,10 @@
     return rewriter.create<mlir::MulFOp>(loc, resultTypes, args);
   }
 
+  // tosa::DivOp (integer element types only): signed integer division.
+  if (isa<tosa::DivOp>(op) && elementTy.isa<IntegerType>())
+    return rewriter.create<mlir::SignedDivIOp>(loc, resultTypes, args);
+
   // tosa::ReciprocalOp
   if (isa<tosa::ReciprocalOp>(op) && elementTy.isa<FloatType>()) {
     auto one =
@@ -2335,6 +2339,7 @@
       PointwiseConverter<tosa::AddOp>,
       PointwiseConverter<tosa::SubOp>,
       PointwiseConverter<tosa::MulOp>,
+      PointwiseConverter<tosa::DivOp>,
       PointwiseConverter<tosa::NegateOp>,
       PointwiseConverter<tosa::PowOp>,
       PointwiseConverter<tosa::ReciprocalOp>,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -294,34 +294,38 @@
   // CHECK: apply_scale
   %3 = "tosa.mul"(%arg0, %arg0) {shift = 2 : i32} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
+  // CHECK: linalg.generic
+  // CHECK: divi_signed
+  %4 = "tosa.div"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+
   // CHECK: linalg.generic
   // CHECK: [[ZERO:%.+]] = constant 0
   // CHECK: subi [[ZERO]], %arg1
-  %4 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
+  %5 = "tosa.negate"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: and
-  %5 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %6 = "tosa.bitwise_and"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: or
-  %6 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %7 = "tosa.bitwise_or"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: xor
-  %7 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %8 = "tosa.bitwise_xor"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_left
-  %8 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %9 = "tosa.logical_left_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_right_unsigned
-  %9 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %10 = "tosa.logical_right_shift"(%arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: shift_right_signed
-  %10 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 0 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant 1
@@ -335,39 +339,39 @@
   // CHECK: and
   // CHECK: zexti
   // CHECK: addi
-  %11 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %12 = "tosa.arithmetic_right_shift"(%arg0, %arg0) {round = 1 : i1} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %12 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %13 = "tosa.greater"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
-  %13 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
+  %14 = "tosa.greater_equal"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: select
-  %14 = "tosa.select"(%12, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %15 = "tosa.select"(%13, %0, %1) : (tensor<1xi1>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %15 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %16 = "tosa.maximum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %16 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %17 = "tosa.minimum"(%0, %1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %17 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %18 = "tosa.clamp"(%0) {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: cmpi
   // CHECK: select
-  %18 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
+  %19 = "tosa.reluN"(%0) {max_int = 5 : i64, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: constant -32768
@@ -377,24 +381,24 @@
   // CHECK: cmpi slt
   // CHECK: select
   // CHECK: trunci
-  %19 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
+  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi16>
 
   // CHECK: linalg.generic
   // CHECK: yield
-  %20 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
+  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: sexti
-  %21 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
+  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi64>
 
   // CHECK: linalg.generic
   // CHECK: constant 0
   // CHECK: cmpi
-  %22 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
+  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xi1>
 
   // CHECK: linalg.generic
   // CHECK: sitofp
-  %23 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
+  %24 = "tosa.cast"(%0) : (tensor<1xi32>) -> tensor<1xf32>
 
   return
 }