diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
 
@@ -32,30 +31,34 @@
     ShapedType weightType = weight.getType().cast<ShapedType>();
     ShapedType resultType = op.getType().cast<ShapedType>();
 
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
+    auto numDynamic = llvm::count_if(inputType.getShape(), [](int64_t d) {
+      return ShapedType::isDynamic(d);
+    });
+    if (numDynamic > 1)
+      return rewriter.notifyMatchFailure(
+          op, "at most one dim in input may be dynamic");
+    if (!weightType.hasRank())
+      return rewriter.notifyMatchFailure(op, "unranked weight input");
 
     // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+    for (APInt stride : op.stride().getAsValueRange<IntegerAttr>()) {
+      if (!stride.isOne())
        return failure();
-      }
    }
 
     // Only works for a 1x1 kernel.
     ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
+    if (weightShape[1] != 1 || weightShape[2] != 1)
       return failure();
-    }
 
     // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
     ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
+    int64_t combined = inputShape[0] * inputShape[1] * inputShape[2];
+    if (combined < 0)
+      combined = ShapedType::kDynamicSize;
+    llvm::SmallVector<int64_t> revisedInputShape{combined, inputShape[3]};
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputType.getElementType());
     auto reshapedInput = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedInputShapeType, input,
@@ -75,11 +78,9 @@
                              .getResult();
 
     // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        resultType.dyn_cast<ShapedType>().getElementType());
+    llvm::SmallVector<int64_t> fullyConnectedShape{combined, weightShape[0]};
+    auto fullyConnectedShapeType =
+        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
 
     Value fullyConnectedValue;
     if (op.quantization_info()) {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -38,3 +38,19 @@
 }
 
 // -----
+
+// CHECK-LABEL: func.func @conv_with_dynamic_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x14x14x64xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<384x1x1x64xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
+// CHECK: %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_0]]) {new_shape = [-1, 64]} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
+// CHECK: %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [384, 64]} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
+// CHECK: %[[VAL_5:.*]] = "tosa.fully_connected"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]]) {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK: %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [-1, 14, 14, 384]} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
+// CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
+// CHECK: }
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = [1, 1]} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>
+  return %0 : tensor<?x14x14x384xi32>
+}
+
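A note on the `combined < 0` test in the hunk above: it relies on `ShapedType` encoding dynamic dimensions as a negative sentinel (`ShapedType::kDynamicSize` at this LLVM vintage), so the product `N * IH * IW` is negative exactly when one of those dims is dynamic. With two dynamic dims the negatives would cancel and yield a bogus positive "static" size, which is what the `numDynamic > 1` bail-out guards against. A minimal standalone sketch of that invariant, outside of MLIR; the sentinel value and the `combineDims` helper here are illustrative, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for ShapedType::kDynamicSize, which is a negative
// sentinel in this LLVM vintage (later renamed ShapedType::kDynamic).
constexpr int64_t kDynamicSize = -1;

// Mirrors the patch's folding of [N, IH, IW] into one leading dimension.
int64_t combineDims(int64_t n, int64_t ih, int64_t iw) {
  int64_t combined = n * ih * iw;
  // Negative product => some dim was dynamic; normalize to the sentinel.
  return combined < 0 ? kDynamicSize : combined;
}

int main() {
  assert(combineDims(2, 14, 14) == 2 * 14 * 14);              // fully static
  assert(combineDims(kDynamicSize, 14, 14) == kDynamicSize);  // dynamic batch
  // With two dynamic dims the sign trick breaks down: the negatives cancel
  // and the result looks static, hence the numDynamic > 1 match failure.
  assert(combineDims(kDynamicSize, kDynamicSize, 14) > 0);
  return 0;
}
```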