diff --git a/clang/include/clang/AST/ParentMapContext.h b/clang/include/clang/AST/ParentMapContext.h --- a/clang/include/clang/AST/ParentMapContext.h +++ b/clang/include/clang/AST/ParentMapContext.h @@ -69,7 +69,7 @@ ASTContext &ASTCtx; class ParentMap; ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs; - std::map> Parents; + std::unique_ptr Parents; }; class TraversalKindScope { diff --git a/clang/lib/AST/ParentMapContext.cpp b/clang/lib/AST/ParentMapContext.cpp --- a/clang/lib/AST/ParentMapContext.cpp +++ b/clang/lib/AST/ParentMapContext.cpp @@ -23,7 +23,7 @@ ParentMapContext::~ParentMapContext() = default; -void ParentMapContext::clear() { Parents.clear(); } +void ParentMapContext::clear() { Parents.reset(); } const Expr *ParentMapContext::traverseIgnored(const Expr *E) const { return traverseIgnored(const_cast(E)); @@ -116,11 +116,79 @@ } } - DynTypedNodeList getParents(const ast_type_traits::DynTypedNode &Node) { - if (Node.getNodeKind().hasPointerIdentity()) - return getDynNodeFromMap(Node.getMemoizationData(), PointerParents); + DynTypedNodeList getParents(ast_type_traits::TraversalKind TK, + const ast_type_traits::DynTypedNode &Node) { + if (Node.getNodeKind().hasPointerIdentity()) { + auto ParentList = + getDynNodeFromMap(Node.getMemoizationData(), PointerParents); + if (ParentList.size() == 1 && + TK == ast_type_traits::TK_IgnoreUnlessSpelledInSource) { + const Expr *E = ParentList[0].get(); + const Expr *Child = Node.get(); + if (E && Child) + return AscendIgnoreUnlessSpelledInSource(E, Child); + } + return ParentList; + } return getDynNodeFromMap(Node, OtherParents); } + + ast_type_traits::DynTypedNode + AscendIgnoreUnlessSpelledInSource(const Expr *E, const Expr *Child) { + + auto ShouldSkip = [](const Expr *E, const Expr *Child) { + if (isa(E)) + return true; + + if (isa(E)) + return true; + + if (isa(E)) + return true; + + if (isa(E)) + return true; + + if (isa(E)) + return true; + + if (isa(E)) + return true; + + auto SR = Child->getSourceRange(); + + if (auto *C = dyn_cast(E)) { + if (C->getSourceRange() == SR || !isa(C)) + return true; + } + + if (auto *C = dyn_cast(E)) { + if (C->getSourceRange() == SR) + return true; + } + + if (auto *C = dyn_cast(E)) { + if (C->getSourceRange() == SR) + return true; + } + return false; + }; + + while (ShouldSkip(E, Child)) { + auto It = PointerParents.find(E); + if (It == PointerParents.end()) + break; + auto *S = It->second.dyn_cast(); + if (!S) + return getSingleDynTypedNodeFromParentMap(It->second); + auto *P = dyn_cast(S); + if (!P) + return ast_type_traits::DynTypedNode::create(*S); + Child = E; + E = P; + } + return ast_type_traits::DynTypedNode::create(*E); + } }; /// Template specializations to abstract away from pointers and TypeLocs. @@ -151,8 +219,7 @@ class ParentMapContext::ParentMap::ASTVisitor : public RecursiveASTVisitor { public: - ASTVisitor(ParentMap &Map, ParentMapContext &MapCtx) - : Map(Map), MapCtx(MapCtx) {} + ASTVisitor(ParentMap &Map) : Map(Map) {} private: friend class RecursiveASTVisitor; @@ -222,11 +289,8 @@ } bool TraverseStmt(Stmt *StmtNode) { - Stmt *FilteredNode = StmtNode; - if (auto *ExprNode = dyn_cast_or_null(FilteredNode)) - FilteredNode = MapCtx.traverseIgnored(ExprNode); - return TraverseNode(FilteredNode, FilteredNode, - [&] { return VisitorBase::TraverseStmt(FilteredNode); }, + return TraverseNode(StmtNode, StmtNode, + [&] { return VisitorBase::TraverseStmt(StmtNode); }, &Map.PointerParents); } @@ -245,21 +309,18 @@ } ParentMap ⤅ - ParentMapContext &MapCtx; llvm::SmallVector ParentStack; }; ParentMapContext::ParentMap::ParentMap(ASTContext &Ctx) { - ASTVisitor(*this, Ctx.getParentMapContext()).TraverseAST(Ctx); + ASTVisitor(*this).TraverseAST(Ctx); } DynTypedNodeList ParentMapContext::getParents(const ast_type_traits::DynTypedNode &Node) { - std::unique_ptr &P = Parents[Traversal]; - if (!P) + if (!Parents) // We build the parent map for the traversal scope (usually whole TU), as // hasAncestor can escape any subtree. - P = std::make_unique(ASTCtx); - return P->getParents(Node); + Parents = std::make_unique(ASTCtx); + return Parents->getParents(getTraversalKind(), Node); } -