# Patch: widen the TOSA matmul operand/result types from rank-2 only to rank 2 or 3.
#
# TosaOps.td: in the op whose builders list is [Tosa_MatMulOpQuantInfoBuilder]
# (presumably tosa.matmul; the op def header is outside the hunk), the `$a` and `$b`
# operands and the `$c` result change from Tosa_Tensor2D to Tosa_Tensor2Dto3D.
# The `$quantization_info` optional attribute is left unchanged.
#
# TosaTypesBase.td: adds `def Tosa_Tensor2Dto3D : TensorRankOf<[Tosa_AnyNumber], [2,3]>;`
# right after the existing Tosa_Tensor1Dto4D / Tosa_Tensor1Dto6D rank-range defs.
#
# NOTE(review): each operand/result is rank-checked independently by this type
# constraint, so mixing ranks (e.g. 2-D `$a` with 3-D `$b`, or a 3-D result for 2-D
# inputs) is not rejected by these declarations alone. Confirm that a verifier or
# shape-inference hook enforces consistent ranks; add one if not.
# NOTE(review): the op's summary/description (outside these hunks) may still describe
# 2-D operands only. TODO: confirm and update it alongside this change.
# NOTE(review): this patch appears to have been flattened onto a single line, so the
# newlines between diff headers, hunk headers and +/-/context lines must be restored
# before `git apply` can use it. `OptionalAttr:$quantization_info` also appears to
# have lost its `OptionalAttr<...>` template argument in transit. Verify both against
# the real TosaOps.td before applying.
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -208,13 +208,13 @@ }]; let arguments = (ins - Tosa_Tensor2D:$a, - Tosa_Tensor2D:$b, + Tosa_Tensor2Dto3D:$a, + Tosa_Tensor2Dto3D:$b, OptionalAttr:$quantization_info ); let results = (outs - Tosa_Tensor2D:$c + Tosa_Tensor2Dto3D:$c ); let builders = [Tosa_MatMulOpQuantInfoBuilder]; diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -124,6 +124,8 @@ def Tosa_Tensor1Dto4D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>; def Tosa_Tensor1Dto6D : TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>; +def Tosa_Tensor2Dto3D : TensorRankOf<[Tosa_AnyNumber], [2,3]>; + def Tosa_TensorUpto4D : TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>; def Tosa_Int32TensorUpto4D : TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>;