diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3175,6 +3175,31 @@
   }
 };
 
+/// Rewrites a vector.broadcast whose result has leading unit dims into a
+/// broadcast to the trimmed type plus a shape_cast back to the original type.
+struct CastAwayBrodcastLeadingOneDim : OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType dstType = broadcastOp.getVectorType();
+    VectorType trimmedDst = trimLeadingOneDims(dstType);
+    if (trimmedDst == dstType)
+      return failure();
+    Location loc = broadcastOp.getLoc();
+    Value src = broadcastOp.source();
+    // Only a vector source can carry leading unit dims worth dropping.
+    if (auto srcType = broadcastOp.getSourceType().dyn_cast<VectorType>()) {
+      VectorType trimmedSrc = trimLeadingOneDims(srcType);
+      if (trimmedSrc != srcType)
+        src = rewriter.create<vector::ShapeCastOp>(loc, trimmedSrc, src);
+    }
+    Value narrow = rewriter.create<vector::BroadcastOp>(loc, trimmedDst, src);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        broadcastOp, dstType, narrow);
+    return success();
+  }
+};
+
 // Returns the values in `arrayAttr` as an integer vector.
 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
@@ -3771,7 +3796,8 @@
   patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
                CastAwayInsertStridedSliceLeadingOneDim,
                CastAwayTransferReadLeadingOneDim,
-               CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
+               CastAwayTransferWriteLeadingOneDim,
+               CastAwayBrodcastLeadingOneDim, ShapeCastOpFolder>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -672,6 +672,23 @@
   return
 }
 
+// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
+func @cast_away_broadcast_leading_one_dims(
+  %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
+  (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>) {
+  // CHECK:  %[[B0:.+]] = vector.broadcast %{{.*}} : vector<8xf32> to vector<8xf32>
+  // CHECK:  vector.shape_cast %[[B0]] : vector<8xf32> to vector<1x1x8xf32>
+  %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
+  // CHECK:  %[[B1:.+]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
+  // CHECK:  vector.shape_cast %[[B1]] : vector<4xf32> to vector<1x1x4xf32>
+  %1 = vector.broadcast %arg1 : f32 to vector<1x1x4xf32>
+  // CHECK:  %[[S2:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK:  %[[B2:.+]] = vector.broadcast %[[S2]] : vector<4xf32> to vector<3x4xf32>
+  // CHECK:  vector.shape_cast %[[B2]] : vector<3x4xf32> to vector<1x3x4xf32>
+  %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
+  return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>
+}
+
 // CHECK-LABEL: func @bubble_down_bitcast_in_extract
 //  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
 func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {