diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -40,6 +40,7 @@ #include #include #include +#include using namespace llvm; @@ -98,7 +99,6 @@ bool Suspend = false; bool End = false; bool KillLoop = false; - bool Changed = false; }; SmallVector Block; @@ -106,17 +106,19 @@ BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); return llvm::predecessors(BB); } + size_t pred_size(BlockData const &BD) const { + BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); + return llvm::pred_size(BB); + } + iterator_range successors(BlockData const &BD) const { + BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); + return llvm::successors(BB); + } BlockData &getBlockData(BasicBlock *BB) { return Block[Mapping.blockToIndex(BB)]; } - /// Compute the BlockData for the current function in one iteration. - /// Returns whether the BlockData changes in this iteration. - /// Initialize - Whether this is the first iteration, we can optimize - /// the initial case a little bit by manual loop switch. - template bool computeBlockData(); - public: #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void dump() const; @@ -223,76 +225,35 @@ } #endif -template bool SuspendCrossingInfo::computeBlockData() { +SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) + : Mapping(F) { const size_t N = Mapping.size(); - bool Changed = false; - - for (size_t I = 0; I < N; ++I) { - auto &B = Block[I]; - - // We don't need to count the predecessors when initialization. - if constexpr (!Initialize) - // If all the predecessors of the current Block don't change, - // the BlockData for the current block must not change too.
- if (all_of(predecessors(B), [this](BasicBlock *BB) { - return !Block[Mapping.blockToIndex(BB)].Changed; - })) { - B.Changed = false; - continue; - } - - // Saved Consumes and Kills bitsets so that it is easy to see - // if anything changed after propagation. - auto SavedConsumes = B.Consumes; - auto SavedKills = B.Kills; - - for (BasicBlock *PI : predecessors(B)) { - auto PrevNo = Mapping.blockToIndex(PI); - auto &P = Block[PrevNo]; - - // Propagate Kills and Consumes from predecessors into B. - B.Consumes |= P.Consumes; - B.Kills |= P.Kills; + Block.resize(N); - // If block P is a suspend block, it should propagate kills into block - // B for every block P consumes. - if (P.Suspend) - B.Kills |= P.Consumes; + std::unordered_set Visiting; + std::unordered_set MaybeLoop; + SmallVector Indegree(N, 0); + // Visit I. + auto visited = [&](size_t I) { + switch (Indegree[I]) { + case 0: + break; + case 1: { + Visiting.insert(I); + MaybeLoop.erase(I); + Indegree[I] = 0; + break; } - - if (B.Suspend) { - // If block S is a suspend block, it should kill all of the blocks it - // consumes. - B.Kills |= B.Consumes; - } else if (B.End) { - // If block B is an end block, it should not propagate kills as the - // blocks following coro.end() are reached during initial invocation - // of the coroutine while all the data are still available on the - // stack or in the registers. - B.Kills.reset(); - } else { - // This is reached when B block it not Suspend nor coro.end and it - // need to make sure that it is not in the kill set.
- B.KillLoop |= B.Kills[I]; - B.Kills.reset(I); + default: { + Indegree[I]--; + MaybeLoop.insert(I); + break; } - - if constexpr (!Initialize) { - B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes); - Changed |= B.Changed; } - } - - if constexpr (Initialize) - return true; + }; - return Changed; -} - -SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) - : Mapping(F) { - const size_t N = Mapping.size(); - Block.resize(N); + // Initialize Visiting by entry block. + Visiting.insert(Mapping.blockToIndex(&(F.getEntryBlock()))); // Initialize every block so that it consumes itself for (size_t I = 0; I < N; ++I) { @@ -300,7 +261,7 @@ B.Consumes.resize(N); B.Kills.resize(N); B.Consumes.set(I); - B.Changed = true; + Indegree[I] = pred_size(B); } // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as
@@ -325,10 +286,117 @@
       markSuspendBlock(Save);
   }
 
