diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -71,6 +71,8 @@
                             SmallVector<Operation *> &newOps) = 0;
 };
 
+using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
+
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
   BufferizationOptions();
@@ -107,7 +109,7 @@
   bool testAnalysisOnly = false;
 
   /// Registered post analysis steps.
-  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
+  PostAnalysisStepList postAnalysisSteps;
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -18,6 +18,7 @@
 
 struct BufferizationOptions;
 struct BufferizationState;
+struct PostAnalysisStep;
 
 /// Bufferize the given op.
 LogicalResult runComprehensiveBufferize(Operation *op,
@@ -25,9 +26,10 @@
 
 /// Bufferize the given function. Does not bufferize the function boundary.
 /// Reuses an existing BufferizationState object.
-LogicalResult runComprehensiveBufferize(Operation *op,
-                                        const BufferizationOptions &options,
-                                        BufferizationState &state);
+LogicalResult runComprehensiveBufferize(
+    Operation *op, const BufferizationOptions &options,
+    BufferizationState &state,
+    const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -717,12 +717,13 @@
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, const BufferizationOptions &options) {
   BufferizationState state(op, options);
-  return runComprehensiveBufferize(op, options, state);
+  PostAnalysisStepList extraSteps;
+  return runComprehensiveBufferize(op, options, state, extraSteps);
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, const BufferizationOptions &options,
-    BufferizationState &state) {
+    BufferizationState &state, const PostAnalysisStepList &extraSteps) {
 
   DominanceInfo domInfo(op);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
@@ -736,16 +737,23 @@
     return failure();
   equivalenceAnalysis(op, aliasInfo);
 
-  for (const std::unique_ptr<PostAnalysisStep> &step :
-       options.postAnalysisSteps) {
-    SmallVector<Operation *> newOps;
-    if (failed(step->run(op, state, newOps)))
-      return failure();
-    // Analyze ops that were created by the PostAnalysisStep.
-    if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
-      return failure();
-    equivalenceAnalysis(newOps, aliasInfo);
-  }
+  auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
+    for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
+      SmallVector<Operation *> newOps;
+      if (failed(step->run(op, state, newOps)))
+        return failure();
+      // Analyze ops that were created by the PostAnalysisStep.
+      if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+        return failure();
+      equivalenceAnalysis(newOps, aliasInfo);
+    }
+    return success();
+  };
+
+  if (failed(runPostAnalysisSteps(extraSteps)))
+    return failure();
+  if (failed(runPostAnalysisSteps(options.postAnalysisSteps)))
+    return failure();
 
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly) {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -33,8 +33,9 @@
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
 
-  /// A mapping of return values to equivalent BlockArguments.
-  DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
+  /// Maps a ReturnOp operand index to the equivalent FuncOp BBArg index.
+  /// NOTE(review): not keyed by FuncOp; confirm functions cannot collide.
+  DenseMap<int64_t, int64_t> equivalentFuncArgs;
 };
 } // namespace
 
@@ -44,6 +45,47 @@
       StandardOpsDialect::getDialectNamespace());
 }
 
+/// Return the sole ReturnOp among the block terminators of `funcOp`.
+/// Return nullptr if no block, or more than one block, ends in a ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+  ReturnOp returnOp; // Stays null until a ReturnOp terminator is seen.
+  for (Block &b : funcOp.body()) {
+    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+      if (returnOp)
+        return nullptr; // A second ReturnOp was found: not unique.
+      returnOp = candidateOp;
+    }
+  }
+  return returnOp;
+}
+
+namespace {
+/// Post-analysis step: for `op` (a FuncOp), record in ModuleBufferizationState
+/// which return operand index is equivalent to which function bbArg index.
+struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
+  LogicalResult run(Operation *op, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override {
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+    auto funcOp = cast<FuncOp>(op);
+
+    // Only functions with exactly one ReturnOp are supported (asserted below).
+    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+    assert(returnOp && "expected func with single return op");
+    // Ranked tensors only; if several bbArgs are equivalent, the last one wins.
+    for (OpOperand &returnVal : returnOp->getOpOperands())
+      if (returnVal.get().getType().isa<RankedTensorType>())
+        for (BlockArgument bbArg : funcOp.getArguments())
+          if (bbArg.getType().isa<RankedTensorType>())
+            if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
+                                                              bbArg))
+              moduleState.equivalentFuncArgs[returnVal.getOperandNumber()] =
+                  bbArg.getArgNumber();
+    // Nothing is appended to `newOps` by this step.
+    return success();
+  }
+};
+} // namespace
+
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
 /// If `value` is a memref::CastOp, return its source. Otherwise, return
@@ -73,20 +115,6 @@
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  ReturnOp returnOp;
-  for (Block &b : funcOp.body()) {
-    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 /// Return the FunctionType with `argumentTypes` and `resultTypes` where each
 /// tensor is replaced by the corresponding buffer type.
 /// In order for all the callers to agree, this *must* bufferize to the most
@@ -128,24 +156,6 @@
   return it2.first->second;
 }
 
-/// Store function BlockArguments that are equivalent to a returned value in
-/// the given ModuleBufferizationState.
-static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
-                                           BufferizationState &state) {
-  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-
-  // Support only single return-terminated block in the function.
-  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (Value returnVal : returnOp.operands())
-    if (returnVal.getType().isa<RankedTensorType>())
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (bbArg.getType().isa<RankedTensorType>())
-          if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
-            moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
-}
-
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
 /// buffer form (using the canonical memref layout for now), according to the
 /// inPlace-bufferizable information of the function arguments.
@@ -217,7 +227,7 @@
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    if (moduleState.equivalentReturnValToBBArg.count(returnVal))
+    if (moduleState.equivalentFuncArgs.count(returnOperand.getOperandNumber()))
       continue;
 
     // Cast values at the call site if necessary.
@@ -503,12 +513,11 @@
         }
 
         // If return operand is equivalent to some bbArg, no need to return it.
-        Value returnVal = returnOperand.get();
-        if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
-          BlockArgument bbArg =
-              moduleState.equivalentReturnValToBBArg[returnVal];
+        if (moduleState.equivalentFuncArgs.count(
+                returnOperand.getOperandNumber())) {
+          int64_t idx =
+              moduleState.equivalentFuncArgs[returnOperand.getOperandNumber()];
           Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-          int64_t idx = bbArg.getArgNumber();
           Value buffer = state.lookupBuffer(callOp->getOperand(idx));
           // Add CallOp operand/result equivalence: this is interprocedural
           // info.
@@ -710,11 +719,14 @@
         aliasInfo.setBufferizesToWritableMemory(bbArg);
     }
 
+    // Register extra post analysis steps. These cannot be stored in `options`
+    // because `options` is immutable.
+    PostAnalysisStepList extraSteps;
+    extraSteps.emplace_back(std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+
     // Analyze and bufferize funcOp.
-    if (failed(runComprehensiveBufferize(funcOp, options, state)))
+    if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
       return failure();
-
-    populateEquivalentFuncOpBBArgs(funcOp, state);
   }
 
   if (options.testAnalysisOnly)