diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -40,6 +40,7 @@ #include <algorithm> #include <deque> #include <optional> +#include <unordered_set> using namespace llvm; @@ -98,7 +99,6 @@ bool Suspend = false; bool End = false; bool KillLoop = false; - bool Changed = false; }; SmallVector<BlockData, SmallVectorThreshold> Block; @@ -106,16 +106,52 @@ BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); return llvm::predecessors(BB); } + size_t pred_size(BlockData const &BD) const { + BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); + return llvm::pred_size(BB); + } + iterator_range<succ_iterator> successors(BlockData const &BD) const { + BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); + return llvm::successors(BB); + } BlockData &getBlockData(BasicBlock *BB) { return Block[Mapping.blockToIndex(BB)]; } - /// Compute the BlockData for the current function in one iteration. - /// Returns whether the BlockData changes in this iteration. - /// Initialize - Whether this is the first iteration, we can optimize - /// the initial case a little bit by manual loop switch. - template <bool Initialize = false> bool computeBlockData(); + /// This algorithm is based on topological sorting. + /// Topological sorting is only defined on a DAG, but the CFG + /// may not be a DAG as it may contain loops, so we need to break them. + /// If the CFG is a DAG, we only need to visit it once. If the CFG is not + /// a DAG, it contains loops, so we need to visit it three times. + + /// Why do we need to visit three times? First, we have to figure out how + /// Consumes info propagates in a loop. 
For example, + /// + /// A -> B -> C -> D -> H + /// ^ | + /// | v + /// G <- F <- E + /// + /// The first traversal (Iteration = 1) order is A, B, C, D, H, E, F, G + /// or A, B, C, D, E, H, F, G + + /// According to the Consumes info at this stage (after the first traversal, + /// Iteration = 1), F.Consumes[C] is true, but C.Consumes[F] + /// is false (it should be true, but has not been propagated yet). Loop node B has + /// B.Consumes[F] set because F propagates its Consumes to G and G + /// propagates to B. So we need the second traversal (Iteration = 2) to + /// propagate B's Consumes to C, which makes C.Consumes[F] true after the + /// second traversal (Iteration = 2). So for a loop, two traversals are needed to + /// get the full Consumes. The third traversal (Iteration = 3) propagates + /// Kills info, because the Kills of a suspend node are derived from its + /// Consumes. If a suspend node is in a loop, we need the third + /// traversal (Iteration = 3) to propagate Kills to each loop node and its + /// successors. NOTE(review): confirm three passes also suffice for nested or irreducible loops. + + /// Returns true if the traversal had to break a (suspected) loop in the CFG. EntryNo is the index of the block the traversal starts from; BlocksIndegree holds one indegree per block and is copied, not modified. + template <int Iteration> + bool computeBlockData(size_t EntryNo, SmallVector<int> const &BlocksIndegree); public: #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -223,70 +259,118 @@ } #endif -template <bool Initialize> bool SuspendCrossingInfo::computeBlockData() { - const size_t N = Mapping.size(); - bool Changed = false; - - for (size_t I = 0; I < N; ++I) { - auto &B = Block[I]; - - // We don't need to count the predecessors when initialization. - if constexpr (!Initialize) - // If all the predecessors of the current Block don't change, - // the BlockData for the current block must not change too. 
- if (all_of(predecessors(B), [this](BasicBlock *BB) { - return !Block[Mapping.blockToIndex(BB)].Changed; - })) { - B.Changed = false; - continue; - } - - // Saved Consumes and Kills bitsets so that it is easy to see - // if anything changed after propagation. 
- auto SavedConsumes = B.Consumes; - auto SavedKills = B.Kills; - - for (BasicBlock *PI : predecessors(B)) { - auto PrevNo = Mapping.blockToIndex(PI); - auto &P = Block[PrevNo]; - - // Propagate Kills and Consumes from predecessors into B. - B.Consumes |= P.Consumes; - B.Kills |= P.Kills; - - // If block P is a suspend block, it should propagate kills into block - // B for every block P consumes. - if (P.Suspend) - B.Kills |= P.Consumes; +template <int Iteration> +bool SuspendCrossingInfo::computeBlockData( + size_t EntryNo, SmallVector<int> const &BlocksIndegree) { + static_assert(Iteration >= 1 && Iteration <= 3, "Out of range [1, 3]"); + bool Loop = false; + // Work on a copy so the caller's indegrees can be reused by later passes. + auto Indegrees = BlocksIndegree; + // Blocks whose remaining indegree has dropped to zero, ready to be visited. NOTE(review): std::queue needs <queue>, but this patch only adds <unordered_set> - confirm it is included elsewhere. + std::queue<size_t> CandidateQueue; + // Blocks that may be in a loop. NOTE(review): iteration order of std::unordered_set is unspecified; confirm that is acceptable here. + std::unordered_set<size_t> MaybeLoopSet; + // Account for one visited incoming edge of block I. + auto visited = [&](size_t I) { + switch (Indegrees[I]) { + case 0: + break; + case 1: { + CandidateQueue.push(I); + MaybeLoopSet.erase(I); + Indegrees[I] = 0; + break; + } + default: { + Indegrees[I]--; + MaybeLoopSet.insert(I); + break; } - - if (B.Suspend) { - // If block S is a suspend block, it should kill all of the blocks it - // consumes. - B.Kills |= B.Consumes; - } else if (B.End) { - // If block B is an end block, it should not propagate kills as the - // blocks following coro.end() are reached during initial invocation - // of the coroutine while all the data are still available on the - // stack or in the registers. 
- B.Kills.reset(); - } else { - // This is reached when B block it not Suspend nor coro.end and it - // need to make sure that it is not in the kill set. - B.KillLoop |= B.Kills[I]; - B.Kills.reset(I); } + }; - if constexpr (!Initialize) { - B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes); - Changed |= B.Changed; + // Seed the traversal with the entry block. + CandidateQueue.push(EntryNo); + + // Topological sorting. 
+ while (!CandidateQueue.empty()) { + auto &B = Block[CandidateQueue.front()]; + CandidateQueue.pop(); + for (BasicBlock *SI : successors(B)) { + auto SuccNo = Mapping.blockToIndex(SI); + auto &S = Block[SuccNo]; + + // Propagate Kills and Consumes from predecessor B into S. + // Consumes only needs to be propagated in iterations 1 and 2. + if constexpr (Iteration < 3) + S.Consumes |= B.Consumes; + S.Kills |= B.Kills; + + if (B.Suspend) + S.Kills |= B.Consumes; + + if (S.Suspend) { + // If block S is a suspend block, it should kill all of the blocks + // it consumes. + S.Kills |= S.Consumes; + } else if (S.End) { + // If block S is an end block, it should not propagate kills as the + // blocks following coro.end() are reached during initial invocation + // of the coroutine while all the data are still available on the + // stack or in the registers. + S.Kills.reset(); + } else { + // This is reached when block S is neither a suspend nor a coro.end block; + // record whether it killed itself and drop it from its own kill set. + S.KillLoop |= S.Kills[SuccNo]; + S.Kills.reset(SuccNo); + } + // Account for the visited edge B -> S. + visited(SuccNo); + } + // If CandidateQueue is empty and MaybeLoopSet is nonempty, it means + // a loop exists and we need to break it. NOTE(review): a block with a predecessor that is unreachable from the entry would also end up here - confirm such CFGs cannot occur. + if (CandidateQueue.empty() && MaybeLoopSet.size()) { + Loop = true; + size_t LoopNo = -1; + if constexpr (Iteration > 1) { + // We get some loop info after the first traversal (Iteration = 1). Not all, + // but enough to find a true loop node. For example, + // A -> B -> C -> D -> H + // ^ | + // | v + // G <- F <- E + // The first traversal (Iteration = 1) order is A, B, C, D, H, E, F, G + // or A, B, C, D, E, H, F, G + // Node B is added to MaybeLoopSet after visiting A. 
To make sure B + // is a loop node, we can iterate over the predecessors of B and check + // whether one of them has been reached from B. 
+ for (auto I : MaybeLoopSet) { + for (BasicBlock *PI : llvm::predecessors(Mapping.indexToBlock(I))) { + auto PredNo = Mapping.blockToIndex(PI); + auto &P = Block[PredNo]; + // Edge PredNo -> I exists; check whether a path I -> PredNo exists too. + if (P.Consumes[I]) { + LoopNo = I; + break; + } + } + // One loop node is enough; stop at the first one found. + if (LoopNo != size_t(-1)) + break; + } + // A loop node must have been found, otherwise this is a bug. NOTE(review): in builds without assertions LoopNo == size_t(-1) would be pushed and used as an index below - confirm this cannot happen. + assert(LoopNo != size_t(-1) && "A bug reached"); + } else + // Just pick one node in MaybeLoopSet as we don't know loop info when + // Iteration = 1. + LoopNo = *(MaybeLoopSet.begin()); + CandidateQueue.push(LoopNo); + MaybeLoopSet.erase(LoopNo); + Indegrees[LoopNo] = 0; } } - - if constexpr (Initialize) - return true; - - return Changed; + return Loop; } SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) @@ -294,13 +378,17 @@ const size_t N = Mapping.size(); Block.resize(N); + // Index of the entry block, where each traversal starts. + size_t EntryNo = Mapping.blockToIndex(&(F.getEntryBlock())); + SmallVector<int> BlocksIndegree(N, 0); + // Initialize every block so that it consumes itself for (size_t I = 0; I < N; ++I) { auto &B = Block[I]; B.Consumes.resize(N); B.Kills.resize(N); B.Consumes.set(I); - B.Changed = true; + BlocksIndegree[I] = pred_size(B); } // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as @@ -325,10 +413,9 @@ markSuspendBlock(Save); } - computeBlockData</*Initialize=*/true>(); - - while (computeBlockData()) - ; + computeBlockData</*Iteration=*/1>(EntryNo, BlocksIndegree) && + computeBlockData</*Iteration=*/2>(EntryNo, BlocksIndegree) && + computeBlockData</*Iteration=*/3>(EntryNo, BlocksIndegree); LLVM_DEBUG(dump()); }