diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -3362,6 +3362,10 @@
     data. The result value is a tensor whose shape and element type match the
     memref operand.
 
+    The opposite of this op is tensor_to_memref. Together, these two ops are
+    useful for source/target materializations when doing type conversions
+    involving tensors and memrefs.
+
     Example:
 
     ```mlir
@@ -3393,6 +3397,8 @@
   }];
 
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -3427,6 +3433,47 @@
   let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
 }
 
+//===----------------------------------------------------------------------===//
+// TensorToMemrefOp
+//===----------------------------------------------------------------------===//
+
+def TensorToMemrefOp : Std_Op<"tensor_to_memref",
+    [SameOperandsAndResultShape, SameOperandsAndResultElementType,
+     TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'",
+                    "memref", "tensor",
+                    "getTensorTypeFromMemRefType($_self)">]> {
+  let summary = "tensor to memref operation";
+  let description = [{
+    Create a memref from a tensor. This is equivalent to allocating a new
+    memref of the appropriate (possibly dynamic) shape, and then copying the
+    elements (as if by a tensor_store op) into the newly allocated memref.
+
+    The opposite of this op is tensor_load. Together, these two ops are useful
+    for source/target materializations when doing type conversions involving
+    tensors and memrefs.
+
+    Note: This op takes the memref type in its pretty form because the tensor
+    type can always be inferred from the memref type, but the reverse is not
+    true. For example, the memref might have a layout map or memory space which
+    cannot be inferred from the tensor type.
+
+    ```mlir
+    // Operand type is tensor<4x?xf32>
+    %12 = tensor_to_memref %10 : memref<4x?xf32, #map0, 42>
+    ```
+  }];
+
+  let arguments = (ins AnyTensor:$tensor);
+  let results = (outs Res<AnyRankedOrUnrankedMemRef,
+                      "the memref to create", [MemAlloc]>:$memref);
+  // This op is fully verified by traits.
+  let verifier = ?;
+
+  let assemblyFormat = "$tensor attr-dict `:` type($memref)";
+
+  let hasFolder = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -16,6 +16,7 @@
 #define MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES_H_
 
 #include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Bufferize.h"
 
 namespace mlir {
 
@@ -27,6 +28,15 @@
 void populateExpandTanhPattern(OwningRewritePatternList &patterns,
                                MLIRContext *ctx);
 
+/// Populates `patterns` with conversion patterns that bufferize std ops. The
+/// patterns use `typeConverter` to compute the converted (memref) types.
+void populateStdBufferizePatterns(MLIRContext *context,
+                                  BufferizeTypeConverter &typeConverter,
+                                  OwningRewritePatternList &patterns);
+
+/// Creates an instance of the StdBufferize pass.
+std::unique_ptr<Pass> createStdBufferizePass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td
@@ -16,4 +16,14 @@
   let constructor = "mlir::createExpandAtomicPass()";
 }
 
+def StdBufferize : FunctionPass<"std-bufferize"> {
+  let summary = "Bufferize the std dialect";
+  let description = [{
+    Converts `tensor_cast` ops to `memref_cast` ops on the corresponding
+    memref types. Values flowing between converted and unconverted code are
+    bridged with `tensor_load` and `tensor_to_memref` ops.
+  }];
+  let constructor = "mlir::createStdBufferizePass()";
+}
+
 #endif // MLIR_DIALECT_STANDARD_TRANSFORMS_PASSES
diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -13,6 +13,16 @@
 // pattern needs to be written. The infrastructure in this file assists in
 // defining these conversion patterns in a composable way.
 //
+// Bufferization conversion patterns should generally use the ordinary
+// conversion pattern classes (e.g. OpConversionPattern). A TypeConverter
+// (accessible with getTypeConverter()) available on such patterns is sufficient
+// for most cases (if needed at all).
+//
+// But some patterns require access to the extra functions on
+// BufferizeTypeConverter that don't exist on the base TypeConverter class. For
+// those cases, BufferizeConversionPattern and its related classes should be
+// used, which provide access to a BufferizeTypeConverter directly.
+//
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_TRANSFORMS_BUFFERIZE_H
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3592,7 +3592,7 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Helpers for Tensor[Load|Store]Op
+// Helpers for Tensor[Load|Store]Op and TensorToMemrefOp
 //===----------------------------------------------------------------------===//
 
 static Type getTensorTypeFromMemRefType(Type type) {
@@ -3603,6 +3603,36 @@
   return NoneType::get(type.getContext());
 }
 
+//===----------------------------------------------------------------------===//
+// TensorLoadOp
+//===----------------------------------------------------------------------===//
+
+/// Folds tensor_load(tensor_to_memref(%t)) -> %t.
+/// NOTE(review): this does not check whether the intermediate memref is
+/// written between the two ops; the fold assumes it is not -- confirm.
+OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
+  if (auto tensorToMemref = memref().getDefiningOp<TensorToMemrefOp>())
+    return tensorToMemref.tensor();
+  return {};
+}
+
+//===----------------------------------------------------------------------===//
+// TensorToMemrefOp
+//===----------------------------------------------------------------------===//
+
+/// Folds tensor_to_memref(tensor_load(%m)) -> %m, but only when the result
+/// type is identical to %m's type (a differing layout or memory space blocks
+/// the fold).
+/// NOTE(review): the folded value aliases %m instead of being a fresh
+/// allocation, and writes to %m after the tensor_load are not checked for --
+/// confirm this is acceptable for all users.
+OpFoldResult TensorToMemrefOp::fold(ArrayRef<Attribute>) {
+  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
+    if (tensorLoad.memref().getType() == getType())
+      return tensorLoad.memref();
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // TransposeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -0,0 +1,73 @@
+//===- Bufferize.cpp - Bufferization for std ops --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements bufferization of std ops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Transforms/Bufferize.h"
+#include "PassDetail.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+
+namespace {
+/// Bufferizes `tensor_cast` by replacing it with a `memref_cast` whose result
+/// type is the type converter's conversion of the original tensor result type.
+class BufferizeTensorCastOp : public OpConversionPattern<TensorCastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    // convertType returns a null type when no conversion applies; fail the
+    // match rather than building a memref_cast with a null result type.
+    if (!resultType)
+      return failure();
+    rewriter.replaceOpWithNewOp<MemRefCastOp>(op, resultType, operands[0]);
+    return success();
+  }
+};
+} // namespace
+
+/// Registers the std bufferization patterns (currently only tensor_cast) in
+/// `patterns`, configured with `typeConverter`.
+void mlir::populateStdBufferizePatterns(MLIRContext *context,
+                                        BufferizeTypeConverter &typeConverter,
+                                        OwningRewritePatternList &patterns) {
+  patterns.insert<BufferizeTensorCastOp>(typeConverter, context);
+}
+
+namespace {
+/// Function pass that converts `tensor_cast` (marked illegal) while keeping
+/// the rest of the std dialect legal, using a partial conversion.
+struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
+  void runOnFunction() override {
+    auto *context = &getContext();
+    BufferizeTypeConverter typeConverter;
+    OwningRewritePatternList patterns;
+    ConversionTarget target(*context);
+
+    target.addLegalDialect<StandardOpsDialect>();
+
+    populateStdBufferizePatterns(context, typeConverter, patterns);
+    target.addIllegalOp<TensorCastOp>();
+
+    if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+/// Creates an instance of the StdBufferize pass.
+std::unique_ptr<Pass> mlir::createStdBufferizePass() {
+  return std::make_unique<StdBufferizePass>();
+}
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/StandardOps/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRStandardOpsTransforms
+  Bufferize.cpp
   ExpandAtomic.cpp
   ExpandTanh.cpp
   FuncConversions.cpp
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -27,6 +27,22 @@
   addConversion([](UnrankedTensorType type) -> Type {
     return UnrankedMemRefType::get(type.getElementType(), 0);
   });
+  // Materialize a tensor from a memref with tensor_load. Registered for the
+  // TensorType base class so unranked tensors are covered as well.
+  addSourceMaterialization([](OpBuilder &builder, TensorType type,
+                              ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1);
+    assert(inputs[0].getType().isa<BaseMemRefType>());
+    return builder.create<TensorLoadOp>(loc, type, inputs[0]);
+  });
+  // Materialize a memref from a tensor with tensor_to_memref. Registered for
+  // BaseMemRefType so unranked memrefs are covered as well.
+  addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
+                              ValueRange inputs, Location loc) -> Value {
+    assert(inputs.size() == 1);
+    assert(inputs[0].getType().isa<TensorType>());
+    return builder.create<TensorToMemrefOp>(loc, type, inputs[0]);
+  });
 }
 
 /// This method tries to decompose a value of a certain type using provided
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s -std-bufferize | FileCheck %s
+
+// CHECK-LABEL:   func @tensor_cast(
+// CHECK-SAME:                      %[[ARG:.*]]: tensor<?xindex>) -> tensor<2xindex> {
+// CHECK:           %[[BUF:.*]] = tensor_to_memref %[[ARG]]
+// CHECK:           %[[CAST:.*]] = memref_cast %[[BUF]] : memref<?xindex> to memref<2xindex>
+// CHECK:           %[[RESULT:.*]] = tensor_load %[[CAST]]
+// CHECK:           return %[[RESULT]] : tensor<2xindex>
+func @tensor_cast(%arg0: tensor<?xindex>) -> tensor<2xindex> {
+  %0 = tensor_cast %arg0 : tensor<?xindex> to tensor<2xindex>
+  return %0 : tensor<2xindex>
+}
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -0,0 +1,33 @@
+// RUN: mlir-opt %s -canonicalize | FileCheck %s
+
+// Test case: Basic folding of tensor_load(tensor_to_memref(t)) -> t
+// CHECK-LABEL:   func @tensor_load_of_tensor_to_memref(
+// CHECK-SAME:                                          %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           return %[[TENSOR]]
+func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = tensor_to_memref %arg0 : memref<?xf32>
+  %1 = tensor_load %0 : memref<?xf32>
+  return %1 : tensor<?xf32>
+}
+
+// Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
+// CHECK-LABEL:   func @tensor_to_memref_of_tensor_load(
+// CHECK-SAME:                                          %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
+// CHECK:           return %[[MEMREF]]
+func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
+  %0 = tensor_load %arg0 : memref<?xf32>
+  %1 = tensor_to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
+}
+
+// Test case: If the memrefs are not the same type, don't fold them.
+// CHECK-LABEL:   func @no_fold_tensor_to_memref_of_tensor_load(
+// CHECK-SAME:                                                  %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>) -> memref<?xf32, 7> {
+// CHECK:           %[[TENSOR:.*]] = tensor_load %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
+// CHECK:           %[[MEMREF_ADDRSPACE7:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32, 7>
+// CHECK:           return %[[MEMREF_ADDRSPACE7]]
+func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
+  %0 = tensor_load %arg0 : memref<?xf32, 2>
+  %1 = tensor_to_memref %0 : memref<?xf32, 7>
+  return %1 : memref<?xf32, 7>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -19,6 +19,15 @@
   return %0 : tensor<index>
 }
 
+// CHECK-LABEL: test_tensor_to_memref
+func @test_tensor_to_memref(%arg0: tensor<?xi64>, %arg1: tensor<*xi64>) -> (memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>) {
+  // CHECK: tensor_to_memref %{{.*}} : memref<?xi64, {{.*}}>
+  %0 = tensor_to_memref %arg0 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>
+  // CHECK: tensor_to_memref %{{.*}} : memref<*xi64, 1>
+  %1 = tensor_to_memref %arg1 : memref<*xi64, 1>
+  return %0, %1 : memref<?xi64, affine_map<(d0) -> (d0 + 7)>>, memref<*xi64, 1>
+}
+
 // CHECK-LABEL: @assert
 func @assert(%arg : i1) {
   assert %arg, "Some message in case this assertion fails."