diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -100,6 +100,19 @@
 // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `shapedType` is ranked and has a static zero-sized dim.
+/// Unranked types return false (getRank() would assert on them).
+static bool hasZeroDimension(ShapedType shapedType) {
+  if (!shapedType.hasRank())
+    return false;
+
+  for (int64_t i = 0, rank = shapedType.getRank(); i < rank; ++i) {
+    if (!shapedType.isDynamicDim(i) && shapedType.getDimSize(i) == 0)
+      return true;
+  }
+  return false;
+}
+
 template <typename T> static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
@@ -115,6 +128,9 @@
     return failure();
   }
 
+  if (hasZeroDimension(inputType))
+    return op.emitOpError() << "cannot take tensor with zero dimensions";
+
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
 
@@ -142,7 +158,11 @@
 }
 
 LogicalResult tosa::AvgPool2dOp::verify() {
-  auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  if (hasZeroDimension(inputType))
+    return emitOpError() << "cannot take tensor with zero dimensions";
+
+  auto inputETy = inputType.getElementType();
   auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
 
   if (auto quantType =
@@ -758,6 +778,9 @@
   ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
   ShapedType outputType = llvm::cast<ShapedType>(getType());
 
+  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
+    return emitOpError() << "cannot take tensor with zero dimensions";
+
   if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
     int64_t inputElementsNum = inputType.getNumElements();
     int64_t outputElementsNum = outputType.getNumElements();
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -150,3 +150,65 @@
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
   return %0 : tensor<100x100xf32>
 }
+
+// -----
+
+func.func @test_reshape_static_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot take tensor with zero dimensions}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 0, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+
+func.func @test_reshape_zero_dim_input(%arg0 : tensor<?x0x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot take tensor with zero dimensions}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 0, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+
+// An unranked operand must not crash the zero-dimension check; the
+// zero-sized result is still diagnosed.
+func.func @test_reshape_unranked_input_zero_dim_output(%arg0 : tensor<*xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op cannot take tensor with zero dimensions}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 0, 3>} : (tensor<*xf32>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+
+func.func @test_conv2d_static_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op cannot take tensor with zero dimensions}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+    : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
+  return %0 : tensor<1x27x27x16xf32>
+}
+
+// -----
+
+func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x?x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op cannot take tensor with zero dimensions}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+    : (tensor<1x?x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
+  return %0 : tensor<1x27x27x16xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_static_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op cannot take tensor with zero dimensions}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+    : (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op cannot take tensor with zero dimensions}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>}
+    : (tensor<1x0x?x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}