diff --git a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp --- a/llvm/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/llvm/lib/Transforms/Coroutines/CoroFrame.cpp @@ -112,10 +112,12 @@ } /// Compute the BlockData for the current function in one iteration. - /// Returns whether the BlockData changes in this iteration. /// Initialize - Whether this is the first iteration, we can optimize /// the initial case a little bit by manual loop switch. - template bool computeBlockData(); + /// The parameter "RPOT" is a reverse post order. + /// Returns whether the BlockData changes in this iteration. + template + bool computeBlockData(ReversePostOrderTraversal &RPOT); public: #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -223,12 +225,15 @@ } #endif -template bool SuspendCrossingInfo::computeBlockData() { - const size_t N = Mapping.size(); +template +bool SuspendCrossingInfo::computeBlockData( + ReversePostOrderTraversal &RPOT) { bool Changed = false; - for (size_t I = 0; I < N; ++I) { - auto &B = Block[I]; + /// Use reverse post order to guide the computation. + for (auto BB : RPOT) { + auto BBNo = Mapping.blockToIndex(BB); + auto &B = Block[BBNo]; // We don't need to count the predecessors when initialization. if constexpr (!Initialize) @@ -261,7 +266,7 @@ } if (B.Suspend) { - // If block S is a suspend block, it should kill all of the blocks it + // If block B is a suspend block, it should kill all of the blocks it // consumes. B.Kills |= B.Consumes; } else if (B.End) { @@ -273,8 +278,8 @@ } else { // This is reached when B block it not Suspend nor coro.end and it // need to make sure that it is not in the kill set. 
- B.KillLoop |= B.Kills[I]; - B.Kills.reset(I); + B.KillLoop |= B.Kills[BBNo]; + B.Kills.reset(BBNo); } if constexpr (!Initialize) { @@ -283,9 +288,6 @@ } } - if constexpr (Initialize) - return true; - return Changed; } @@ -325,9 +327,11 @@ markSuspendBlock(Save); } - computeBlockData(); - - while (computeBlockData()) + /// Use reverse post order to guide the computation. This helps reach the + /// fixed point faster. + ReversePostOrderTraversal RPOT(&F); + computeBlockData(RPOT); + while (computeBlockData(RPOT)) ; LLVM_DEBUG(dump());