diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -108,7 +108,8 @@
 
 def AffineForOp : Affine_Op<"for",
     [AutomaticAllocationScope, ImplicitAffineTerminator, RecursiveSideEffects,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+     DeclareOpInterfaceMethods<LoopLikeOpInterface,
+     ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep"]>]> {
   let summary = "for operation";
   let description = [{
     Syntax:
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h
@@ -27,7 +27,7 @@
 } // namespace vector
 
 namespace memref {
-
+class AllocOp;
 //===----------------------------------------------------------------------===//
 // Patterns
 //===----------------------------------------------------------------------===//
@@ -51,6 +51,33 @@
 /// terms of shapes of its input operands.
 void populateResolveShapedTypeResultDimsPatterns(RewritePatternSet &patterns);
 
+/// Transformation to do multi-buffering/array expansion to remove dependencies
+/// on the temporary allocation between consecutive loop iterations.
+/// Returns success if the allocation was multi-buffered and failure otherwise
+/// (e.g. when no op in a loop fully overwrites it before all its other uses).
+/// Example (with `multiplier` = 5):
+/// ```
+/// %0 = memref.alloc() : memref<4x128xf32>
+/// scf.for %iv = %c1 to %c1024 step %c3 {
+///   memref.copy %1, %0 : memref<4x128xf32> to memref<4x128xf32>
+///   "some_use"(%0) : (memref<4x128xf32>) -> ()
+/// }
+/// ```
+/// into:
+/// ```
+/// %0 = memref.alloc() : memref<5x4x128xf32>
+/// scf.for %iv = %c1 to %c1024 step %c3 {
+///   %i = affine.apply
+///     affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 5)>
+///     (%iv, %c1, %c3)
+///   %sv = memref.subview %0[%i, 0, 0] [1, 4, 128] [1, 1, 1] :
+///     memref<5x4x128xf32> to memref<4x128xf32, #map0>
+///   memref.copy %1, %sv : memref<4x128xf32> to memref<4x128xf32, #map0>
+///   "some_use"(%sv) : (memref<4x128xf32, #map0>) -> ()
+/// }
+/// ```
+LogicalResult multiBuffer(memref::AllocOp allocOp, unsigned multiplier);
+
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -110,7 +110,8 @@
 }
 
 def ForOp : SCF_Op<"for",
-      [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface>,
+      [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
+       ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep"]>,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
        SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveSideEffects]> {
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -44,6 +44,42 @@
       }],
       "::mlir::LogicalResult", "moveOutOfLoop", (ins "::mlir::ArrayRef<::mlir::Operation *>":$ops)
     >,
+    InterfaceMethod<[{
+        If there is a single induction variable, return it; otherwise return
+        llvm::None.
+      }],
+      /*retTy=*/"::mlir::Optional<::mlir::Value>",
+      /*methodName=*/"getSingleInductionVar",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::None;
+      }]
+    >,
+    InterfaceMethod<[{
+        Return the single lower bound, as a value or an attribute, if it exists;
+        otherwise return llvm::None.
+      }],
+      /*retTy=*/"::mlir::Optional<::mlir::OpFoldResult>",
+      /*methodName=*/"getSingleLowerBound",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::None;
+      }]
+    >,
+    InterfaceMethod<[{
+        Return the single step, as a value or an attribute, if it exists;
+        otherwise return llvm::None.
+      }],
+      /*retTy=*/"::mlir::Optional<::mlir::OpFoldResult>",
+      /*methodName=*/"getSingleStep",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return llvm::None;
+      }]
+    >,
   ];
 }
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1820,6 +1820,26 @@
   return !region().isAncestor(value.getParentRegion());
 }
 
+/// An affine.for has exactly one induction variable.
+Optional<Value> AffineForOp::getSingleInductionVar() {
+  return getInductionVar();
+}
+
+/// Only a constant lower bound is reported; it is returned as an i64 integer
+/// attribute. Non-constant lower bounds yield llvm::None.
+Optional<OpFoldResult> AffineForOp::getSingleLowerBound() {
+  if (hasConstantLowerBound())
+    return OpFoldResult(
+        Builder(getContext()).getI64IntegerAttr(getConstantLowerBound()));
+  return llvm::None;
+}
+
+/// The step is always reported, as an i64 integer attribute built from
+/// getStep().
+Optional<OpFoldResult> AffineForOp::getSingleStep() {
+  return OpFoldResult(Builder(getContext()).getI64IntegerAttr(getStep()));
+}
+
 LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
   for (auto *op : ops)
     op->moveBefore(*this);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   ComposeSubView.cpp
   ExpandOps.cpp
   FoldSubViewOps.cpp
