Summary of this patch (free text before the first "diff --git" line is skipped
by git-apply and patch):
  * TosaOps.td: tosa.gather now takes a 3-D `values` tensor and a 2-D
    Tosa_Int32 `indices` tensor and yields a 3-D result; the `axis` attribute
    is removed.
  * TosaOps.td: adds tosa.scatter (values_in, indices, input) -> values_out.
  * TosaOps.td: tosa.resize gains the `stride_fp` and `offset_fp` attributes.
  * TosaOps.td / TosaTypesBase.td: the `aint8` quantized type is dropped; the
    rescale documentation table uses `int8` instead.
  * TosaTypesBase.td: adds Tosa_Tensor3D and Tosa_Fp32ArrayAttr2..6.
  * ops.mlir: updates the gather test, adds a scatter test, and enables the
    previously commented-out resize test.
NOTE(review): indentation and column alignment inside lines were reconstructed
from a whitespace-collapsed copy; if context lines do not match the target
files exactly, apply with `git apply --ignore-whitespace`.
NOTE(review): the `Confined<...>` template arguments were missing from the
collapsed copy and have been restored as `<F32ArrayAttr, [ArrayCount<N>]>` /
`<I64ArrayAttr, [ArrayCount<N>]>` -- confirm against TosaTypesBase.td.

diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1363,17 +1363,38 @@
 
   let description = [{
     Generate a tensor for which each element in the output is a subtensor of the
-    values tensor along the given axis, based on the value of indices.
+    values tensor based on the value of indices.
   }];
 
   let arguments = (ins
-    Tosa_Int32Or64Tensor:$indices,
-    Tosa_Tensor1Dto4D:$values,
-    I32Attr:$axis
+    Tosa_Tensor3D:$values,
+    2DTensorOf<[Tosa_Int32]>:$indices
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor3D:$output
+  );
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: scatter
+//===----------------------------------------------------------------------===//
+def Tosa_ScatterOp : Tosa_Op<"scatter", [NoSideEffect]> {
+  let summary = "Scatter operation";
+
+  let description = [{
+    The values_out tensor is set to the values_in tensor with data modified as follows:
+    data from the input tensor is inserted at the positions specified by the indices tensor.
+  }];
+
+  let arguments = (ins
+    Tosa_Tensor3D:$values_in,
+    2DTensorOf<[Tosa_Int32]>:$indices,
+    Tosa_Tensor3D:$input
+  );
+
+  let results = (outs
+    Tosa_Tensor3D:$values_out
   );
 }
 
@@ -1402,6 +1423,8 @@
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$offset,
     I32Attr:$shift,
+    Tosa_Fp32ArrayAttr2:$stride_fp,
+    Tosa_Fp32ArrayAttr2:$offset_fp,
     Tosa_ResizeTypeAttr:$mode
   );
 
@@ -1462,20 +1485,20 @@
   let description = [{
     Rescale quantized values into a new domain. Supported rescalings are:
     Mode                    Input   Output
-    signed 8 to 8           aint8   aint8
-    signed 8 to 16          aint8   int16
-    signed 8 to 32          aint8   int32
-    signed 16 to 8          int16   aint8
+    signed 8 to 8           int8    int8
+    signed 8 to 16          int8    int16
+    signed 8 to 32          int8    int32
+    signed 16 to 8          int16   int8
     signed 16 to 16         int16   int16
     signed 16 to 32         int16   int32
-    signed 32 to 8          int32   aint8
+    signed 32 to 8          int32   int8
     signed 32 to 16         int32   int16
     signed 32 to 32         int32   int32
-    signed 48 to 8          int48   aint8
+    signed 48 to 8          int48   int8
     signed 48 to 16         int48   int16
     signed 48 to 32         int48   int32
-    unsigned 8 to signed 8  uint8   aint8
-    signed 8 to unsigned 8  aint8   uint8
+    unsigned 8 to signed 8  uint8   int8
+    signed 8 to unsigned 8  int8    uint8
   }];
 
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -66,14 +66,12 @@
 //===----------------------------------------------------------------------===//
 // Name    Symmetry   Grouping                Sign
 //===----------------------------------------------------------------------===//
