diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h
@@ -20,7 +20,12 @@
 
 struct InitTensorEliminationStep : public bufferization::PostAnalysisStep {
   /// A function that matches anchor OpOperands for InitTensorOp elimination.
-  using AnchorMatchFn = std::function<bool(OpOperand &)>;
+  /// If an OpOperand is matched, the function should populate the SmallVector
+  /// with all values that `RewriteFn` needs in order to produce the
+  /// replacement value.
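+  /// E.g., for a tensor.insert_slice anchor, these are the offsets, sizes,
+  /// strides and the destination tensor of the InsertSliceOp.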
+  using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
 
   /// A function that rewrites matched anchors.
   using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Operation.h"
 
 using namespace mlir;
@@ -444,6 +445,85 @@
 
 } // namespace
 
+/// Return true if all `neededValues` are in scope at the given
+/// `insertionPoint`.
+static bool
+neededValuesDominateInsertionPoint(const DominanceInfo &domInfo,
+                                   Operation *insertionPoint,
+                                   const SmallVector<Value> &neededValues) {
+  for (Value val : neededValues) {
+    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+      Block *owner = bbArg.getOwner();
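+      // A block argument is in scope at `insertionPoint` only if the
+      // insertion point is nested under the argument's owning block.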
+      if (!owner->findAncestorOpInBlock(*insertionPoint))
+        return false;
+    } else {
+      auto opResult = val.cast<OpResult>();
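+      // An OpResult is in scope only if its defining op dominates the
+      // insertion point.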
+      if (!domInfo.dominates(opResult.getOwner(), insertionPoint))
+        return false;
+    }
+  }
+  return true;
+}
+
+/// Return true if the given `insertionPoint` dominates all uses of
+/// `initTensorOp`.
+static bool insertionPointDominatesUses(const DominanceInfo &domInfo,
+                                        Operation *insertionPoint,
+                                        Operation *initTensorOp) {
+  for (Operation *user : initTensorOp->getUsers())
+    if (!domInfo.dominates(insertionPoint, user))
+      return false;
+  return true;
+}
+
+/// Find a valid insertion point for a replacement of `initTensorOp`, assuming
+/// that the replacement may use any value from `neededValues`.
+static Operation *
+findValidInsertionPoint(Operation *initTensorOp,
+                        const SmallVector<Value> &neededValues) {
+  DominanceInfo domInfo;
+
+  // Gather all possible insertion points: the location of `initTensorOp` and
+  // right after the definition of each value in `neededValues`.
+  SmallVector<Operation *> insertionPointCandidates;
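+  // Candidates are tried in order, so the original location of the
+  // InitTensorOp is preferred whenever it is valid.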
+  insertionPointCandidates.push_back(initTensorOp);
+  for (Value val : neededValues) {
+    // Note: The anchor op uses all of `neededValues`, so:
+    // * in case of a block argument: there must be at least one op in the block
+    //                                (the anchor op or one of its parents).
+    // * in case of an OpResult: there must be at least one op right after the
+    //                           defining op (the anchor op or one of its
+    //                           parents).
+    if (auto bbArg = val.dyn_cast<BlockArgument>()) {
+      insertionPointCandidates.push_back(
+          &bbArg.getOwner()->getOperations().front());
+    } else {
+      insertionPointCandidates.push_back(val.getDefiningOp()->getNextNode());
+    }
+  }
+
+  // Select first matching insertion point.
+  for (Operation *insertionPoint : insertionPointCandidates) {
+    // Check if all needed values are in scope.
+    if (!neededValuesDominateInsertionPoint(domInfo, insertionPoint,
+                                            neededValues))
+      continue;
+    // Check if the insertion point is before all uses.
+    if (!insertionPointDominatesUses(domInfo, insertionPoint, initTensorOp))
+      continue;
+    return insertionPoint;
+  }
+
+  // No suitable insertion point was found.
+  return nullptr;
+}
+
 /// Try to eliminate InitTensorOps inside `op`. An InitTensorOp is replaced
 /// with the result of `rewriteFunc` if it is anchored on a matching
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
@@ -462,8 +542,10 @@
       // Skip operands that do not bufferize inplace.
       if (!aliasInfo.isInPlace(operand))
         continue;
+      // All values that are needed to create the replacement op.
+      SmallVector<Value> neededValues;
       // Is this a matching OpOperand?
-      if (!anchorMatchFunc(operand))
+      if (!anchorMatchFunc(operand, neededValues))
         continue;
       SetVector<Value> maybeInitTensor =
           state.findValueInReverseUseDefChain(operand.get(), [&](Value val) {
@@ -492,8 +574,16 @@
         return WalkResult::skip();
       Value initTensor = maybeInitTensor.front();
 
+      // Find a suitable insertion point.
+      Operation *insertionPoint =
+          findValidInsertionPoint(initTensor.getDefiningOp(), neededValues);
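+      // If no valid insertion point exists, the InitTensorOp cannot be
+      // eliminated and is left as is.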
+      if (!insertionPoint)
+        continue;
+
       // Create a replacement for the InitTensorOp.
-      b.setInsertionPoint(initTensor.getDefiningOp());
+      b.setInsertionPoint(insertionPoint);
       Value replacement = rewriteFunc(b, initTensor.getLoc(), operand);
       if (!replacement)
         continue;
@@ -552,7 +642,7 @@
   return eliminateInitTensors(
       op, state, aliasInfo,
       /*anchorMatchFunc=*/
-      [&](OpOperand &operand) {
+      [&](OpOperand &operand, SmallVector<Value> &neededValues) {
         auto insertSliceOp =
             dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
         if (!insertSliceOp)
@@ -560,7 +650,22 @@
         // Only inplace bufferized InsertSliceOps are eligible.
         if (!aliasInfo.isInPlace(insertSliceOp->getOpOperand(1) /*dest*/))
           return false;
-        return &operand == &insertSliceOp->getOpOperand(0) /*source*/;
+        if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
+          return false;
+
+        // Collect all values that are needed to construct the replacement op.
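+        // The replacement op reads from the destination tensor at the same
+        // offsets, sizes and strides, so all of these values must be in
+        // scope at the insertion point.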
+        neededValues.append(insertSliceOp.offsets().begin(),
+                            insertSliceOp.offsets().end());
+        neededValues.append(insertSliceOp.sizes().begin(),
+                            insertSliceOp.sizes().end());
+        neededValues.append(insertSliceOp.strides().begin(),
+                            insertSliceOp.strides().end());
+        neededValues.push_back(insertSliceOp.dest());
+
+        return true;
       },
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-init-tensor-elimination.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref init-tensor-elimination" -canonicalize -split-input-file | FileCheck %s
 
 // -----
 
@@ -62,3 +62,67 @@
 
   return %r1: tensor<?xf32>
 }
+
+// -----
+
+//      CHECK: func @insertion_point_inside_loop(
+// CHECK-SAME:     %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index)
+func @insertion_point_inside_loop(%t : tensor<?xf32>, %sz : index) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
+  // CHECK-NOT: memref.alloc
+  %blank = linalg.init_tensor [5] : tensor<5xf32>
+
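+  // The tensor.insert_slice below is offset by the induction variable %iv,
+  // so the replacement subview can only be created inside the loop body.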
+  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
+  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
+    // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[iv]]] [5] [1]
+    %iv_i32 = arith.index_cast %iv : index to i32
+    %f = arith.sitofp %iv_i32 : i32 to f32
+
+    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
+    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+
+    // CHECK-NOT: memref.copy
+    %inserted = tensor.insert_slice %filled into %bb[%iv][5][1] : tensor<5xf32> into tensor<?xf32>
+    scf.yield %inserted : tensor<?xf32>
+  }
+
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+//      CHECK: func @insertion_point_outside_loop(
+// CHECK-SAME:     %[[t:.*]]: memref<?xf32, #{{.*}}>, %[[sz:.*]]: index, %[[idx:.*]]: index)
+func @insertion_point_outside_loop(%t : tensor<?xf32>, %sz : index,
+                                   %idx : index) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c5 = arith.constant 5 : index
+
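+  // All dynamic offsets of the tensor.insert_slice below (%idx) are defined
+  // above the loop, so the replacement subview is expected in front of the
+  // loop.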
+  // CHECK-NOT: memref.alloc
+  // CHECK: %[[subview:.*]] = memref.subview %[[t]][%[[idx]]] [5] [1]
+  %blank = linalg.init_tensor [5] : tensor<5xf32>
+
+  // CHECK: scf.for %[[iv:.*]] = %{{.*}} to %[[sz]] step %{{.*}} {
+  %r = scf.for %iv = %c0 to %sz step %c5 iter_args(%bb = %t) -> (tensor<?xf32>) {
+    %iv_i32 = arith.index_cast %iv : index to i32
+    %f = arith.sitofp %iv_i32 : i32 to f32
+
+    // CHECK: linalg.fill(%{{.*}}, %[[subview]])
+    %filled = linalg.fill(%f, %blank) : f32, tensor<5xf32> -> tensor<5xf32>
+
+    // CHECK-NOT: memref.copy
+    %inserted = tensor.insert_slice %filled into %bb[%idx][5][1] : tensor<5xf32> into tensor<?xf32>
+    scf.yield %inserted : tensor<?xf32>
+  }
+
+  return %r : tensor<?xf32>
+}