diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -108,7 +108,7 @@ let arguments = (ins Tosa_Tensor4D:$input, - Tosa_Tensor4D:$weight, + 4DTensorOf<[Tosa_Weight]>:$weight, Tosa_Tensor1D:$bias, Tosa_IntArrayAttr4:$pad, @@ -140,7 +140,7 @@ let arguments = (ins Tosa_Tensor5D:$input, - Tosa_Tensor5D:$weight, + TensorRankOf<[Tosa_Weight], [5]>:$weight, Tosa_Tensor1D:$bias, Tosa_IntArrayAttr6:$pad, @@ -173,7 +173,7 @@ let arguments = (ins Tosa_Tensor4D:$input, - Tosa_Tensor4D:$weight, + 4DTensorOf<[Tosa_Weight]>:$weight, Tosa_Tensor1D:$bias, Tosa_IntArrayAttr4:$pad, @@ -235,7 +235,7 @@ let arguments = (ins Tosa_Tensor2D:$input, - Tosa_Tensor2D:$weight, + 2DTensorOf<[Tosa_Weight]>:$weight, Tosa_Tensor1D:$bias, OptionalAttr:$quantization_info ); @@ -351,7 +351,7 @@ let arguments = (ins Tosa_Tensor4D:$input, - Tosa_Tensor4D:$filter, + 4DTensorOf<[Tosa_Weight]>:$filter, Tosa_Tensor1D:$bias, Tosa_IntArrayAttr4:$out_pad, @@ -1817,7 +1817,7 @@ ); let results = (outs - Tosa_Tensor_Plus_F64:$output + TensorOf<[AnyTypeOf<[Tosa_AnyNumber_Plus_F64, Tosa_Int4]>]>:$output ); let hasFolder = 1; } diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -41,6 +41,7 @@ def Tosa_UInt8 : UI<8>; def Tosa_UInt16 : UI<16>; +def Tosa_Int4 : I<4>; def Tosa_Int8 : I<8>; def Tosa_Int16 : I<16>; def Tosa_Int32 : I<32>; @@ -95,10 +96,16 @@ //===----------------------------------------------------------------------===// def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float], "number">; + // Add F64 type support just for tosa::CastOp and tosa::ConstOp def Tosa_AnyNumber_Plus_F64 : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, Tosa_Float, F64], 
"number_plus_f64">; +// Element types for weight/filter operands of tosa::Conv2DOp, tosa::Conv3DOp, +// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp, tosa::FullyConnectedOp. +def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8, + Tosa_QuantizedInt, Tosa_Float]>; + //===----------------------------------------------------------------------===// // Tensor types //===----------------------------------------------------------------------===// @@ -109,6 +116,7 @@ // Either ranked or unranked tensor of TOSA supported element types. def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>; def Tosa_Tensor_Plus_F64 : TensorOf<[Tosa_AnyNumber_Plus_F64]>; + // Must be ranked but no further constraints def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>; diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -20,13 +20,12 @@ // ----- func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<*xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> { - // expected-error@+1 {{expect a ranked tensor for weight, got of type 'tensor<*xi8>' at index: 1}} + // expected-error@+1 {{'tosa.conv2d' op operand #1 must be 4D tensor of 4-bit signless integer or 8-bit signless integer or Quint8 type or Qint4 type or Qint8 type or Qint16 type or Qint32 type or 32-bit float or 16-bit float or bfloat16 type values, but got 'tensor<*xi8>'}} %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : (tensor<1x29x29x4xi8>, tensor<*xi8>, tensor<16xi8>) -> tensor<1x27x27x16xi8> return %0 : tensor<1x27x27x16xi8> } - // ----- func.func @test_conv2d(%arg0: tensor<1x29x29x4xi8>, %arg1: tensor<16x3x3x4xi8>, %arg2: tensor<16xi8>) -> tensor<1x27x27x16xi8> {