diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -48,7 +48,10 @@
   for (int i = 0, s = inputShape.size(); i < s; i++) {
     auto lowPad = pad[i * 2];
     auto highPad = pad[i * 2 + 1];
-    paddedShape.push_back(inputShape[i] + highPad + lowPad);
+    if (ShapedType::isDynamic(inputShape[i]))
+      paddedShape.push_back(inputShape[i]);
+    else
+      paddedShape.push_back(inputShape[i] + highPad + lowPad);
     lowIndices.push_back(rewriter.getIndexAttr(lowPad));
     highIndices.push_back(rewriter.getIndexAttr(highPad));
   }
@@ -68,7 +71,6 @@
 }
 
 // Calculating the output width/height using the formula:
-// Out =((initDim+padBefore+padAttr-(dilation*(kernelDim-1)+1))/stride+1
 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
 static mlir::Value
@@ -94,6 +96,54 @@
   return builder.create<arith::AddIOp>(divide, one);
 }
 
+// Creates a vector of the dynamic output dims for Conv2D and Depthwise_Conv2D
+static SmallVector<Value> inferDynamicDimsForConv(
+    Location loc, Value input, Value weight, ShapedType resultTy,
+    ArrayAttr padAttr, ArrayAttr strideAttr, ArrayAttr dilationAttr,
+    int64_t weightHDim, int64_t weightWDim, OpBuilder &rewriter) {
+  ShapedType inputTy = input.getType().cast<ShapedType>();
+  Type inputETy = inputTy.getElementType();
+  int64_t inputRank = inputTy.getRank();
+  int64_t heightDim = 1;
+  int64_t weightDim = 2;
+
+  SmallVector<Value> dynDims;
+  dynDims.resize(resultTy.getRank());
+  for (int i = 0; i < inputRank; i++) {
+    if (inputTy.isDynamicDim(i) && i != heightDim && i != weightDim)
+      dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
+  }
+
+  // Dynamic input height
+  if (inputTy.isDynamicDim(heightDim)) {
+    Value initHDim =
+        rewriter.create<tensor::DimOp>(loc, input, heightDim).getResult();
+    Value kernelHDim =
+        rewriter.create<tensor::DimOp>(loc, weight, weightHDim).getResult();
+    // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
+    dynDims[heightDim] = getConvOutputDim(
+        loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], kernelHDim,
+        strideAttr.getValue()[0], dilationAttr.getValue()[0], inputETy,
+        rewriter);
+  }
+
+  // Dynamic input width
+  if (inputTy.isDynamicDim(weightDim)) {
+    Value initWDim =
+        rewriter.create<tensor::DimOp>(loc, input, weightDim).getResult();
+    Value kernelWDim =
+        rewriter.create<tensor::DimOp>(loc, weight, weightWDim).getResult();
+    // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x)
+    dynDims[weightDim] = getConvOutputDim(
+        loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], kernelWDim,
+        strideAttr.getValue()[1], dilationAttr.getValue()[1], inputETy,
+        rewriter);
+  }
+
+  SmallVector<Value> filteredDims = condenseValues(dynDims);
+  return filteredDims;
+}
+
 namespace {
 
 class ConvConverter : public OpConversionPattern<tosa::Conv2DOp> {
@@ -111,7 +161,6 @@
     ShapedType weightTy = weight.getType().cast<ShapedType>();
     ShapedType biasTy = bias.getType().cast<ShapedType>();
     ShapedType resultTy = op->getResult(0).getType().cast<ShapedType>();
-    int64_t inputRank = inputTy.getRank();
 
     Type inputETy = inputTy.getElementType();
     Type resultETy = resultTy.getElementType();
@@ -129,41 +178,9 @@
       return rewriter.notifyMatchFailure(
           op, "tosa.conv ops does not support unsigned integer input");
 
-    SmallVector<Value> dynDims;
-    dynDims.resize(resultTy.getRank());
-    for (int i = 0; i < inputRank; i++) {
-      if (inputTy.isDynamicDim(i)) {
-        // Dynamic input height
-        // H = F(IH, pad_top, pad_bottom, dilation_y, KH, sride_y)
-        if (i == 1) {
-          Value initHDim =
-              rewriter.create<tensor::DimOp>(loc, input, 1).getResult();
-          Value kernelHDim =
-              rewriter.create<tensor::DimOp>(loc, weight, 1).getResult();
-          dynDims[i] = getConvOutputDim(
-              loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1],
-              kernelHDim, strideTosaAttr.getValue()[0],
-              dilationTosaAttr.getValue()[0], inputETy, rewriter);
-
-          // Dynamic input weight
-          // W = F(IH, pad_left, pad_right, dilation_x, KW, sride_x)
-        } else if (i == 2) {
-          Value initWDim =
-              rewriter.create<tensor::DimOp>(loc, input, 2).getResult();
-          Value kernelWDim =
-              rewriter.create<tensor::DimOp>(loc, weight, 2).getResult();
-          dynDims[i] = getConvOutputDim(
-              loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3],
-              kernelWDim, strideTosaAttr.getValue()[1],
-              dilationTosaAttr.getValue()[1], inputETy, rewriter);
-
-        } else {
-          dynDims[i] = rewriter.create<tensor::DimOp>(loc, input, i);
-        }
-      }
-    }
-
-    SmallVector<Value> filteredDims = condenseValues(dynDims);
+    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
+        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
+        /*weightHDim=*/1, /*weightWDim=*/2, rewriter);
 
     auto weightShape = weightTy.getShape();
 
@@ -322,6 +339,15 @@
     auto strideTosaAttr = op->getAttr("stride").cast<ArrayAttr>();
     auto dilationTosaAttr = op->getAttr("dilation").cast<ArrayAttr>();
 
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "tosa.depthwise_conv ops require static shapes");
+
+    // Compute output dynamic dims
+    SmallVector<Value> filteredDims = inferDynamicDimsForConv(
+        loc, input, weight, resultTy, padAttr, strideTosaAttr, dilationTosaAttr,
+        /*weightHDim=*/0, /*weightWDim=*/1, rewriter);
+
     bool isQuantized = op->hasAttr("quantization_info");
     IntegerAttr iZp;
     IntegerAttr kZp;
@@ -334,16 +360,6 @@
           quantizationInfo.weight_zp().getValue().getSExtValue());
     }
 
-    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
-      return rewriter.notifyMatchFailure(
-          op, "tosa.depthwise_conv ops require static shapes");
-
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.output()});
-    if (!dynamicDimsOr.hasValue())
-      return failure();
-    SmallVector<Value> dynamicDims = dynamicDimsOr.getValue();
-
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
 
@@ -401,7 +417,7 @@
     Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value initTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, linalgConvTy.getShape(), resultETy);
+        loc, filteredDims, linalgConvTy.getShape(), resultETy);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{initTensor})
                            .result();
@@ -409,7 +425,7 @@
 
     Value biasInitTensor = rewriter.create<linalg::InitTensorOp>(
-        loc, dynamicDims, resultTy.getShape(), resultETy);
+        loc, filteredDims, resultTy.getShape(), resultETy);
     if (!isQuantized) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -479,7 +479,7 @@
   // CHECK: [[DEPTH:%.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>) outs([[FILL]] : tensor<1x5x5x3x11xf32>)
   // CHECK: [[COLLAPSED:%.+]] = "tosa.reshape"([[DEPTH]]) {new_shape = [1, 5, 5, 33]}
   // CHECK: [[BIAS:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, [[COLLAPSED]] : tensor<33xf32>, tensor<1x5x5x33xf32>) outs([[OUT]] : tensor<1x5x5x33xf32>) {
-  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: [[ADD:%.+]] = arith.addf %arg3, %arg4 : f32
   // CHECK: linalg.yield [[ADD]] : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
@@ -503,7 +503,7 @@
   // CHECK: %[[DEPTH:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x7x5x3xf32>, tensor<3x1x3x11xf32>) outs(%[[FILL]] : tensor<?x5x5x3x11xf32>)
   // CHECK: %[[COLLAPSED:.+]] = "tosa.reshape"(%[[DEPTH]]) {new_shape = [-1, 5, 5, 33]}
   // CHECK: %[[BIAS:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[COLLAPSED]] : tensor<33xf32>, tensor<?x5x5x33xf32>) outs(%[[OUT]] : tensor<?x5x5x33xf32>) {
-  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+  // CHECK: ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
   // CHECK: %[[ADD:.+]] = arith.addf %arg3, %arg4 : f32
   // CHECK: linalg.yield %[[ADD]] : f32
   // CHECK: } -> tensor<?x5x5x33xf32>
@@ -584,3 +584,19 @@
   %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, weight_zp = 42 : i32}, stride = [1, 1], dilation = [2, 2] } : (tensor<1x14x14x4xi8>, tensor<3x3x4x128xi8>, tensor<512xi32>) -> tensor<1x10x10x512xi32>
   return
 }
+
+// CHECK-LABEL: @depthwise_conv2d_dyn_w_h
+func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x6x3x5xf32>, %arg2: tensor<15xf32>) {
+  // CHECK: arith.addi
+  // CHECK: arith.subi
+  // CHECK: arith.muli
+  // CHECK: arith.divui
+  // CHECK: %[[PADDED:.+]] = tensor.pad %arg0 low[0, 1, 3, 0] high[0, 2, 4, 0] {
+  // CHECK: ^bb0(%arg3: index, %arg4: index, %arg5: index, %arg6: index):
+  // CHECK: tensor.yield %cst : f32
+  // CHECK: } : tensor<2x?x?x3xf32> to tensor<2x?x?x3xf32>
+  // CHECK: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<[1, 2]> : tensor<2xi64>} ins(%[[PADDED]], %arg1 : tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>) outs(%22 : tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x3x5xf32>
+  // CHECK: %[[RESHAPED:.+]] = "tosa.reshape"(%[[CONV]]) {new_shape = [2, -1, -1, 15]} : (tensor<2x?x?x3x5xf32>) -> tensor<2x?x?x15xf32>
+  %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [1, 2, 3, 4], dilation = [2, 1], stride = [1, 2]} : (tensor<2x?x?x3xf32>, tensor<3x6x3x5xf32>, tensor<15xf32>) -> tensor<2x?x?x15xf32>
+  return
+}
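
Note (illustration, not part of the patch): the `arith.addi`/`arith.subi`/`arith.muli`/`arith.divui` CHECKs in `@depthwise_conv2d_dyn_w_h` match the output-size arithmetic that `inferDynamicDimsForConv` emits for each dynamic spatial dimension. A hand-written MLIR sketch of that arithmetic for the height dimension of this test follows; the function name and SSA names are invented for illustration, the two pad terms are folded into a single constant, and the way the pass materializes the attribute constants is simplified. For this test the formula reduces to H_out = IH - 1 and W_out = ((IW + 1) / 2) + 1.

// Illustrative only: dynamic output height for @depthwise_conv2d_dyn_w_h
// (pad_top = 1, pad_bottom = 2, dilation_y = 2, KH = 3, stride_y = 1).
func @sketch_dyn_output_height(%input: tensor<2x?x?x3xf32>,
                               %weight: tensor<3x6x3x5xf32>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  // Dynamic input height (NHWC dim 1) and kernel height (HWCM dim 0).
  %ih = tensor.dim %input, %c1 : tensor<2x?x?x3xf32>
  %kh = tensor.dim %weight, %c0 : tensor<3x6x3x5xf32>
  // H = ((IH + pad_top + pad_bottom - (dilation_y * (KH - 1) + 1)) / stride_y) + 1
  %padded  = arith.addi %ih, %c3 : index      // IH + pad_top + pad_bottom
  %kh_m1   = arith.subi %kh, %c1 : index      // KH - 1
  %dilated = arith.muli %c2, %kh_m1 : index   // dilation_y * (KH - 1)
  %window  = arith.addi %dilated, %c1 : index // effective kernel extent
  %span    = arith.subi %padded, %window : index
  %quot    = arith.divui %span, %c1 : index   // / stride_y
  %out_h   = arith.addi %quot, %c1 : index    // + 1
  return %out_h : index
}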