diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -84,7 +84,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv2DOp : Tosa_Op<"conv2d", [NoSideEffect]> {
+def Tosa_Conv2DOp : Tosa_Op<"conv2d", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
+      NoSideEffect]> {
   let summary = "2D Convolution Operator";
 
   let description = [{
@@ -115,7 +118,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: conv3d
 //===----------------------------------------------------------------------===//
-def Tosa_Conv3DOp : Tosa_Op<"conv3d", [NoSideEffect]> {
+def Tosa_Conv3DOp : Tosa_Op<"conv3d", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
+      NoSideEffect]> {
   let summary = "3D Convolution operator";
 
   let description = [{
@@ -145,7 +151,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: depthwise_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [NoSideEffect]> {
+def Tosa_DepthwiseConv2DOp : Tosa_Op<"depthwise_conv2d", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
+      NoSideEffect]> {
   let summary = "Depthwise 2D Convolution operator";
 
   let description = [{
@@ -259,7 +268,10 @@
 //===----------------------------------------------------------------------===//
 // Operator: transpose_conv2d
 //===----------------------------------------------------------------------===//
-def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [NoSideEffect]> {
+def Tosa_TransposeConv2DOp : Tosa_Op<"transpose_conv2d", [
+      DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                                ["inferReturnTypeComponents"]>,
+      NoSideEffect]> {
   let summary = "Transpose 2D Convolution operator.";
 
   let description = [{
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -845,6 +845,280 @@
 NARY_SHAPE_INFER(tosa::SigmoidOp)
 #undef PRED_SHAPE_INFER
 
+LogicalResult Conv2DOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
+  Conv2DOp::Adaptor adaptor(operands);
+
+  int32_t inputWidth = ShapedType::kDynamicSize;
+  int32_t inputHeight = ShapedType::kDynamicSize;
+  int32_t weightWidth = ShapedType::kDynamicSize;
+  int32_t weightHeight = ShapedType::kDynamicSize;
+
+  // Input shape describes input width/height and batch.
+  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[0] = inputTy.getDimSize(0);
+    inputHeight = inputTy.getDimSize(1);
+    inputWidth = inputTy.getDimSize(2);
+  }
+
+  // Weight shape describes the filter width/height and the output channels.
+  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[3] = weightTy.getDimSize(0);
+    weightHeight = weightTy.getDimSize(1);
+    weightWidth = weightTy.getDimSize(2);
+  }
+
+  // Bias shape can describe the output channels.
+  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[3] = ShapedType::isDynamic(outputShape[3])
+                         ? biasTy.getDimSize(0)
+                         : outputShape[3];
+  }
+
+  llvm::SmallVector<int64_t> dilation;
+  llvm::SmallVector<int64_t> padding;
+  llvm::SmallVector<int64_t> stride;
+
+  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
+  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
+  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+
+  if (!ShapedType::isDynamic(inputHeight) &&
+      !ShapedType::isDynamic(weightHeight)) {
+    int32_t inputSize = inputHeight + padding[0] + padding[1];
+    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
+    int32_t unstridedResult = inputSize - filterSize + 1;
+    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
+  }
+
+  if (!ShapedType::isDynamic(inputWidth) &&
+      !ShapedType::isDynamic(weightWidth)) {
+    int32_t inputSize = inputWidth + padding[2] + padding[3];
+    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
+    int32_t unstridedResult = inputSize - filterSize + 1;
+    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult Conv3DOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outputShape(5, ShapedType::kDynamicSize);
+  Conv3DOp::Adaptor adaptor(operands);
+
+  int32_t inputWidth = ShapedType::kDynamicSize;
+  int32_t inputHeight = ShapedType::kDynamicSize;
+  int32_t inputDepth = ShapedType::kDynamicSize;
+
+  int32_t weightWidth = ShapedType::kDynamicSize;
+  int32_t weightHeight = ShapedType::kDynamicSize;
+  int32_t weightDepth = ShapedType::kDynamicSize;
+
+  // Input shape describes input width/height/depth and batch.
+  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[0] = inputTy.getDimSize(0);
+    inputHeight = inputTy.getDimSize(1);
+    inputWidth = inputTy.getDimSize(2);
+    inputDepth = inputTy.getDimSize(3);
+  }
+
+  // Weight shape describes the filter width/height/depth and the output
+  // channels.
+  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[4] = weightTy.getDimSize(0);
+    weightHeight = weightTy.getDimSize(1);
+    weightWidth = weightTy.getDimSize(2);
+    weightDepth = weightTy.getDimSize(3);
+  }
+
+  // Bias shape can describe the output channels.
+  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[4] =
+        (outputShape[4] == -1) ? biasTy.getDimSize(0) : outputShape[4];
+  }
+
+  llvm::SmallVector<int64_t> dilation;
+  llvm::SmallVector<int64_t> padding;
+  llvm::SmallVector<int64_t> stride;
+
+  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
+  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
+  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+
+  if (!ShapedType::isDynamic(inputHeight) &&
+      !ShapedType::isDynamic(weightHeight)) {
+    int32_t inputSize = inputHeight + padding[0] + padding[1];
+    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
+    int32_t unstridedResult = inputSize - filterSize + 1;
+    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
+  }
+
+  if (!ShapedType::isDynamic(inputWidth) &&
+      !ShapedType::isDynamic(weightWidth)) {
+    int32_t inputSize = inputWidth + padding[2] + padding[3];
+    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
+    int32_t unstridedResult = inputSize - filterSize + 1;
+    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
+  }
+
+  if (!ShapedType::isDynamic(inputDepth) &&
+      !ShapedType::isDynamic(weightDepth)) {
+    int32_t inputSize = inputDepth + padding[4] + padding[5];
+    int32_t filterSize = (weightDepth - 1) * dilation[2] + 1;
+    int32_t unstridedResult = inputSize - filterSize + 1;
+    outputShape[3] = (unstridedResult - 1) / stride[2] + 1;
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  llvm::SmallVector<int64_t> outputShape(4, ShapedType::kDynamicSize);
+  DepthwiseConv2DOp::Adaptor adaptor(operands);
+
+  int32_t inputWidth = ShapedType::kDynamicSize;
+  int32_t inputHeight = ShapedType::kDynamicSize;
+  int32_t inputChannels = ShapedType::kDynamicSize;
+
+  int32_t weightWidth = ShapedType::kDynamicSize;
+  int32_t weightHeight = ShapedType::kDynamicSize;
+  int32_t depthChannels = ShapedType::kDynamicSize;
+
+  // Input shape describes input width/height and batch.
+  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[0] = inputTy.getDimSize(0);
+    inputHeight = inputTy.getDimSize(1);
+    inputWidth = inputTy.getDimSize(2);
+    inputChannels = inputTy.getDimSize(3);
+  }
+
+  // Weight shape describes the filter width/height, the input channels, and
+  // the channel multiplier.
+  if (auto weightTy = adaptor.weight().getType().dyn_cast<RankedTensorType>()) {
+    weightHeight = weightTy.getDimSize(0);
+    weightWidth = weightTy.getDimSize(1);
+    inputChannels = ShapedType::isDynamic(inputChannels)
+                        ? weightTy.getDimSize(2)
+                        : inputChannels;
+    depthChannels = weightTy.getDimSize(3);
+  }
+
+  // If both inputChannels and depthChannels are available we can determine
+  // the output channels.
+  if (!ShapedType::isDynamic(inputChannels) &&
+      !ShapedType::isDynamic(depthChannels)) {
+    outputShape[3] = inputChannels * depthChannels;
+  }
+
+  // Bias shape can describe the output channels.
+  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[3] = ShapedType::isDynamic(outputShape[3])
+                         ? biasTy.getDimSize(0)
+                         : outputShape[3];
+  }
+
+  llvm::SmallVector<int64_t> dilation;
+  llvm::SmallVector<int64_t> padding;
+  llvm::SmallVector<int64_t> stride;
+
+  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
+  getI64Values(attributes.get("pad").cast<ArrayAttr>(), padding);
+  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+
+  if (!ShapedType::isDynamic(inputHeight) &&
+      !ShapedType::isDynamic(weightHeight)) {
+    int32_t inputSize = inputHeight + padding[0] + padding[1];
+    int32_t filterSize = (weightHeight - 1) * dilation[0] + 1;
+    int32_t unstridedResult = inputSize - filterSize + 1;
+    outputShape[1] = (unstridedResult - 1) / stride[0] + 1;
+  }
+
+  if (!ShapedType::isDynamic(inputWidth) &&
+      !ShapedType::isDynamic(weightWidth)) {
+    int32_t inputSize = inputWidth + padding[2] + padding[3];
+    int32_t filterSize = (weightWidth - 1) * dilation[1] + 1;
+    int32_t unstridedResult = inputSize - filterSize + 1;
+    outputShape[2] = (unstridedResult - 1) / stride[1] + 1;
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
+LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
+    MLIRContext *context, ::llvm::Optional<Location> location,
+    ValueRange operands, DictionaryAttr attributes, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  TransposeConv2DOp::Adaptor adaptor(operands);
+  llvm::SmallVector<int64_t> outputShape;
+  getI64Values(attributes.get("out_shape").cast<ArrayAttr>(), outputShape);
+
+  int32_t inputWidth = ShapedType::kDynamicSize;
+  int32_t inputHeight = ShapedType::kDynamicSize;
+  int32_t weightWidth = ShapedType::kDynamicSize;
+  int32_t weightHeight = ShapedType::kDynamicSize;
+
+  // Input shape describes input width/height and batch.
+  if (auto inputTy = adaptor.input().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[0] = ShapedType::isDynamic(outputShape[0])
+                         ? inputTy.getDimSize(0)
+                         : outputShape[0];
+    inputHeight = inputTy.getDimSize(1);
+    inputWidth = inputTy.getDimSize(2);
+  }
+
+  // Weight shape describes the filter width/height and the output channels.
+  if (auto weightTy = adaptor.filter().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[3] = ShapedType::isDynamic(outputShape[3])
+                         ? weightTy.getDimSize(0)
+                         : outputShape[3];
+    weightHeight = weightTy.getDimSize(1);
+    weightWidth = weightTy.getDimSize(2);
+  }
+
+  // Bias shape can describe the output channels.
+  if (auto biasTy = adaptor.bias().getType().dyn_cast<RankedTensorType>()) {
+    outputShape[3] = ShapedType::isDynamic(outputShape[3])
+                         ? biasTy.getDimSize(0)
+                         : outputShape[3];
+  }
+
+  llvm::SmallVector<int64_t> dilation;
+  llvm::SmallVector<int64_t> padding;
+  llvm::SmallVector<int64_t> stride;
+
+  getI64Values(attributes.get("dilation").cast<ArrayAttr>(), dilation);
+  getI64Values(attributes.get("out_pad").cast<ArrayAttr>(), padding);
+  getI64Values(attributes.get("stride").cast<ArrayAttr>(), stride);
+
+  if (!ShapedType::isDynamic(inputHeight) &&
+      !ShapedType::isDynamic(weightHeight)) {
+    int32_t dilated = (weightHeight - 1) * dilation[0] + 1;
+    int32_t calculateSize =
+        (inputHeight - 1) * stride[0] - padding[0] + dilated;
+    outputShape[1] = outputShape[1] == -1 ? calculateSize : outputShape[1];
+  }
+
+  if (!ShapedType::isDynamic(inputWidth) &&
+      !ShapedType::isDynamic(weightWidth)) {
+    int32_t dilated = (weightWidth - 1) * dilation[1] + 1;
+    int32_t calculateSize = (inputWidth - 1) * stride[1] - padding[1] + dilated;
+    outputShape[2] = outputShape[2] == -1 ? calculateSize : outputShape[2];
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TOSA Operator Definitions.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -660,3 +660,264 @@
   %0 = "tosa.scatter"(%arg0, %arg1, %arg2) : (tensor, tensor<3x?xi32>, tensor) -> (tensor)
   return
 }
+
+// -----
+
+// CHECK-LABEL: @conv2d_static
+func @conv2d_static(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
+  // CHECK: -> tensor<2x6x4x5xf32>
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> (tensor<?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_dynamic_input
+func @conv2d_dynamic_input(%input: tensor<?x?x?x?xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
+  // CHECK: -> tensor<?x?x?x5xf32>
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> (tensor<?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_dynamic_weight
+func @conv2d_dynamic_weight(%input: tensor<2x8x9x3xf32>, %weights: tensor<?x?x?x?xf32>, %bias: tensor<5xf32>) -> () {
+  // CHECK: -> tensor<2x?x?x5xf32>
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> (tensor<?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_dynamic_bias
+func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<?xf32>) -> () {
+  // CHECK: -> tensor<2x6x4x5xf32>
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> (tensor<?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_padded
+func @conv2d_padded(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
+  // CHECK: -> tensor<2x9x11x5xf32>
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [1, 2, 3, 4], stride = [1, 1], dilation = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> (tensor<?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_dilated
+func @conv2d_dilated(%input: tensor<2x12x14x3xf32>, %weights: tensor<5x3x6x3xf32>, %bias: tensor<5xf32>) -> () {
+  // CHECK: -> tensor<2x6x4x5xf32>
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [3, 2]} : (tensor<2x12x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> (tensor<?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_strided
+func @conv2d_strided(%input: tensor<1x13x14x1xf32>, %weights: tensor<1x1x1x1xf32>, %bias: tensor<1xf32>) -> () {
+  // CHECK: -> tensor<1x5x7x1xf32>
+  %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [3, 2], dilation = [1, 1]} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> (tensor<?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_static
+func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<5xf32>) -> () {
+  // CHECK: -> tensor<2x6x4x7x5xf32>
+  %0 = "tosa.conv3d"(%input, %weights, %bias) {dilation = [1, 1, 1], pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1]} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> (tensor<?x?x?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_dynamic_input
+func @conv3d_dynamic_input(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<?x?x?x?x5xf32>
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = [1, 1, 1], pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1]} : (tensor<?x?x?x?x?xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_dynamic_weight
+func @conv3d_dynamic_weight(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<?x?x?x?x?xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x?x?x?x5xf32>
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = [1, 1, 1], pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1]} : (tensor<2x8x9x10x3xf32>, tensor<?x?x?x?x?xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_dynamic_bias
+func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<?xf32>) {
+  // CHECK: -> tensor<2x6x4x7x5xf32>
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = [1, 1, 1], pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1]} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<?xf32>) -> tensor<?x?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_padded
+func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x9x11x18x5xf32>
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = [1, 1, 1], pad = [1, 2, 3, 4, 5, 6], stride = [1, 1, 1]} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_dilated
+func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x6x4x12x5xf32>
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = [3, 2, 4], pad = [0, 0, 0, 0, 0, 0], stride = [1, 1, 1]} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>) -> tensor<?x?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_strided
+func @conv3d_strided(%arg0: tensor<1x13x14x15x1xf32>, %arg1: tensor<1x1x1x1x1xf32>, %arg2: tensor<1xf32>) {
+  // CHECK: -> tensor<1x5x7x4x1xf32>
+  %0 = "tosa.conv3d"(%arg0, %arg1, %arg2) {dilation = [1, 1, 1], pad = [0, 0, 0, 0, 0, 0], stride = [3, 2, 4]} : (tensor<1x13x14x15x1xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>) -> tensor<?x?x?x?x?xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_static
+func @depthwise_conv2d_static(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
+  // CHECK: -> tensor<2x6x4x15xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_dynamic_input
+func @depthwise_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
+  // CHECK: -> tensor<?x?x?x15xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<?x?x?x?xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<?x?x?x15xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_dynamic_weight
+func @depthwise_conv2d_dynamic_weight(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<15xf32>) {
+  // CHECK: -> tensor<2x?x?x15xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<?x?x?x?xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_dynamic_bias
+func @depthwise_conv2d_dynamic_bias(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<?xf32>) {
+  // CHECK: -> tensor<2x6x4x15xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<?xf32>) -> tensor<2x6x4x15xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_padded
+func @depthwise_conv2d_padded(%arg0: tensor<2x8x9x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
+  // CHECK: -> tensor<2x9x11x15xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [1, 2, 3, 4], stride = [1, 1]} : (tensor<2x8x9x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x9x11x15xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_dilated
+func @depthwise_conv2d_dilated(%arg0: tensor<2x12x14x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
+  // CHECK: -> tensor<2x6x4x15xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [3, 2], pad = [0, 0, 0, 0], stride = [1, 1]} : (tensor<2x12x14x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x6x4x15xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv2d_strided
+func @depthwise_conv2d_strided(%arg0: tensor<1x13x14x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) {
+  // CHECK: -> tensor<1x5x7x1xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], stride = [3, 2]} : (tensor<1x13x14x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x5x7x1xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_out_shape
+func @transpose_conv2d_out_shape(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x8x9x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, 8, 9, -1], stride = [1, 1]} : (tensor<2x?x?x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_static
+func @transpose_conv2d_static(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x8x9x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x8x9x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_dynamic_input
+func @transpose_conv2d_dynamic_input(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<?x?x?x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<?x?x?x?xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<?x?x?x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_dynamic_weights
+func @transpose_conv2d_dynamic_weights(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x?x?x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<?x?x?x?xf32>, tensor<5xf32>) -> tensor<2x?x?x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_dynamic_bias
+func @transpose_conv2d_dynamic_bias(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<?xf32>) {
+  // CHECK: -> tensor<2x8x9x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<?xf32>) -> tensor<2x8x9x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_padded
+func @transpose_conv2d_padded(%arg0: tensor<2x9x11x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x10x13x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [1, 3], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x9x11x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x10x13x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_dilated
+func @transpose_conv2d_dilated(%arg0: tensor<2x6x4x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) {
+  // CHECK: -> tensor<2x12x14x5xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [3, 2], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [1, 1]} : (tensor<2x6x4x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>) -> tensor<2x12x14x5xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_conv2d_strided
+func @transpose_conv2d_strided(%arg0: tensor<1x5x7x1xf32>, %arg1: tensor<1x1x1x1xf32>, %arg2: tensor<1xf32>) {
+  // CHECK: -> tensor<1x13x13x1xf32>
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], out_pad = [0, 0], out_shape = [-1, -1, -1, -1], stride = [3, 2]} : (tensor<1x5x7x1xf32>, tensor<1x1x1x1xf32>, tensor<1xf32>) -> tensor<1x13x13x1xf32>
+  return
+}