diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -17,11 +17,7 @@
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Block.h"
-#include "mlir/IR/Location.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/SmallVector.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include <memory>
 #include <optional>
@@ -35,6 +31,186 @@
 class Operation;
 class Value;
+// LoopNestStateCollector walks loop nests and collects the affine.for ops and
+// affine load/store operations they contain, and records whether or not a
+// region-holding op other than AffineForOp and AffineIfOp was encountered.
+struct LoopNestStateCollector {
+  // affine.for ops visited by the walk.
+  SmallVector<AffineForOp, 4> forOps;
+  // Visited ops implementing AffineReadOpInterface.
+  SmallVector<Operation *, 4> loadOpInsts;
+  // Visited ops implementing AffineWriteOpInterface.
+  SmallVector<Operation *, 4> storeOpInsts;
+  // Set once a region-holding op other than affine.for / affine.if is seen.
+  bool hasNonAffineRegionOp = false;
+  // Walks 'opToWalk' and the ops nested in it, appending to the vectors above
+  // and updating 'hasNonAffineRegionOp'. Nothing is cleared first, so results
+  // accumulate across calls.
+  void collect(Operation *opToWalk);
+};
+// MemRefDependenceGraph is a graph data structure where graph nodes are
+// top-level operations in a `Block` which contain load/store ops, and edges
+// are memref dependences between the nodes.
+// TODO: Add a more flexible dependence graph representation.
+// TODO: Add a depth parameter to dependence graph construction.
+struct MemRefDependenceGraph {
+  // Node represents a node in the graph. A Node is either an entire loop nest
+  // rooted at the top level which contains loads/stores, or a top level
+  // load/store.
+  struct Node {
+    // The unique identifier of this node in the graph.
+    unsigned id;
+    // The top-level statement which is (or contains) a load/store.
+    Operation *op;
+    // List of load operations.
+    SmallVector<Operation *, 4> loads;
+    // List of store operations.
+    SmallVector<Operation *, 4> stores;
+    Node(unsigned id, Operation *op) : id(id), op(op) {}
+    // Returns the number of ops in 'loads' which read from 'memref'.
+    unsigned getLoadOpCount(Value memref) const;
+    // Returns the number of ops in 'stores' which write to 'memref'.
+    unsigned getStoreOpCount(Value memref) const;
+    // Appends to 'storeOps' every op in 'stores' which accesses 'memref'.
+    void getStoreOpsForMemref(Value memref,
+                              SmallVectorImpl<Operation *> *storeOps) const;
+    // Appends to 'loadOps' every op in 'loads' which accesses 'memref'.
+    void getLoadOpsForMemref(Value memref,
+                             SmallVectorImpl<Operation *> *loadOps) const;
+    // Inserts into 'loadAndStoreMemrefSet' every memref which this node both
+    // loads from and stores to.
+    void getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const;
+  };
+
+  // Edge represents a data dependence between nodes in the graph.
+  struct Edge {
+    // The id of the node at the other end of the edge.
+    // If this edge is stored in Edge = Node.inEdges[i], then
+    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
+    // If this edge is stored in Edge = Node.outEdges[i], then
+    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
+    unsigned id;
+    // The SSA value on which this edge represents a dependence.
+    // If the value is a memref, then the dependence is between graph nodes
+    // which contain accesses to the same memref 'value'. If the value is a
+    // non-memref value, then the dependence is between a graph node which
+    // defines an SSA value and another graph node which uses the SSA value
+    // (e.g. a constant or load operation defining a value which is used inside
+    // a loop nest).
+    Value value;
+  };
+
+  // Map from node id to Node.
+  DenseMap<unsigned, Node> nodes;
+  // Map from node id to list of input edges.
+  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
+  // Map from node id to list of output edges.
+  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
+  // Map from memref to a count on the dependence edges associated with that
+  // memref.
+  DenseMap<Value, unsigned> memrefEdgeCount;
+  // The next unique identifier to use for newly created graph nodes.
+  unsigned nextNodeId = 0;
+
+  MemRefDependenceGraph(Block &block) : block(block) {}
+
+  // Initializes the dependence graph based on operations in `block`.
+  // Returns true on success, false otherwise.
+  bool init();
+
+  // Returns the graph node for 'id'. 'id' must be a node of the graph.
+  Node *getNode(unsigned id);
+
+  // Returns the graph node whose op is 'forOp', or nullptr if there is none.
+  Node *getForOpNode(AffineForOp forOp);
+
+  // Adds a node with 'op' to the graph and returns its unique identifier.
+  unsigned addNode(Operation *op);
+
+  // Remove node 'id' (and its associated edges) from graph.
+  void removeNode(unsigned id);
+
+  // Returns true if node 'id' stores to any memref which is a block argument
+  // or which has a user that does not access it through
+  // AffineMapAccessInterface (i.e. the memref may escape). Returns false
+  // otherwise.
+  bool writesToLiveInOrEscapingMemrefs(unsigned id);
+
+  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
+  // is for 'value' if non-null, or for any value otherwise. Returns false
+  // otherwise.
+  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr);
+
+  // Adds an edge from node 'srcId' to node 'dstId' for 'value', unless such an
+  // edge already exists.
+  void addEdge(unsigned srcId, unsigned dstId, Value value);
+
+  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
+  void removeEdge(unsigned srcId, unsigned dstId, Value value);
+
+  // Returns true if there is a path in the dependence graph from node 'srcId'
+  // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
+  // operations that the edges connected are expected to be from the same block.
+  bool hasDependencePath(unsigned srcId, unsigned dstId);
+
+  // Returns the number of input edges of node 'id' on 'memref' whose source
+  // node stores to 'memref'.
+  unsigned getIncomingMemRefAccesses(unsigned id, Value memref);
+
+  // Returns the output edge count for node 'id' and 'memref' (if non-null),
+  // otherwise returns the total output edge count from node 'id'.
+  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr);
+
+  /// Inserts into 'definingNodes' the ids of all nodes which define SSA values
+  /// used in node 'id' (the sources of non-memref input edges of 'id').
+  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes);
+
+  // Computes and returns an insertion point operation, before which the
+  // fused <srcId, dstId> loop nest can be inserted while preserving
+  // dependences. Returns nullptr if no such insertion point is found.
+  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId);
+
+  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
+  // taking into account that:
+  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
+  //   *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by
+  //      a private memref.
+  void updateEdges(unsigned srcId, unsigned dstId,
+                   const DenseSet<Value> &privateMemRefs, bool removeSrcId);
+
+  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
+  // of sibling node 'sibId' into node 'dstId'.
+  void updateEdges(unsigned sibId, unsigned dstId);
+
+  // Appends ops in 'loads' and 'stores' to node at 'id'.
+  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
+                 const SmallVectorImpl<Operation *> &stores);
+
+  // Clears the load and store lists of node 'id'.
+  void clearNodeLoadAndStores(unsigned id);
+
+  // Calls 'callback' for each input edge incident to node 'id' which carries a
+  // memref dependence and whose source node is an affine.for loop nest.
+  void forEachMemRefInputEdge(unsigned id,
+                              const std::function<void(Edge)> &callback);
+
+  // Calls 'callback' for each output edge from node 'id' which carries a
+  // memref dependence and whose destination node is an affine.for loop nest.
+  void forEachMemRefOutputEdge(unsigned id,
+                               const std::function<void(Edge)> &callback);
+
+  // Calls 'callback' for each edge in 'edges' which carries a memref
+  // dependence and whose other endpoint is an affine.for loop nest.
+  void forEachMemRefEdge(ArrayRef<Edge> edges,
+                         const std::function<void(Edge)> &callback);
+
+  // Prints every node with its input and output edges to 'os'.
+  void print(raw_ostream &os) const;
+
+  // Prints the graph to llvm::errs().
+  void dump() const { print(llvm::errs()); }
+
+  /// The block for which this graph is created to perform fusion.
+  Block &block;
+};
 /// Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered
 /// from the outermost 'affine.for' operation to the innermost one.
 void getAffineForIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops);
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -32,6 +32,475 @@
 using llvm::SmallDenseMap;
