diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -13,7 +13,6 @@
 
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
-#include "mlir/Pass/Pass.h"
 
 using namespace mlir;
 using namespace mlir::tosa;
@@ -32,30 +31,40 @@
     ShapedType weightType = weight.getType().cast<ShapedType>();
     ShapedType resultType = op.getType().cast<ShapedType>();
 
-    if (!inputType.hasStaticShape() || !weightType.hasRank()) {
-      return failure();
-    }
+    if (!inputType.hasRank())
+      return rewriter.notifyMatchFailure(op, "unranked input");
+    if (!weightType.hasRank())
+      return rewriter.notifyMatchFailure(op, "unranked weight input");
+
+    // The reshapes built below can infer at most one dimension (-1), so
+    // allow at most one dynamic dimension in the input.
+    auto numDynamic =
+        llvm::count_if(inputType.getShape(), ShapedType::isDynamic);
+    if (numDynamic > 1)
+      return rewriter.notifyMatchFailure(
+          op, "at most one dim in input may be dynamic");
 
     // Stride must be 1 for this optimization.
-    for (Attribute stride : op.stride().getValue()) {
-      if (!stride.cast<IntegerAttr>().getValue().isOne()) {
+    for (APInt stride : op.stride().getAsValueRange<IntegerAttr>()) {
+      if (!stride.isOne())
         return failure();
-      }
     }
 
     // Only works for a 1x1 kernel.
     ArrayRef<int64_t> weightShape = weightType.getShape();
-    if (weightShape[1] != 1 || weightShape[2] != 1) {
+    if (weightShape[1] != 1 || weightShape[2] != 1)
       return failure();
-    }
 
     // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC].
     ArrayRef<int64_t> inputShape = inputType.getShape();
-    llvm::SmallVector<int64_t, 2> revisedInputShape{
-        inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]};
-    auto revisedInputShapeType = RankedTensorType::get(
-        revisedInputShape,
-        input.getType().dyn_cast<RankedTensorType>().getElementType());
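+    // Fold N, IH, and IW into a single leading dimension; if any of them is
+    // dynamic, the folded dimension is dynamic as well.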
+    int64_t combined = ShapedType::kDynamicSize;
+    if (llvm::none_of(inputShape.take_front(3), ShapedType::isDynamic))
+      combined = inputShape[0] * inputShape[1] * inputShape[2];
+    llvm::SmallVector<int64_t, 2> revisedInputShape{combined, inputShape[3]};
+    auto revisedInputShapeType =
+        RankedTensorType::get(revisedInputShape, inputType.getElementType());
     auto reshapedInput = rewriter
                              .create<tosa::ReshapeOp>(
                                  op.getLoc(), revisedInputShapeType, input,
@@ -75,11 +84,11 @@
                               .getResult();
 
     // Perform a fully connected network over the reshaped input and weight.
-    llvm::SmallVector<int64_t, 2> fullyConnectedShape{
-        inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]};
-    auto fullyConnectedShapeType = RankedTensorType::get(
-        fullyConnectedShape,
-        resultType.dyn_cast<ShapedType>().getElementType());
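+    // The folded (possibly dynamic) leading dimension carries over; the
+    // trailing dimension becomes the output-channel count weightShape[0].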
+    llvm::SmallVector<int64_t, 2> fullyConnectedShape{combined, weightShape[0]};
+    auto fullyConnectedShapeType =
+        RankedTensorType::get(fullyConnectedShape, resultType.getElementType());
 
     Value fullyConnectedValue;
     if (op.quantization_info()) {
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -38,3 +38,20 @@
 }
 
 // -----
+
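+// Decomposition of a 1x1 convolution with a dynamic batch dimension: the
+// folded leading dim and the final reshape use -1 in new_shape.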
+// CHECK-LABEL:   func.func @conv_with_dynamic_dim(
+// CHECK-SAME:                                     %[[VAL_0:.*]]: tensor<?x14x14x64xi8>,
+// CHECK-SAME:                                     %[[VAL_1:.*]]: tensor<384x1x1x64xi8>,
+// CHECK-SAME:                                     %[[VAL_2:.*]]: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
+func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
+// CHECK:           %[[VAL_3:.*]] = "tosa.reshape"(%[[VAL_0]]) {new_shape = [-1, 64]} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
+// CHECK:           %[[VAL_4:.*]] = "tosa.reshape"(%[[VAL_1]]) {new_shape = [384, 64]} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
+// CHECK:           %[[VAL_5:.*]] = "tosa.fully_connected"(%[[VAL_3]], %[[VAL_4]], %[[VAL_2]]) {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK:           %[[VAL_6:.*]] = "tosa.reshape"(%[[VAL_5]]) {new_shape = [-1, 14, 14, 384]} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
+// CHECK:           return %[[VAL_6]] : tensor<?x14x14x384xi32>
+// CHECK:         }
+  %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>, stride = [1, 1]} : (tensor<?x14x14x64xi8>, tensor<384x1x1x64xi8>, tensor<384xi32>) -> tensor<?x14x14x384xi32>
+  return %0 : tensor<?x14x14x384xi32>
+}