diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h --- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h @@ -17,11 +17,7 @@ #define MLIR_DIALECT_AFFINE_ANALYSIS_UTILS_H #include "mlir/Dialect/Affine/Analysis/AffineStructures.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Location.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include #include @@ -35,6 +31,186 @@ class Operation; class Value; +// LoopNestStateCollector walks loop nests and collects load and store +// operations, and whether or not a region holding op other than ForOp and IfOp +// was encountered in the loop nest. +struct LoopNestStateCollector { + SmallVector forOps; + SmallVector loadOpInsts; + SmallVector storeOpInsts; + bool hasNonAffineRegionOp = false; + + // Collects load and store operations, and whether or not a region holding op + // other than ForOp and IfOp was encountered in the loop nest. + void collect(Operation *opToWalk); +}; + +// MemRefDependenceGraph is a graph data structure where graph nodes are +// top-level operations in a `Block` which contain load/store ops, and edges +// are memref dependences between the nodes. +// TODO: Add a more flexible dependence graph representation. +// TODO: Add a depth parameter to dependence graph construction. +struct MemRefDependenceGraph { +public: + // Node represents a node in the graph. A Node is either an entire loop nest + // rooted at the top level which contains loads/stores, or a top level + // load/store. + struct Node { + // The unique identifier of this node in the graph. + unsigned id; + // The top-level statement which is (or contains) a load/store. + Operation *op; + // List of load operations. + SmallVector loads; + // List of store op insts. 
+ SmallVector stores; + + Node(unsigned id, Operation *op) : id(id), op(op) {} + + // Returns the load op count for 'memref'. + unsigned getLoadOpCount(Value memref) const; + + // Returns the store op count for 'memref'. + unsigned getStoreOpCount(Value memref) const; + + // Returns all store ops in 'storeOps' which access 'memref'. + void getStoreOpsForMemref(Value memref, + SmallVectorImpl *storeOps) const; + + // Returns all load ops in 'loadOps' which access 'memref'. + void getLoadOpsForMemref(Value memref, + SmallVectorImpl *loadOps) const; + + // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node + // has at least one load and store operation. + void getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) const; + }; + + // Edge represents a data dependence between nodes in the graph. + struct Edge { + // The id of the node at the other end of the edge. + // If this edge is stored in Edge = Node.inEdges[i], then + // 'Node.inEdges[i].id' is the identifier of the source node of the edge. + // If this edge is stored in Edge = Node.outEdges[i], then + // 'Node.outEdges[i].id' is the identifier of the dest node of the edge. + unsigned id; + // The SSA value on which this edge represents a dependence. + // If the value is a memref, then the dependence is between graph nodes + // which contain accesses to the same memref 'value'. If the value is a + // non-memref value, then the dependence is between a graph node which + // defines an SSA value and another graph node which uses the SSA value + // (e.g. a constant or load operation defining a value which is used inside + // a loop nest). + Value value; + }; + + // Map from node id to Node. + DenseMap nodes; + // Map from node id to list of input edges. + DenseMap> inEdges; + // Map from node id to list of output edges. + DenseMap> outEdges; + // Map from memref to a count on the dependence edges associated with that + // memref. 
+ DenseMap memrefEdgeCount; + // The next unique identifier to use for newly created graph nodes. + unsigned nextNodeId = 0; + + MemRefDependenceGraph(Block &block) : block(block) {} + + // Initializes the dependence graph based on operations in `block'. + // Returns true on success, false otherwise. + bool init(); + + // Returns the graph node for 'id'. + Node *getNode(unsigned id); + + // Returns the graph node for 'forOp'. + Node *getForOpNode(AffineForOp forOp); + + // Adds a node with 'op' to the graph and returns its unique identifier. + unsigned addNode(Operation *op); + + // Remove node 'id' (and its associated edges) from graph. + void removeNode(unsigned id); + + // Returns true if node 'id' writes to any memref which escapes (or is an + // argument to) the block. Returns false otherwise. + bool writesToLiveInOrEscapingMemrefs(unsigned id); + + // Returns true iff there is an edge from node 'srcId' to node 'dstId' which + // is for 'value' if non-null, or for any value otherwise. Returns false + // otherwise. + bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr); + + // Adds an edge from node 'srcId' to node 'dstId' for 'value'. + void addEdge(unsigned srcId, unsigned dstId, Value value); + + // Removes an edge from node 'srcId' to node 'dstId' for 'value'. + void removeEdge(unsigned srcId, unsigned dstId, Value value); + + // Returns true if there is a path in the dependence graph from node 'srcId' + // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the + // operations that the edges connected are expected to be from the same block. + bool hasDependencePath(unsigned srcId, unsigned dstId); + + // Returns the input edge count for node 'id' and 'memref' from src nodes + // which access 'memref' with a store operation. 
+ unsigned getIncomingMemRefAccesses(unsigned id, Value memref); + + // Returns the output edge count for node 'id' and 'memref' (if non-null), + // otherwise returns the total output edge count from node 'id'. + unsigned getOutEdgeCount(unsigned id, Value memref = nullptr); + + /// Return all nodes which define SSA values used in node 'id'. + void gatherDefiningNodes(unsigned id, DenseSet &definingNodes); + + // Computes and returns an insertion point operation, before which the + // the fused loop nest can be inserted while preserving + // dependences. Returns nullptr if no such insertion point is found. + Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId); + + // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them, + // taking into account that: + // *) if 'removeSrcId' is true, 'srcId' will be removed after fusion, + // *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a + // private memref. + void updateEdges(unsigned srcId, unsigned dstId, + const DenseSet &privateMemRefs, bool removeSrcId); + + // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion + // of sibling node 'sibId' into node 'dstId'. + void updateEdges(unsigned sibId, unsigned dstId); + + // Adds ops in 'loads' and 'stores' to node at 'id'. + void addToNode(unsigned id, const SmallVectorImpl &loads, + const SmallVectorImpl &stores); + + void clearNodeLoadAndStores(unsigned id); + + // Calls 'callback' for each input edge incident to node 'id' which carries a + // memref dependence. + void forEachMemRefInputEdge(unsigned id, + const std::function &callback); + + // Calls 'callback' for each output edge from node 'id' which carries a + // memref dependence. + void forEachMemRefOutputEdge(unsigned id, + const std::function &callback); + + // Calls 'callback' for each edge in 'edges' which carries a memref + // dependence. 
+ void forEachMemRefEdge(ArrayRef edges, + const std::function &callback); + + void print(raw_ostream &os) const; + + void dump() const { print(llvm::errs()); } + + /// The block for which this graph is created to perform fusion. + Block &block; +}; + /// Populates 'loops' with IVs of the affine.for ops surrounding 'op' ordered /// from the outermost 'affine.for' operation to the innermost one. void getAffineForIVs(Operation &op, SmallVectorImpl *loops); diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -32,6 +32,475 @@ using llvm::SmallDenseMap; +using Node = MemRefDependenceGraph::Node; + +// LoopNestStateCollector walks loop nests and collects load and store +// operations, and whether or not a region holding op other than ForOp and IfOp +// was encountered in the loop nest. +void LoopNestStateCollector::collect(Operation *opToWalk) { + opToWalk->walk([&](Operation *op) { + if (isa(op)) + forOps.push_back(cast(op)); + else if (op->getNumRegions() != 0 && !isa(op)) + hasNonAffineRegionOp = true; + else if (isa(op)) + loadOpInsts.push_back(op); + else if (isa(op)) + storeOpInsts.push_back(op); + }); +} + +// Returns the load op count for 'memref'. +unsigned Node::getLoadOpCount(Value memref) const { + unsigned loadOpCount = 0; + for (Operation *loadOp : loads) { + if (memref == cast(loadOp).getMemRef()) + ++loadOpCount; + } + return loadOpCount; +} + +// Returns the store op count for 'memref'. +unsigned Node::getStoreOpCount(Value memref) const { + unsigned storeOpCount = 0; + for (Operation *storeOp : stores) { + if (memref == cast(storeOp).getMemRef()) + ++storeOpCount; + } + return storeOpCount; +} + +// Returns all store ops in 'storeOps' which access 'memref'. 
+void Node::getStoreOpsForMemref(Value memref, + SmallVectorImpl *storeOps) const { + for (Operation *storeOp : stores) { + if (memref == cast(storeOp).getMemRef()) + storeOps->push_back(storeOp); + } +} + +// Returns all load ops in 'loadOps' which access 'memref'. +void Node::getLoadOpsForMemref(Value memref, + SmallVectorImpl *loadOps) const { + for (Operation *loadOp : loads) { + if (memref == cast(loadOp).getMemRef()) + loadOps->push_back(loadOp); + } +} + +// Returns all memrefs in 'loadAndStoreMemrefSet' for which this node +// has at least one load and store operation. +void Node::getLoadAndStoreMemrefSet( + DenseSet *loadAndStoreMemrefSet) const { + llvm::SmallDenseSet loadMemrefs; + for (Operation *loadOp : loads) { + loadMemrefs.insert(cast(loadOp).getMemRef()); + } + for (Operation *storeOp : stores) { + auto memref = cast(storeOp).getMemRef(); + if (loadMemrefs.count(memref) > 0) + loadAndStoreMemrefSet->insert(memref); + } +} + +// Returns the graph node for 'id'. +Node *MemRefDependenceGraph::getNode(unsigned id) { + auto it = nodes.find(id); + assert(it != nodes.end()); + return &it->second; +} + +// Returns the graph node for 'forOp'. +Node *MemRefDependenceGraph::getForOpNode(AffineForOp forOp) { + for (auto &idAndNode : nodes) + if (idAndNode.second.op == forOp) + return &idAndNode.second; + return nullptr; +} + +// Adds a node with 'op' to the graph and returns its unique identifier. +unsigned MemRefDependenceGraph::addNode(Operation *op) { + Node node(nextNodeId++, op); + nodes.insert({node.id, node}); + return node.id; +} + +// Remove node 'id' (and its associated edges) from graph. +void MemRefDependenceGraph::removeNode(unsigned id) { + // Remove each edge in 'inEdges[id]'. + if (inEdges.count(id) > 0) { + SmallVector oldInEdges = inEdges[id]; + for (auto &inEdge : oldInEdges) { + removeEdge(inEdge.id, id, inEdge.value); + } + } + // Remove each edge in 'outEdges[id]'. 
+ if (outEdges.count(id) > 0) { + SmallVector oldOutEdges = outEdges[id]; + for (auto &outEdge : oldOutEdges) { + removeEdge(id, outEdge.id, outEdge.value); + } + } + // Erase remaining node state. + inEdges.erase(id); + outEdges.erase(id); + nodes.erase(id); +} + +// Returns true if node 'id' writes to any memref which escapes (or is an +// argument to) the block. Returns false otherwise. +bool MemRefDependenceGraph::writesToLiveInOrEscapingMemrefs(unsigned id) { + Node *node = getNode(id); + for (auto *storeOpInst : node->stores) { + auto memref = cast(storeOpInst).getMemRef(); + auto *op = memref.getDefiningOp(); + // Return true if 'memref' is a block argument. + if (!op) + return true; + // Return true if any use of 'memref' does not deference it in an affine + // way. + for (auto *user : memref.getUsers()) + if (!isa(*user)) + return true; + } + return false; +} + +// Returns true iff there is an edge from node 'srcId' to node 'dstId' which +// is for 'value' if non-null, or for any value otherwise. Returns false +// otherwise. +bool MemRefDependenceGraph::hasEdge(unsigned srcId, unsigned dstId, + Value value) { + if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { + return false; + } + bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { + return edge.id == dstId && (!value || edge.value == value); + }); + bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { + return edge.id == srcId && (!value || edge.value == value); + }); + return hasOutEdge && hasInEdge; +} + +// Adds an edge from node 'srcId' to node 'dstId' for 'value'. +void MemRefDependenceGraph::addEdge(unsigned srcId, unsigned dstId, + Value value) { + if (!hasEdge(srcId, dstId, value)) { + outEdges[srcId].push_back({dstId, value}); + inEdges[dstId].push_back({srcId, value}); + if (value.getType().isa()) + memrefEdgeCount[value]++; + } +} + +// Removes an edge from node 'srcId' to node 'dstId' for 'value'. 
+void MemRefDependenceGraph::removeEdge(unsigned srcId, unsigned dstId, + Value value) { + assert(inEdges.count(dstId) > 0); + assert(outEdges.count(srcId) > 0); + if (value.getType().isa()) { + assert(memrefEdgeCount.count(value) > 0); + memrefEdgeCount[value]--; + } + // Remove 'srcId' from 'inEdges[dstId]'. + for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { + if ((*it).id == srcId && (*it).value == value) { + inEdges[dstId].erase(it); + break; + } + } + // Remove 'dstId' from 'outEdges[srcId]'. + for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); ++it) { + if ((*it).id == dstId && (*it).value == value) { + outEdges[srcId].erase(it); + break; + } + } +} + +// Returns true if there is a path in the dependence graph from node 'srcId' +// to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the +// operations that the edges connected are expected to be from the same block. +bool MemRefDependenceGraph::hasDependencePath(unsigned srcId, unsigned dstId) { + // Worklist state is: + SmallVector, 4> worklist; + worklist.push_back({srcId, 0}); + Operation *dstOp = getNode(dstId)->op; + // Run DFS traversal to see if 'dstId' is reachable from 'srcId'. + while (!worklist.empty()) { + auto &idAndIndex = worklist.back(); + // Return true if we have reached 'dstId'. + if (idAndIndex.first == dstId) + return true; + // Pop and continue if node has no out edges, or if all out edges have + // already been visited. + if (outEdges.count(idAndIndex.first) == 0 || + idAndIndex.second == outEdges[idAndIndex.first].size()) { + worklist.pop_back(); + continue; + } + // Get graph edge to traverse. + Edge edge = outEdges[idAndIndex.first][idAndIndex.second]; + // Increment next output edge index for 'idAndIndex'. + ++idAndIndex.second; + // Add node at 'edge.id' to the worklist. We don't need to consider + // nodes that are "after" dstId in the containing block; one can't have a + // path to `dstId` from any of those nodes. 
+ bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op); + if (!afterDst && edge.id != idAndIndex.first) + worklist.push_back({edge.id, 0}); + } + return false; +} + +// Returns the input edge count for node 'id' and 'memref' from src nodes +// which access 'memref' with a store operation. +unsigned MemRefDependenceGraph::getIncomingMemRefAccesses(unsigned id, + Value memref) { + unsigned inEdgeCount = 0; + if (inEdges.count(id) > 0) + for (auto &inEdge : inEdges[id]) + if (inEdge.value == memref) { + Node *srcNode = getNode(inEdge.id); + // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref' + if (srcNode->getStoreOpCount(memref) > 0) + ++inEdgeCount; + } + return inEdgeCount; +} + +// Returns the output edge count for node 'id' and 'memref' (if non-null), +// otherwise returns the total output edge count from node 'id'. +unsigned MemRefDependenceGraph::getOutEdgeCount(unsigned id, Value memref) { + unsigned outEdgeCount = 0; + if (outEdges.count(id) > 0) + for (auto &outEdge : outEdges[id]) + if (!memref || outEdge.value == memref) + ++outEdgeCount; + return outEdgeCount; +} + +/// Return all nodes which define SSA values used in node 'id'. +void MemRefDependenceGraph::gatherDefiningNodes( + unsigned id, DenseSet &definingNodes) { + for (MemRefDependenceGraph::Edge edge : inEdges[id]) + // By definition of edge, if the edge value is a non-memref value, + // then the dependence is between a graph node which defines an SSA value + // and another graph node which uses the SSA value. + if (!edge.value.getType().isa()) + definingNodes.insert(edge.id); +} + +// Computes and returns an insertion point operation, before which the +// the fused loop nest can be inserted while preserving +// dependences. Returns nullptr if no such insertion point is found. 
+Operation * +MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId, + unsigned dstId) { + if (outEdges.count(srcId) == 0) + return getNode(dstId)->op; + + // Skip if there is any defining node of 'dstId' that depends on 'srcId'. + DenseSet definingNodes; + gatherDefiningNodes(dstId, definingNodes); + if (llvm::any_of(definingNodes, + [&](unsigned id) { return hasDependencePath(srcId, id); })) { + LLVM_DEBUG(llvm::dbgs() + << "Can't fuse: a defining op with a user in the dst " + "loop has dependence from the src loop\n"); + return nullptr; + } + + // Build set of insts in range (srcId, dstId) which depend on 'srcId'. + SmallPtrSet srcDepInsts; + for (auto &outEdge : outEdges[srcId]) + if (outEdge.id != dstId) + srcDepInsts.insert(getNode(outEdge.id)->op); + + // Build set of insts in range (srcId, dstId) on which 'dstId' depends. + SmallPtrSet dstDepInsts; + for (auto &inEdge : inEdges[dstId]) + if (inEdge.id != srcId) + dstDepInsts.insert(getNode(inEdge.id)->op); + + Operation *srcNodeInst = getNode(srcId)->op; + Operation *dstNodeInst = getNode(dstId)->op; + + // Computing insertion point: + // *) Walk all operation positions in Block operation list in the + // range (src, dst). For each operation 'op' visited in this search: + // *) Store in 'firstSrcDepPos' the first position where 'op' has a + // dependence edge from 'srcNode'. + // *) Store in 'lastDstDepPost' the last position where 'op' has a + // dependence edge to 'dstNode'. + // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the + // operation insertion point (or return null pointer if no such + // insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos'). 
+ SmallVector depInsts; + std::optional firstSrcDepPos; + std::optional lastDstDepPos; + unsigned pos = 0; + for (Block::iterator it = std::next(Block::iterator(srcNodeInst)); + it != Block::iterator(dstNodeInst); ++it) { + Operation *op = &(*it); + if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt) + firstSrcDepPos = pos; + if (dstDepInsts.count(op) > 0) + lastDstDepPos = pos; + depInsts.push_back(op); + ++pos; + } + + if (firstSrcDepPos.has_value()) { + if (lastDstDepPos.has_value()) { + if (*firstSrcDepPos <= *lastDstDepPos) { + // No valid insertion point exists which preserves dependences. + return nullptr; + } + } + // Return the insertion point at 'firstSrcDepPos'. + return depInsts[*firstSrcDepPos]; + } + // No dependence targets in range (or only dst deps in range), return + // 'dstNodInst' insertion point. + return dstNodeInst; +} + +// Updates edge mappings from node 'srcId' to node 'dstId' after fusing them, +// taking into account that: +// *) if 'removeSrcId' is true, 'srcId' will be removed after fusion, +// *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a +// private memref. +void MemRefDependenceGraph::updateEdges(unsigned srcId, unsigned dstId, + const DenseSet &privateMemRefs, + bool removeSrcId) { + // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'. + if (inEdges.count(srcId) > 0) { + SmallVector oldInEdges = inEdges[srcId]; + for (auto &inEdge : oldInEdges) { + // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref. + if (privateMemRefs.count(inEdge.value) == 0) + addEdge(inEdge.id, dstId, inEdge.value); + } + } + // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'. + // If 'srcId' is going to be removed, remap all the out edges to 'dstId'. + if (outEdges.count(srcId) > 0) { + SmallVector oldOutEdges = outEdges[srcId]; + for (auto &outEdge : oldOutEdges) { + // Remove any out edges from 'srcId' to 'dstId' across memrefs. 
+ if (outEdge.id == dstId) + removeEdge(srcId, outEdge.id, outEdge.value); + else if (removeSrcId) { + addEdge(dstId, outEdge.id, outEdge.value); + removeEdge(srcId, outEdge.id, outEdge.value); + } + } + } + // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being + // replaced by a private memref). These edges could come from nodes + // other than 'srcId' which were removed in the previous step. + if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) { + SmallVector oldInEdges = inEdges[dstId]; + for (auto &inEdge : oldInEdges) + if (privateMemRefs.count(inEdge.value) > 0) + removeEdge(inEdge.id, dstId, inEdge.value); + } +} + +// Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion +// of sibling node 'sibId' into node 'dstId'. +void MemRefDependenceGraph::updateEdges(unsigned sibId, unsigned dstId) { + // For each edge in 'inEdges[sibId]': + // *) Add new edge from source node 'inEdge.id' to 'dstNode'. + // *) Remove edge from source node 'inEdge.id' to 'sibNode'. + if (inEdges.count(sibId) > 0) { + SmallVector oldInEdges = inEdges[sibId]; + for (auto &inEdge : oldInEdges) { + addEdge(inEdge.id, dstId, inEdge.value); + removeEdge(inEdge.id, sibId, inEdge.value); + } + } + + // For each edge in 'outEdges[sibId]' to node 'id' + // *) Add new edge from 'dstId' to 'outEdge.id'. + // *) Remove edge from 'sibId' to 'outEdge.id'. + if (outEdges.count(sibId) > 0) { + SmallVector oldOutEdges = outEdges[sibId]; + for (auto &outEdge : oldOutEdges) { + addEdge(dstId, outEdge.id, outEdge.value); + removeEdge(sibId, outEdge.id, outEdge.value); + } + } +} + +// Adds ops in 'loads' and 'stores' to node at 'id'. 
+void MemRefDependenceGraph::addToNode( + unsigned id, const SmallVectorImpl &loads, + const SmallVectorImpl &stores) { + Node *node = getNode(id); + llvm::append_range(node->loads, loads); + llvm::append_range(node->stores, stores); +} + +void MemRefDependenceGraph::clearNodeLoadAndStores(unsigned id) { + Node *node = getNode(id); + node->loads.clear(); + node->stores.clear(); +} + +// Calls 'callback' for each input edge incident to node 'id' which carries a +// memref dependence. +void MemRefDependenceGraph::forEachMemRefInputEdge( + unsigned id, const std::function &callback) { + if (inEdges.count(id) > 0) + forEachMemRefEdge(inEdges[id], callback); +} + +// Calls 'callback' for each output edge from node 'id' which carries a +// memref dependence. +void MemRefDependenceGraph::forEachMemRefOutputEdge( + unsigned id, const std::function &callback) { + if (outEdges.count(id) > 0) + forEachMemRefEdge(outEdges[id], callback); +} + +// Calls 'callback' for each edge in 'edges' which carries a memref +// dependence. +void MemRefDependenceGraph::forEachMemRefEdge( + ArrayRef edges, const std::function &callback) { + for (const auto &edge : edges) { + // Skip if 'edge' is not a memref dependence edge. + if (!edge.value.getType().isa()) + continue; + assert(nodes.count(edge.id) > 0); + // Skip if 'edge.id' is not a loop nest. + if (!isa(getNode(edge.id)->op)) + continue; + // Visit current input edge 'edge'. 
+ callback(edge); + } +} + +void MemRefDependenceGraph::print(raw_ostream &os) const { + os << "\nMemRefDependenceGraph\n"; + os << "\nNodes:\n"; + for (const auto &idAndNode : nodes) { + os << "Node: " << idAndNode.first << "\n"; + auto it = inEdges.find(idAndNode.first); + if (it != inEdges.end()) { + for (const auto &e : it->second) + os << " InEdge: " << e.id << " " << e.value << "\n"; + } + it = outEdges.find(idAndNode.first); + if (it != outEdges.end()) { + for (const auto &e : it->second) + os << " OutEdge: " << e.id << " " << e.value << "\n"; + } + } +} + void mlir::getAffineForIVs(Operation &op, SmallVectorImpl *loops) { auto *currOp = op.getParentOp(); AffineForOp currAffineForOp; diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -80,536 +80,6 @@ namespace { -// LoopNestStateCollector walks loop nests and collects load and store -// operations, and whether or not a region holding op other than ForOp and IfOp -// was encountered in the loop nest. -struct LoopNestStateCollector { - SmallVector forOps; - SmallVector loadOpInsts; - SmallVector storeOpInsts; - bool hasNonAffineRegionOp = false; - - void collect(Operation *opToWalk) { - opToWalk->walk([&](Operation *op) { - if (isa(op)) - forOps.push_back(cast(op)); - else if (op->getNumRegions() != 0 && !isa(op)) - hasNonAffineRegionOp = true; - else if (isa(op)) - loadOpInsts.push_back(op); - else if (isa(op)) - storeOpInsts.push_back(op); - }); - } -}; - -// MemRefDependenceGraph is a graph data structure where graph nodes are -// top-level operations in a `Block` which contain load/store ops, and edges -// are memref dependences between the nodes. -// TODO: Add a more flexible dependence graph representation. -// TODO: Add a depth parameter to dependence graph construction. 
-struct MemRefDependenceGraph { -public: - // Node represents a node in the graph. A Node is either an entire loop nest - // rooted at the top level which contains loads/stores, or a top level - // load/store. - struct Node { - // The unique identifier of this node in the graph. - unsigned id; - // The top-level statement which is (or contains) a load/store. - Operation *op; - // List of load operations. - SmallVector loads; - // List of store op insts. - SmallVector stores; - Node(unsigned id, Operation *op) : id(id), op(op) {} - - // Returns the load op count for 'memref'. - unsigned getLoadOpCount(Value memref) const { - unsigned loadOpCount = 0; - for (Operation *loadOp : loads) { - if (memref == cast(loadOp).getMemRef()) - ++loadOpCount; - } - return loadOpCount; - } - - // Returns the store op count for 'memref'. - unsigned getStoreOpCount(Value memref) const { - unsigned storeOpCount = 0; - for (Operation *storeOp : stores) { - if (memref == cast(storeOp).getMemRef()) - ++storeOpCount; - } - return storeOpCount; - } - - // Returns all store ops in 'storeOps' which access 'memref'. - void getStoreOpsForMemref(Value memref, - SmallVectorImpl *storeOps) const { - for (Operation *storeOp : stores) { - if (memref == cast(storeOp).getMemRef()) - storeOps->push_back(storeOp); - } - } - - // Returns all load ops in 'loadOps' which access 'memref'. - void getLoadOpsForMemref(Value memref, - SmallVectorImpl *loadOps) const { - for (Operation *loadOp : loads) { - if (memref == cast(loadOp).getMemRef()) - loadOps->push_back(loadOp); - } - } - - // Returns all memrefs in 'loadAndStoreMemrefSet' for which this node - // has at least one load and store operation. 
- void - getLoadAndStoreMemrefSet(DenseSet *loadAndStoreMemrefSet) const { - llvm::SmallDenseSet loadMemrefs; - for (Operation *loadOp : loads) { - loadMemrefs.insert(cast(loadOp).getMemRef()); - } - for (Operation *storeOp : stores) { - auto memref = cast(storeOp).getMemRef(); - if (loadMemrefs.count(memref) > 0) - loadAndStoreMemrefSet->insert(memref); - } - } - }; - - // Edge represents a data dependence between nodes in the graph. - struct Edge { - // The id of the node at the other end of the edge. - // If this edge is stored in Edge = Node.inEdges[i], then - // 'Node.inEdges[i].id' is the identifier of the source node of the edge. - // If this edge is stored in Edge = Node.outEdges[i], then - // 'Node.outEdges[i].id' is the identifier of the dest node of the edge. - unsigned id; - // The SSA value on which this edge represents a dependence. - // If the value is a memref, then the dependence is between graph nodes - // which contain accesses to the same memref 'value'. If the value is a - // non-memref value, then the dependence is between a graph node which - // defines an SSA value and another graph node which uses the SSA value - // (e.g. a constant or load operation defining a value which is used inside - // a loop nest). - Value value; - }; - - // Map from node id to Node. - DenseMap nodes; - // Map from node id to list of input edges. - DenseMap> inEdges; - // Map from node id to list of output edges. - DenseMap> outEdges; - // Map from memref to a count on the dependence edges associated with that - // memref. - DenseMap memrefEdgeCount; - // The next unique identifier to use for newly created graph nodes. - unsigned nextNodeId = 0; - - MemRefDependenceGraph(Block &block) : block(block) {} - - // Initializes the dependence graph based on operations in `block'. - // Returns true on success, false otherwise. - bool init(); - - // Returns the graph node for 'id'. 
- Node *getNode(unsigned id) { - auto it = nodes.find(id); - assert(it != nodes.end()); - return &it->second; - } - - // Returns the graph node for 'forOp'. - Node *getForOpNode(AffineForOp forOp) { - for (auto &idAndNode : nodes) - if (idAndNode.second.op == forOp) - return &idAndNode.second; - return nullptr; - } - - // Adds a node with 'op' to the graph and returns its unique identifier. - unsigned addNode(Operation *op) { - Node node(nextNodeId++, op); - nodes.insert({node.id, node}); - return node.id; - } - - // Remove node 'id' (and its associated edges) from graph. - void removeNode(unsigned id) { - // Remove each edge in 'inEdges[id]'. - if (inEdges.count(id) > 0) { - SmallVector oldInEdges = inEdges[id]; - for (auto &inEdge : oldInEdges) { - removeEdge(inEdge.id, id, inEdge.value); - } - } - // Remove each edge in 'outEdges[id]'. - if (outEdges.count(id) > 0) { - SmallVector oldOutEdges = outEdges[id]; - for (auto &outEdge : oldOutEdges) { - removeEdge(id, outEdge.id, outEdge.value); - } - } - // Erase remaining node state. - inEdges.erase(id); - outEdges.erase(id); - nodes.erase(id); - } - - // Returns true if node 'id' writes to any memref which escapes (or is an - // argument to) the block. Returns false otherwise. - bool writesToLiveInOrEscapingMemrefs(unsigned id) { - Node *node = getNode(id); - for (auto *storeOpInst : node->stores) { - auto memref = cast(storeOpInst).getMemRef(); - auto *op = memref.getDefiningOp(); - // Return true if 'memref' is a block argument. - if (!op) - return true; - // Return true if any use of 'memref' does not deference it in an affine - // way. - for (auto *user : memref.getUsers()) - if (!isa(*user)) - return true; - } - return false; - } - - // Returns true iff there is an edge from node 'srcId' to node 'dstId' which - // is for 'value' if non-null, or for any value otherwise. Returns false - // otherwise. 
- bool hasEdge(unsigned srcId, unsigned dstId, Value value = nullptr) { - if (outEdges.count(srcId) == 0 || inEdges.count(dstId) == 0) { - return false; - } - bool hasOutEdge = llvm::any_of(outEdges[srcId], [=](Edge &edge) { - return edge.id == dstId && (!value || edge.value == value); - }); - bool hasInEdge = llvm::any_of(inEdges[dstId], [=](Edge &edge) { - return edge.id == srcId && (!value || edge.value == value); - }); - return hasOutEdge && hasInEdge; - } - - // Adds an edge from node 'srcId' to node 'dstId' for 'value'. - void addEdge(unsigned srcId, unsigned dstId, Value value) { - if (!hasEdge(srcId, dstId, value)) { - outEdges[srcId].push_back({dstId, value}); - inEdges[dstId].push_back({srcId, value}); - if (value.getType().isa()) - memrefEdgeCount[value]++; - } - } - - // Removes an edge from node 'srcId' to node 'dstId' for 'value'. - void removeEdge(unsigned srcId, unsigned dstId, Value value) { - assert(inEdges.count(dstId) > 0); - assert(outEdges.count(srcId) > 0); - if (value.getType().isa()) { - assert(memrefEdgeCount.count(value) > 0); - memrefEdgeCount[value]--; - } - // Remove 'srcId' from 'inEdges[dstId]'. - for (auto *it = inEdges[dstId].begin(); it != inEdges[dstId].end(); ++it) { - if ((*it).id == srcId && (*it).value == value) { - inEdges[dstId].erase(it); - break; - } - } - // Remove 'dstId' from 'outEdges[srcId]'. - for (auto *it = outEdges[srcId].begin(); it != outEdges[srcId].end(); - ++it) { - if ((*it).id == dstId && (*it).value == value) { - outEdges[srcId].erase(it); - break; - } - } - } - - // Returns true if there is a path in the dependence graph from node 'srcId' - // to node 'dstId'. Returns false otherwise. `srcId`, `dstId`, and the - // operations that the edges connected are expected to be from the same block. 
-  bool hasDependencePath(unsigned srcId, unsigned dstId) {
-    // Worklist state is: <node-id, next-output-edge-index-to-visit>
-    SmallVector<std::pair<unsigned, unsigned>, 4> worklist;
-    worklist.push_back({srcId, 0});
-    Operation *dstOp = getNode(dstId)->op;
-    // Run DFS traversal to see if 'dstId' is reachable from 'srcId'.
-    while (!worklist.empty()) {
-      auto &idAndIndex = worklist.back();
-      // Return true if we have reached 'dstId'.
-      if (idAndIndex.first == dstId)
-        return true;
-      // Pop and continue if node has no out edges, or if all out edges have
-      // already been visited.
-      if (outEdges.count(idAndIndex.first) == 0 ||
-          idAndIndex.second == outEdges[idAndIndex.first].size()) {
-        worklist.pop_back();
-        continue;
-      }
-      // Get graph edge to traverse.
-      Edge edge = outEdges[idAndIndex.first][idAndIndex.second];
-      // Increment next output edge index for 'idAndIndex'.
-      ++idAndIndex.second;
-      // Add node at 'edge.id' to the worklist. We don't need to consider
-      // nodes that are "after" dstId in the containing block; one can't have a
-      // path to `dstId` from any of those nodes.
-      bool afterDst = dstOp->isBeforeInBlock(getNode(edge.id)->op);
-      if (!afterDst && edge.id != idAndIndex.first)
-        worklist.push_back({edge.id, 0});
-    }
-    return false;
-  }
-
-  // Returns the input edge count for node 'id' and 'memref' from src nodes
-  // which access 'memref' with a store operation.
-  unsigned getIncomingMemRefAccesses(unsigned id, Value memref) {
-    unsigned inEdgeCount = 0;
-    if (inEdges.count(id) > 0)
-      for (auto &inEdge : inEdges[id])
-        if (inEdge.value == memref) {
-          Node *srcNode = getNode(inEdge.id);
-          // Only count in edges from 'srcNode' if 'srcNode' accesses 'memref'
-          if (srcNode->getStoreOpCount(memref) > 0)
-            ++inEdgeCount;
-        }
-    return inEdgeCount;
-  }
-
-  // Returns the output edge count for node 'id' and 'memref' (if non-null),
-  // otherwise returns the total output edge count from node 'id'.
-  unsigned getOutEdgeCount(unsigned id, Value memref = nullptr) {
-    unsigned outEdgeCount = 0;
-    if (outEdges.count(id) > 0)
-      for (auto &outEdge : outEdges[id])
-        if (!memref || outEdge.value == memref)
-          ++outEdgeCount;
-    return outEdgeCount;
-  }
-
-  /// Return all nodes which define SSA values used in node 'id'.
-  void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
-    for (MemRefDependenceGraph::Edge edge : inEdges[id])
-      // By definition of edge, if the edge value is a non-memref value,
-      // then the dependence is between a graph node which defines an SSA value
-      // and another graph node which uses the SSA value.
-      if (!edge.value.getType().isa<MemRefType>())
-        definingNodes.insert(edge.id);
-  }
-
-  // Computes and returns an insertion point operation, before which the
-  // the fused <src, dst> loop nest can be inserted while preserving
-  // dependences. Returns nullptr if no such insertion point is found.
-  Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
-    if (outEdges.count(srcId) == 0)
-      return getNode(dstId)->op;
-
-    // Skip if there is any defining node of 'dstId' that depends on 'srcId'.
-    DenseSet<unsigned> definingNodes;
-    gatherDefiningNodes(dstId, definingNodes);
-    if (llvm::any_of(definingNodes, [&](unsigned id) {
-          return hasDependencePath(srcId, id);
-        })) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Can't fuse: a defining op with a user in the dst "
-                    "loop has dependence from the src loop\n");
-      return nullptr;
-    }
-
-    // Build set of insts in range (srcId, dstId) which depend on 'srcId'.
-    SmallPtrSet<Operation *, 2> srcDepInsts;
-    for (auto &outEdge : outEdges[srcId])
-      if (outEdge.id != dstId)
-        srcDepInsts.insert(getNode(outEdge.id)->op);
-
-    // Build set of insts in range (srcId, dstId) on which 'dstId' depends.
-    SmallPtrSet<Operation *, 2> dstDepInsts;
-    for (auto &inEdge : inEdges[dstId])
-      if (inEdge.id != srcId)
-        dstDepInsts.insert(getNode(inEdge.id)->op);
-
-    Operation *srcNodeInst = getNode(srcId)->op;
-    Operation *dstNodeInst = getNode(dstId)->op;
-
-    // Computing insertion point:
-    // *) Walk all operation positions in Block operation list in the
-    //    range (src, dst). For each operation 'op' visited in this search:
-    //   *) Store in 'firstSrcDepPos' the first position where 'op' has a
-    //      dependence edge from 'srcNode'.
-    //   *) Store in 'lastDstDepPost' the last position where 'op' has a
-    //      dependence edge to 'dstNode'.
-    // *) Compare 'firstSrcDepPos' and 'lastDstDepPost' to determine the
-    //    operation insertion point (or return null pointer if no such
-    //    insertion point exists: 'firstSrcDepPos' <= 'lastDstDepPos').
-    SmallVector<Operation *, 2> depInsts;
-    std::optional<unsigned> firstSrcDepPos;
-    std::optional<unsigned> lastDstDepPos;
-    unsigned pos = 0;
-    for (Block::iterator it = std::next(Block::iterator(srcNodeInst));
-         it != Block::iterator(dstNodeInst); ++it) {
-      Operation *op = &(*it);
-      if (srcDepInsts.count(op) > 0 && firstSrcDepPos == std::nullopt)
-        firstSrcDepPos = pos;
-      if (dstDepInsts.count(op) > 0)
-        lastDstDepPos = pos;
-      depInsts.push_back(op);
-      ++pos;
-    }
-
-    if (firstSrcDepPos.has_value()) {
-      if (lastDstDepPos.has_value()) {
-        if (*firstSrcDepPos <= *lastDstDepPos) {
-          // No valid insertion point exists which preserves dependences.
-          return nullptr;
-        }
-      }
-      // Return the insertion point at 'firstSrcDepPos'.
-      return depInsts[*firstSrcDepPos];
-    }
-    // No dependence targets in range (or only dst deps in range), return
-    // 'dstNodInst' insertion point.
-    return dstNodeInst;
-  }
-
-  // Updates edge mappings from node 'srcId' to node 'dstId' after fusing them,
-  // taking into account that:
-  //   *) if 'removeSrcId' is true, 'srcId' will be removed after fusion,
-  //   *) memrefs in 'privateMemRefs' has been replaced in node at 'dstId' by a
-  //      private memref.
-  void updateEdges(unsigned srcId, unsigned dstId,
-                   const DenseSet<Value> &privateMemRefs, bool removeSrcId) {
-    // For each edge in 'inEdges[srcId]': add new edge remapping to 'dstId'.
-    if (inEdges.count(srcId) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[srcId];
-      for (auto &inEdge : oldInEdges) {
-        // Add edge from 'inEdge.id' to 'dstId' if it's not a private memref.
-        if (privateMemRefs.count(inEdge.value) == 0)
-          addEdge(inEdge.id, dstId, inEdge.value);
-      }
-    }
-    // For each edge in 'outEdges[srcId]': remove edge from 'srcId' to 'dstId'.
-    // If 'srcId' is going to be removed, remap all the out edges to 'dstId'.
-    if (outEdges.count(srcId) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[srcId];
-      for (auto &outEdge : oldOutEdges) {
-        // Remove any out edges from 'srcId' to 'dstId' across memrefs.
-        if (outEdge.id == dstId)
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        else if (removeSrcId) {
-          addEdge(dstId, outEdge.id, outEdge.value);
-          removeEdge(srcId, outEdge.id, outEdge.value);
-        }
-      }
-    }
-    // Remove any edges in 'inEdges[dstId]' on 'oldMemRef' (which is being
-    // replaced by a private memref). These edges could come from nodes
-    // other than 'srcId' which were removed in the previous step.
-    if (inEdges.count(dstId) > 0 && !privateMemRefs.empty()) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[dstId];
-      for (auto &inEdge : oldInEdges)
-        if (privateMemRefs.count(inEdge.value) > 0)
-          removeEdge(inEdge.id, dstId, inEdge.value);
-    }
-  }
-
-  // Update edge mappings for nodes 'sibId' and 'dstId' to reflect fusion
-  // of sibling node 'sibId' into node 'dstId'.
-  void updateEdges(unsigned sibId, unsigned dstId) {
-    // For each edge in 'inEdges[sibId]':
-    // *) Add new edge from source node 'inEdge.id' to 'dstNode'.
-    // *) Remove edge from source node 'inEdge.id' to 'sibNode'.
-    if (inEdges.count(sibId) > 0) {
-      SmallVector<Edge, 2> oldInEdges = inEdges[sibId];
-      for (auto &inEdge : oldInEdges) {
-        addEdge(inEdge.id, dstId, inEdge.value);
-        removeEdge(inEdge.id, sibId, inEdge.value);
-      }
-    }
-
-    // For each edge in 'outEdges[sibId]' to node 'id'
-    // *) Add new edge from 'dstId' to 'outEdge.id'.
-    // *) Remove edge from 'sibId' to 'outEdge.id'.
-    if (outEdges.count(sibId) > 0) {
-      SmallVector<Edge, 2> oldOutEdges = outEdges[sibId];
-      for (auto &outEdge : oldOutEdges) {
-        addEdge(dstId, outEdge.id, outEdge.value);
-        removeEdge(sibId, outEdge.id, outEdge.value);
-      }
-    }
-  }
-
-  // Adds ops in 'loads' and 'stores' to node at 'id'.
-  void addToNode(unsigned id, const SmallVectorImpl<Operation *> &loads,
-                 const SmallVectorImpl<Operation *> &stores) {
-    Node *node = getNode(id);
-    llvm::append_range(node->loads, loads);
-    llvm::append_range(node->stores, stores);
-  }
-
-  void clearNodeLoadAndStores(unsigned id) {
-    Node *node = getNode(id);
-    node->loads.clear();
-    node->stores.clear();
-  }
-
-  // Calls 'callback' for each input edge incident to node 'id' which carries a
-  // memref dependence.
-  void forEachMemRefInputEdge(unsigned id,
-                              const std::function<void(Edge)> &callback) {
-    if (inEdges.count(id) > 0)
-      forEachMemRefEdge(inEdges[id], callback);
-  }
-
-  // Calls 'callback' for each output edge from node 'id' which carries a
-  // memref dependence.
-  void forEachMemRefOutputEdge(unsigned id,
-                               const std::function<void(Edge)> &callback) {
-    if (outEdges.count(id) > 0)
-      forEachMemRefEdge(outEdges[id], callback);
-  }
-
-  // Calls 'callback' for each edge in 'edges' which carries a memref
-  // dependence.
-  void forEachMemRefEdge(ArrayRef<Edge> edges,
-                         const std::function<void(Edge)> &callback) {
-    for (const auto &edge : edges) {
-      // Skip if 'edge' is not a memref dependence edge.
-      if (!edge.value.getType().isa<MemRefType>())
-        continue;
-      assert(nodes.count(edge.id) > 0);
-      // Skip if 'edge.id' is not a loop nest.
-      if (!isa<AffineForOp>(getNode(edge.id)->op))
-        continue;
-      // Visit current input edge 'edge'.
-      callback(edge);
-    }
-  }
-
-  void print(raw_ostream &os) const {
-    os << "\nMemRefDependenceGraph\n";
-    os << "\nNodes:\n";
-    for (const auto &idAndNode : nodes) {
-      os << "Node: " << idAndNode.first << "\n";
-      auto it = inEdges.find(idAndNode.first);
-      if (it != inEdges.end()) {
-        for (const auto &e : it->second)
-          os << " InEdge: " << e.id << " " << e.value << "\n";
-      }
-      it = outEdges.find(idAndNode.first);
-      if (it != outEdges.end()) {
-        for (const auto &e : it->second)
-          os << " OutEdge: " << e.id << " " << e.value << "\n";
-      }
-    }
-  }
-  void dump() const { print(llvm::errs()); }
-
-  /// The block for which this graph is created to perform fusion.
-  Block &block;
-};
-
 /// Returns true if node 'srcId' can be removed after fusing it with node
 /// 'dstId'. The node can be removed if any of the following conditions are met:
 ///   1. 'srcId' has no output dependences after fusion and no escaping memrefs.