+using Node = MemRefDependenceGraph::Node;
+// Walks 'opToWalk' and the ops nested in it, recording affine.for ops, affine
+// reads and writes, and whether a region-holding op other than affine.for or
+// affine.if was encountered. Results accumulate across calls.
+void LoopNestStateCollector::collect(Operation *opToWalk) {
+  opToWalk->walk([&](Operation *op) {
+    if (auto forOp = dyn_cast<AffineForOp>(op))
+      forOps.push_back(forOp);
+    else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
+      hasNonAffineRegionOp = true;
+    else if (isa<AffineReadOpInterface>(op))
+      loadOpInsts.push_back(op);
+    else if (isa<AffineWriteOpInterface>(op))
+      storeOpInsts.push_back(op);
+  });
+}
+// Returns the number of load ops of this node which read from 'memref'.
+unsigned Node::getLoadOpCount(Value memref) const {
+  unsigned loadOpCount = 0;
+  for (Operation *loadOp : loads)
+    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
+      ++loadOpCount;
+  return loadOpCount;
+}
+// Returns the number of store ops of this node which write to 'memref'.
+unsigned Node::getStoreOpCount(Value memref) const {
+  unsigned storeOpCount = 0;
+  for (Operation *storeOp : stores)
+    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
+      ++storeOpCount;
+  return storeOpCount;
+}
+// Appends to 'storeOps' every store op of this node which accesses 'memref'.
+// Existing contents of 'storeOps' are kept.
+void Node::getStoreOpsForMemref(Value memref,
+                                SmallVectorImpl<Operation *> *storeOps) const {
+  for (Operation *storeOp : stores)
+    if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
+      storeOps->push_back(storeOp);
+}
+// Appends to 'loadOps' every load op of this node which accesses 'memref'.
+// Existing contents of 'loadOps' are kept.
+void Node::getLoadOpsForMemref(Value memref,
+                               SmallVectorImpl<Operation *> *loadOps) const {
+  for (Operation *loadOp : loads)
+    if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
+      loadOps->push_back(loadOp);
+}
+// Inserts into 'loadAndStoreMemrefSet' every memref which this node both loads
+// from and stores to.
+void Node::getLoadAndStoreMemrefSet(
+    DenseSet<Value> *loadAndStoreMemrefSet) const {
+  // Memrefs read by at least one load of this node.
+  llvm::SmallDenseSet<Value, 2> loadMemrefs;
+  for (Operation *loadOp : loads)
+    loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
+  for (Operation *storeOp : stores) {
+    Value memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
+    if (loadMemrefs.count(memref) > 0)
+      loadAndStoreMemrefSet->insert(memref);
+  }
+}
+// Returns the graph node for 'id'. 'id' must be a node of the graph.
+Node *MemRefDependenceGraph::getNode(unsigned id) {
+  auto it = nodes.find(id);
+  assert(it != nodes.end() && "node id not found in graph");
+  return &it->second;
+}
+// Returns the graph node whose op is 'forOp', or nullptr if there is none.
+// This is a linear scan over all nodes.
+Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) {
+  for (auto &idAndNode : nodes)
+    if (idAndNode.second.op == forOp.getOperation())
+      return &idAndNode.second;
+  return nullptr;
+}
+// Adds a node with 'op' to the graph and returns its unique identifier.
+// Identifiers are handed out sequentially from 'nextNodeId'.
+unsigned MemRefDependenceGraph::addNode(Operation *op) {
+  Node node(nextNodeId++, op);
+  nodes.insert({node.id, node});
+  return node.id;
+}
+// Remove node 'id' (and its associated edges) from graph.
+void MemRefDependenceGraph::removeNode(unsigned id) {
+  // Remove each edge in 'inEdges[id]'. Iterate over a copy because
+  // removeEdge() erases entries from the very list being walked.
+  if (inEdges.count(id) > 0) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[id];
+    for (const Edge &inEdge : oldInEdges)
+      removeEdge(inEdge.id, id, inEdge.value);
+  }
+  // Remove each edge in 'outEdges[id]' (again over a copy).
+  if (outEdges.count(id) > 0) {
+    SmallVector<Edge, 2> oldOutEdges = outEdges[id];
+    for (const Edge &outEdge : oldOutEdges)
+      removeEdge(id, outEdge.id, outEdge.value);
+  }
+  // Erase remaining node state.
+  inEdges.erase(id);
+  outEdges.erase(id);
+  nodes.erase(id);
+}
+// Returns true if node 'id' stores to any memref which is a block argument or
+// which has a user that does not access it through AffineMapAccessInterface
+// (i.e. the memref may escape). Returns false otherwise.
+bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) {
+  Node *node = getNode(id);
+  for (Operation *storeOpInst : node->stores) {
+    Value memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
+    // Return true if 'memref' is a block argument (it has no defining op).
+    if (!memref.getDefiningOp())
+      return true;
+    // Return true if any use of 'memref' does not dereference it in an affine
+    // way.
+    for (Operation *user : memref.getUsers())
+      if (!isa<AffineMapAccessInterface>(*user))
+        return true;
+  }
+  return false;
+}
+// Returns true iff there is an edge from node 'srcId' to node 'dstId' which
+// is for 'value' if non-null, or for any value otherwise. The edge must be
+// recorded in both 'outEdges[srcId]' and 'inEdges[dstId]'. Returns false
+// otherwise.
+bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId,
+                                    Value value) {
+  // Look up both lists once; no map is modified below, so the iterators
+  // stay valid.
+  auto outIt = outEdges.find(srcId);
+  auto inIt = inEdges.find(dstId);
+  if (outIt == outEdges.end() || inIt == inEdges.end())
+    return false;
+  bool hasOutEdge = llvm::any_of(outIt->second, [=](const Edge &edge) {
+    return edge.id == dstId && (!value || edge.value == value);
+  });
+  bool hasInEdge = llvm::any_of(inIt->second, [=](const Edge &edge) {
+    return edge.id == srcId && (!value || edge.value == value);
+  });
+  return hasOutEdge && hasInEdge;
+}
+// Adds an edge from node 'srcId' to node 'dstId' for 'value', unless such an
+// edge already exists. Edges on memref values are also counted in
+// 'memrefEdgeCount'.
+void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId,
+                                    Value value) {
+  if (hasEdge(srcId, dstId, value))
+    return;
+  outEdges[srcId].push_back({dstId, value});
+  inEdges[dstId].push_back({srcId, value});
+  if (value.getType().isa<MemRefType>())
+    memrefEdgeCount[value]++;
+}
+// Removes an edge from node 'srcId' to node 'dstId' for 'value'. At most one
+// matching entry is erased from each of 'inEdges[dstId]' and
+// 'outEdges[srcId]'; memref edges also decrement 'memrefEdgeCount'.
+void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId,
+                                       Value value) {
+  assert(inEdges.count(dstId) > 0);
+  assert(outEdges.count(srcId) > 0);
+  if (value.getType().isa<MemRefType>()) {
+    assert(memrefEdgeCount.count(value) > 0);
+    memrefEdgeCount[value]--;
+  }
+  // Remove 'srcId' from 'inEdges[dstId]'.
+  SmallVectorImpl<Edge> &dstInEdges = inEdges[dstId];
+  for (auto *it = dstInEdges.begin(); it != dstInEdges.end(); ++it) {
+    if (it->id == srcId && it->value == value) {
+      dstInEdges.erase(it);
+      break;
+    }
+  }
+  // Remove 'dstId' from 'outEdges[srcId]'.
+  SmallVectorImpl<Edge> &srcOutEdges = outEdges[srcId];
+  for (auto *it = srcOutEdges.begin(); it != srcOutEdges.end(); ++it) {
+    if (it->id == dstId && it->value == value) {
+      srcOutEdges.erase(it);
+      break;
+    }
+  }
+}
+// Returns true if there is a path in the dependence graph from node 'srcId'
+// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
+// operations that the edges connected are expected to be from the same block.
+bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) {
+  // Worklist state is: <node-id, next-output-edge-index-to-visit>
+  SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
+  worklist.push_back({srcId, 0});
+  // Nodes already pushed onto the worklist. Each node is expanded at most
+  // once, which keeps the search linear and guarantees termination even if
+  // the graph contains cycles (e.g. A->B and B->A).
+  DenseSet<unsigned> visited;
+  visited.insert(srcId);
+  Operation *dstOp = getNode(dstId)->op;
+  // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
+  while (!worklist.empty()) {
+    auto &idAndIndex = worklist.back();
+    // Return true if we have reached 'dstId'.
+    if (idAndIndex.first == dstId)
+      return true;
+    // Pop and continue if node has no out edges, or if all out edges have
+    // already been visited.
+    auto outIt = outEdges.find(idAndIndex.first);
+    if (outIt == outEdges.end() || idAndIndex.second == outIt->second.size()) {
+      worklist.pop_back();
+      continue;
+    }
+    // Get graph edge to traverse, then advance the index. This must happen
+    // before push_back below, which may invalidate 'idAndIndex'.
+    Edge edge = outIt->second[idAndIndex.second];
+    ++idAndIndex.second;
+    // Add node at 'edge.id' to the worklist if it has not been seen yet. We
+    // don't need to consider nodes that are "after" dstId in the containing
+    // block; one can't have a path to `dstId` from any of those nodes.
+    bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
+    if (!afterDst && visited.insert(edge.id).second)
+      worklist.push_back({edge.id, 0});
+  }
+  return false;
+}
+// Returns the number of input edges of node 'id' on 'memref' whose source
+// node stores to 'memref'.
+unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id,
+                                                          Value memref) {
+  auto it = inEdges.find(id);
+  if (it == inEdges.end())
+    return 0;
+  unsigned inEdgeCount = 0;
+  for (const Edge &inEdge : it->second) {
+    if (inEdge.value != memref)
+      continue;
+    // Only count in edges from source nodes which store to 'memref'.
+    if (getNode(inEdge.id)->getStoreOpCount(memref) > 0)
+      ++inEdgeCount;
+  }
+  return inEdgeCount;
+}
+// Returns the output edge count for node 'id' and 'memref' (if non-null),
+// otherwise returns the total output edge count from node 'id'.
+unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) {
+  auto it = outEdges.find(id);
+  if (it == outEdges.end())
+    return 0;
+  unsigned outEdgeCount = 0;
+  for (const Edge &outEdge : it->second)
+    if (!memref || outEdge.value == memref)
+      ++outEdgeCount;
+  return outEdgeCount;
+}
+/// Inserts into 'definingNodes' the ids of all nodes which define SSA values
+/// used in node 'id', i.e. the sources of non-memref input edges of 'id'.
+void MemRefDependenceGraph::gatherDefiningNodes(
+    unsigned id, DenseSet<unsigned> &definingNodes) {
+  // Use find() rather than operator[] so that querying a node without input
+  // edges does not insert an empty entry into 'inEdges'.
+  auto it = inEdges.find(id);
+  if (it == inEdges.end())
+    return;
+  for (const Edge &edge : it->second)
+    // By definition of edge, if the edge value is a non-memref value,
+    // then the dependence is between a graph node which defines an SSA value
+    // and another graph node which uses the SSA value.
+    if (!edge.value.getType().isa<MemRefType>())
+      definingNodes.insert(edge.id);
+}
+// Computes and returns an insertion point operation, before which the
+// fused <srcId, dstId> loop nest can be inserted while preserving
+// dependences. Returns nullptr if no such insertion point is found.
+Operation *
+MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId,
+                                                      unsigned dstId) {
+  if (outEdges.count(srcId) == 0)
+    return getNode(dstId)->op;
+
+  // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
+  DenseSet<unsigned> definingNodes;
+  gatherDefiningNodes(dstId, definingNodes);
+  if (llvm::any_of(definingNodes,
+                   [&](unsigned id) { return hasDependencePath(srcId, id); })) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Can't fuse: a defining op with a user in the dst "
+                  "loop has dependence from the src loop\n");
+    return nullptr;
+  }
+
+  // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
+  SmallPtrSet<Operation *, 2> srcDepInsts;
+  for (auto &outEdge : outEdges[srcId])
+    if (outEdge.id != dstId)
+      srcDepInsts.insert(getNode(outEdge.id)->op);
+
+  // Build set of insts in range (srcId, dstId) on which 'dstId' depends. Use
+  // find() so that a 'dstId' without input edges is not inserted into
+  // 'inEdges' as a side effect of this query.
+  SmallPtrSet<Operation *, 2> dstDepInsts;
+  auto dstInIt = inEdges.find(dstId);
+  if (dstInIt != inEdges.end())
+    for (auto &inEdge : dstInIt->second)
+      if (inEdge.id != srcId)
+        dstDepInsts.insert(getNode(inEdge.id)->op);
+
+  Operation *srcNodeInst = getNode(srcId)->op;
+  Operation *dstNodeInst = getNode(dstId)->op;
+
+  // Computing insertion point:
+  // *) Walk the operations of the block strictly between src and dst and
+  //    remember the first one ('firstSrcDepInst') with a dependence edge from
+  //    'srcNode'; the fused nest is inserted before it.
+  // *) If any op at or after 'firstSrcDepInst' has a dependence edge to
+  //    'dstNode', the fused nest would also have to come after that op, so no
+  //    insertion point preserving dependences exists: return nullptr.
+  Operation *firstSrcDepInst = nullptr;
+  for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
+       it != Block::iterator(dstNodeInst); ++it) {
+    Operation *op = &(*it);
+    if (!firstSrcDepInst && srcDepInsts.count(op) > 0)
+      firstSrcDepInst = op;
+    if (firstSrcDepInst && dstDepInsts.count(op) > 0)
+      return nullptr;
+  }
+
+  // Insert before the first op in range which depends on src. If there is
+  // none (no dependence targets in range, or only dst deps in range), insert
+  // before 'dstNodeInst'.
+  return firstSrcDepInst ? firstSrcDepInst : dstNodeInst;
+}
+// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
+// taking into account that:
+//   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
+//   *) memrefs in 'privateMemRefs' have been replaced in node at 'dstId' by a
+//      private memref.
+void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId,
+                                        const DenseSet<Value> &privateMemRefs,
+                                        bool removeSrcId) {
+  // Edge lists are copied before iterating below because addEdge/removeEdge
+  // modify 'inEdges'/'outEdges' while we walk them.
+  // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
+  if (inEdges.count(srcId) > 0) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
+    for (auto &inEdge : oldInEdges) {
+      // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
+      if (privateMemRefs.count(inEdge.value) == 0)
+        addEdge(inEdge.id, dstId, inEdge.value);
+    }
+  }
+  // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
+  // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
+  if (outEdges.count(srcId) > 0) {
+    SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
+    for (auto &outEdge : oldOutEdges) {
+      // Remove any out edges from 'srcId' to 'dstId' across memrefs.
+      if (outEdge.id == dstId) {
+        removeEdge(srcId, outEdge.id, outEdge.value);
+      } else if (removeSrcId) {
+        addEdge(dstId, outEdge.id, outEdge.value);
+        removeEdge(srcId, outEdge.id, outEdge.value);
+      }
+    }
+  }
+  // Remove any edges in 'inEdges[dstId]' on a memref in 'privateMemRefs'
+  // (each of which has been replaced by a private memref in 'dstId'). These
+  // edges may come from nodes other than 'srcId'.
+  if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
+    for (auto &inEdge : oldInEdges)
+      if (privateMemRefs.count(inEdge.value) > 0)
+        removeEdge(inEdge.id, dstId, inEdge.value);
+  }
+}
+// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
+// of sibling node 'sibId' into node 'dstId': every edge incident to 'sibId'
+// is re-targeted to 'dstId'.
+void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) {
+  // For each edge in 'inEdges[sibId]':
+  // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
+  // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
+  // The list is copied first because addEdge/removeEdge modify 'inEdges'.
+  if (inEdges.count(sibId) > 0) {
+    SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
+    for (auto &inEdge : oldInEdges) {
+      addEdge(inEdge.id, dstId, inEdge.value);
+      removeEdge(inEdge.id, sibId, inEdge.value);
+    }
+  }
+  // For each edge in 'outEdges[sibId]':
+  // *) Add new edge from 'dstId' to 'outEdge.id'.
+  // *) Remove edge from 'sibId' to 'outEdge.id'.
+  if (outEdges.count(sibId) > 0) {
+    SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
+    for (auto &outEdge : oldOutEdges) {
+      addEdge(dstId, outEdge.id, outEdge.value);
+      removeEdge(sibId, outEdge.id, outEdge.value);
+    }
+  }
+}
+// Appends ops in 'loads' and 'stores' to the load/store lists of node 'id'.
+void MemRefDependenceGraph::addToNode(
+    unsigned id, const SmallVectorImpl<Operation *> &loads,
+    const SmallVectorImpl<Operation *> &stores) {
+  Node *node = getNode(id);
+  llvm::append_range(node->loads, loads);
+  llvm::append_range(node->stores, stores);
+}
+// Clears the load and store lists of node 'id'; its edges are left untouched.
+void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) {
+  Node *node = getNode(id);
+  node->loads.clear();
+  node->stores.clear();
+}
+// Calls 'callback' for each input edge incident to node 'id' which carries a
+// memref dependence and whose source node is an affine.for loop nest.
+void MemRefDependenceGraph::forEachMemRefInputEdge(
+    unsigned id, const std::function<void(Edge)> &callback) {
+  auto it = inEdges.find(id);
+  if (it != inEdges.end())
+    forEachMemRefEdge(it->second, callback);
+}
+// Calls 'callback' for each output edge from node 'id' which carries a
+// memref dependence and whose destination node is an affine.for loop nest.
+void MemRefDependenceGraph::forEachMemRefOutputEdge(
+    unsigned id, const std::function<void(Edge)> &callback) {
+  auto it = outEdges.find(id);
+  if (it != outEdges.end())
+    forEachMemRefEdge(it->second, callback);
+}
+// Calls 'callback' for each edge in 'edges' which carries a memref
+// dependence and whose other endpoint ('edge.id') is an affine.for loop nest.
+void MemRefDependenceGraph::forEachMemRefEdge(
+    ArrayRef<Edge> edges, const std::function<void(Edge)> &callback) {
+  for (const auto &edge : edges) {
+    // Skip if 'edge' is not a memref dependence edge.
+    if (!edge.value.getType().isa<MemRefType>())
+      continue;
+    assert(nodes.count(edge.id) > 0 && "edge endpoint is not a graph node");
+    // Skip if 'edge.id' is not a loop nest.
+    if (!isa<AffineForOp>(getNode(edge.id)->op))
+      continue;
+    // Visit current edge 'edge'.
+    callback(edge);
+  }
+}
+// Prints every node id followed by its input and output edges to 'os'.
+void MemRefDependenceGraph::print(raw_ostream &os) const {
+  os << "\nMemRefDependenceGraph\n";
+  os << "\nNodes:\n";
+  for (const auto &idAndNode : nodes) {
+    os << "Node: " << idAndNode.first << "\n";
+    auto inIt = inEdges.find(idAndNode.first);
+    if (inIt != inEdges.end()) {
+      for (const auto &e : inIt->second)
+        os << "  InEdge: " << e.id << " " << e.value << "\n";
+    }
+    auto outIt = outEdges.find(idAndNode.first);
+    if (outIt != outEdges.end()) {
+      for (const auto &e : outIt->second)
+        os << "  OutEdge: " << e.id << " " << e.value << "\n";
+    }
+  }
+}
 void mlir::getAffineForIVs(Operation &op, SmallVectorImpl<AffineForOp> *loops) {
   auto *currOp = op.getParentOp();
   AffineForOp currAffineForOp;
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -70,546 +70,6 @@
 } // namespace
