diff --git a/llvm/include/llvm/ADT/SCCIterator.h b/llvm/include/llvm/ADT/SCCIterator.h
--- a/llvm/include/llvm/ADT/SCCIterator.h
+++ b/llvm/include/llvm/ADT/SCCIterator.h
@@ -23,6 +23,7 @@
 #define LLVM_ADT_SCCITERATOR_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/GraphTraits.h"
 #include "llvm/ADT/iterator.h"
 #include <cassert>
@@ -258,7 +259,12 @@
   struct NodeInfo {
     NodeInfo *Group = this;
     uint32_t Rank = 0;
-    bool Visited = true;
+    // Used by the MST traversal: set when the node is enqueued as an MST root,
+    // so that a root which is the source of several edges is pushed only once.
+    bool Visited = false;
+    // Incoming MST edges whose source has not been traversed yet. The node
+    // becomes ready for traversal once this set is empty.
+    DenseSet<const EdgeType *> IncomingMSTEdges;
   };
 
   // Find the root group of the node and compress the path from node to the
@@ -340,20 +346,25 @@
       MSTEdges.insert(Edge);
   }
 
-  // Do BFS on MST, starting from nodes that have no incoming edge. These nodes
-  // are "roots" of the MST forest. This ensures that nodes are visited before
-  // their decsendents are, thus ensures hot edges are processed before cold
-  // edges, based on how MST is computed.
+  // Run Kahn's algorithm on the MST to compute a topological traversal order.
+  // The algorithm starts from nodes that have no incoming MST edge; these are
+  // the "roots" of the MST forest. A node is only visited once all of its
+  // incoming MST edges have been processed, so nodes are visited before their
+  // descendants. This in turn ensures hot edges are processed before cold
+  // edges, based on how the MST is computed.
+  std::queue<NodeType *> Queue;
+  // Record, for every node, the MST edges that point into it.
   for (const auto *Edge : MSTEdges)
-    NodeInfoMap[Edge->Target].Visited = false;
+    NodeInfoMap[Edge->Target].IncomingMSTEdges.insert(Edge);
 
-  std::queue<NodeType *> Queue;
-  // Initialze the queue with MST roots. Note that walking through SortedEdges
-  // instead of NodeInfoMap ensures an ordered deterministic push.
+  // Seed the queue with the MST roots, i.e. nodes with no incoming MST edge.
+  // Walk SortedEdges rather than NodeInfoMap so that the roots are pushed in a
+  // deterministic order.
   for (auto *Edge : SortedEdges) {
-    if (NodeInfoMap[Edge->Source].Visited) {
+    auto &SourceInfo = NodeInfoMap[Edge->Source];
+    if (!SourceInfo.Visited && SourceInfo.IncomingMSTEdges.empty()) {
       Queue.push(Edge->Source);
-      NodeInfoMap[Edge->Source].Visited = false;
+      SourceInfo.Visited = true;
     }
   }
 
@@ -362,8 +373,14 @@
     Queue.pop();
     Nodes.push_back(Node);
     for (auto &Edge : Node->Edges) {
-      if (MSTEdges.count(&Edge) && !NodeInfoMap[Edge.Target].Visited) {
-        NodeInfoMap[Edge.Target].Visited = true;
-        Queue.push(Edge.Target);
+      // Only MST edges take part in the traversal. Testing membership first
+      // avoids a NodeInfoMap lookup (and, with operator[], a possible
+      // default-insertion) for every other out-edge of the node.
+      if (MSTEdges.count(&Edge)) {
+        auto &TargetInfo = NodeInfoMap[Edge.Target];
+        TargetInfo.IncomingMSTEdges.erase(&Edge);
+        // The target is ready once all of its incoming MST edges are done.
+        if (TargetInfo.IncomingMSTEdges.empty())
+          Queue.push(Edge.Target);
       }
     }