diff --git a/clang/include/clang/Tooling/Syntax/Tree.h b/clang/include/clang/Tooling/Syntax/Tree.h --- a/clang/include/clang/Tooling/Syntax/Tree.h +++ b/clang/include/clang/Tooling/Syntax/Tree.h @@ -28,8 +28,10 @@ #include "clang/Tooling/Syntax/Tokens.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/iterator.h" #include "llvm/Support/Allocator.h" #include +#include namespace clang { namespace syntax { @@ -152,6 +154,34 @@ /// A node that has children and represents a syntactic language construct. class Tree : public Node { + /// Iterator over children (common base of ChildIterator/ConstChildIterator). + /// Not invalidated by tree mutations (holds a stable node pointer). + template + class ChildIteratorBase + : public llvm::iterator_facade_base { + protected: + NodeT *N = nullptr; + using Base = ChildIteratorBase; + + public: + ChildIteratorBase() = default; + explicit ChildIteratorBase(NodeT *N) : N(N) {} + + bool operator==(const DerivedT &O) const { return O.N == N; } + NodeT &operator*() const { return *N; } + DerivedT &operator++() { + N = N->getNextSibling(); + return *static_cast(this); + } + + /// Truthy if valid (not past-the-end). + /// This allows: if (auto It = llvm::find_if(N.getChildren(), ...)) + explicit operator bool() const { return N != nullptr; } + /// The current child node, or nullptr if past-the-end. + NodeT *asPointer() const { return N; } + }; + public: using Node::Node; static bool classof(const Node *N); @@ -169,6 +199,23 @@ return const_cast(this)->findLastLeaf(); } + /// ChildIterator and ConstChildIterator are not invalidated by mutations. 
+ struct ChildIterator : ChildIteratorBase { + using Base::ChildIteratorBase; + }; + struct ConstChildIterator + : ChildIteratorBase { + using Base::ChildIteratorBase; + ConstChildIterator(const ChildIterator &I) : Base(I.asPointer()) {} + }; + + llvm::iterator_range getChildren() { + return {ChildIterator(getFirstChild()), ChildIterator()}; + } + llvm::iterator_range getChildren() const { + return {ConstChildIterator(getFirstChild()), ConstChildIterator()}; + } + protected: /// Find the first node with a corresponding role. Node *findChild(NodeRole R); @@ -195,6 +242,14 @@ Node *FirstChild = nullptr; }; +// Provide the missing ChildIterator == ConstChildIterator overload. +// iterator_facade_base requires == to be a member, but implicit conversions +// don't work on the LHS of a member operator. +inline bool operator==(const Tree::ConstChildIterator &A, + const Tree::ConstChildIterator &B) { + return A.operator==(B); +} + /// A list of Elements separated or terminated by a fixed token. /// /// This type models the following grammar construct: diff --git a/clang/lib/Tooling/Syntax/Tree.cpp b/clang/lib/Tooling/Syntax/Tree.cpp --- a/clang/lib/Tooling/Syntax/Tree.cpp +++ b/clang/lib/Tooling/Syntax/Tree.cpp @@ -19,8 +19,8 @@ static void traverse(const syntax::Node *N, llvm::function_ref Visit) { if (auto *T = dyn_cast(N)) { - for (const auto *C = T->getFirstChild(); C; C = C->getNextSibling()) - traverse(C, Visit); + for (const syntax::Node &C : T->getChildren()) + traverse(&C, Visit); } Visit(N); } @@ -194,21 +194,21 @@ DumpExtraInfo(N); OS << "\n"; - for (const auto *It = T->getFirstChild(); It; It = It->getNextSibling()) { + for (const syntax::Node &It : T->getChildren()) { for (bool Filled : IndentMask) { if (Filled) OS << "| "; else OS << " "; } - if (!It->getNextSibling()) { + if (!It.getNextSibling()) { OS << "`-"; IndentMask.push_back(false); } else { OS << "|-"; IndentMask.push_back(true); } - dumpNode(OS, It, SM, IndentMask); + dumpNode(OS, &It, SM, IndentMask); 
IndentMask.pop_back(); } } @@ -243,22 +243,22 @@ const auto *T = dyn_cast(this); if (!T) return; - for (const auto *C = T->getFirstChild(); C; C = C->getNextSibling()) { + for (const Node &C : T->getChildren()) { if (T->isOriginal()) - assert(C->isOriginal()); - assert(!C->isDetached()); - assert(C->getParent() == T); + assert(C.isOriginal()); + assert(!C.isDetached()); + assert(C.getParent() == T); } const auto *L = dyn_cast(T); if (!L) return; - for (const auto *C = T->getFirstChild(); C; C = C->getNextSibling()) { - assert(C->getRole() == NodeRole::ListElement || - C->getRole() == NodeRole::ListDelimiter); - if (C->getRole() == NodeRole::ListDelimiter) { + for (const Node &C : T->getChildren()) { + assert(C.getRole() == NodeRole::ListElement || + C.getRole() == NodeRole::ListDelimiter); + if (C.getRole() == NodeRole::ListDelimiter) { assert(isa(C)); - assert(cast(C)->getToken()->kind() == L->getDelimiterTokenKind()); + assert(cast(C).getToken()->kind() == L->getDelimiterTokenKind()); } } @@ -272,10 +272,10 @@ } syntax::Leaf *syntax::Tree::findFirstLeaf() { - for (auto *C = getFirstChild(); C; C = C->getNextSibling()) { - if (auto *L = dyn_cast(C)) + for (Node &C : getChildren()) { + if (auto *L = dyn_cast(&C)) return L; - if (auto *L = cast(C)->findFirstLeaf()) + if (auto *L = cast(C).findFirstLeaf()) return L; } return nullptr; @@ -283,19 +283,19 @@ syntax::Leaf *syntax::Tree::findLastLeaf() { syntax::Leaf *Last = nullptr; - for (auto *C = getFirstChild(); C; C = C->getNextSibling()) { - if (auto *L = dyn_cast(C)) + for (Node &C : getChildren()) { + if (auto *L = dyn_cast(&C)) Last = L; - else if (auto *L = cast(C)->findLastLeaf()) + else if (auto *L = cast(C).findLastLeaf()) Last = L; } return Last; } syntax::Node *syntax::Tree::findChild(NodeRole R) { - for (auto *C = FirstChild; C; C = C->getNextSibling()) { - if (C->getRole() == R) - return C; + for (Node &C : getChildren()) { + if (C.getRole() == R) + return &C; } return nullptr; } @@ -318,17 +318,17 @@ 
std::vector> Children; syntax::Node *ElementWithoutDelimiter = nullptr; - for (auto *C = getFirstChild(); C; C = C->getNextSibling()) { - switch (C->getRole()) { + for (Node &C : getChildren()) { + switch (C.getRole()) { case syntax::NodeRole::ListElement: { if (ElementWithoutDelimiter) { Children.push_back({ElementWithoutDelimiter, nullptr}); } - ElementWithoutDelimiter = C; + ElementWithoutDelimiter = &C; break; } case syntax::NodeRole::ListDelimiter: { - Children.push_back({ElementWithoutDelimiter, cast(C)}); + Children.push_back({ElementWithoutDelimiter, cast(&C)}); ElementWithoutDelimiter = nullptr; break; } @@ -363,13 +363,13 @@ std::vector Children; syntax::Node *ElementWithoutDelimiter = nullptr; - for (auto *C = getFirstChild(); C; C = C->getNextSibling()) { - switch (C->getRole()) { + for (Node &C : getChildren()) { + switch (C.getRole()) { case syntax::NodeRole::ListElement: { if (ElementWithoutDelimiter) { Children.push_back(ElementWithoutDelimiter); } - ElementWithoutDelimiter = C; + ElementWithoutDelimiter = &C; break; } case syntax::NodeRole::ListDelimiter: { diff --git a/clang/unittests/Tooling/Syntax/TreeTest.cpp b/clang/unittests/Tooling/Syntax/TreeTest.cpp --- a/clang/unittests/Tooling/Syntax/TreeTest.cpp +++ b/clang/unittests/Tooling/Syntax/TreeTest.cpp @@ -8,6 +8,7 @@ #include "clang/Tooling/Syntax/Tree.h" #include "TreeTestBase.h" +#include "clang/Basic/SourceManager.h" #include "clang/Tooling/Syntax/BuildTree.h" #include "clang/Tooling/Syntax/Nodes.h" #include "llvm/ADT/STLExtras.h" @@ -17,6 +18,7 @@ using namespace clang::syntax; namespace { +using testing::ElementsAre; class TreeTest : public SyntaxTreeTest { private: @@ -124,6 +126,56 @@ } } +TEST_F(TreeTest, Iterators) { + buildTree("", allTestClangConfigs().front()); + std::vector Children = {createLeaf(*Arena, tok::identifier, "a"), + createLeaf(*Arena, tok::identifier, "b"), + createLeaf(*Arena, tok::identifier, "c")}; + auto *Tree = syntax::createTree(*Arena, + {{Children[0], 
NodeRole::LeftHandSide}, + {Children[1], NodeRole::OperatorToken}, + {Children[2], NodeRole::RightHandSide}}, + NodeKind::TranslationUnit); + const auto *ConstTree = Tree; + + auto Range = Tree->getChildren(); + EXPECT_THAT(Range, ElementsAre(role(NodeRole::LeftHandSide), + role(NodeRole::OperatorToken), + role(NodeRole::RightHandSide))); + + auto ConstRange = ConstTree->getChildren(); + EXPECT_THAT(ConstRange, ElementsAre(role(NodeRole::LeftHandSide), + role(NodeRole::OperatorToken), + role(NodeRole::RightHandSide))); + + // FIXME: mutate and observe no invalidation. Mutations are private for now... + auto It = Range.begin(); + auto CIt = ConstRange.begin(); + static_assert(std::is_same::value, + "mutable range"); + static_assert(std::is_same::value, + "const range"); + + for (unsigned I = 0; I < 3; ++I) { + EXPECT_EQ(It, CIt); + EXPECT_TRUE(It); + EXPECT_TRUE(CIt); + EXPECT_EQ(It.asPointer(), Children[I]); + EXPECT_EQ(CIt.asPointer(), Children[I]); + EXPECT_EQ(&*It, Children[I]); + EXPECT_EQ(&*CIt, Children[I]); + ++It; + ++CIt; + } + EXPECT_EQ(It, CIt); + EXPECT_EQ(It, Tree::ChildIterator()); + EXPECT_EQ(CIt, Tree::ConstChildIterator()); + EXPECT_FALSE(It); + EXPECT_FALSE(CIt); + EXPECT_EQ(nullptr, It.asPointer()); + EXPECT_EQ(nullptr, CIt.asPointer()); +} + class ListTest : public SyntaxTreeTest { private: std::string dumpQuotedTokensOrNull(const Node *N) { diff --git a/clang/unittests/Tooling/Syntax/TreeTestBase.h b/clang/unittests/Tooling/Syntax/TreeTestBase.h --- a/clang/unittests/Tooling/Syntax/TreeTestBase.h +++ b/clang/unittests/Tooling/Syntax/TreeTestBase.h @@ -20,7 +20,9 @@ #include "clang/Tooling/Syntax/Tokens.h" #include "clang/Tooling/Syntax/Tree.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/ScopedPrinter.h" #include "llvm/Testing/Support/Annotations.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" namespace clang { @@ -53,6 +55,14 @@ }; std::vector allTestClangConfigs(); + +MATCHER_P(role, R, "") { + if (arg.getRole() == R) 
return true; + *result_listener << "role is " << llvm::to_string(arg.getRole()); + return false; +} + } // namespace syntax } // namespace clang #endif // LLVM_CLANG_UNITTESTS_TOOLING_SYNTAX_TREETESTBASE_H