+  MultiBuffer.cpp
   NormalizeMemRefs.cpp
   ResolveShapedTypeResultDims.cpp
 
@@ -16,6 +17,7 @@
   MLIRAffineUtils
   MLIRArithmetic
   MLIRInferTypeOpInterface
+  MLIRLoopLikeInterface
   MLIRMemRef
   MLIRPass
   MLIRStandard
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -0,0 +1,175 @@
+//===- MultiBuffer.cpp - Multi-buffering transformation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the multi-buffering transformation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
+
+using namespace mlir;
+
+/// Return true if `op` fully overwrites `buffer`. Only a memref.copy whose
+/// target is `buffer` is recognized; the previous content of `buffer` is dead
+/// after such an op.
+static bool overrideBuffer(Operation *op, Value buffer) {
+  auto copyOp = dyn_cast<memref::CopyOp>(op);
+  if (!copyOp)
+    return false;
+  return copyOp.target() == buffer;
+}
+
+/// Replace the uses of `oldOp` with the given `val` and for subview uses
+/// propagate the type change. Changing the memref type may require propagating
+/// it through subview ops so we cannot just do a replaceAllUse but need to
+/// propagate the type change and erase old subview ops.
+static void replaceUsesAndPropagateType(Operation *oldOp, Value val,
+                                        OpBuilder &builder) {
+  SmallVector<Operation *> opToDelete;
+  SmallVector<OpOperand *> operandsToReplace;
+  for (OpOperand &use : oldOp->getUses()) {
+    auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner());
+    if (!subviewUse) {
+      // Save the operand and replace it after the loop so that the use-list
+      // iterator is not invalidated.
+      operandsToReplace.push_back(&use);
+      continue;
+    }
+    // Recreate the subview on top of `val` so that its result type is
+    // re-inferred from the new source type, then update its users recursively.
+    builder.setInsertionPoint(subviewUse);
+    Type newType = memref::SubViewOp::inferRankReducedResultType(
+        subviewUse.getType().getRank(), val.getType().cast<MemRefType>(),
+        extractFromI64ArrayAttr(subviewUse.static_offsets()),
+        extractFromI64ArrayAttr(subviewUse.static_sizes()),
+        extractFromI64ArrayAttr(subviewUse.static_strides()));
+    Value newSubview = builder.create<memref::SubViewOp>(
+        subviewUse->getLoc(), newType.cast<MemRefType>(), val,
+        subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(),
+        subviewUse.getMixedStrides());
+    replaceUsesAndPropagateType(subviewUse, newSubview, builder);
+    opToDelete.push_back(use.getOwner());
+  }
+  for (OpOperand *operand : operandsToReplace)
+    operand->set(val);
+  // Clean up old subview ops.
+  for (Operation *op : opToDelete)
+    op->erase();
+}
+
+/// Return `res` as a Value. If `res` holds an integer attribute, materialize
+/// it as an index `arith.constant` at the builder's insertion point.
+static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
+                              Location loc) {
+  Value value = res.dyn_cast<Value>();
+  if (value)
+    return value;
+  return builder.create<arith::ConstantIndexOp>(
+      loc, res.dyn_cast<Attribute>().cast<IntegerAttr>().getInt());
+}
+
+// Transformation to do multi-buffering/array expansion to remove dependencies
+// on the temporary allocation between consecutive loop iterations.
+// Returns success if the transformation happened and failure otherwise; the
+// IR is left unchanged on failure.
+// This is not a pattern as it requires propagating the new memref type to its
+// uses and requires updating subview ops.
+//
+// The transformation is only applied when:
+//  - `multiplier` is non-zero and the allocation has a static shape;
+//  - a user fully overwriting the buffer dominates all other users, and the
+//    innermost loop enclosing that user has a single induction variable, lower
+//    bound and step;
+//  - the allocation dominates that loop.
+LogicalResult mlir::memref::multiBuffer(memref::AllocOp allocOp,
+                                        unsigned multiplier) {
+  // A zero multiplier would create an empty buffer and index it with `mod 0`.
+  if (multiplier == 0)
+    return failure();
+  // The new allocation is created without dynamic size operands and the
+  // subview sizes are read from the static shape.
+  MemRefType oldType = allocOp.getType();
+  if (!oldType.hasStaticShape())
+    return failure();
+  DominanceInfo dom(allocOp->getParentOp());
+  LoopLikeOpInterface candidateLoop;
+  for (Operation *user : allocOp->getUsers()) {
+    auto parentLoop = user->getParentOfType<LoopLikeOpInterface>();
+    if (!parentLoop)
+      return failure();
+    // Make sure there is no loop-carried dependency on the allocation.
+    if (!overrideBuffer(user, allocOp.getResult()))
+      continue;
+    // If this user doesn't dominate all the other users keep looking.
+    if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
+          return !dom.dominates(user, otherUser);
+        }))
+      continue;
+    candidateLoop = parentLoop;
+    break;
+  }
+  if (!candidateLoop)
+    return failure();
+  // The allocation must dominate the loop: the new buffer is created where the
+  // old allocation was, but it is used at the start of the loop body.
+  if (!dom.dominates(allocOp.getOperation(), candidateLoop.getOperation()))
+    return failure();
+  llvm::Optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
+  llvm::Optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
+  llvm::Optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
+  if (!inductionVar || !lowerBound || !singleStep)
+    return failure();
+  // Materialize constant loop bounds (if any) right before the loop.
+  OpBuilder builder(candidateLoop);
+  Value stepValue =
+      getOrCreateValue(*singleStep, builder, candidateLoop->getLoc());
+  Value lowerBoundValue =
+      getOrCreateValue(*lowerBound, builder, candidateLoop->getLoc());
+  // The new type is the original shape with a leading `multiplier` dimension.
+  SmallVector<int64_t, 4> newShape(1, multiplier);
+  ArrayRef<int64_t> oldShape = oldType.getShape();
+  newShape.append(oldShape.begin(), oldShape.end());
+  auto newMemref = MemRefType::get(newShape, oldType.getElementType(),
+                                   MemRefLayoutAttrInterface(),
+                                   oldType.getMemorySpace());
+  builder.setInsertionPoint(allocOp);
+  Location loc = allocOp->getLoc();
+  auto newAlloc = builder.create<memref::AllocOp>(loc, newMemref);
+  // At the start of the loop body, compute the slot used by this iteration:
+  //   ((iv - lb) floordiv step) mod multiplier
+  builder.setInsertionPointToStart(&candidateLoop.getLoopBody().front());
+  AffineExpr induc = getAffineDimExpr(0, allocOp.getContext());
+  AffineExpr init = getAffineDimExpr(1, allocOp.getContext());
+  AffineExpr step = getAffineDimExpr(2, allocOp.getContext());
+  AffineExpr expr = ((induc - init).floorDiv(step)) % multiplier;
+  auto map = AffineMap::get(3, 0, expr);
+  std::array<Value, 3> operands = {*inductionVar, lowerBoundValue, stepValue};
+  Value bufferIndex = builder.create<AffineApplyOp>(loc, map, operands);
+  // Take the rank-reduced slice [bufferIndex, 0, ..., 0] of the new buffer; it
+  // has the original shape and replaces every use of the old allocation.
+  SmallVector<OpFoldResult> offsets, sizes, strides;
+  offsets.push_back(bufferIndex);
+  offsets.append(oldShape.size(), builder.getIndexAttr(0));
+  strides.assign(oldShape.size() + 1, builder.getIndexAttr(1));
+  sizes.push_back(builder.getIndexAttr(1));
+  for (int64_t size : oldShape)
+    sizes.push_back(builder.getIndexAttr(size));
+  auto dstMemref =
+      memref::SubViewOp::inferRankReducedResultType(
+          oldType.getRank(), newMemref, offsets, sizes, strides)
+          .cast<MemRefType>();
+  Value subview = builder.create<memref::SubViewOp>(loc, dstMemref, newAlloc,
+                                                    offsets, sizes, strides);
+  replaceUsesAndPropagateType(allocOp, subview, builder);
+  allocOp.erase();
+  return success();
+}
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -320,6 +320,23 @@
   return RegionBranchOpInterface::verifyTypes(*this);
 }
 