-mlir::createLoopFusionPass(unsigned fastMemorySpace,
-                           uint64_t localBufSizeThreshold, bool maximalFusion,
-                           enum FusionMode affineFusionMode) {
-  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
-                                      maximalFusion, affineFusionMode);
-namespace {
-// LoopNestStateCollector walks loop nests and collects load and store
-// operations, and whether or not a region holding op other than ForOp and IfOp
-// was encountered in the loop nest.
-struct LoopNestStateCollector {
-  SmallVector<AffineForOp, 4> forOps;
-  SmallVector<Operation *, 4> loadOpInsts;
-  SmallVector<Operation *, 4> storeOpInsts;
-  bool hasNonAffineRegionOp = false;
-  void collect(Operation *opToWalk) {
-    opToWalk->walk([&](Operation *op) {
-      if (isa<AffineForOp>(op))
-        forOps.push_back(cast<AffineForOp>(op));
-      else if (op->getNumRegions() != 0 && !isa<AffineIfOp>(op))
-        hasNonAffineRegionOp = true;
-      else if (isa<AffineReadOpInterface>(op))
-        loadOpInsts.push_back(op);
-      else if (isa<AffineWriteOpInterface>(op))
-        storeOpInsts.push_back(op);
-    });
-  }
-// MemRefDependenceGraph is a graph data structure where graph nodes are
-// top-level operations in a `Block` which contain load/store ops, and edges
-// are memref dependences between the nodes.
-// TODO: Add a more flexible dependence graph representation.
-// TODO: Add a depth parameter to dependence graph construction.
-struct MemRefDependenceGraph {
-  // Node represents a node in the graph. A Node is either an entire loop nest
-  // rooted at the top level which contains loads/stores, or a top level
-  // load/store.
-  struct Node {
-    // The unique identifier of this node in the graph.
-    unsigned id;
-    // The top-level statement which is (or contains) a load/store.
-    Operation *op;
-    // List of load operations.
-    SmallVector<Operation *, 4> loads;
-    // List of store op insts.
-    SmallVector<Operation *, 4> stores;
-    Node(unsigned id, Operation *op) : id(id), op(op) {}
-    // Returns the load op count for 'memref'.
-    unsigned getLoadOpCount(Value memref) const {
-      unsigned loadOpCount = 0;
-      for (Operation *loadOp : loads) {
-        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
-          ++loadOpCount;
-      }
-      return loadOpCount;
-    }
-    // Returns the store op count for 'memref'.
-    unsigned getStoreOpCount(Value memref) const {
-      unsigned storeOpCount = 0;
-      for (Operation *storeOp : stores) {
-        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
-          ++storeOpCount;
-      }
-      return storeOpCount;
-    }
-    // Returns all store ops in 'storeOps' which access 'memref'.
-    void getStoreOpsForMemref(Value memref,
-                              SmallVectorImpl<Operation *> *storeOps) const {
-      for (Operation *storeOp : stores) {
-        if (memref == cast<AffineWriteOpInterface>(storeOp).getMemRef())
-          storeOps->push_back(storeOp);
-      }
-    }
-    // Returns all load ops in 'loadOps' which access 'memref'.
-    void getLoadOpsForMemref(Value memref,
-                             SmallVectorImpl<Operation *> *loadOps) const {
-      for (Operation *loadOp : loads) {
-        if (memref == cast<AffineReadOpInterface>(loadOp).getMemRef())
-          loadOps->push_back(loadOp);
-      }
-    }
-    // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node
-    // has at least one load and store operation.
-    void
-    getLoadAndStoreMemrefSet(DenseSet<Value> *loadAndStoreMemrefSet) const {
-      llvm::SmallDenseSet<Value, 2> loadMemrefs;
-      for (Operation *loadOp : loads) {
-        loadMemrefs.insert(cast<AffineReadOpInterface>(loadOp).getMemRef());
-      }
-      for (Operation *storeOp : stores) {
-        auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
-        if (loadMemrefs.count(memref) > 0)
-          loadAndStoreMemrefSet->insert(memref);
-      }
-    }
-  };
-  // Edge represents a data dependence between nodes in the graph.
-  struct Edge {
-    // The id of the node at the other end of the edge.
-    // If this edge is stored in Edge = Node.inEdges[i], then
-    // 'Node.inEdges[i].id' is the identifier of the source node of the edge.
-    // If this edge is stored in Edge = Node.outEdges[i], then
-    // 'Node.outEdges[i].id' is the identifier of the dest node of the edge.
-    unsigned id;
-    // The SSA value on which this edge represents a dependence.
-    // If the value is a memref, then the dependence is between graph nodes
-    // which contain accesses to the same memref 'value'. If the value is a
-    // non-memref value, then the dependence is between a graph node which
-    // defines an SSA value and another graph node which uses the SSA value
-    // (e.g. a constant or load operation defining a value which is used inside
-    // a loop nest).
-    Value value;
-  };
-  // Map from node id to Node.
-  DenseMap<unsigned, Node> nodes;
-  // Map from node id to list of input edges.
-  DenseMap<unsigned, SmallVector<Edge, 2>> inEdges;
-  // Map from node id to list of output edges.
-  DenseMap<unsigned, SmallVector<Edge, 2>> outEdges;
-  // Map from memref to a count on the dependence edges associated with that
-  // memref.
-  DenseMap<Value, unsigned> memrefEdgeCount;
-  // The next unique identifier to use for newly created graph nodes.
-  unsigned nextNodeId = 0;
-  MemRefDependenceGraph(Block &block) : block(block) {}
-  // Initializes the dependence graph based on operations in `block'.
-  // Returns true on success, false otherwise.
-  bool init();
-  // Returns the graph node for 'id'.
-  Node *getNode(unsigned id) {
-    auto it = nodes.find(id);
-    assert(it != nodes.end());
-    return &it->second;
-  }
-  // Returns the graph node for 'forOp'.
-  Node *getForOpNode(AffineForOp forOp) {
-    for (auto &idAndNode : nodes)
-      if (idAndNode.second.op == forOp)
-        return &idAndNode.second;
-    return nullptr;
-  }
-  // Adds a node with 'op' to the graph and returns its unique identifier.
-  unsigned addNode(Operation *op) {
-    Node node(nextNodeId++, op);
-    nodes.insert({node.id, node});
-    return node.id;
-  }
-  // Remove node 'id' (and its associated edges) from graph.
-  void removeNode(unsigned id) {
-    // Remove each edge in 'inEdges[id]'.
-    if (inEdges.count(id) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[id];
-      for (auto &inEdge : oldInEdges) {
-        removeEdge(inEdge.id, id, inEdge.value);
-      }
-    }
-    // Remove each edge in 'outEdges[id]'.
-    if (outEdges.count(id) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[id];
-      for (auto &outEdge : oldOutEdges) {
-        removeEdge(id, outEdge.id, outEdge.value);
-      }
-    }
-    // Erase remaining node state.
-    inEdges.erase(id);
-    outEdges.erase(id);
-    nodes.erase(id);
-  }
-  // Returns true if node 'id' writes to any memref which escapes (or is an
-  // argument to) the block. Returns false otherwise.
-  bool writesToLiveInOrEscapingMemrefs(unsigned id) {
-    Node *node = getNode(id);
-    for (auto *storeOpInst : node->stores) {
-      auto memref = cast<AffineWriteOpInterface>(storeOpInst).getMemRef();
-      auto *op = memref.getDefiningOp();
-      // Return true if 'memref' is a block argument.
-      if (!op)
-        return true;
-      // Return true if any use of 'memref' does not deference it in an affine
-      // way.
-      for (auto *user : memref.getUsers())
-        if (!isa<AffineMapAccessInterface>(*user))
-          return true;
-    }
-    return false;
-  }
-  // Returns true iff there is an edge from node 'srcId' to node 'dstId' which
-  // is for 'value' if non-null, or for any value otherwise. Returns false
-  // otherwise.
-  bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) {
-    if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) {
-      return false;
-    }
-    bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) {
-      return edge.id == dstId && (!value || edge.value == value);
-    });
-    bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) {
-      return edge.id == srcId && (!value || edge.value == value);
-    });
-    return hasOutEdge && hasInEdge;
-  }
-  // Adds an edge from node 'srcId' to node 'dstId' for 'value'.
-  void addEdge(unsigned srcId, unsigned dstId, Value value) {
-    if (!hasEdge(srcId, dstId, value)) {
-      outEdges[srcId].push_back({dstId, value});
-      inEdges[dstId].push_back({srcId, value});
-      if (value.getType().isa<MemRefType>())
-        memrefEdgeCount[value]++;
-    }
-  }
-  // Removes an edge from node 'srcId' to node 'dstId' for 'value'.
-  void removeEdge(unsigned srcId, unsigned dstId, Value value) {
-    assert(inEdges.count(dstId) > 0);
-    assert(outEdges.count(srcId) > 0);
-    if (value.getType().isa<MemRefType>()) {
-      assert(memrefEdgeCount.count(value) > 0);
-      memrefEdgeCount[value]--;
-    }
-    // Remove 'srcId' from 'inEdges[dstId]'.
-    for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) {
-      if ((*it).id == srcId && (*it).value == value) {
-        inEdges[dstId].erase(it);
-        break;
-      }
-    }
-    // Remove 'dstId' from 'outEdges[srcId]'.
-    for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end();
-         ++it) {
-      if ((*it).id == dstId && (*it).value == value) {
-        outEdges[srcId].erase(it);
-        break;
-      }
-    }
-  }
-  // Returns true if there is a path in the dependence graph from node 'srcId'
-  // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the
-  // operations that the edges connected are expected to be from the same block.
-  bool hasDependencePath(unsigned srcId, unsigned dstId) {
-    // Worklist state is: <node-id, next-output-edge-index-to-visit>
-    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
-    worklist.push_back({srcId, 0});
-    Operation *dstOp = getNode(dstId)->op;
-    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
-    while (!worklist.empty()) {
-      auto &idAndIndex = worklist.back();
-      // Return true if we have reached 'dstId'.
-      if (idAndIndex.first == dstId)
-        return true;
-      // Pop and continue if node has no out edges, or if all out edges have
-      // already been visited.
-      if (outEdges.count(idAndIndex.first) == 0 ||
-          idAndIndex.second == outEdges[idAndIndex.first].size()) {
-        worklist.pop_back();
-        continue;
-      }
-      // Get graph edge to traverse.
-      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
-      // Increment next output edge index for 'idAndIndex'.
-      ++idAndIndex.second;
-      // Add node at 'edge.id' to the worklist. We don't need to consider
-      // nodes that are "after" dstId in the containing block; one can't have a
-      // path to `dstId` from any of those nodes.
-      bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
-      if (!afterDst && edge.id != idAndIndex.first)
-        worklist.push_back({edge.id, 0});
-    }
-    return false;
-  }
-  // Returns the input edge count for node 'id' and 'memref' from src nodes
-  // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
-    unsigned inEdgeCount = 0;
-    if (inEdges.count(id) > 0)
-      for (auto &inEdge : inEdges[id])
-        if (inEdge.value == memref) {
-          Node *srcNode = getNode(inEdge.id);
-          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
-          if (srcNode->getStoreOpCount(memref) > 0)
-            ++inEdgeCount;
-        }
-    return inEdgeCount;
-  }
-  // Returns the output edge count for node 'id' and 'memref' (if non-null),
-  // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
-    unsigned outEdgeCount = 0;
-    if (outEdges.count(id) > 0)
-      for (auto &outEdge : outEdges[id])
-        if (!memref || outEdge.value == memref)
-          ++outEdgeCount;
-    return outEdgeCount;
-  }
-  /// Return all nodes which define SSA values used in node 'id'.
-  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
-    for (MemRefDependenceGraph::Edge edge : inEdges[id])
-      // By definition of edge, if the edge value is a non-memref value,
-      // then the dependence is between a graph node which defines an SSA value
-      // and another graph node which uses the SSA value.
-      if (!edge.value.getType().isa<MemRefType>())
-        definingNodes.insert(edge.id);
-  }
-  // Computes and returns an insertion point operation, before which the
-  // the fused <srcId, dstId> loop nest can be inserted while preserving
-  // dependences. Returns nullptr if no such insertion point is found.
-  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
-    if (outEdges.count(srcId) == 0)
-      return getNode(dstId)->op;
-    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
-    DenseSet<unsigned> definingNodes;
-    gatherDefiningNodes(dstId, definingNodes);
-    if (llvm::any_of(definingNodes, [&](unsigned id) {
-          return hasDependencePath(srcId, id);
-        })) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Can't fuse: a defining op with a user in the dst "
-                    "loop has dependence from the src loop\n");
-      return nullptr;
-    }
-    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
-    SmallPtrSet<Operation *, 2> srcDepInsts;
-    for (auto &outEdge : outEdges[srcId])
-      if (outEdge.id != dstId)
-        srcDepInsts.insert(getNode(outEdge.id)->op);
-    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
-    SmallPtrSet<Operation *, 2> dstDepInsts;
-    for (auto &inEdge : inEdges[dstId])
-      if (inEdge.id != srcId)
-        dstDepInsts.insert(getNode(inEdge.id)->op);
-    Operation *srcNodeInst = getNode(srcId)->op;
-    Operation *dstNodeInst = getNode(dstId)->op;
-    // Computing insertion point:
-    // *) Walk all operation positions in Block operation list in the
-    //    range (src, dst). For each operation 'op' visited in this search:
-    //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
-    //      dependence edge from 'srcNode'.
-    //   *) Store in 'lastDstDepPost' the last position where 'op' has a
-    //      dependence edge to 'dstNode'.
-    // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the
-    //    operation insertion point (or return null pointer if no such
-    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
-    SmallVector<Operation *, 2> depInsts;
-    std::optional<unsigned> firstSrcDepPos;
-    std::optional<unsigned> lastDstDepPos;
-    unsigned pos = 0;
-    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
-         it != Block::iterator(dstNodeInst); ++it) {
-      Operation *op = &(*it);
-      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
-        firstSrcDepPos = pos;
-      if (dstDepInsts.count(op) > 0)
-        lastDstDepPos = pos;
-      depInsts.push_back(op);
-      ++pos;
-    }
-    if (firstSrcDepPos.has_value()) {
-      if (lastDstDepPos.has_value()) {
-        if (*firstSrcDepPos <= *lastDstDepPos) {
-          // No valid insertion point exists which preserves dependences.
-          return nullptr;
-        }
-      }
-      // Return the insertion point at 'firstSrcDepPos'.
-      return depInsts[*firstSrcDepPos];
-    }
-    // No dependence targets in range (or only dst deps in range), return
-    // 'dstNodInst' insertion point.
-    return dstNodeInst;
-  }
-  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
-  // taking into account that:
-  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
-  //   *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
-  //      private memref.
-  void updateEdges(unsigned srcId, unsigned dstId,
-                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
-    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
-    if (inEdges.count(srcId) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
-      for (auto &inEdge : oldInEdges) {
-        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
-        if (privateMemRefs.count(inEdge.value) == 0)
-          addEdge(inEdge.id, dstId, inEdge.value);
-      }
-    }
-    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
-    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
-    if (outEdges.count(srcId) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
-      for (auto &outEdge : oldOutEdges) {
-        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
-        if (outEdge.id == dstId)
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        else if (removeSrcId) {
-          addEdge(dstId, outEdge.id, outEdge.value);
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        }
-      }
-    }
-    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
-    // replaced by a private memref). These edges could come from nodes
-    // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
-      for (auto &inEdge : oldInEdges)
-        if (privateMemRefs.count(inEdge.value) > 0)
-          removeEdge(inEdge.id, dstId, inEdge.value);
-    }
-  }
-  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
-  // of sibling node 'sibId' into node 'dstId'.
-  void updateEdges(unsigned sibId, unsigned dstId) {
-    // For each edge in 'inEdges[sibId]':
-    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
-    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
-    if (inEdges.count(sibId) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
-      for (auto &inEdge : oldInEdges) {
-        addEdge(inEdge.id, dstId, inEdge.value);
-        removeEdge(inEdge.id, sibId, inEdge.value);
-      }
-    }
-    // For each edge in 'outEdges[sibId]' to node 'id'
-    // *) Add new edge from 'dstId' to 'outEdge.id'.
-    // *) Remove edge from 'sibId' to 'outEdge.id'.
-    if (outEdges.count(sibId) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
-      for (auto &outEdge : oldOutEdges) {
-        addEdge(dstId, outEdge.id, outEdge.value);
-        removeEdge(sibId, outEdge.id, outEdge.value);
-      }
-    }
-  }
-  // Adds ops in 'loads' and 'stores' to node at 'id'.
-  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
-                 const SmallVectorImpl<Operation *> &stores) {
-    Node *node = getNode(id);
-    llvm::append_range(node->loads, loads);
-    llvm::append_range(node->stores, stores);
-  }
-  void clearNodeLoadAndStores(unsigned id) {
-    Node *node = getNode(id);
-    node->loads.clear();
-    node->stores.clear();
-  }
-  // Calls 'callback' for each input edge incident to node 'id' which carries a
-  // memref dependence.
-  void forEachMemRefInputEdge(unsigned id,
-                              const std::function<void(Edge)> &callback) {
-    if (inEdges.count(id) > 0)
-      forEachMemRefEdge(inEdges[id], callback);
-  }
-  // Calls 'callback' for each output edge from node 'id' which carries a
-  // memref dependence.
-  void forEachMemRefOutputEdge(unsigned id,
-                               const std::function<void(Edge)> &callback) {
-    if (outEdges.count(id) > 0)
-      forEachMemRefEdge(outEdges[id], callback);
-  }
-  // Calls 'callback' for each edge in 'edges' which carries a memref
-  // dependence.
-  void forEachMemRefEdge(ArrayRef<Edge> edges,
-                         const std::function<void(Edge)> &callback) {
-    for (const auto &edge : edges) {
-      // Skip if 'edge' is not a memref dependence edge.
-      if (!edge.value.getType().isa<MemRefType>())
-        continue;
-      assert(nodes.count(edge.id) > 0);
-      // Skip if 'edge.id' is not a loop nest.
-      if (!isa<AffineForOp>(getNode(edge.id)->op))
-        continue;
-      // Visit current input edge 'edge'.
-      callback(edge);
-    }
-  }
-  void print(raw_ostream &os) const {
-    os << "\nMemRefDependenceGraph\n";
-    os << "\nNodes:\n";
-    for (const auto &idAndNode : nodes) {
-      os << "Node: " << idAndNode.first << "\n";
-      auto it = inEdges.find(idAndNode.first);
-      if (it != inEdges.end()) {
-        for (const auto &e : it->second)
-          os << "  InEdge: " << e.id << " " << e.value << "\n";
-      }
-      it = outEdges.find(idAndNode.first);
-      if (it != outEdges.end()) {
-        for (const auto &e : it->second)
-          os << "  OutEdge: " << e.id << " " << e.value << "\n";
-      }
-    }
-  }
-  void dump() const { print(llvm::errs()); }
-  /// The block for which this graph is created to perform fusion.
-  Block &block;
 /// Returns true if node 'srcId' can be removed after fusing it with node
 /// 'dstId'. The node can be removed if any of the following conditions are met:
 ///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.
@@ -755,8 +215,8 @@
 /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id'
 /// that escape the block or are accessed in a non-affine way.
-void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
-                           DenseSet<Value> &escapingMemRefs) {
+static void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg,
+                                  DenseSet<Value> &escapingMemRefs) {
   auto *node = mdg->getNode(id);
   for (Operation *storeOp : node->stores) {
     auto memref = cast<AffineWriteOpInterface>(storeOp).getMemRef();
@@ -767,8 +227,6 @@
-} // namespace
 // Initializes the data dependence graph by walking operations in `block`.
 // Assigns each node in the graph a node id based on program order in 'f'.
 bool MemRefDependenceGraph::init() {
@@ -2042,3 +1500,11 @@
     for (Block &block : region.getBlocks())
+mlir::createLoopFusionPass(unsigned fastMemorySpace,
+                           uint64_t localBufSizeThreshold, bool maximalFusion,
+                           enum FusionMode affineFusionMode) {
+  return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold,
+                                      maximalFusion, affineFusionMode);