-// aint8 : asymmetric per tensor,             signed
 // uint8 : asymmetric per tensor ,            unsigned
 // int4  : symmetric  per channel,            signed
 // int8  : symmetric  per tensor/per channel, signed
 // int16 : symmetric  per tensor,             signed
 //===----------------------------------------------------------------------===//
-def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"aint8", [8], 1>,
-                                   Tosa_QuantizedType<"uint8", [8], 0>,
+def Tosa_QuantizedInt : AnyTypeOf<[ Tosa_QuantizedType<"uint8", [8], 0>,
                                    Tosa_QuantizedType<"int4", [4, 0], 1>,
                                    Tosa_QuantizedType<"int8", [8, 0], 1>,
                                    Tosa_QuantizedType<"int16", [16, 0], 1>]>;
@@ -114,6 +112,7 @@
 // Must be listed rank.
 def Tosa_Tensor1D : 1DTensorOf<[Tosa_AnyNumber]>;
 def Tosa_Tensor2D : 2DTensorOf<[Tosa_AnyNumber]>;
+def Tosa_Tensor3D : 3DTensorOf<[Tosa_AnyNumber]>;
 def Tosa_Tensor4D : 4DTensorOf<[Tosa_AnyNumber]>;
 def Tosa_Tensor5D : TensorRankOf<[Tosa_AnyNumber], [5]>;
 def Tosa_Tensor6D : TensorRankOf<[Tosa_AnyNumber], [6]>;
@@ -149,6 +148,12 @@
     CPred<"$_self.cast<::mlir::ArrayAttr>().size() <= " # n>,
     "with at least " # n # " elements">;
 
+def Tosa_Fp32ArrayAttr2 : Confined<F32ArrayAttr, [ArrayCount<2>]>;
+def Tosa_Fp32ArrayAttr3 : Confined<F32ArrayAttr, [ArrayCount<3>]>;
+def Tosa_Fp32ArrayAttr4 : Confined<F32ArrayAttr, [ArrayCount<4>]>;
+def Tosa_Fp32ArrayAttr5 : Confined<F32ArrayAttr, [ArrayCount<5>]>;
+def Tosa_Fp32ArrayAttr6 : Confined<F32ArrayAttr, [ArrayCount<6>]>;
+
 def Tosa_IntArrayAttr2 : Confined<I64ArrayAttr, [ArrayCount<2>]>;
 def Tosa_IntArrayAttr3 : Confined<I64ArrayAttr, [ArrayCount<3>]>;
 def Tosa_IntArrayAttr4 : Confined<I64ArrayAttr, [ArrayCount<4>]>;
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -406,18 +406,24 @@
 
 // -----
 // CHECK-LABEL: gather
-func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<26xi32>) -> tensor<26x21x3xi32> {
-  %0 = "tosa.gather"(%arg0, %arg1) {axis = 0 : i32, batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>) -> tensor<26x21x3xi32>
-  return %0 : tensor<26x21x3xi32>
-}
-
-// Test TBD
-// DISABLED-CHECK-LABEL: resize
-//func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
-//  %0 = "tosa.const"() {value = dense<64> : tensor<2xi32>} : () -> tensor<2xi32>
-//  %1 = "tosa.resize"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<1x32x32x8xf32>, tensor<2xi32>) -> tensor<1x64x64x8xf32>
-//  return %1 : tensor<1x64x64x8xf32>
-//}
+func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x3xf32> {
+  %0 = "tosa.gather"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x3xf32>
+  return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: scatter
+func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
+  %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
+  return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: resize
+func @test_resize(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> {
+  %1 = "tosa.resize"(%arg0) {output_size = [64, 64], stride = [1024, 1024], offset = [0, 0], shift = 10 : i32, stride_fp = [0.0 : f32, 0.0 : f32], offset_fp = [0.0 : f32, 0.0 : f32], mode = "BILINEAR"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32>
+  return %1 : tensor<1x64x64x8xf32>
+}
 
 // -----
 // CHECK-LABEL: cast