diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -61,6 +61,39 @@ .result(); } +static mlir::Value reifyConstantDim(Attribute attr, + ImplicitLocOpBuilder &builder) { + return builder.createOrFold( + builder.getIndexType(), builder.create(attr)); +} + +// Calculating the output width/height using the formula: +// Out = ((initDim+padBefore+padAfter-(dilation*(kernelDim-1)+1))/stride)+1 +// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1 +// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1 +static mlir::Value +getConvOutputDim(Location loc, Value initDim, Attribute padBeforeAttr, + Attribute padAfterAttr, Value kernelDim, Attribute strideAttr, + Attribute dilationAttr, Type inputETy, OpBuilder &rewriter) { + ImplicitLocOpBuilder builder(loc, rewriter); + auto one = rewriter.create( + loc, IntegerAttr::get(initDim.getType(), 1)); + Value padBefore = reifyConstantDim(padBeforeAttr, builder); + Value paddedBefore = builder.create(initDim, padBefore); + Value padAfter = reifyConstantDim(padAfterAttr, builder); + Value paddedAfter = builder.create(paddedBefore, padAfter); + + Value subOne = builder.create(kernelDim, one); + Value dilation = reifyConstantDim(dilationAttr, builder); + Value dilated = builder.create(dilation, subOne); + Value addOne = builder.create(dilated, one); + + Value subtract = builder.create(paddedAfter, addOne); + Value stride = reifyConstantDim(strideAttr, builder); + Value divide = builder.create(subtract, stride); + return builder.create(divide, one); +} + namespace { class ConvConverter : public OpConversionPattern { @@ -78,6 +111,7 @@ ShapedType weightTy = weight.getType().cast(); ShapedType biasTy = bias.getType().cast(); ShapedType resultTy = op->getResult(0).getType().cast(); + int64_t inputRank = inputTy.getRank(); 
Type inputETy = inputTy.getElementType(); Type resultETy = resultTy.getElementType(); @@ -91,16 +125,46 @@ return rewriter.notifyMatchFailure( op, "tosa.conv ops require static shapes for weight and bias"); - auto dynamicDimsOr = - checkHasDynamicBatchDims(rewriter, op, {input, op.output()}); - if (!dynamicDimsOr.hasValue()) - return failure(); - SmallVector dynamicDims = dynamicDimsOr.getValue(); - if (inputETy.isUnsignedInteger()) return rewriter.notifyMatchFailure( op, "tosa.conv ops does not support unsigned integer input"); + SmallVector dynDims; + dynDims.resize(resultTy.getRank()); + for (int i = 0; i < inputRank; i++) { + if (inputTy.isDynamicDim(i)) { + // Dynamic input height + // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y) + if (i == 1) { + Value initHDim = + rewriter.create(loc, input, 1).getResult(); + Value kernelHDim = + rewriter.create(loc, weight, 1).getResult(); + dynDims[i] = getConvOutputDim( + loc, initHDim, padAttr.getValue()[0], padAttr.getValue()[1], + kernelHDim, strideTosaAttr.getValue()[0], + dilationTosaAttr.getValue()[0], inputETy, rewriter); + + // Dynamic input width + // W = F(IW, pad_left, pad_right, dilation_x, KW, stride_x) + } else if (i == 2) { + Value initWDim = + rewriter.create(loc, input, 2).getResult(); + Value kernelWDim = + rewriter.create(loc, weight, 2).getResult(); + dynDims[i] = getConvOutputDim( + loc, initWDim, padAttr.getValue()[2], padAttr.getValue()[3], + kernelWDim, strideTosaAttr.getValue()[1], + dilationTosaAttr.getValue()[1], inputETy, rewriter); + + } else { + dynDims[i] = rewriter.create(loc, input, i); + } + } + } + + SmallVector filteredDims = condenseValues(dynDims); + auto weightShape = weightTy.getShape(); // Apply padding as necessary. 
@@ -148,7 +212,7 @@ Attribute resultZeroAttr = rewriter.getZeroAttr(resultETy); Value initTensor = rewriter.create( - loc, dynamicDims, resultTy.getShape(), resultETy); + loc, filteredDims, resultTy.getShape(), resultETy); Value zero = rewriter.create(loc, resultZeroAttr); Value zeroTensor = rewriter.create(loc, zero, initTensor).getResult(0); @@ -173,7 +237,7 @@ indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank())); Value biasInitTensor = rewriter.create( - loc, dynamicDims, resultTy.getShape(), resultETy); + loc, filteredDims, resultTy.getShape(), resultETy); if (isQuantized) { auto quantizationInfo = diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -383,6 +383,66 @@ // ----- +// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + +// CHECK-LABEL: @conv2d_dyn_w_h +func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { + // Computing output height (NOTE(review): the last step is checked as arith.subi by one, but the formula comment on getConvOutputDim ends in +1; confirm which is intended) + // CHECK: %[[C1:.+]] = arith.constant 1 + // CHECK: %[[H:.+]] = tensor.dim %arg0, %[[C1]] + // CHECK: %[[C1_0:.+]] = arith.constant 1 + // CHECK: %[[KH:.+]] = tensor.dim %arg1, %[[C1_0]] + // CHECK: %[[ONE:.+]] = arith.constant 1 : index + // CHECK: %[[PAD_0:.+]] = arith.constant 0 : index + // CHECK: %[[ADD_PAD_0:.+]] = arith.addi %[[H]], %[[PAD_0]] : index + // CHECK: %[[PAD_1:.+]] = arith.constant 0 : index + // CHECK: %[[ADD_PAD_1:.+]] = arith.addi %[[ADD_PAD_0]], %[[PAD_1]] : index + // CHECK: %[[SUB_ONE:.+]] = arith.subi %[[KH]], %[[ONE]] : index + // CHECK: %[[DIL_H:.+]] = arith.constant 2 : index + // CHECK: %[[DILATED:.+]] = arith.muli %[[DIL_H]], %[[SUB_ONE]] : index + // CHECK: %[[ADD_ONE:.+]] = arith.addi %[[DILATED]], 
%[[ONE]] : index + // CHECK: %[[SUBTRACTED:.+]] = arith.subi %[[ADD_PAD_1]], %[[ADD_ONE]] : index + // CHECK: %[[STRIDE_H:.+]] = arith.constant 1 : index + // CHECK: %[[DIVIDED:.+]] = arith.divui %[[SUBTRACTED]], %[[STRIDE_H]] : index + // CHECK: %[[H_OUT:.+]] = arith.subi %[[DIVIDED]], %[[ONE]] : index + + // Computing output width + // CHECK: %[[C2:.+]] = arith.constant 2 + // CHECK: %[[W:.+]] = tensor.dim %arg0, %[[C2]] + // CHECK: %[[C2_0:.+]] = arith.constant 2 + // CHECK: %[[KW:.+]] = tensor.dim %arg1, %[[C2_0]] + // CHECK: %[[ONE_0:.+]] = arith.constant 1 : index + // CHECK: %[[PAD_2:.+]] = arith.constant 0 : index + // CHECK: %[[ADD_PAD_2:.+]] = arith.addi %[[W]], %[[PAD_2]] : index + // CHECK: %[[PAD_3:.+]] = arith.constant 0 : index + // CHECK: %[[ADD_PAD_3:.+]] = arith.addi %[[ADD_PAD_2]], %[[PAD_3]] : index + // CHECK: %[[SUB_ONE_0:.+]] = arith.subi %[[KW]], %[[ONE_0]] : index + // CHECK: %[[DIL_W:.+]] = arith.constant 1 : index + // CHECK: %[[DILATED_0:.+]] = arith.muli %[[DIL_W]], %[[SUB_ONE_0]] : index + // CHECK: %[[ADD_ONE_0:.+]] = arith.addi %[[DILATED_0]], %[[ONE_0]] : index + // CHECK: %[[SUBTRACTED_0:.+]] = arith.subi %[[ADD_PAD_3]], %[[ADD_ONE_0]] : index + // CHECK: %[[STRIDE_W:.+]] = arith.constant 1 : index + // CHECK: %[[DIVIDED_0:.+]] = arith.divui %[[SUBTRACTED_0]], %[[STRIDE_W]] : index + // CHECK: %[[W_OUT:.+]] = arith.subi %[[DIVIDED_0]], %[[ONE_0]] : index + + // Running convolution + // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> + // CHECK: %[[WEIGHT:.+]] = "tosa.transpose"(%arg1, %[[PERM]]) + // CHECK: %[[M_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28] + // CHECK: %[[CST:.+]] = arith.constant 0 + // CHECK: %[[FILL:.+]] = linalg.fill + // CHECK: %[[B_IN:.+]] = linalg.init_tensor [1, %[[H_OUT]], %[[W_OUT]], 28] + // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, 
tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>) + // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>) + // CHECK: %[[ADD:.+]] = arith.addf + // CHECK: linalg.yield %[[ADD]] : f32 + %0 = "tosa.conv2d"(%input, %weights, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [2, 1]} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> (tensor<1x?x?x28xf32>) + return +} + +// ----- + // CHECK-LABEL: @conv2d_padded_f32 func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28x3x3x28xf32>, %bias: tensor<28xf32>) -> () { // CHECK: %[[C0:.+]] = arith.constant 0