diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -125,7 +125,7 @@ Node(unsigned id, Operation *op) : id(id), op(op) {} // Returns the load op count for 'memref'. - unsigned getLoadOpCount(Value memref) { + unsigned getLoadOpCount(Value memref) const { unsigned loadOpCount = 0; for (auto *loadOpInst : loads) { if (memref == cast(loadOpInst).getMemRef()) @@ -135,7 +135,7 @@ } // Returns the store op count for 'memref'. - unsigned getStoreOpCount(Value memref) { + unsigned getStoreOpCount(Value memref) const { unsigned storeOpCount = 0; for (auto *storeOpInst : stores) { if (memref == cast(storeOpInst).getMemRef()) @@ -146,7 +146,7 @@ // Returns all store ops in 'storeOps' which access 'memref'. void getStoreOpsForMemref(Value memref, - SmallVectorImpl *storeOps) { + SmallVectorImpl *storeOps) const { for (auto *storeOpInst : stores) { if (memref == cast(storeOpInst).getMemRef()) storeOps->push_back(storeOpInst); @@ -155,7 +155,7 @@ // Returns all load ops in 'loadOps' which access 'memref'. void getLoadOpsForMemref(Value memref, - SmallVectorImpl *loadOps) { + SmallVectorImpl *loadOps) const { for (auto *loadOpInst : loads) { if (memref == cast(loadOpInst).getMemRef()) loadOps->push_back(loadOpInst); @@ -164,7 +164,8 @@ // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node // has at least one load and store operation. - void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) { + void + getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) const { llvm::SmallDenseSet loadMemrefs; for (auto *loadOpInst : loads) { loadMemrefs.insert(cast(loadOpInst).getMemRef()); @@ -744,14 +745,12 @@ // Check if 'memref' is used by a non-deferencing op (including unknown ones) // (e.g., call ops, alias creating ops, etc.). 
- for (Operation *user : memref.getUsers()) { + return llvm::any_of(memref.getUsers(), [&](Operation *user) { // Ignore users outside of `block`. if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block) - continue; - if (!isa(*user)) - return true; - } - return false; + return false; + return !isa(*user); + }); } /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' @@ -1076,10 +1075,9 @@ return WalkResult::advance(); }); // Looking for users between node 'srcId' and node 'dstId'. - for (Value memref : memRefValues) - if (hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg)) - return true; - return false; + return llvm::any_of(memRefValues, [&](Value memref) { + return hasNonAffineUsersOnThePath(srcId, dstId, memref, mdg); + }); } // Checks the profitability of fusing a backwards slice of the loop nest @@ -1452,280 +1450,301 @@ eraseUnusedMemRefAllocations(); } - /// Visit each node in the graph, and for each node, attempt to fuse it with - /// producer-consumer candidates. No fusion is performed when producers with a - /// user count greater than `maxSrcUserCount` for any of the memrefs involved - /// are encountered. - void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { - LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); - init(); - while (!worklist.empty()) { - unsigned dstId = worklist.back(); - worklist.pop_back(); + /// Returns true if a private memref can be created for `memref` given + /// the fusion scenario reflected by the other arguments. 
+ bool canCreatePrivateMemRef(Value memref, + const DenseSet &srcEscapingMemRefs, + unsigned producerId, unsigned consumerId, + bool removeSrcNode) { + const Node *consumerNode = mdg->getNode(consumerId); + // If `memref` is an escaping one, do not create a private memref + // for the below scenarios, since doing so will leave the escaping + // memref unmodified as all the writes originally meant for the + // escaping memref would be performed on the private memref: + // 1. The source is to be removed after fusion, + // OR + // 2. The destination writes to `memref`. + if (srcEscapingMemRefs.count(memref) > 0 && + (removeSrcNode || consumerNode->getStoreOpCount(memref) > 0)) + return false; - // Skip if this node was removed (fused into another node). - if (mdg->nodes.count(dstId) == 0) - continue; - // Get 'dstNode' into which to attempt fusion. - auto *dstNode = mdg->getNode(dstId); - // Skip if 'dstNode' is not a loop nest. - if (!isa(dstNode->op)) - continue; - // Skip if 'dstNode' is a loop nest returning values. - // TODO: support loop nests that return values. - if (dstNode->op->getNumResults() > 0) - continue; + // Don't create a private memref if 'srcNode' has in edges on + // 'memref' or 'dstNode' has out edges on 'memref'. + if (mdg->getIncomingMemRefAccesses(producerId, memref) > 0 || + mdg->getOutEdgeCount(consumerId, memref) > 0) + return false; - LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); - - // Sink sequential loops in 'dstNode' (and thus raise parallel loops) - // while preserving relative order. This can increase the maximum loop - // depth at which we can fuse a slice of a producer loop nest into a - // consumer loop nest. - sinkSequentialLoops(dstNode); - auto dstAffineForOp = cast(dstNode->op); - - // Try to fuse 'dstNode' with candidate producer loops until a fixed point - // is reached. Fusing two loops may expose new fusion opportunities. 
- bool dstNodeChanged; - do { - // Gather src loop candidates for 'dstNode' and visit them in "quasi" - // reverse program order to minimize the number of iterations needed to - // reach the fixed point. Note that this is a best effort approach since - // 'getProducerCandidates' does not always guarantee that program order - // in 'srcIdCandidates'. - dstNodeChanged = false; - SmallVector srcIdCandidates; - getProducerCandidates(dstId, mdg, srcIdCandidates); - - for (unsigned srcId : llvm::reverse(srcIdCandidates)) { - // Get 'srcNode' from which to attempt fusion into 'dstNode'. - auto *srcNode = mdg->getNode(srcId); - auto srcAffineForOp = cast(srcNode->op); - LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId - << " for dst loop " << dstId << "\n"); - - // Skip if 'srcNode' is a loop nest returning values. - // TODO: support loop nests that return values. - if (isa(srcNode->op) && srcNode->op->getNumResults() > 0) - continue; + // If 'srcNode' will be removed but it has out edges on 'memref' to + // nodes other than 'dstNode', we have to preserve dependences and + // cannot create a private memref. + if (removeSrcNode && + any_of(mdg->outEdges[producerId], [&](const auto &edge) { + return edge.value == memref && edge.id != consumerId; + })) + return false; - DenseSet producerConsumerMemrefs; - gatherProducerConsumerMemrefs(srcId, dstId, mdg, - producerConsumerMemrefs); + return true; + } - // Skip if 'srcNode' out edge count on any memref is greater than - // 'maxSrcUserCount'. - if (any_of(producerConsumerMemrefs, [&](Value memref) { - return mdg->getOutEdgeCount(srcNode->id, memref) > - maxSrcUserCount; - })) - continue; + /// Perform fusions with node `dstId` as the destination of fusion. Producers + /// with a user count greater than `maxSrcUserCount` for any of the memrefs + /// involved are skipped.
+ void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) { + // Return early, before logging, if 'dstId' cannot be a fusion destination. + // Skip if this node was removed (fused into another node). + if (mdg->nodes.count(dstId) == 0) + return; + // Get 'dstNode' into which to attempt fusion. + auto *dstNode = mdg->getNode(dstId); + // Skip if 'dstNode' is not a loop nest. + if (!isa(dstNode->op)) + return; + // Skip if 'dstNode' is a loop nest returning values. + // TODO: support loop nests that return values. + if (dstNode->op->getNumResults() > 0) + return; + + LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); + + // Sink sequential loops in 'dstNode' (and thus raise parallel loops) + // while preserving relative order. This can increase the maximum loop + // depth at which we can fuse a slice of a producer loop nest into a + // consumer loop nest. + sinkSequentialLoops(dstNode); + auto dstAffineForOp = cast(dstNode->op); - // Gather memrefs in 'srcNode' that are written and escape out of the - // block (e.g., memref block arguments, returned memrefs, - // memrefs passed to function calls, etc.). - DenseSet srcEscapingMemRefs; - gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); - - // Skip if there are non-affine operations in between the 'srcNode' - // and 'dstNode' using their memrefs. If so, we wouldn't be able to - // compute a legal insertion point for now. 'srcNode' and 'dstNode' - // memrefs with non-affine operation users would be considered - // escaping memrefs so we can limit this check to only scenarios with - // escaping memrefs. - if (!srcEscapingMemRefs.empty() && - hasNonAffineUsersOnThePath(srcId, dstId, mdg)) { - LLVM_DEBUG( - llvm::dbgs() - << "Can't fuse: non-affine users in between the loops\n."); - continue; - } + // Try to fuse 'dstNode' with candidate producer loops until a fixed point + // is reached. Fusing two loops may expose new fusion opportunities.
+ bool dstNodeChanged; + do { + // Gather src loop candidates for 'dstNode' and visit them in "quasi" + // reverse program order to minimize the number of iterations needed to + // reach the fixed point. Note that this is a best effort approach since + // 'getProducerCandidates' does not always guarantee that program order + // in 'srcIdCandidates'. + dstNodeChanged = false; + SmallVector srcIdCandidates; + getProducerCandidates(dstId, mdg, srcIdCandidates); + + for (unsigned srcId : llvm::reverse(srcIdCandidates)) { + // Get 'srcNode' from which to attempt fusion into 'dstNode'. + auto *srcNode = mdg->getNode(srcId); + auto srcAffineForOp = cast(srcNode->op); + LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId + << " for dst loop " << dstId << "\n"); + + // Skip if 'srcNode' is a loop nest returning values. + // TODO: support loop nests that return values. + if (isa(srcNode->op) && srcNode->op->getNumResults() > 0) + continue; - // Compute an operation list insertion point for the fused loop - // nest which preserves dependences. - Operation *fusedLoopInsPoint = - mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); - if (fusedLoopInsPoint == nullptr) - continue; + DenseSet producerConsumerMemrefs; + gatherProducerConsumerMemrefs(srcId, dstId, mdg, + producerConsumerMemrefs); - // Compute the innermost common loop depth for dstNode - // producer-consumer loads/stores. - SmallVector dstMemrefOps; - for (Operation *op : dstNode->loads) - if (producerConsumerMemrefs.count( - cast(op).getMemRef()) > 0) - dstMemrefOps.push_back(op); - for (Operation *op : dstNode->stores) - if (producerConsumerMemrefs.count( - cast(op).getMemRef())) - dstMemrefOps.push_back(op); - unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); - - // Check the feasibility of fusing src loop nest into dst loop nest - // at loop depths in range [1, dstLoopDepthTest]. 
- unsigned maxLegalFusionDepth = 0; - SmallVector depthSliceUnions; - depthSliceUnions.resize(dstLoopDepthTest); - FusionStrategy strategy(FusionStrategy::ProducerConsumer); - for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { - FusionResult result = mlir::canFuseLoops( - srcAffineForOp, dstAffineForOp, - /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); - - if (result.value == FusionResult::Success) - maxLegalFusionDepth = i; - } + // Skip if 'srcNode' out edge count on any memref is greater than + // 'maxSrcUserCount'. + if (any_of(producerConsumerMemrefs, [&](Value memref) { + return mdg->getOutEdgeCount(srcNode->id, memref) > + maxSrcUserCount; + })) + continue; - if (maxLegalFusionDepth == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Can't fuse: fusion is not legal at any depth\n"); - continue; - } + // Gather memrefs in 'srcNode' that are written and escape out of the + // block (e.g., memref block arguments, returned memrefs, + // memrefs passed to function calls, etc.). + DenseSet srcEscapingMemRefs; + gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); + + // Skip if there are non-affine operations in between the 'srcNode' + // and 'dstNode' using their memrefs. If so, we wouldn't be able to + // compute a legal insertion point for now. 'srcNode' and 'dstNode' + // memrefs with non-affine operation users would be considered + // escaping memrefs so we can limit this check to only scenarios with + // escaping memrefs. + if (!srcEscapingMemRefs.empty() && + hasNonAffineUsersOnThePath(srcId, dstId, mdg)) { + LLVM_DEBUG(llvm::dbgs() + << "Can't fuse: non-affine users in between the loops\n."); + continue; + } - // Check if fusion would be profitable. We skip profitability analysis - // for maximal fusion since we already know the maximal legal depth to - // fuse. - unsigned bestDstLoopDepth = maxLegalFusionDepth; - if (!maximalFusion) { - // Retrieve producer stores from the src loop. 
- SmallVector producerStores; - for (Operation *op : srcNode->stores) - if (producerConsumerMemrefs.count( - cast(op).getMemRef())) - producerStores.push_back(op); - - // TODO: Suppport multiple producer stores in profitability - // analysis. We limit profitability analysis to only scenarios with - // a single producer store for now. Note that some multi-store - // producer scenarios will still go through profitability analysis - // if only one of the stores is involved the producer-consumer - // relationship of the candidate loops. - assert(!producerStores.empty() && "Expected producer store"); - if (producerStores.size() > 1) - LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " - "supported for this case\n"); - else if (!isFusionProfitable(producerStores[0], producerStores[0], - dstAffineForOp, depthSliceUnions, - maxLegalFusionDepth, &bestDstLoopDepth, - computeToleranceThreshold)) - continue; - } + // Compute an operation list insertion point for the fused loop + // nest which preserves dependences. + Operation *fusedLoopInsPoint = + mdg->getFusedLoopNestInsertionPoint(srcNode->id, dstNode->id); + if (fusedLoopInsPoint == nullptr) + continue; + + // Compute the innermost common loop depth for dstNode + // producer-consumer loads/stores. + SmallVector dstMemrefOps; + for (Operation *op : dstNode->loads) + if (producerConsumerMemrefs.count( + cast(op).getMemRef()) > 0) + dstMemrefOps.push_back(op); + for (Operation *op : dstNode->stores) + if (producerConsumerMemrefs.count( + cast(op).getMemRef())) + dstMemrefOps.push_back(op); + unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); + + // Check the feasibility of fusing src loop nest into dst loop nest + // at loop depths in range [1, dstLoopDepthTest]. 
+ unsigned maxLegalFusionDepth = 0; + SmallVector depthSliceUnions; + depthSliceUnions.resize(dstLoopDepthTest); + FusionStrategy strategy(FusionStrategy::ProducerConsumer); + for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { + FusionResult result = mlir::canFuseLoops( + srcAffineForOp, dstAffineForOp, + /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); + + if (result.value == FusionResult::Success) + maxLegalFusionDepth = i; + } - assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); - ComputationSliceState &bestSlice = - depthSliceUnions[bestDstLoopDepth - 1]; - assert(!bestSlice.isEmpty() && "Missing slice union for depth"); - - // Determine if 'srcId' can be removed after fusion, taking into - // account remaining dependences, escaping memrefs and the fusion - // insertion point. - bool removeSrcNode = canRemoveSrcNodeAfterFusion( - srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, - mdg); - - DenseSet privateMemrefs; - for (Value memref : producerConsumerMemrefs) { - // If `memref` is an escaping one, do not create a private memref - // for the below scenarios, since doing so will leave the escaping - // memref unmodified as all the writes originally meant for the - // escaping memref would be performed on the private memref: - // 1. The source is to be removed after fusion, - // OR - // 2. The destination writes to `memref`. - if (srcEscapingMemRefs.count(memref) > 0 && - (removeSrcNode || dstNode->getStoreOpCount(memref) > 0)) - continue; - - // Don't create a private memref if 'srcNode' has in edges on - // 'memref' or 'dstNode' has out edges on 'memref'. - if (mdg->getIncomingMemRefAccesses(srcId, memref) > 0 || - mdg->getOutEdgeCount(dstId, memref) > 0) - continue; - - // If 'srcNode' will be removed but it has out edges on 'memref' to - // nodes other than 'dstNode', we have to preserve dependences and - // cannot create a private memref. 
- if (removeSrcNode && - any_of(mdg->outEdges[srcId], [&](const auto &edge) { - return edge.value == memref && edge.id != dstId; - })) - continue; + if (maxLegalFusionDepth == 0) { + LLVM_DEBUG(llvm::dbgs() + << "Can't fuse: fusion is not legal at any depth\n"); + continue; + } + // Check if fusion would be profitable. We skip profitability analysis + // for maximal fusion since we already know the maximal legal depth to + // fuse. + unsigned bestDstLoopDepth = maxLegalFusionDepth; + if (!maximalFusion) { + // Retrieve producer stores from the src loop. + SmallVector producerStores; + for (Operation *op : srcNode->stores) + if (producerConsumerMemrefs.count( + cast(op).getMemRef())) + producerStores.push_back(op); + + // TODO: Suppport multiple producer stores in profitability + // analysis. We limit profitability analysis to only scenarios with + // a single producer store for now. Note that some multi-store + // producer scenarios will still go through profitability analysis + // if only one of the stores is involved the producer-consumer + // relationship of the candidate loops. + assert(!producerStores.empty() && "Expected producer store"); + if (producerStores.size() > 1) + LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not " + "supported for this case\n"); + else if (!isFusionProfitable(producerStores[0], producerStores[0], + dstAffineForOp, depthSliceUnions, + maxLegalFusionDepth, &bestDstLoopDepth, + computeToleranceThreshold)) + continue; + } + + assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); + ComputationSliceState &bestSlice = + depthSliceUnions[bestDstLoopDepth - 1]; + assert(!bestSlice.isEmpty() && "Missing slice union for depth"); + + // Determine if 'srcId' can be removed after fusion, taking into + // account remaining dependences, escaping memrefs and the fusion + // insertion point. 
+ bool removeSrcNode = canRemoveSrcNodeAfterFusion( + srcId, dstId, bestSlice, fusedLoopInsPoint, srcEscapingMemRefs, + mdg); + + DenseSet privateMemrefs; + for (Value memref : producerConsumerMemrefs) { + if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId, + removeSrcNode)) { + // Create a private version of this memref. + LLVM_DEBUG(llvm::dbgs() + << "Creating private memref for " << memref << '\n'); // Create a private version of this memref. privateMemrefs.insert(memref); } + } - // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); - dstNodeChanged = true; + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. + fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); + dstNodeChanged = true; + + LLVM_DEBUG(llvm::dbgs() + << "Fused src loop " << srcId << " into dst loop " << dstId + << " at depth " << bestDstLoopDepth << ":\n" + << dstAffineForOp << "\n"); + + // Move 'dstAffineForOp' before 'insertPointInst' if needed. + if (fusedLoopInsPoint != dstAffineForOp) + dstAffineForOp->moveBefore(fusedLoopInsPoint); + + // Update edges between 'srcNode' and 'dstNode'. + mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, + removeSrcNode); + + // Create private memrefs. + if (!privateMemrefs.empty()) { + // Gather stores for all the private-to-be memrefs. + DenseMap> privateMemRefToStores; + dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { + Value storeMemRef = storeOp.getMemRef(); + if (privateMemrefs.count(storeMemRef) > 0) + privateMemRefToStores[storeMemRef].push_back(storeOp); + }); - LLVM_DEBUG(llvm::dbgs() - << "Fused src loop " << srcId << " into dst loop " << dstId - << " at depth " << bestDstLoopDepth << ":\n" - << dstAffineForOp << "\n"); - - // Move 'dstAffineForOp' before 'insertPointInst' if needed. - if (fusedLoopInsPoint != dstAffineForOp) - dstAffineForOp->moveBefore(fusedLoopInsPoint); - - // Update edges between 'srcNode' and 'dstNode'. 
- mdg->updateEdges(srcNode->id, dstNode->id, privateMemrefs, - removeSrcNode); - - // Create private memrefs. - if (!privateMemrefs.empty()) { - // Gather stores for all the private-to-be memrefs. - DenseMap> privateMemRefToStores; - dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) { - Value storeMemRef = storeOp.getMemRef(); - if (privateMemrefs.count(storeMemRef) > 0) - privateMemRefToStores[storeMemRef].push_back(storeOp); - }); - - // Replace original memrefs with private memrefs. Note that all the - // loads and stores on these memrefs will be replaced with a new - // loads and stores. Any reference to the original ones becomes - // invalid after this point. - for (auto &memrefToStoresPair : privateMemRefToStores) { - // TODO: Use union of memref write regions to compute - // private memref footprint. - SmallVector &storesForMemref = - memrefToStoresPair.second; - Value newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth, - fastMemorySpace, localBufSizeThreshold); - // Create new node in dependence graph for 'newMemRef' alloc op. - unsigned newMemRefNodeId = - mdg->addNode(newMemRef.getDefiningOp()); - // Add edge from 'newMemRef' node to dstNode. - mdg->addEdge(newMemRefNodeId, dstId, newMemRef); - } - // One or more entries for 'newMemRef' alloc op are inserted into - // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to - // reallocate, update dstNode. - dstNode = mdg->getNode(dstId); + // Replace original memrefs with private memrefs. Note that all the + // loads and stores on these memrefs will be replaced with a new + // loads and stores. Any reference to the original ones becomes + // invalid after this point. + for (auto &memrefToStoresPair : privateMemRefToStores) { + // TODO: Use union of memref write regions to compute + // private memref footprint. 
+ SmallVector &storesForMemref = + memrefToStoresPair.second; + Value newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); + // Add edge from 'newMemRef' node to dstNode. + mdg->addEdge(newMemRefNodeId, dstId, newMemRef); } + // One or more entries for 'newMemRef' alloc op are inserted into + // the DenseMap mdg->nodes. Since an insertion may cause DenseMap to + // reallocate, update dstNode. + dstNode = mdg->getNode(dstId); + } - // Collect dst loop stats after memref privatization transformation. - LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp); + // Collect dst loop stats after memref privatization transformation. + LoopNestStateCollector dstLoopCollector; + dstLoopCollector.collect(dstAffineForOp); - // Clear and add back loads and stores. - mdg->clearNodeLoadAndStores(dstNode->id); - mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, - dstLoopCollector.storeOpInsts); + // Clear and add back loads and stores. + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); - if (removeSrcNode) { - LLVM_DEBUG(llvm::dbgs() - << "Removing src loop " << srcId << " after fusion\n"); - // srcNode is no longer valid after it is removed from mdg. - srcAffineForOp.erase(); - mdg->removeNode(srcId); - srcNode = nullptr; - } + if (removeSrcNode) { + LLVM_DEBUG(llvm::dbgs() + << "Removing src loop " << srcId << " after fusion\n"); + // srcNode is no longer valid after it is removed from mdg. + srcAffineForOp.erase(); + mdg->removeNode(srcId); + srcNode = nullptr; } - } while (dstNodeChanged); + } + } while (dstNodeChanged); + } + + /// Visit each node in the graph, and for each node, attempt to fuse it with + /// producer-consumer candidates. 
No fusion is performed when producers with a + /// user count greater than `maxSrcUserCount` for any of the memrefs involved + /// are encountered. + void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { + LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); + init(); + while (!worklist.empty()) { + unsigned dstId = worklist.back(); + worklist.pop_back(); + performFusionsIntoDest(dstId, maxSrcUserCount); } }