diff --git a/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h b/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h --- a/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h +++ b/clang-tools-extra/pseudo/include/clang-pseudo/grammar/LRTable.h @@ -40,6 +40,8 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/BitVector.h" #include "llvm/Support/Capacity.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" #include #include @@ -123,12 +125,18 @@ // Returns the state after we reduce a nonterminal. // Expected to be called by LR parsers. - // REQUIRES: Nonterminal is valid here. - StateID getGoToState(StateID State, SymbolID Nonterminal) const; + // If the nonterminal is invalid here, returns None. + llvm::Optional getGoToState(StateID State, + SymbolID Nonterminal) const { + return Gotos.get(gotoIndex(State, Nonterminal, numStates())); + } // Returns the state after we shift a terminal. // Expected to be called by LR parsers. // If the terminal is invalid here, returns None. - llvm::Optional getShiftState(StateID State, SymbolID Terminal) const; + llvm::Optional getShiftState(StateID State, + SymbolID Terminal) const { + return Shifts.get(shiftIndex(State, Terminal, numStates())); + } // Returns the possible reductions from a state. // @@ -164,9 +172,7 @@ StateID getStartState(SymbolID StartSymbol) const; size_t bytes() const { - return sizeof(*this) + llvm::capacity_in_bytes(Actions) + - llvm::capacity_in_bytes(Symbols) + - llvm::capacity_in_bytes(StateOffset) + + return sizeof(*this) + Gotos.bytes() + Shifts.bytes() + llvm::capacity_in_bytes(Reduces) + llvm::capacity_in_bytes(ReduceOffset) + llvm::capacity_in_bytes(FollowSets); @@ -194,22 +200,92 @@ llvm::ArrayRef); private: - // Looks up actions stored in the generic table. - llvm::ArrayRef find(StateID State, SymbolID Symbol) const; - - // Conceptually the LR table is a multimap from (State, SymbolID) => Action. 
- // Our physical representation is quite different for compactness. - - // Index is StateID, value is the offset into Symbols/Actions - // where the entries for this state begin. - // Give a state id, the corresponding half-open range of Symbols/Actions is - // [StateOffset[id], StateOffset[id+1]). - std::vector StateOffset; - // Parallel to Actions, the value is SymbolID (columns of the matrix). - // Grouped by the StateID, and only subranges are sorted. - std::vector Symbols; - // A flat list of available actions, sorted by (State, SymbolID). - std::vector Actions; + unsigned numStates() const { return ReduceOffset.size() - 1; } + + // A map from unsigned key => StateID, used to store actions. + // Keys are dense integers in [0, NumKeys); only a sparse subset has values. + // + // In practice, the keys encode (origin state, symbol) pairs, and the values + // are the state we should move to after seeing that symbol. + // + // We store one bit for presence/absence of the value for each key. + // At every 64th key, we store the offset into the table of values. + // e.g. key 0x500 is checkpoint 0x500/64 = 20 + // Checkpoints[20] = 34 + // get(0x500) = Values[34] (assuming it has a value) + // To look up values in between, we count the set bits: + // get(0x509) has a value if HasValue[20] & (1<<9) + // #values in [0x500, 0x509): popcnt(HasValue[20] & ((1<<9) - 1)) + // get(0x509) = Values[34 + popcnt(...)] + // + // Overall size is 1.25 bits/key + 16 bits/value. + // Lookup is constant time with a low factor (no hashing). 
+ class TransitionTable { + using Word = uint64_t; + constexpr static unsigned WordBits = CHAR_BIT * sizeof(Word); + + std::vector Values; + std::vector HasValue; + std::vector Checkpoints; + + public: + TransitionTable() = default; + TransitionTable(const llvm::DenseMap &Entries, + unsigned NumKeys) { + assert( + Entries.size() < + std::numeric_limits::max() && + "16 bits too small for value offsets!"); + unsigned NumWords = (NumKeys + WordBits - 1) / WordBits; + HasValue.resize(NumWords, 0); + Checkpoints.reserve(NumWords); + Values.reserve(Entries.size()); + for (unsigned I = 0; I < NumKeys; ++I) { + if ((I % WordBits) == 0) + Checkpoints.push_back(Values.size()); + auto It = Entries.find(I); + if (It != Entries.end()) { + HasValue[I / WordBits] |= (Word(1) << (I % WordBits)); + Values.push_back(It->second); + } + } + } + + llvm::Optional get(unsigned Key) const { + // Is this key's presence bit set in its word of HasValue? + Word KeyMask = Word(1) << (Key % WordBits); + unsigned KeyWord = Key / WordBits; + if ((HasValue[KeyWord] & KeyMask) == 0) + return llvm::None; + // Count the present keys below Key within the same word. + Word BelowKeyMask = KeyMask - 1; + unsigned CountSinceCheckpoint = + llvm::countPopulation(HasValue[KeyWord] & BelowKeyMask); + // Values are in key order: skip that many past the word's checkpoint. + return Values[Checkpoints[KeyWord] + CountSinceCheckpoint]; + } + + unsigned size() const { return Values.size(); } + + size_t bytes() const { + return llvm::capacity_in_bytes(HasValue) + + llvm::capacity_in_bytes(Values) + + llvm::capacity_in_bytes(Checkpoints); + } + }; + // Shift/Goto tables are keyed by (State, Symbol) as NumStates*Sym + State. 
+ static unsigned shiftIndex(StateID State, SymbolID Terminal, + unsigned NumStates) { + return State + NumStates * symbolToToken(Terminal); + } + static unsigned gotoIndex(StateID State, SymbolID Nonterminal, + unsigned NumStates) { + assert(isNonterminal(Nonterminal)); + return State + NumStates * Nonterminal; + } + TransitionTable Shifts; + TransitionTable Gotos; + // A sorted table, storing the start state for each target parsing symbol. std::vector> StartStates; diff --git a/clang-tools-extra/pseudo/lib/GLR.cpp b/clang-tools-extra/pseudo/lib/GLR.cpp --- a/clang-tools-extra/pseudo/lib/GLR.cpp +++ b/clang-tools-extra/pseudo/lib/GLR.cpp @@ -318,9 +318,11 @@ do { const PushSpec &Push = Sequences.top().second; FamilySequences.emplace_back(Sequences.top().first.Rule, *Push.Seq); - for (const GSS::Node *Base : Push.LastPop->parents()) - FamilyBases.emplace_back( - Params.Table.getGoToState(Base->State, F.Symbol), Base); + for (const GSS::Node *Base : Push.LastPop->parents()) { + auto NextState = Params.Table.getGoToState(Base->State, F.Symbol); + assert(NextState && "goto must succeed after reduce!"); + FamilyBases.emplace_back(*NextState, Base); + } Sequences.pop(); } while (!Sequences.empty() && Sequences.top().first == F); @@ -393,8 +395,9 @@ } const ForestNode *Parsed = &Params.Forest.createSequence(Rule.Target, *RID, TempSequence); - StateID NextState = Params.Table.getGoToState(Base->State, Rule.Target); - Heads->push_back(Params.GSStack.addNode(NextState, Parsed, {Base})); + auto NextState = Params.Table.getGoToState(Base->State, Rule.Target); + assert(NextState && "goto must succeed after reduce!"); + Heads->push_back(Params.GSStack.addNode(*NextState, Parsed, {Base})); return true; } }; @@ -444,7 +447,8 @@ } LLVM_DEBUG(llvm::dbgs() << llvm::formatv("Reached eof\n")); - StateID AcceptState = Params.Table.getGoToState(StartState, StartSymbol); + auto AcceptState = Params.Table.getGoToState(StartState, StartSymbol); + 
assert(AcceptState.hasValue() && "goto must succeed after start symbol!"); const ForestNode *Result = nullptr; for (const auto *Head : Heads) { if (Head->State == AcceptState) { diff --git a/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp b/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp --- a/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp +++ b/clang-tools-extra/pseudo/lib/grammar/LRTable.cpp @@ -34,11 +34,10 @@ return llvm::formatv(R"( Statistics of the LR parsing table: number of states: {0} - number of actions: {1} - number of reduces: {2} - size of the table (bytes): {3} + number of actions: shift={1} goto={2} reduce={3} + size of the table (bytes): {4} )", - StateOffset.size() - 1, Actions.size(), Reduces.size(), + numStates(), Shifts.size(), Gotos.size(), Reduces.size(), bytes()) .str(); } @@ -47,15 +46,13 @@ std::string Result; llvm::raw_string_ostream OS(Result); OS << "LRTable:\n"; - for (StateID S = 0; S < StateOffset.size() - 1; ++S) { + for (StateID S = 0; S < numStates(); ++S) { OS << llvm::formatv("State {0}\n", S); for (uint16_t Terminal = 0; Terminal < NumTerminals; ++Terminal) { SymbolID TokID = tokenSymbol(static_cast(Terminal)); - for (auto A : find(S, TokID)) { - if (A.kind() == LRTable::Action::Shift) - OS.indent(4) << llvm::formatv("{0}: shift state {1}\n", - G.symbolName(TokID), A.getShiftState()); - } + if (auto SS = getShiftState(S, TokID)) + OS.indent(4) << llvm::formatv("{0}: shift state {1}\n", + G.symbolName(TokID), *SS); } for (RuleID R : getReduceRules(S)) { SymbolID Target = G.lookupRule(R).Target; @@ -71,55 +68,15 @@ } for (SymbolID NontermID = 0; NontermID < G.table().Nonterminals.size(); ++NontermID) { - if (find(S, NontermID).empty()) - continue; - OS.indent(4) << llvm::formatv("{0}: go to state {1}\n", - G.symbolName(NontermID), - getGoToState(S, NontermID)); + if (auto GS = getGoToState(S, NontermID)) { + OS.indent(4) << llvm::formatv("{0}: go to state {1}\n", + G.symbolName(NontermID), *GS); + } } } return OS.str(); } 
-llvm::Optional -LRTable::getShiftState(StateID State, SymbolID Terminal) const { - // FIXME: we spend a significant amount of time on misses here. - // We could consider storing a std::bitset for a cheaper test? - assert(pseudo::isToken(Terminal) && "expected terminal symbol!"); - for (const auto &Result : find(State, Terminal)) - if (Result.kind() == Action::Shift) - return Result.getShiftState(); // unique: no shift/shift conflicts. - return llvm::None; -} - -LRTable::StateID LRTable::getGoToState(StateID State, - SymbolID Nonterminal) const { - assert(pseudo::isNonterminal(Nonterminal) && "expected nonterminal symbol!"); - auto Result = find(State, Nonterminal); - assert(Result.size() == 1 && Result.front().kind() == Action::GoTo); - return Result.front().getGoToState(); -} - -llvm::ArrayRef LRTable::find(StateID Src, SymbolID ID) const { - assert(Src + 1u < StateOffset.size()); - std::pair Range = - std::make_pair(StateOffset[Src], StateOffset[Src + 1]); - auto SymbolRange = llvm::makeArrayRef(Symbols.data() + Range.first, - Symbols.data() + Range.second); - - assert(llvm::is_sorted(SymbolRange) && - "subrange of the Symbols should be sorted!"); - const LRTable::StateID *Start = - llvm::partition_point(SymbolRange, [&ID](SymbolID S) { return S < ID; }); - if (Start == SymbolRange.end()) - return {}; - const LRTable::StateID *End = Start; - while (End != SymbolRange.end() && *End == ID) - ++End; - return llvm::makeArrayRef(&Actions[Start - Symbols.data()], - /*length=*/End - Start); -} - LRTable::StateID LRTable::getStartState(SymbolID Target) const { assert(llvm::is_sorted(StartStates) && "StartStates must be sorted!"); auto It = llvm::partition_point( diff --git a/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp b/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp --- a/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp +++ b/clang-tools-extra/pseudo/lib/grammar/LRTableBuild.cpp @@ -45,49 +45,24 @@ llvm::DenseMap> Reduces; std::vector> FollowSets; 
- LRTable build(unsigned NumStates) && { - // E.g. given the following parsing table with 3 states and 3 terminals: - // - // a b c - // +-------+----+-------+-+ - // |state0 | | s0,r0 | | - // |state1 | acc| | | - // |state2 | | r1 | | - // +-------+----+-------+-+ - // - // The final LRTable: - // - StateOffset: [s0] = 0, [s1] = 2, [s2] = 3, [sentinel] = 4 - // - Symbols: [ b, b, a, b] - // Actions: [ s0, r0, acc, r1] - // ~~~~~~ range for state 0 - // ~~~~ range for state 1 - // ~~ range for state 2 - // First step, we sort all entries by (State, Symbol, Action). - std::vector Sorted(Entries.begin(), Entries.end()); - llvm::sort(Sorted, [](const Entry &L, const Entry &R) { - return std::forward_as_tuple(L.State, L.Symbol, L.Act.opaque()) < - std::forward_as_tuple(R.State, R.Symbol, R.Act.opaque()); - }); - + LRTable build(unsigned NumStates, unsigned NumNonterminals) && { LRTable Table; - Table.Actions.reserve(Sorted.size()); - Table.Symbols.reserve(Sorted.size()); - // We are good to finalize the States and Actions. - for (const auto &E : Sorted) { - Table.Actions.push_back(E.Act); - Table.Symbols.push_back(E.Symbol); - } - // Initialize the terminal and nonterminal offset, all ranges are empty by - // default. - Table.StateOffset = std::vector(NumStates + 1, 0); - size_t SortedIndex = 0; - for (StateID State = 0; State < Table.StateOffset.size(); ++State) { - Table.StateOffset[State] = SortedIndex; - while (SortedIndex < Sorted.size() && Sorted[SortedIndex].State == State) - ++SortedIndex; - } Table.StartStates = std::move(StartStates); + // Compile goto and shift actions into TransitionTables (others skipped). 
+ llvm::DenseMap Shifts; + llvm::DenseMap Gotos; + for (const auto &Cell : Entries) { + if (Cell.Act.kind() == Action::GoTo) + Gotos.try_emplace(gotoIndex(Cell.State, Cell.Symbol, NumStates), + Cell.Act.getGoToState()); + else if (Cell.Act.kind() == Action::Shift) + Shifts.try_emplace(shiftIndex(Cell.State, Cell.Symbol, NumStates), + Cell.Act.getShiftState()); + } + Table.Gotos = TransitionTable(Gotos, NumStates * NumNonterminals); + Table.Shifts = TransitionTable(Shifts, NumStates * NumTerminals); + // Compile the follow sets into a bitmap. Table.FollowSets.resize(tok::NUM_TOKENS * FollowSets.size()); for (SymbolID NT = 0; NT < FollowSets.size(); ++NT) @@ -128,7 +103,8 @@ for (const ReduceEntry &E : Reduces) Build.Reduces[E.State].insert(E.Rule); Build.FollowSets = followSets(G); - return std::move(Build).build(/*NumStates=*/MaxState + 1); + return std::move(Build).build(/*NumStates=*/MaxState + 1, + G.table().Nonterminals.size()); } LRTable LRTable::buildSLR(const Grammar &G) { @@ -156,7 +132,8 @@ Build.Reduces[SID].insert(I.rule()); } } - return std::move(Build).build(Graph.states().size()); + return std::move(Build).build(Graph.states().size(), + G.table().Nonterminals.size()); } } // namespace pseudo diff --git a/clang-tools-extra/pseudo/unittests/LRTableTest.cpp b/clang-tools-extra/pseudo/unittests/LRTableTest.cpp --- a/clang-tools-extra/pseudo/unittests/LRTableTest.cpp +++ b/clang-tools-extra/pseudo/unittests/LRTableTest.cpp @@ -60,7 +60,7 @@ EXPECT_EQ(T.getShiftState(1, Eof), llvm::None); EXPECT_EQ(T.getShiftState(1, Identifier), llvm::None); - EXPECT_EQ(T.getGoToState(1, Term), 3); + EXPECT_THAT(T.getGoToState(1, Term), ValueIs(3)); EXPECT_THAT(T.getReduceRules(1), ElementsAre(2)); // Verify the behaivor for other non-available-actions terminals.