-  computeBlockData();
-
-  while (computeBlockData())
-    ;
+  // Stash the initial traversal state; every pass restarts from it.
+  const auto VisitingSave = Visiting;
+  const auto IndegreeSave = Indegree;
+  decltype(Visiting) TopVisiting;
+
+  // Propagate Consumes and Kills along the CFG in topological order. The CFG
+  // may contain loops, so it is not necessarily a DAG: when the traversal
+  // gets stuck (Visiting is empty but MaybeLoop is not), break a cycle by
+  // treating one block of MaybeLoop as a loop header and continue from it.
+  //
+  // A pass that never had to break a cycle visited the blocks in a true
+  // topological order, so its result is exact. Otherwise, whatever reaches a
+  // header over a back edge only reaches the header's successors in the next
+  // pass. For example, in
+  //
+  //   A -> B -> C -> D -> H
+  //        ^         |
+  //        |         v
+  //        G <- F <- E
+  //
+  // C.Consumes[F] only becomes true in the second pass, and the Kills a
+  // suspend block in the loop derives from its Consumes may need yet another
+  // pass. Each further back edge on a propagation path (e.g. in nested loops)
+  // can cost one more pass, so no fixed number of passes is enough. Instead,
+  // repeat passes until one changes no Consumes/Kills bit: that pass
+  // propagated every edge with final values, so the result is the fixed
+  // point of the data-flow equations regardless of which headers were picked
+  // or of the iteration order of the sets.
+  bool HasLoop = true;
+  bool Changed = true;
+  for (unsigned Pass = 0; HasLoop && Changed; ++Pass) {
+    HasLoop = false;
+    Changed = false;
+    Visiting = VisitingSave;
+    Indegree = IndegreeSave;
+    while (!Visiting.empty() || !MaybeLoop.empty()) {
+      // Break a cycle.
+      if (Visiting.empty()) {
+        HasLoop = true;
+        size_t LoopI = size_t(-1);
+        if (Pass) {
+          // Earlier passes give enough Consumes information to prefer a
+          // block I that provably lies on a cycle: one with a predecessor P
+          // such that P.Consumes[I], i.e. the edge P -> I closes a path
+          // I -> P. Picking a real header speeds up convergence.
+          for (auto I : MaybeLoop) {
+            BasicBlock *BB = Mapping.indexToBlock(I);
+            if (any_of(llvm::predecessors(BB), [&](BasicBlock *PI) {
+                  return Block[Mapping.blockToIndex(PI)].Consumes[I];
+                })) {
+              LoopI = I;
+              break;
+            }
+          }
+        }
+        // The choice only affects how fast we converge, not the result, so
+        // fall back to any candidate when none is known to be on a cycle
+        // (first pass, irreducible CFGs, or predecessors the traversal from
+        // the entry block never visits).
+        if (LoopI == size_t(-1))
+          LoopI = *MaybeLoop.begin();
+        Visiting.insert(LoopI);
+        MaybeLoop.erase(LoopI);
+        Indegree[LoopI] = 0;
+      }
+      TopVisiting.clear();
+      TopVisiting.swap(Visiting);
+      for (auto I : TopVisiting) {
+        auto &B = Block[I];
+        for (BasicBlock *SI : successors(B)) {
+          auto SuccNo = Mapping.blockToIndex(SI);
+          auto &S = Block[SuccNo];
+
+          // Saved so that we can tell whether this edge changed S.
+          auto SavedConsumes = S.Consumes;
+          auto SavedKills = S.Kills;
+
+          // Propagate Kills and Consumes from B into its successor S.
+          S.Consumes |= B.Consumes;
+          S.Kills |= B.Kills;
+
+          // If block B is a suspend block, it should propagate kills into
+          // block S for every block B consumes.
+          if (B.Suspend)
+            S.Kills |= B.Consumes;
+
+          if (S.Suspend) {
+            // If block S is a suspend block, it should kill all of the blocks
+            // it consumes.
+            S.Kills |= S.Consumes;
+          } else if (S.End) {
+            // If block S is an end block, it should not propagate kills as the
+            // blocks following coro.end() are reached during initial invocation
+            // of the coroutine while all the data are still available on the
+            // stack or in the registers.
+            S.Kills.reset();
+          } else {
+            // S is neither a suspend nor a coro.end block, so it must not be
+            // in its own kill set; KillLoop records that it was.
+            S.KillLoop |= S.Kills[SuccNo];
+            S.Kills.reset(SuccNo);
+          }
+
+          Changed |= S.Consumes != SavedConsumes || S.Kills != SavedKills;
+
+          // Visit SuccNo.
+          visited(SuccNo);
+        }
+      }
+    }
+  }
 
   LLVM_DEBUG(dump());
 }