diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -100,6 +100,19 @@ // TOSA Operator Verifiers.
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `shapedType` is ranked and at least one of its statically
+/// known dimensions has size zero. Unranked types are never reported. Dynamic
+/// dimensions are skipped implicitly: they hold the `ShapedType::kDynamic`
+/// sentinel, which is never 0, so a partially dynamic shape such as
+/// `tensor<?x0xf32>` is still reported.
+static bool hasZeroDimension(ShapedType shapedType) {
+  if (!shapedType.hasRank())
+    return false;
+
+  return llvm::any_of(shapedType.getShape(),
+                      [](int64_t dim) { return dim == 0; });
+}
+
 template <typename T>
 static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
@@ -115,6 +128,10 @@
     return failure();
   }
 
+  if (hasZeroDimension(inputType))
+    return op.emitOpError() << "tensor has a dimension with size zero. Each "
+                               "dimension of a tensor must have size >= 1";
+
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
 
@@ -142,7 +159,12 @@
 }
 
 LogicalResult tosa::AvgPool2dOp::verify() {
-  auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
+  auto inputType = llvm::cast<ShapedType>(getInput().getType());
+  if (hasZeroDimension(inputType))
+    return emitOpError() << "tensor has a dimension with size zero. Each "
+                            "dimension of a tensor must have size >= 1";
+
+  auto inputETy = inputType.getElementType();
   auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
 
   if (auto quantType =
@@ -766,6 +788,13 @@
                            << " elements into " << outputElementsNum;
     }
   }
+
+  // Checked after (not inside) the element-count block, which can only run
+  // for fully static shapes, so partially dynamic shapes are rejected too.
+  if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
+    return emitOpError() << "tensor has a dimension with size zero. Each "
+                            "dimension of a tensor must have size >= 1";
+
   return mlir::success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -150,3 +150,37 @@
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
   return %0 : tensor<100x100xf32>
 }
+
+// -----
+
+func.func @test_reshape_zero_dim_input(%arg0 : tensor<13x0x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 0, 3>} : (tensor<13x0x3xf32>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+
+func.func @test_reshape_zero_dim_dynamic_input(%arg0 : tensor<?x0x3xf32>) -> () {
+  // expected-error@+1 {{'tosa.reshape' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  %0 = "tosa.reshape"(%arg0) {new_shape = array<i64: 13, 0, 3>} : (tensor<?x0x3xf32>) -> tensor<13x0x3xf32>
+  return
+}
+
+// -----
+
+func.func @test_conv2d_zero_dim_input(%arg0: tensor<1x29x0x4xf32>, %arg1: tensor<16x3x3x4xf32>, %arg2: tensor<16xf32>) -> tensor<1x27x27x16xf32> {
+  // expected-error@+1 {{'tosa.conv2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+           : (tensor<1x29x0x4xf32>, tensor<16x3x3x4xf32>, tensor<16xf32>) -> tensor<1x27x27x16xf32>
+  return %0 : tensor<1x27x27x16xf32>
+}
+
+// -----
+
+func.func @test_avg_pool2d_zero_dim_input(%arg0: tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op tensor has a dimension with size zero. Each dimension of a tensor must have size >= 1}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>}
+           : (tensor<1x0x7x9xf32>) -> tensor<1x7x7x9xf32>
+  return %0 : tensor<1x7x7x9xf32>
+}