diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -71,6 +71,8 @@
                             SmallVector<Operation *> &newOps) = 0;
 };
 
+using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
+
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
   BufferizationOptions();
@@ -107,7 +109,7 @@
   bool testAnalysisOnly = false;
 
   /// Registered post analysis steps.
-  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
+  PostAnalysisStepList postAnalysisSteps;
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -18,13 +18,16 @@
 struct BufferizationOptions;
 struct BufferizationState;
+struct PostAnalysisStep;
 
 /// Bufferize the given function. Does not bufferize the function boundary.
+/// Reuses an existing BufferizationState object.
 // TODO: This function is meant to be called from ModuleBufferize and cannot
 // yet be called standalone.
-LogicalResult runComprehensiveBufferize(FuncOp funcOp,
-                                        const BufferizationOptions &options,
-                                        BufferizationState &state);
+LogicalResult runComprehensiveBufferize(
+    FuncOp funcOp, const BufferizationOptions &options,
+    BufferizationState &state,
+    const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -726,7 +726,7 @@
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     FuncOp funcOp, const BufferizationOptions &options,
-    BufferizationState &state) {
+    BufferizationState &state, const PostAnalysisStepList &extraSteps) {
   DominanceInfo domInfo(funcOp);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
@@ -744,16 +744,23 @@
     return failure();
   equivalenceAnalysis(op, aliasInfo);
 
-  for (const std::unique_ptr<PostAnalysisStep> &step :
-       options.postAnalysisSteps) {
-    SmallVector<Operation *> newOps;
-    if (failed(step->run(funcOp, state, newOps)))
-      return failure();
-    // Analyze ops that were created by the PostAnalysisStep.
-    if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
-      return failure();
-    equivalenceAnalysis(newOps, aliasInfo);
-  }
+  auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
+    for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
+      SmallVector<Operation *> newOps;
+      if (failed(step->run(funcOp, state, newOps)))
+        return failure();
+      // Analyze ops that were created by the PostAnalysisStep.
+      if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+        return failure();
+      equivalenceAnalysis(newOps, aliasInfo);
+    }
+    return success();
+  };
+
+  if (failed(runPostAnalysisSteps(extraSteps)))
+    return failure();
+  if (failed(runPostAnalysisSteps(options.postAnalysisSteps)))
+    return failure();
 
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly) {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -33,8 +33,9 @@
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
 
-  /// A mapping of return values to equivalent BlockArguments.
-  DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
+  /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
+  /// indices.
+  DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
 };
 } // namespace
@@ -44,6 +45,70 @@
       StandardOpsDialect::getDialectNamespace());
 }
 
+/// Return the unique ReturnOp that terminates `funcOp`.
+/// Return nullptr if there is no such unique ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+  ReturnOp returnOp;
+  for (Block &b : funcOp.body()) {
+    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+      if (returnOp)
+        return nullptr;
+      returnOp = candidateOp;
+    }
+  }
+  return returnOp;
+}
+
+namespace {
+/// Store function BlockArguments that are equivalent to a returned value in
+/// ModuleBufferizationState.
+struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
+  /// Annotate IR with the results of the analysis. For testing purposes only.
+  static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
+    const char *kEquivalentArgsAttr = "__equivalent_func_args__";
+    Operation *op = returnVal.getOwner();
+
+    SmallVector<int64_t> equivBbArgs;
+    if (op->hasAttr(kEquivalentArgsAttr)) {
+      auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
+      equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
+        return a.cast<IntegerAttr>().getValue().getSExtValue();
+      }));
+    } else {
+      equivBbArgs.append(op->getNumOperands(), -1);
+    }
+    equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
+
+    OpBuilder b(op->getContext());
+    op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
+  }
+
+  LogicalResult run(FuncOp funcOp, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override {
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+
+    // Support only single return-terminated block in the function.
+    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+    assert(returnOp && "expected func with single return op");
+
+    for (OpOperand &returnVal : returnOp->getOpOperands())
+      if (returnVal.get().getType().isa<RankedTensorType>())
+        for (BlockArgument bbArg : funcOp.getArguments())
+          if (bbArg.getType().isa<RankedTensorType>())
+            if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
+                                                              bbArg)) {
+              moduleState
+                  .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
+                  bbArg.getArgNumber();
+              if (state.options.testAnalysisOnly)
+                annotateReturnOp(returnVal, bbArg);
+            }
+
+    return success();
+  }
+};
+} // namespace
+
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
 /// If `value` is a memref::CastOp, return its source. Otherwise, return
@@ -73,20 +138,6 @@
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  ReturnOp returnOp;
-  for (Block &b : funcOp.body()) {
-    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 /// Return the FunctionType with `argumentTypes` and `resultTypes` where each
 /// tensor is replaced by the corresponding buffer type.
 /// In order for all the callers to agree, this *must* bufferize to the most
@@ -128,22 +179,30 @@
   return it2.first->second;
 }
 
-/// Store function BlockArguments that are equivalent to a returned value in
-/// the given ModuleBufferizationState.
-static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
-                                           BufferizationState &state) {
-  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-
-  // Support only single return-terminated block in the function.
-  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
+/// Gather equivalence info of CallOps.
+/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
+// TODO: This does not handle cyclic function call graphs etc.
+static void equivalenceAnalysis(FuncOp funcOp,
+                                BufferizationAliasInfo &aliasInfo,
+                                ModuleBufferizationState &moduleState) {
+  funcOp->walk([&](CallOp callOp) {
+    FuncOp calledFunction = getCalledFunction(callOp);
+    assert(calledFunction && "could not retrieve called FuncOp");
+
+    // No equivalence info available for the called function.
+    if (!moduleState.equivalentFuncArgs.count(calledFunction))
+      return WalkResult::skip();
+
+    for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
+      int64_t returnIdx = it.first;
+      int64_t bbargIdx = it.second;
+      Value returnVal = callOp.getResult(returnIdx);
+      Value argVal = callOp->getOperand(bbargIdx);
+      aliasInfo.unionEquivalenceClasses(returnVal, argVal);
+    }
 
-  for (Value returnVal : returnOp.operands())
-    if (returnVal.getType().isa<RankedTensorType>())
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (bbArg.getType().isa<RankedTensorType>())
-          if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
-            moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
+    return WalkResult::advance();
+  });
 }
 
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
@@ -217,7 +276,8 @@
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    if (moduleState.equivalentReturnValToBBArg.count(returnVal))
+    if (moduleState.equivalentFuncArgs[funcOp].count(
+            returnOperand.getOperandNumber()))
       continue;
 
     // Cast values at the call site if necessary.
@@ -499,12 +559,12 @@
       }
 
      // If return operand is equivalent to some bbArg, no need to return it.
-      Value returnVal = returnOperand.get();
-      if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
-        BlockArgument bbArg =
-            moduleState.equivalentReturnValToBBArg[returnVal];
+      if (moduleState.equivalentFuncArgs[funcOp].count(
+              returnOperand.getOperandNumber())) {
+        int64_t idx =
+            moduleState
+                .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
         Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-        int64_t idx = bbArg.getArgNumber();
         Value buffer = state.lookupBuffer(callOp->getOperand(idx));
         // Add CallOp operand/result equivalence: this is interprocedural
         // info.
@@ -667,6 +727,7 @@
     return failure();
 
   BufferizationState state(moduleOp, options);
+  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
   // Interestingly, all function args that are not visible outside of a module
@@ -698,11 +759,17 @@
         aliasInfo.setBufferizesToWritableMemory(bbArg);
     }
 
+    // Register extra post analysis steps. These cannot be stored in `options`
+    // because `options` is immutable.
+    PostAnalysisStepList extraSteps;
+    extraSteps.emplace_back(std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+
+    // Gather equivalence info for CallOps.
+    equivalenceAnalysis(funcOp, aliasInfo, moduleState);
+
     // Analyze and bufferize funcOp.
-    if (failed(runComprehensiveBufferize(funcOp, options, state)))
+    if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
       return failure();
-
-    populateEquivalentFuncOpBBArgs(funcOp, state);
   }
 
   if (options.testAnalysisOnly)
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir
@@ -40,15 +40,17 @@
     -> (tensor<?xf32>, tensor<?xf32>)
 {
   // must bufferize out of place.
-  // CHECK: tensor.insert_slice
+  //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %r0 = tensor.insert_slice %C into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
   // bufferizes inplace.
-  // CHECK: tensor.insert_slice
+  //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %r1 = tensor.insert_slice %C into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -81,6 +83,8 @@
     outs(%B: tensor<4x4xf32>)
   -> tensor<4x4xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, -1, 1]}
   return %C, %D, %E: tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>
 }
 
@@ -136,6 +140,8 @@
   // CHECK: {__inplace_results_attr__ = ["false"]}
   %r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -172,6 +178,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %r3 = tensor.insert_slice %r2 into %B[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -208,6 +216,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["false"]}
   %r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -234,6 +244,9 @@
   %2 = tensor.insert_slice %1 into %A[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
   %3 = vector.transfer_read %1[%idx2], %cst2 : tensor<?xf32>, vector<5xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %2, %3 : tensor<?xf32>, vector<5xf32>
 }
 
@@ -274,6 +287,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %6 = tensor.insert_slice %5 into %2[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %6, %3 : tensor<?xf32>, vector<5xf32>
 }
 
@@ -306,6 +321,8 @@
     outs(%C: tensor<4x4xf32>)
   -> tensor<4x4xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 2]}
   return %D, %E: tensor<4x4xf32>, tensor<4x4xf32>
 }
 
@@ -372,6 +389,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]}
   %20 = tensor.insert_slice %19 into %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor<?x?xf32> into tensor<30x20xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [6]}
   return %20 : tensor<30x20xf32>
 }
 
@@ -502,6 +521,8 @@
   %rsC = tensor.insert_slice %FC into %sC[0, 0][12345, 67890][1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
   %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1, 2]}
   return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
 }
 
@@ -531,6 +552,8 @@
     scf.yield %t : tensor<?xf32>
   }
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
   return %r0, %r1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -562,6 +585,8 @@
     scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
   }
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
   return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -621,6 +646,8 @@
     linalg.yield %t : tensor<?xf32>
   }
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 1]}
   return %r1, %r3: tensor<?xf32>, tensor<?xf32>
 }
 
@@ -766,6 +793,8 @@
     ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
     outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [2]}
   return %r : tensor<256x256xf32>
 }
 
@@ -811,6 +840,8 @@
     ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
     outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [2]}
   return %r : tensor<256x256xf32>
 }
 
@@ -856,6 +887,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %14 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [4]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -881,6 +914,9 @@
     %t3 = tensor.insert_slice %t2 into %arg1[%x, 0] [5, %y] [1, 1] : tensor<5x?xf32> into tensor<10x20xf32>
     scf.yield %t3 : tensor<10x20xf32>
   }
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r : tensor<10x20xf32>
 }
 
@@ -908,6 +944,9 @@
   ^bb(%0: f32, %1: f32, %2 : f32) :
     linalg.yield %0, %0 : f32, f32
   } -> (tensor<?xf32>, tensor<?xf32>)
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [1, -1]}
   return %o#0, %o#1 : tensor<?xf32>, tensor<?xf32>
 }
 
@@ -949,6 +988,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %14 into %e[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<?x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [2, -1]}
   return %8, %15 : tensor<62x90xf32>, tensor<?x90xf32>
 }
 
@@ -978,6 +1019,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %10 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -1007,6 +1050,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %10 into %8[31, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -1029,6 +1074,8 @@
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %15 = tensor.insert_slice %2 into %8[15, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32>
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %15 : tensor<62x90xf32>
 }
 
@@ -1130,6 +1177,8 @@
     linalg.yield %cst : f32
   } -> (tensor<?xf32>)
 
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %o, %v3 : tensor<?xf32>, vector<5xf32>
 }
 
@@ -1158,6 +1207,9 @@
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %3 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [-1, 0]}
   return %2, %3 : tensor<?xf32>, tensor<?xf32>
 }
 
@@ -1178,6 +1230,9 @@
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, 0]}
   return %2, %2 : tensor<?xf32>, tensor<?xf32>
 }
 
@@ -1212,6 +1267,8 @@
     %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
     scf.yield %t2 : tensor<?xf32>
   }
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r : tensor<?xf32>
 }
 
@@ -1261,6 +1318,9 @@
     scf.yield %r : tensor<?xf32>
   }
   %v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
   return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
 }
 
@@ -1286,6 +1346,9 @@
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r2 : tensor<?xf32>
 }
 
@@ -1316,6 +1379,9 @@
     %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
     scf.yield %t3 : tensor<?xf32>
   }
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r : tensor<?xf32>
 }
 
@@ -1394,6 +1460,9 @@
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r2 : tensor<?xf32>
 }
 
@@ -1418,6 +1487,9 @@
   //      CHECK: tensor.insert_slice
   // CHECK-SAME: {__inplace_results_attr__ = ["true"]
   %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
+
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
   return %r2 : tensor<?xf32>
 }
 
@@ -1531,3 +1603,44 @@
   return %r1, %r2 : vector<5xf32>, vector<5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @inner_func
+func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
+  return %t : tensor<?xf32>
+}
+
+func @equivalent_func_arg(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
+  // This test does not check IR. It just asserts there is no failure due to
+  // non-equivalent scf.for yield values.
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    %3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %3 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @inner_func_2
+func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  //      CHECK: return
+  // CHECK-SAME: {__equivalent_func_args__ = [0]}
+  return %0 : tensor<?xf32>
+}
+
+func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
+  // This test does not check IR. It just asserts there is no failure due to
+  // non-equivalent scf.for yield values.
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %3 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -928,3 +928,54 @@
 // CHECK: return
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @inner_func(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: memref.store %{{.*}}, %[[arg0]]
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @equivalent_func_arg(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @equivalent_func_arg(%t0: tensor<?xf32> {linalg.inplaceable = true},
+                          %c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
+  // CHECK-NOT: copy
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    // CHECK: call @inner_func(%[[arg0]])
+    %3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %3 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @inner_func_2(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  // CHECK: memref.store %{{.*}}, %[[arg0]]
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func @equivalent_func_arg_2(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
+                            %c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
+  %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
+    // TODO: There should be a memory copy here. This is a bug in CallOp
+    // bufferization.
+    // CHECK: call @inner_func_2(%[[arg0]])
+    %3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
+    scf.yield %t1 : tensor<?xf32>
+  }
+  return %1: tensor<?xf32>
+}
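
Usage sketch (not part of the patch): the snippet below shows how a client could hook into the new mechanism, either by registering a PostAnalysisStep in BufferizationOptions or by passing one through the new `extraSteps` parameter. `MyPostAnalysisStep` and `bufferizeFuncWithExtraStep` are hypothetical names introduced for illustration; the code assumes only the declarations added by this patch and is a sketch, not a drop-in implementation.

// Illustrative only; assumes the PostAnalysisStep API introduced by this
// patch. `MyPostAnalysisStep` and `bufferizeFuncWithExtraStep` are
// hypothetical names.
#include <memory>

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h"

using namespace mlir;
using namespace mlir::linalg::comprehensive_bufferize;

namespace {
/// A do-nothing post analysis step. A real step would inspect or transform
/// `funcOp` after the inplaceability analysis has run; any ops it creates
/// must be appended to `newOps` so that they are re-analyzed afterwards.
struct MyPostAnalysisStep : public PostAnalysisStep {
  LogicalResult run(FuncOp funcOp, BufferizationState &state,
                    SmallVector<Operation *> &newOps) override {
    return success();
  }
};
} // namespace

/// Analyze and bufferize `funcOp` with one extra step that is not stored in
/// the (immutable) options, mirroring how ModuleBufferization passes
/// EquivalentFuncOpBBArgsAnalysis.
static LogicalResult
bufferizeFuncWithExtraStep(FuncOp funcOp, const BufferizationOptions &options,
                           BufferizationState &state) {
  PostAnalysisStepList extraSteps;
  extraSteps.emplace_back(std::make_unique<MyPostAnalysisStep>());
  return runComprehensiveBufferize(funcOp, options, state, extraSteps);
}

Note that, per the change to runComprehensiveBufferize above, the extra steps run before the steps registered in `options.postAnalysisSteps`.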