+/// scf.for has exactly one induction variable.
+Optional<Value> ForOp::getSingleInductionVar() {
+  return getInductionVar();
+}
+
+/// The lower bound operand is always reported, wrapped as an OpFoldResult.
+Optional<OpFoldResult> ForOp::getSingleLowerBound() {
+  OpFoldResult lowerBound = getLowerBound();
+  return lowerBound;
+}
+
+/// The step operand is always reported, wrapped as an OpFoldResult.
+Optional<OpFoldResult> ForOp::getSingleStep() {
+  OpFoldResult step = getStep();
+  return step;
+}
+
 /// Prints the initialization list in the form of
 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
 /// where 'inner' values are assumed to be region arguments and 'outer' values
diff --git a/mlir/test/Dialect/MemRef/multibuffer.mlir b/mlir/test/Dialect/MemRef/multibuffer.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/MemRef/multibuffer.mlir
@@ -0,0 +1,110 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-multi-buffering=multiplier=5 -cse -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer
+func @multi_buffer(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[A:.*]] = memref.alloc() : memref<5x4x128xf32>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %[[C1]]
+  scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]], %[[C1]], %[[C3]])
+// CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, #[[$MAP0]]>
+   %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+    memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, #[[$MAP0]]>
+   memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> ()
+    "some_use"(%0) : (memref<4x128xf32>) -> ()
+// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> ()
+   "some_use"(%0) : (memref<4x128xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_buffer_affine
+func @multi_buffer_affine(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[A:.*]] = memref.alloc() : memref<5x4x128xf32>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+// CHECK: affine.for %[[IV:.*]] = 1
+  affine.for %arg2 = 1 to 1024 step 3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]], %[[C1]], %[[C3]])
+// CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, #[[$MAP0]]>
+   %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+    memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, #[[$MAP0]]>
+   memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> ()
+    "some_use"(%0) : (memref<4x128xf32>) -> ()
+// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> ()
+   "some_use"(%0) : (memref<4x128xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (((d0 - d1) floordiv d2) mod 5)>
+
+// CHECK-LABEL: func @multi_buffer_subview_use
+func @multi_buffer_subview_use(%a: memref<1024x1024xf32>) {
+// CHECK-DAG: %[[A:.*]] = memref.alloc() : memref<5x4x128xf32>
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+// CHECK: scf.for %[[IV:.*]] = %[[C1]]
+  scf.for %arg2 = %c1 to %c1024 step %c3 {
+// CHECK: %[[I:.*]] = affine.apply #[[$MAP1]](%[[IV]], %[[C1]], %[[C3]])
+// CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0, 0] [1, 4, 128] [1, 1, 1] : memref<5x4x128xf32> to memref<4x128xf32, #[[$MAP0]]>
+   %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+    memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+// CHECK: memref.copy %{{.*}}, %[[SV]] : memref<4x128xf32, #{{.*}}> to memref<4x128xf32, #[[$MAP0]]>
+   memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+// CHECK: %[[SV1:.*]] = memref.subview %[[SV]][0, 1] [4, 127] [1, 1] : memref<4x128xf32, #[[$MAP0]]> to memref<4x127xf32, #[[$MAP0]]>
+   %s = memref.subview %0[0, 1] [4, 127] [1, 1] :
+      memref<4x128xf32> to memref<4x127xf32, affine_map<(d0, d1) -> (d0 * 128 + d1 + 1)>>
+// CHECK: "some_use"(%[[SV1]]) : (memref<4x127xf32, #[[$MAP0]]>) -> ()
+   "some_use"(%s) : (memref<4x127xf32, affine_map<(d0, d1) -> (d0 * 128 + d1 + 1)>>) -> ()
+// CHECK: "some_use"(%[[SV]]) : (memref<4x128xf32, #[[$MAP0]]>) -> ()
+   "some_use"(%0) : (memref<4x128xf32>) -> ()
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @multi_buffer_negative
+func @multi_buffer_negative(%a: memref<1024x1024xf32>) {
+// CHECK-NOT: %{{.*}} = memref.alloc() : memref<5x4x128xf32>
+//     CHECK: %{{.*}} = memref.alloc() : memref<4x128xf32>
+  %0 = memref.alloc() : memref<4x128xf32>
+  %c1024 = arith.constant 1024 : index
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  scf.for %arg2 = %c0 to %c1024 step %c3 {
+   "blocking_use"(%0) : (memref<4x128xf32>) -> ()
+   %1 = memref.subview %a[%arg2, 0] [4, 128] [1, 1] :
+    memref<1024x1024xf32> to memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>>
+   memref.copy %1, %0 : memref<4x128xf32, affine_map<(d0, d1)[s0] -> (d0 * 1024 + s0 + d1)>> to memref<4x128xf32>
+   "some_use"(%0) : (memref<4x128xf32>) -> ()
+  }
+  return
+}
+
diff --git a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
--- a/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/MemRef/CMakeLists.txt
@@ -1,11 +1,13 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRMemRefTestPasses
   TestComposeSubView.cpp
+  TestMultiBuffer.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
   LINK_LIBS PUBLIC
   MLIRPass
+  MLIRMemRef
   MLIRMemRefTransforms
   MLIRTestDialect
   )
diff --git a/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/MemRef/TestMultiBuffer.cpp
@@ -0,0 +1,59 @@
+//===- TestMultiBuffer.cpp - Test multi-buffering transformation ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that tries to multi-buffer every memref.alloc nested in the
+/// function, creating `multiplier` versions of each buffer.
+struct TestMultiBufferingPass
+    : public PassWrapper<TestMultiBufferingPass, OperationPass<FuncOp>> {
+  TestMultiBufferingPass() = default;
+  TestMultiBufferingPass(const TestMultiBufferingPass &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The transformation emits affine.apply to compute the buffer index.
+    registry.insert<AffineDialect>();
+  }
+  StringRef getArgument() const final { return "test-multi-buffering"; }
+  StringRef getDescription() const final {
+    return "Test multi buffering transformation";
+  }
+  void runOnOperation() override;
+  Option<unsigned> multiplier{
+      *this, "multiplier",
+      llvm::cl::desc(
+          "Decide how many versions of the buffer should be created"),
+      llvm::cl::init(2)};
+};
+
+void TestMultiBufferingPass::runOnOperation() {
+  // Collect the allocs up front: multiBuffer erases the alloc it rewrites and
+  // creates new ops, so the IR is not rewritten while it is being walked.
+  SmallVector<memref::AllocOp> allocs;
+  getOperation().walk(
+      [&allocs](memref::AllocOp alloc) { allocs.push_back(alloc); });
+  // A failure only means the alloc is not a candidate; it is left untouched.
+  for (memref::AllocOp alloc : allocs)
+    (void)multiBuffer(alloc, multiplier);
+}
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestMultiBuffering() {
+  PassRegistration<TestMultiBufferingPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -74,6 +74,7 @@
 void registerTestDynamicPipelinePass();
 void registerTestExpandTanhPass();
 void registerTestComposeSubView();
+void registerTestMultiBuffering();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
 void registerTestGenericIRVisitorsPass();
@@ -164,6 +165,7 @@
   mlir::test::registerTestDynamicPipelinePass();
   mlir::test::registerTestExpandTanhPass();
   mlir::test::registerTestComposeSubView();
+  mlir::test::registerTestMultiBuffering();
   mlir::test::registerTestGpuParallelLoopMappingPass();
   mlir::test::registerTestIRVisitorsPass();
   mlir::test::registerTestGenericIRVisitorsPass();
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -8037,6 +8037,7 @@
         ":ArithmeticTransforms",
         ":IR",
         ":InferTypeOpInterface",
+        ":LoopLikeInterface",
         ":MemRefDialect",
         ":MemRefPassIncGen",
         ":Pass",
diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel
@@ -438,6 +438,7 @@
     deps = [
         ":TestDialect",
         "//mlir:Affine",
+        "//mlir:MemRefDialect",
         "//mlir:MemRefTransforms",
         "//mlir:Pass",
         "//mlir:Transforms",