diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -40,6 +40,7 @@ #include #include #include +#include using namespace llvm; @@ -98,7 +99,6 @@ bool Suspend = false; bool End = false; bool KillLoop = false; - bool Changed = false; }; SmallVector Block; @@ -106,16 +106,51 @@ BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); return llvm::predecessors(BB); } + size_t pred_size(BlockData const &BD) const { + BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); + return llvm::pred_size(BB); + } + iterator_range successors(BlockData const &BD) const { + BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); + return llvm::successors(BB); + } BlockData &getBlockData(BasicBlock *BB) { return Block[Mapping.blockToIndex(BB)]; } - /// Compute the BlockData for the current function in one iteration. - /// Returns whether the BlockData changes in this iteration. - /// Initialize - Whether this is the first iteration, we can optimize - /// the initial case a little bit by manual loop switch. - template bool computeBlockData(); + /// This algorithm is based on topological sorting. + /// Topological sorting only applies to a DAG, but a CFG + /// may not be a DAG, as it may contain back edges (loops). So we break the + /// back edge when we meet it. + /// If the CFG is a DAG, we only need to visit it once (from entry block to + /// end block). If the CFG is not a DAG, it contains at least one back edge. + /// In this case, we need an extra traversal to propagate Consumes and Kills + /// info along the back edges. + + /// Why do we need an extra traversal when the CFG has back edges? + /// Firstly, we need to figure out how Consumes info propagates along a back + /// edge. 
For example, + /// + /// A -> B -> C -> D -> H + /// ^ | + /// | v + /// G <- F <- E + /// + /// Following the direction of the arrows, we can get the traversal sequences: + /// A, B, C, D, H, E, F, G or A, B, C, D, E, H, F, G. + /// We know that there is a path from C to G after the first traversal. But we + /// don't know whether a path from G to C exists, as the Consumes info of + /// G has not yet propagated to C (via B). So we need the second traversal to + /// propagate G's Consumes info to C (via B) and its successors. We can get + /// full Consumes info after the second traversal. Since the computation of + /// Kills info depends on Consumes info, we can compute the full Kills info + /// from the full Consumes info of each block in the second traversal. + + /// Returns true if there are back edges in the CFG. + template + bool collectConsumeKillInfo(size_t EntryNo, + SmallVector const &BlocksIndegree); public: #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -223,70 +258,104 @@ } #endif -template bool SuspendCrossingInfo::computeBlockData() { - const size_t N = Mapping.size(); - bool Changed = false; - - for (size_t I = 0; I < N; ++I) { - auto &B = Block[I]; - - // We don't need to count the predecessors when initialization. - if constexpr (!Initialize) - // If all the predecessors of the current Block don't change, - // the BlockData for the current block must not change too. - if (all_of(predecessors(B), [this](BasicBlock *BB) { - return !Block[Mapping.blockToIndex(BB)].Changed; - })) { - B.Changed = false; - continue; - } - - // Saved Consumes and Kills bitsets so that it is easy to see - // if anything changed after propagation. - auto SavedConsumes = B.Consumes; - auto SavedKills = B.Kills; - - for (BasicBlock *PI : predecessors(B)) { - auto PrevNo = Mapping.blockToIndex(PI); - auto &P = Block[PrevNo]; - - // Propagate Kills and Consumes from predecessors into B. 
- B.Consumes |= P.Consumes; - B.Kills |= P.Kills; - - // If block P is a suspend block, it should propagate kills into block - // B for every block P consumes. - if (P.Suspend) - B.Kills |= P.Consumes; +template +bool SuspendCrossingInfo::collectConsumeKillInfo( + size_t EntryNo, SmallVector const &BlocksIndegree) { + bool ExistBackEdge = false; + // Mutable working copy of BlocksIndegree (the parameter is const). + auto IndegreeOfBlocks = BlocksIndegree; + // Worklist of blocks that are ready to be processed. + std::queue CandidateQueue; + // Reached blocks with unvisited incoming edges; may be back-edge targets. + std::unordered_set MaybeBackEdgeSet; + // Account for one visited incoming edge of block I. + auto visited = [&](size_t I) { + switch (IndegreeOfBlocks[I]) { + case 0: + break; + case 1: { + CandidateQueue.push(I); + MaybeBackEdgeSet.erase(I); + IndegreeOfBlocks[I] = 0; + break; + } + default: { + IndegreeOfBlocks[I]--; + MaybeBackEdgeSet.insert(I); + break; } - - if (B.Suspend) { - // If block S is a suspend block, it should kill all of the blocks it - // consumes. - B.Kills |= B.Consumes; - } else if (B.End) { - // If block B is an end block, it should not propagate kills as the - // blocks following coro.end() are reached during initial invocation - // of the coroutine while all the data are still available on the - // stack or in the registers. - B.Kills.reset(); - } else { - // This is reached when B block it not Suspend nor coro.end and it - // need to make sure that it is not in the kill set. - B.KillLoop |= B.Kills[I]; - B.Kills.reset(I); } + }; - if constexpr (!Initialize) { - B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes); - Changed |= B.Changed; + // Start the traversal from EntryNo. + CandidateQueue.push(EntryNo); + + // Topological sorting. + while (!CandidateQueue.empty()) { + auto &B = Block[CandidateQueue.front()]; + CandidateQueue.pop(); + for (BasicBlock *SI : successors(B)) { + auto SuccNo = Mapping.blockToIndex(SI); + auto &S = Block[SuccNo]; + + // Propagate Kills and Consumes from the predecessor B into S. 
+ S.Consumes |= B.Consumes; + S.Kills |= B.Kills; + + if (B.Suspend) + S.Kills |= B.Consumes; + + if (S.Suspend) { + // If block S is a suspend block, it should kill all of the blocks + // it consumes. + S.Kills |= S.Consumes; + } else if (S.End) { + // If block S is an end block, it should not propagate kills as the + // blocks following coro.end() are reached during initial invocation + // of the coroutine while all the data are still available on the + // stack or in the registers. + S.Kills.reset(); + } else { + // This is reached when block S is neither Suspend nor coro.end; it + // needs to make sure that it is not in its own kill set. + S.KillLoop |= S.Kills[SuccNo]; + S.Kills.reset(SuccNo); + } + // Mark one incoming edge of SuccNo as visited. + visited(SuccNo); + } + // If CandidateQueue is empty and MaybeBackEdgeSet is nonempty, it means + // there exists a back edge and we need to break it. + if (CandidateQueue.empty() && MaybeBackEdgeSet.size()) { + ExistBackEdge = true; + size_t CandidateNo = -1; + if constexpr (HasBackEdge) { + for (auto I : MaybeBackEdgeSet) { + for (BasicBlock *PI : llvm::predecessors(Mapping.indexToBlock(I))) { + auto PredNo = Mapping.blockToIndex(PI); + auto &P = Block[PredNo]; + // PredNo -> I exists; now check for a path I -> PredNo. + if (P.Consumes[I]) { + CandidateNo = I; + break; + } + } + // Stop after the first candidate is found. + if (CandidateNo != size_t(-1)) + break; + } + // Must find CandidateNo, otherwise this is a bug. + assert(CandidateNo != size_t(-1) && "A bug reached"); + } else + // Just pick any block from MaybeBackEdgeSet as we don't know any back + // edge info when HasBackEdge == false. 
+ CandidateNo = *(MaybeBackEdgeSet.begin()); + CandidateQueue.push(CandidateNo); + MaybeBackEdgeSet.erase(CandidateNo); + IndegreeOfBlocks[CandidateNo] = 0; } } - - if constexpr (Initialize) - return true; - - return Changed; + return ExistBackEdge; } SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) @@ -294,13 +363,17 @@ const size_t N = Mapping.size(); Block.resize(N); + // Index of the entry block. + size_t EntryNo = Mapping.blockToIndex(&(F.getEntryBlock())); + SmallVector BlocksIndegree(N, 0); + // Initialize every block so that it consumes itself for (size_t I = 0; I < N; ++I) { auto &B = Block[I]; B.Consumes.resize(N); B.Kills.resize(N); B.Consumes.set(I); - B.Changed = true; + BlocksIndegree[I] = pred_size(B); } // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as @@ -325,10 +398,10 @@ markSuspendBlock(Save); } - computeBlockData(); - - while (computeBlockData()) - ; + // Collect Consumes and Kills info. If a back edge exists, collect it + // again. + if (collectConsumeKillInfo(EntryNo, BlocksIndegree)) + collectConsumeKillInfo(EntryNo, BlocksIndegree); LLVM_DEBUG(dump()); }