diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5432,6 +5432,119 @@
   return success();
 }
 
+namespace {
+
+/// Returns the operand number of the (at most one) vector.yield use of any
+/// result of `op`, or std::nullopt if `op` is not used by a vector.yield.
+Optional<unsigned> getYieldOpUseOperandNum(Operation *op) {
+  Optional<unsigned> yieldOperandNum;
+  unsigned numYieldUses = 0;
+  for (OpOperand &use : op->getUses()) {
+    if (!isa<vector::YieldOp>(use.getOwner()))
+      continue;
+    ++numYieldUses;
+    if (!yieldOperandNum)
+      yieldOperandNum = use.getOperandNumber();
+  }
+  assert(numYieldUses <= 1 &&
+         "Yielding the same value multiple times is not supported yet");
+  (void)numYieldUses;
+  return yieldOperandNum;
+}
+
+/// Given a vector.mask operation with multiple nested operations (other than
+/// the vector.yield), hoists all the operations that do not need masking out of
+/// the vector.mask operation and creates individual vector.mask operations for
+/// each nested operation that needs masking, using the mask of the input
+/// vector.mask operation. Nested operations with more than one result are not
+/// supported and make the pattern fail to match.
+struct FlattenMultiOpMaskOp : public OpRewritePattern<MaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MaskOp maskOp,
+                                PatternRewriter &rewriter) const override {
+    Block &block = maskOp.getMaskRegion().getBlocks().front();
+    // Nothing to flatten. Fail the match: returning success without changing
+    // the IR would make the pattern driver believe progress was made.
+    if (block.getOperations().size() <= 2)
+      return failure();
+
+    // Reject unsupported input before any IR is mutated.
+    auto isMultiResult = [](Operation &op) { return op.getNumResults() > 1; };
+    if (llvm::any_of(block, isMultiResult))
+      return rewriter.notifyMatchFailure(maskOp,
+                                         "Multi-result ops are not supported");
+
+    PatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(maskOp);
+    Value activeMask = maskOp.getMask();
+
+    // Hoist every operation either to an individual vector.mask or outside of
+    // the original vector.mask operation.
+    for (Operation &op : llvm::make_early_inc_range(block)) {
+      Operation *nestedOp = &op;
+      if (isa<vector::YieldOp>(nestedOp))
+        continue;
+
+      // The U-D chain of operations that are returned by the original
+      // vector.mask needs to be rewired properly: remember which vector.yield
+      // operand, i.e. which vector.mask result, this op feeds. The uses of the
+      // op are deliberately kept: dropping all of them would also cut the
+      // operands of sibling ops in the region that consume the op, and the old
+      // vector.yield is erased together with the original vector.mask.
+      Optional<unsigned> yieldOperandIdx = getYieldOpUseOperandNum(nestedOp);
+
+      if (auto maskableOp = dyn_cast<MaskableOpInterface>(nestedOp)) {
+        // Create a new vector.mask operation for this maskable op using the
+        // original mask. The region builder moves the op to the start of the
+        // new region block and yields its results.
+        auto createRegionMask = [nestedOp](OpBuilder &builder, Location loc) {
+          Block *insBlock = builder.getInsertionBlock();
+          nestedOp->moveBefore(insBlock, insBlock->begin());
+          builder.create<vector::YieldOp>(loc, nestedOp->getResults());
+        };
+
+        auto newMaskOp = maskableOp->getResults().empty()
+            ? rewriter.create<vector::MaskOp>(maskOp.getLoc(), activeMask,
+                                              createRegionMask)
+            : rewriter.create<vector::MaskOp>(
+                  maskOp.getLoc(), maskableOp->getResultTypes().front(),
+                  activeMask, createRegionMask);
+
+        Operation *newMaskOpTerminator =
+            &newMaskOp.getMaskRegion().front().back();
+
+        // Replace the original uses of the maskable op with result value of the
+        // new vector.mask containing the maskable op, except for the terminator
+        // of the new vector.mask, which must keep yielding the nested value.
+        for (auto [resIdx, resVal] : llvm::enumerate(maskableOp->getResults()))
+          rewriter.replaceAllUsesExcept(resVal, newMaskOp.getResult(resIdx),
+                                        newMaskOpTerminator);
+
+        // If the maskable op was returned by the original vector.mask, replace
+        // the uses of that result with the result of the new vector.mask.
+        if (yieldOperandIdx)
+          rewriter.replaceAllUsesWith(maskOp.getResult(*yieldOperandIdx),
+                                      newMaskOp.getResult(0));
+      } else {
+        // This operation doesn't need a mask: move it before the vector.mask.
+        nestedOp->moveBefore(maskOp);
+
+        // If the operation was returned by the original vector.mask, replace
+        // the uses of that result with the result of the hoisted operation.
+        if (yieldOperandIdx)
+          rewriter.replaceAllUsesWith(maskOp.getResult(*yieldOperandIdx),
+                                      nestedOp->getResult(0));
+      }
+    }
+
+    rewriter.eraseOp(maskOp);
+    return success();
+  }
+};
+
+} // namespace
+
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.