diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -44,12 +44,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D: $input,
+    Tosa_Tensor: $input,
     I64Attr: $axis
   );
 
   let results = (outs
-    Tosa_TensorUpto4D: $output
+    Tosa_Tensor: $output
   );
 }
 
@@ -1222,12 +1222,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1251,12 +1251,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1280,12 +1280,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1309,12 +1309,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1338,12 +1338,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1367,12 +1367,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1515,12 +1515,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input,
+    Tosa_Tensor:$input,
     I64Attr:$axis
   );
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1541,13 +1541,13 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto6D:$input,
+    Tosa_Tensor:$input,
     DenseI64ArrayAttr:$start,
     DenseI64ArrayAttr:$size
   );
 
   let results = (outs
-    Tosa_Tensor1Dto6D:$output
+    Tosa_Tensor:$output
   );
 
   let hasCanonicalizer = 1;
@@ -1568,11 +1568,11 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto4D:$input1,
+    Tosa_Tensor:$input1,
     DenseI64ArrayAttr:$multiples);
 
   let results = (outs
-    Tosa_Tensor1Dto4D:$output
+    Tosa_Tensor:$output
   );
 
   let hasFolder = 1;
@@ -1592,12 +1592,12 @@
   }];
 
   let arguments = (ins
-    Tosa_Tensor1Dto6D:$input1,
+    Tosa_Tensor:$input1,
     Tosa_Int32Or64Tensor:$perms
   );
 
   let results = (
-    outs Tosa_Tensor1Dto6D:$output
+    outs Tosa_Tensor:$output
   );
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td
@@ -74,6 +74,14 @@
   let cppNamespace = "mlir::tosa";
 }
 
+def TosaLevelType : I32EnumAttr<"TosaLevelEnum", "Tosa level",
+  [
+    I32EnumAttrCase<"None", 0, "none">,
+    I32EnumAttrCase<"EightK", 1, "8k">,
+  ]>{
+  let cppNamespace = "mlir::tosa";
+}
+
 def TosaValidation : Pass<"tosa-validate", "func::FuncOp"> {
   let summary = "Validates TOSA dialect";
   let description = [{
@@ -89,6 +97,9 @@
     Option<"StrictOperationSpecAlignment", "strict-op-spec-alignment", "bool",
       /*default=*/"false",
       "Verify if the properties of certain operations align the spec requirement">,
+    Option<"levelName", "level", "std::string",
+      /*default=*/"\"8k\"",
+      "Validate if operator parameters are within specification for the given level">,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -72,6 +72,24 @@
   return success();
 }
 
+// Per-level limits checked by this pass; all-zero means "no level checks".
+struct tosa_level_t {
+  int32_t MAX_RANK = 0;
+  int32_t MAX_KERNEL = 0;
+  int32_t MAX_STRIDE = 0;
+  int32_t MAX_SCALE = 0;
+
+  // @todo: MAX_LOG2_SIZE value and checks
+
+  bool operator==(const tosa_level_t &rhs) const {
+    return MAX_RANK == rhs.MAX_RANK && MAX_KERNEL == rhs.MAX_KERNEL &&
+           MAX_STRIDE == rhs.MAX_STRIDE && MAX_SCALE == rhs.MAX_SCALE;
+  }
+};
+
+static constexpr tosa_level_t TOSA_LEVEL_EIGHTK = {6, 8192, 8192, 64};
+static constexpr tosa_level_t TOSA_LEVEL_NONE = {0, 0, 0, 0};
+
 //===----------------------------------------------------------------------===//
 // TOSA Validation Pass.
 //===----------------------------------------------------------------------===//
@@ -89,6 +107,8 @@
     return success();
   }
 
+  LogicalResult applyLevelCheck(Operation *op);
+
 private:
   void populateConstantOperandChecks() {
     const_checkers.emplace_back(checkConstantOperandPad);
@@ -96,13 +116,343 @@
     const_checkers.emplace_back(checkConstantOperandFullyConnected);
   }
 
+  // Each levelCheck* helper emits an op error and returns false on failure.
+  bool levelCheckKernel(Operation *op, int64_t v,
+                        const std::string &check_desc) {
+    if (v > tosa_level.MAX_KERNEL) {
+      op->emitOpError() << "failed level check: " << check_desc;
+      return false;
+    }
+    return true;
+  }
+
+  // Level check `dilation * size <= MAX_KERNEL` for one kernel dimension.
+  // Dynamic sizes cannot be checked statically and are skipped (so that
+  // ShapedType::kDynamic is never multiplied); a non-positive dilation gives
+  // a non-positive product. Dividing instead of multiplying avoids overflow.
+  bool levelCheckDilatedKernel(Operation *op, int64_t dilation, int64_t size,
+                               const std::string &check_desc) {
+    if (ShapedType::isDynamic(size) || dilation <= 0)
+      return true;
+    if (size > tosa_level.MAX_KERNEL / dilation) {
+      op->emitOpError() << "failed level check: " << check_desc;
+      return false;
+    }
+    return true;
+  }
+
+  bool levelCheckStride(Operation *op, int64_t v,
+                        const std::string &check_desc) {
+    if (v > tosa_level.MAX_STRIDE) {
+      op->emitOpError() << "failed level check: " << check_desc;
+      return false;
+    }
+    return true;
+  }
+
+  bool levelCheckScale(Operation *op, int64_t v,
+                       const std::string &check_desc) {
+    if (v > tosa_level.MAX_SCALE) {
+      op->emitOpError() << "failed level check: " << check_desc;
+      return false;
+    }
+    return true;
+  }
+
+  // Unranked tensors have no static rank and are not rank-checked.
+  bool levelCheckRank(Operation *op, const Value &v,
+                      const std::string &check_desc) {
+    if (ShapedType type = dyn_cast<ShapedType>(v.getType())) {
+      if (type.hasRank() && type.getRank() > tosa_level.MAX_RANK) {
+        op->emitOpError() << "failed level check: " << check_desc;
+        return false;
+      }
+    }
+    return true;
+  }
+
+  template <typename T>
+  bool levelCheckRanksFor(Operation *op) {
+    if (isa<T>(op)) {
+      // level check ranks of all operands and results
+      for (auto v : op->getOperands()) {
+        if (!levelCheckRank(op, v, "operand rank(shape) <= MAX_RANK"))
+          return false;
+      }
+      for (auto v : op->getResults()) {
+        if (!levelCheckRank(op, v, "result rank(shape) <= MAX_RANK"))
+          return false;
+      }
+    }
+    return true;
+  }
+
+  bool levelCheckRanks(Operation *op) {
+#define CHECK_RANKS_FOR(tosa_op)                                               \
+  if (!levelCheckRanksFor<tosa_op##Op>(op))                                    \
+    return false;
+
+    // tensor operators:
+    CHECK_RANKS_FOR(ArgMax);
+    // all activation functions:
+    CHECK_RANKS_FOR(Clamp);
+    CHECK_RANKS_FOR(Sigmoid);
+    CHECK_RANKS_FOR(Tanh);
+    // all elementwise binary operators:
+    CHECK_RANKS_FOR(Add);
+    CHECK_RANKS_FOR(ArithmeticRightShift);
+    CHECK_RANKS_FOR(BitwiseAnd);
+    CHECK_RANKS_FOR(BitwiseOr);
+    CHECK_RANKS_FOR(BitwiseXor);
+    CHECK_RANKS_FOR(Div);
+    CHECK_RANKS_FOR(LogicalAnd);
+    CHECK_RANKS_FOR(LogicalLeftShift);
+    CHECK_RANKS_FOR(LogicalRightShift);
+    CHECK_RANKS_FOR(LogicalOr);
+    CHECK_RANKS_FOR(LogicalXor);
+    CHECK_RANKS_FOR(Maximum);
+    CHECK_RANKS_FOR(Minimum);
+    CHECK_RANKS_FOR(Mul);
+    CHECK_RANKS_FOR(Pow);
+    CHECK_RANKS_FOR(Sub);
+    CHECK_RANKS_FOR(Table);
+    // all elementwise unary operators:
+    CHECK_RANKS_FOR(Abs);
+    CHECK_RANKS_FOR(BitwiseNot);
+    CHECK_RANKS_FOR(Ceil);
+    CHECK_RANKS_FOR(Clz);
+    CHECK_RANKS_FOR(Exp);
+    CHECK_RANKS_FOR(Floor);
+    CHECK_RANKS_FOR(Log);
+    CHECK_RANKS_FOR(LogicalNot);
+    CHECK_RANKS_FOR(Negate);
+    CHECK_RANKS_FOR(Reciprocal);
+    CHECK_RANKS_FOR(Rsqrt);
+    // all elementwise ternary operators:
+    CHECK_RANKS_FOR(Select);
+    // all comparison operators:
+    CHECK_RANKS_FOR(Equal);
+    CHECK_RANKS_FOR(Greater);
+    CHECK_RANKS_FOR(GreaterEqual);
+    // all reduction operators:
+    CHECK_RANKS_FOR(ReduceAll);
+    CHECK_RANKS_FOR(ReduceAny);
+    CHECK_RANKS_FOR(ReduceMax);
+    CHECK_RANKS_FOR(ReduceMin);
+    CHECK_RANKS_FOR(ReduceProd);
+    CHECK_RANKS_FOR(ReduceSum);
+    // all data layout operators:
+    CHECK_RANKS_FOR(Concat);
+    CHECK_RANKS_FOR(Pad);
+    CHECK_RANKS_FOR(Reshape);
+    CHECK_RANKS_FOR(Reverse);
+    CHECK_RANKS_FOR(Slice);
+    CHECK_RANKS_FOR(Tile);
+    CHECK_RANKS_FOR(Transpose);
+    // all type conversion operators:
+    CHECK_RANKS_FOR(Cast);
+    CHECK_RANKS_FOR(Rescale);
+    // all data nodes operators:
+    CHECK_RANKS_FOR(Const);
+    CHECK_RANKS_FOR(Identity);
+
+#undef CHECK_RANKS_FOR
+    return true;
+  }
+
+  // Pool Op: level check kernel/stride/pad values
+  template <typename T>
+  bool levelCheckPool(Operation *op) {
+    if (auto pool_op = dyn_cast<T>(op)) {
+      for (auto k : pool_op.getKernel()) {
+        if (!levelCheckKernel(op, k, "kernel <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto s : pool_op.getStride()) {
+        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
+          return false;
+        }
+      }
+      for (auto p : pool_op.getPad()) {
+        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // Conv Op: level check dilation/stride/pad values
+  template <typename T>
+  bool levelCheckConv(Operation *op) {
+    if (auto conv_op = dyn_cast<T>(op)) {
+
+      for (auto k : conv_op.getDilation()) {
+        if (!levelCheckKernel(op, k, "dilation <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto p : conv_op.getPad()) {
+        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto s : conv_op.getStride()) {
+        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
+          return false;
+        }
+      }
+      auto dilation = conv_op.getDilation();
+      if (auto weight_type =
+              dyn_cast<RankedTensorType>(op->getOperand(1).getType())) {
+        auto shape = weight_type.getShape();
+        if (isa<tosa::Conv2DOp>(op)) {
+          assert(shape.size() == 4);
+          assert(dilation.size() == 2);
+          if (!levelCheckDilatedKernel(op, dilation[0], shape[1],
+                                       "dilation_y * KH <= MAX_KERNEL") ||
+              !levelCheckDilatedKernel(op, dilation[1], shape[2],
+                                       "dilation_x * KW <= MAX_KERNEL"))
+            return false;
+        } else if (isa<tosa::Conv3DOp>(op)) {
+          assert(shape.size() == 5);
+          assert(dilation.size() == 3);
+          if (!levelCheckDilatedKernel(op, dilation[0], shape[1],
+                                       "dilation_d * KD <= MAX_KERNEL") ||
+              !levelCheckDilatedKernel(op, dilation[1], shape[2],
+                                       "dilation_y * KH <= MAX_KERNEL") ||
+              !levelCheckDilatedKernel(op, dilation[2], shape[3],
+                                       "dilation_x * KW <= MAX_KERNEL"))
+            return false;
+        } else if (isa<tosa::DepthwiseConv2DOp>(op)) {
+          assert(shape.size() == 4);
+          assert(dilation.size() == 2);
+          if (!levelCheckDilatedKernel(op, dilation[0], shape[0],
+                                       "dilation_y * KH <= MAX_KERNEL") ||
+              !levelCheckDilatedKernel(op, dilation[1], shape[1],
+                                       "dilation_x * KW <= MAX_KERNEL"))
+            return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // FFT op: level check H, W in input shape [N,H,W]
+  template <typename T>
+  bool levelCheckFFT(Operation *op) {
+    if (isa<T>(op)) {
+      for (auto v : op->getOperands()) {
+        if (auto type = dyn_cast<RankedTensorType>(v.getType())) {
+          auto shape = type.getShape();
+          assert(shape.size() == 3);
+          if (!levelCheckKernel(op, shape[1], "H <= MAX_KERNEL") ||
+              !levelCheckKernel(op, shape[2], "W <= MAX_KERNEL")) {
+            return false;
+          }
+        }
+      }
+    }
+    return true;
+  }
+
+  // TransposeConv2d op: level check kH/kW, outpad, and stride
+  bool levelCheckTransposeConv2d(Operation *op) {
+    if (auto transpose = dyn_cast<tosa::TransposeConv2DOp>(op)) {
+      if (auto filter_type =
+              dyn_cast<RankedTensorType>(transpose.getFilter().getType())) {
+        auto shape = filter_type.getShape();
+        assert(shape.size() == 4);
+        // level check kernel sizes for kH and KW
+        if (!levelCheckKernel(op, shape[1], "KH <= MAX_KERNEL") ||
+            !levelCheckKernel(op, shape[2], "KW <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto p : transpose.getOutPad()) {
+        if (!levelCheckKernel(op, p, "pad <= MAX_KERNEL")) {
+          return false;
+        }
+      }
+      for (auto s : transpose.getStride()) {
+        if (!levelCheckStride(op, s, "stride <= MAX_STRIDE")) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // Resize op: level check max scales
+  bool levelCheckResize(Operation *op) {
+    if (auto resize = dyn_cast<tosa::ResizeOp>(op)) {
+      auto scale = resize.getScale();
+      int64_t scale_y_n = scale[0];
+      int64_t scale_y_d = scale[1];
+      int64_t scale_x_n = scale[2];
+      int64_t scale_x_d = scale[3];
+      // The spec requires positive scale values; reject non-positive
+      // denominators before dividing (division by zero is UB).
+      if (scale_y_d <= 0 || scale_x_d <= 0) {
+        op->emitOpError() << "expected positive scale_y_d and scale_x_d";
+        return false;
+      }
+      if (!levelCheckScale(op, scale_y_n / scale_y_d,
+                           "scale_y_n/scale_y_d <= MAX_SCALE") ||
+          !levelCheckScale(op, scale_x_n / scale_x_d,
+                           "scale_x_n/scale_x_d <= MAX_SCALE")) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // configure profile and level values from pass options profileName and
+  // levelName
+  void configLevelAndProfile() {
+    profileType = symbolizeEnum<TosaProfileEnum>(profileName);
+
+    auto levelType = symbolizeEnum<TosaLevelEnum>(levelName);
+
+    tosa_level = TOSA_LEVEL_NONE;
+    if (levelType == TosaLevelEnum::EightK) {
+      tosa_level = TOSA_LEVEL_EIGHTK;
+    }
+  }
+
   SmallVector<std::function<LogicalResult(Operation *)>> const_checkers;
   std::optional<TosaProfileEnum> profileType;
+  tosa_level_t tosa_level;
 };
 
-void TosaValidation::runOnOperation() {
-  profileType = symbolizeEnum<TosaProfileEnum>(profileName);
+LogicalResult TosaValidation::applyLevelCheck(Operation *op) {
+  if (tosa_level == TOSA_LEVEL_NONE) {
+    // no need to do level checks
+    return success();
+  }
+
+  if (!levelCheckRanks(op)) {
+    return failure();
+  }
+
+  // additional level checks from spec 0.70
+  if (!levelCheckPool<tosa::AvgPool2dOp>(op) ||
+      !levelCheckConv<tosa::Conv2DOp>(op) ||
+      !levelCheckConv<tosa::Conv3DOp>(op) ||
+      !levelCheckConv<tosa::DepthwiseConv2DOp>(op) ||
+      !levelCheckFFT<tosa::FFT2dOp>(op) ||
+      !levelCheckPool<tosa::MaxPool2dOp>(op) ||
+      !levelCheckFFT<tosa::RFFT2dOp>(op) || !levelCheckTransposeConv2d(op) ||
+      !levelCheckResize(op)) {
+    return failure();
+  }
+  return success();
+}
+
+void TosaValidation::runOnOperation() {
+  configLevelAndProfile();
 
   getOperation().walk([&](Operation *op) {
     for (Value operand : op->getOperands()) {
       if ((profileType == TosaProfileEnum::BaseInference) &&
@@ -117,6 +467,10 @@
     // Some uses of TOSA rely on the constant operands of particular operations.
if (StrictOperationSpecAlignment && failed(applyConstantOperandCheck(op))) signalPassFailure(); + + // do level checks + if (failed(applyLevelCheck(op))) + signalPassFailure(); }); } } // namespace diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -0,0 +1,698 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics --tosa-validate + + +func.func @test_argmax(%arg0: tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> { + // expected-error@+1 {{'tosa.argmax' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.argmax"(%arg0) {axis = 4 : i64} : (tensor<1x1x1x1x29x29x4xf32>) -> tensor<1x1x1x1x29x4xf32> + return %0 : tensor<1x1x1x1x29x4xf32> +} + +// ----- + +func.func @test_reduce_all(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> { + // expected-error@+1 {{'tosa.reduce_all' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.reduce_all"(%arg0) {axis = 4 : i64} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x1x21x3xi1> + return %0 : tensor<1x1x1x1x1x21x3xi1> +} + +// ----- + +func.func @test_reduce_any(%arg0: tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> { + // expected-error@+1 {{'tosa.reduce_any' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.reduce_any"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xi1>) -> tensor<1x1x1x1x13x21x3xi1> + return %0 : tensor<1x1x1x1x13x21x3xi1> +} + +// ----- + +func.func @test_reduce_max(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { + // expected-error@+1 {{'tosa.reduce_max' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.reduce_max"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> + return %0 : tensor<1x1x1x1x13x21x3xf32> +} + +// ----- + +func.func @test_reduce_min(%arg0: tensor<1x1x1x1x13x21x3xf32>) ->
tensor<1x1x1x1x13x21x3xf32> { + // expected-error@+1 {{'tosa.reduce_min' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.reduce_min"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> + return %0 : tensor<1x1x1x1x13x21x3xf32> +} + +// ----- + +func.func @test_reduce_prod(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { + // expected-error@+1 {{'tosa.reduce_prod' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.reduce_prod"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> + return %0 : tensor<1x1x1x1x13x21x3xf32> +} + +// ----- + +func.func @test_reduce_sum(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { + // expected-error@+1 {{'tosa.reduce_sum' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.reduce_sum"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> + return %0 : tensor<1x1x1x1x13x21x3xf32> +} + +// ----- + +func.func @test_concat(%arg0: tensor<1x1x1x13x21x3x8xf32>, %arg1: tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32> { + // expected-error@+1 {{'tosa.concat' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.concat"(%arg0, %arg1) {axis = 3 : i64} : (tensor<1x1x1x13x21x3x8xf32>, tensor<1x1x1x13x21x3x8xf32>) -> tensor<1x1x1x26x21x3x8xf32> + return %0 : tensor<1x1x1x26x21x3x8xf32> +} + +// ----- + +func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> { + // expected-error@+1 {{'tosa.reshape' op failed level check: result rank(shape) <= MAX_RANK}} + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<13x21x3xf32>) -> tensor<1x1x1x1x1x1x819xf32> + return %0 : tensor<1x1x1x1x1x1x819xf32> +} + +// ----- + +func.func @test_reverse(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> { + // expected-error@+1 {{'tosa.reverse' op failed level check: operand rank(shape) <=
MAX_RANK}} + %0 = "tosa.reverse"(%arg0) {axis = 0 : i64} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x13x21x3xf32> + return %0 : tensor<1x1x1x1x13x21x3xf32> +} + +// ----- +// CHECK-LABEL: slice +func.func @test_slice(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> { + // expected-error@+1 {{'tosa.slice' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.slice"(%arg0) {start = array, size = array} : + (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x4x11x1xf32> + return %0 : tensor<1x1x1x1x4x11x1xf32> +} + +// ----- +// CHECK-LABEL: tile +func.func @test_tile(%arg0: tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> { + // expected-error@+1 {{'tosa.tile' op failed level check: operand rank(shape) <= MAX_RANK}} + %0 = "tosa.tile"(%arg0) {multiples = array} : (tensor<1x1x1x1x13x21x3xf32>) -> tensor<1x1x1x1x39x21x6xf32> + return %0 : tensor<1x1x1x1x39x21x6xf32> +} + +// ----- + +func.func @test_transpose(%arg0: tensor<13x21x3x1x1x1x1xf32>) -> tensor<3x13x21x1x1x1x1xf32> { + %0 = "tosa.const"() {value = dense<[2, 0, 1, 3, 4, 5, 6]> : tensor<7xi32>} : () -> tensor<7xi32> + // expected-error@+1 {{'tosa.transpose' op failed level check: operand rank(shape) <= MAX_RANK}} + %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3x1x1x1x1xf32>, tensor<7xi32>) -> tensor<3x13x21x1x1x1x1xf32> + return %1 : tensor<3x13x21x1x1x1x1xf32> +} + +// ----- + +func.func @test_const(%arg0 : tensor<1x1xi32>) -> tensor<1x1x1x1x1x1x1xi32> { + // expected-error@+1 {{'tosa.const' op failed level check: result rank(shape) <= MAX_RANK}} + %0 = "tosa.const"() {value = dense<0> : tensor<1x1x1x1x1x1x1xi32>} : () -> tensor<1x1x1x1x1x1x1xi32> + return %0: tensor<1x1x1x1x1x1x1xi32> +} + +// ----- + +func.func @test_avgpool2d_kernel_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array,
acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_avgpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: kernel <= MAX_KERNEL}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_avgpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_avgpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + + +// ----- + +func.func @test_avgpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_avgpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 :
tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_avgpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_avgpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.avg_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.avg_pool2d"(%arg0) {kernel = array, pad = array, stride = array, acc_type = f32} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv2d"(%arg0,
%arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) ->
tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x2x2x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_dilation_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_d * KD <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_dilation_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_y * KH <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_dilation_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: dilation_x * KW <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +//
----- + +func.func @test_conv3d_pad_d0(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_pad_d1(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_pad_top(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_pad_bottom(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_pad_left(%arg0: tensor<1x1x32x32x8xf32>, %arg1:
tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_pad_right(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_stride_d(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_stride_y(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + // expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_conv3d_stride_x(%arg0: tensor<1x1x32x32x8xf32>, %arg1: tensor<16x2x2x2x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x32x32x16xf32> { + //
expected-error@+1 {{'tosa.conv3d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x1x32x32x8xf32>, tensor<16x2x2x2x8xf32>, tensor<16xf32>) -> tensor<1x1x32x32x16xf32> + return %0 : tensor<1x1x32x32x16xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_dilation_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_y * KH <= MAX_KERNEL}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_dilation_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: dilation_x * KW <= MAX_KERNEL}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1
{{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_depthwise_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x8xf32>, %arg2: tensor<64xf32>) -> tensor<1x32x32x64xf32> { + // expected-error@+1 {{'tosa.depthwise_conv2d' op failed
level check: stride <= MAX_STRIDE}} + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<2x2x8x8xf32>, tensor<64xf32>) -> tensor<1x32x32x64xf32> + return %0 : tensor<1x32x32x64xf32> +} + +// ----- + +func.func @test_fft2d_real_h(%arg0: tensor<32x8193x32xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) { + // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}} + %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } : + (tensor<32x8193x32xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) + return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32> +} + +// ----- + +func.func @test_fft2d_real_w(%arg0: tensor<32x32x8193xf32>, %arg1: tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) { + // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}} + %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } : + (tensor<32x32x8193xf32>, tensor<32x32x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) + return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32> +} + +// ----- + +func.func @test_fft2d_imag_h(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) { + // expected-error@+1 {{'tosa.fft2d' op failed level check: H <= MAX_KERNEL}} + %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } : + (tensor<32x32x32xf32>, tensor<32x8193x32xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) + return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32> +} + +// ----- + +func.func @test_fft2d_imag_w(%arg0: tensor<32x32x32xf32>, %arg1: tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>, tensor<32x32x32xf32>) { + // expected-error@+1 {{'tosa.fft2d' op failed level check: W <= MAX_KERNEL}} + %0, %1 = "tosa.fft2d"(%arg0, %arg1) { inverse = false } : + (tensor<32x32x32xf32>, tensor<32x32x8193xf32>) -> (tensor<32x32x32xf32>,
tensor<32x32x32xf32>) + return %0, %1 : tensor<32x32x32xf32>, tensor<32x32x32xf32> +} + +// ----- + +func.func @test_maxpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_maxpool2d_kernel_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: kernel <= MAX_KERNEL}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_maxpool2d_stride_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_maxpool2d_stride_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + + +// ----- + +func.func @test_maxpool2d_pad_top(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_maxpool2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>) -> 
tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_maxpool2d_pad_left(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_maxpool2d_pad_right(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { + // expected-error@+1 {{'tosa.max_pool2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : + (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> + return %0 : tensor<1x32x32x8xf32> +} + +// ----- + +func.func @test_rfft2d_input_h(%arg0: tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) { + // expected-error@+1 {{'tosa.rfft2d' op failed level check: H <= MAX_KERNEL}} + %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8193x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) + return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32> +} + +// ----- + +func.func @test_rfft2d_input_w(%arg0: tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) { + // expected-error@+1 {{'tosa.rfft2d' op failed level check: W <= MAX_KERNEL}} + %0, %1 = "tosa.rfft2d"(%arg0) {} : (tensor<13x8x8193xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) + return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32> +} + +// ----- + +func.func @test_transpose_conv2d_weight_h(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x8193x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KH <= 
MAX_KERNEL}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x8193x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_transpose_conv2d_weight_w(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x8193x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: KW <= MAX_KERNEL}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x1x8193x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_transpose_conv2d_pad_top(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_transpose_conv2d_pad_bottom(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_transpose_conv2d_pad_left(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= 
MAX_KERNEL}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_transpose_conv2d_pad_right(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: pad <= MAX_KERNEL}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_transpose_conv2d_stride_y(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_transpose_conv2d_stride_x(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> { + // expected-error@+1 {{'tosa.transpose_conv2d' op failed level check: stride <= MAX_STRIDE}} + %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {out_pad = array, out_shape = array, stride = array} : + (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32> + return %0 : tensor<1x32x32x16xf32> +} + +// ----- + +func.func @test_resize_scale_y(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { + // expected-error@+1 {{'tosa.resize' op failed level check: scale_y_n/scale_y_d <= MAX_SCALE}} + %1 = "tosa.resize"(%arg0) { scale = array, offset = 
array, border = array, mode = "BILINEAR"} : + (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> + return %1 : tensor<1x64x64x8xf32> +} + +// ----- + +func.func @test_resize_scale_x(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> { + // expected-error@+1 {{'tosa.resize' op failed level check: scale_x_n/scale_x_d <= MAX_SCALE}} + %1 = "tosa.resize"(%arg0) { scale = array, offset = array, border = array, mode = "BILINEAR"} : + (tensor<1x32x32x8xf32>) -> tensor<1x64x64x8xf32> + return %1 : tensor<1x64x64x8xf32> +} + +// ----- + +// CHECK-LABEL: @test_cond_if +func.func @test_cond_if(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor<1x1x1x1x1x1x1xf32>, %arg2: tensor) -> tensor<1x1x1x1x1x1x1xf32> { + %0 = "tosa.cond_if"(%arg2, %arg0, %arg1) ({ + ^bb0(%arg3: tensor<1x1x1x1x1x1x1xf32>, %arg4: tensor<1x1x1x1x1x1x1xf32>): + "tosa.yield"(%arg3) : (tensor<1x1x1x1x1x1x1xf32>) -> () + }, { + ^bb0(%arg3: tensor<1x1x1x1x1x1x1xf32>, %arg4: tensor<1x1x1x1x1x1x1xf32>): + "tosa.yield"(%arg4) : (tensor<1x1x1x1x1x1x1xf32>) -> () + }) : (tensor, tensor<1x1x1x1x1x1x1xf32>, tensor<1x1x1x1x1x1x1xf32>) -> tensor<1x1x1x1x1x1x1xf32> + return %0 : tensor<1x1x1x1x1x1x1xf32> +} + +// ----- + +// CHECK-LABEL: @test_while_loop +func.func @test_while_loop(%arg0: tensor<1x1x1x1x1x1x1xf32>, %arg1: tensor) { + %0 = "tosa.const"() {value = dense<0> : tensor} : () -> tensor + %1:2 = "tosa.while_loop"(%0, %arg0) ({ + ^bb0(%arg3: tensor, %arg4: tensor<1x1x1x1x1x1x1xf32>): + %2 = "tosa.greater_equal"(%arg3, %arg1) : (tensor, tensor) -> tensor + %3 = "tosa.logical_not"(%2) : (tensor) -> tensor + "tosa.yield"(%3) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor<1x1x1x1x1x1x1xf32>): + %2 = "tosa.const"() {value = dense<1> : tensor} : () -> tensor + %3 = "tosa.add"(%arg3, %2) : (tensor, tensor) -> tensor + "tosa.yield"(%3, %arg4) : (tensor, tensor<1x1x1x1x1x1x1xf32>) -> () + }) : (tensor, tensor<1x1x1x1x1x1x1xf32>) -> (tensor, tensor<1x1x1x1x1x1x1xf32>) + return +} + +// ----- + +// CHECK-LABEL: 
@test_custom +func.func @test_custom(%arg0: tensor<1x1x1x1x1x1x10xi32>) -> tensor<1x1x1x1x1x1x10xi32> { + %0 = "tosa.custom"(%arg0) {identifier="custom_test", config="tosa_mlir_test", implementation_attrs=""} : + (tensor<1x1x1x1x1x1x10xi32>) -> (tensor<1x1x1x1x1x1x10xi32>) + return %0 : tensor<1x1x1x1x1x1x10xi32> +} + +