diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -208,13 +208,13 @@ }]; let arguments = (ins - Tosa_Tensor2Dto3D:$a, - Tosa_Tensor2Dto3D:$b, + Tosa_Tensor3D:$a, + Tosa_Tensor3D:$b, OptionalAttr:$quantization_info ); let results = (outs - Tosa_Tensor2Dto3D:$c + Tosa_Tensor3D:$c ); let builders = [Tosa_MatMulOpQuantInfoBuilder]; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -124,8 +124,6 @@ def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>; def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>; -def Tosa_Tensor2Dto3D : TensorRankOf<[Tosa_AnyNumber], [2,3]>; - def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>; def Tosa_Int32TensorUpto4D : TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>; diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -39,9 +39,9 @@ // ----- // CHECK-LABEL: test_matmul -func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> { - %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32> - return %0 : tensor<14x28xf32> +func @test_matmul(%arg0: tensor<1x14x19xf32>, %arg1: tensor<1x19x28xf32>) -> tensor<1x14x28xf32> { + %0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> + return %0 : tensor<1x14x28xf32> } // -----