Index: include/clang/AST/ASTContext.h =================================================================== --- include/clang/AST/ASTContext.h +++ include/clang/AST/ASTContext.h @@ -556,7 +556,18 @@ const TargetInfo *AuxTarget = nullptr; clang::PrintingPolicy PrintingPolicy; + ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs; + public: + + ast_type_traits::TraversalKind GetTraversalKind() const { return Traversal; } + void SetTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; } + + const Expr *TraverseIgnored(const Expr *E); + Expr *TraverseIgnored(Expr *E); + ast_type_traits::DynTypedNode + TraverseIgnored(const ast_type_traits::DynTypedNode &N); + IdentifierTable &Idents; SelectorTable &Selectors; Builtin::Context &BuiltinInfo; @@ -2942,7 +2953,7 @@ std::vector TraversalScope; class ParentMap; - std::unique_ptr Parents; + std::map> Parents; std::unique_ptr VTContext; Index: include/clang/AST/ASTNodeTraverser.h =================================================================== --- include/clang/AST/ASTNodeTraverser.h +++ include/clang/AST/ASTNodeTraverser.h @@ -65,6 +65,8 @@ /// not already been loaded. 
bool Deserialize = false; + ast_type_traits::TraversalKind Traversal = ast_type_traits::TraversalKind::TK_AsIs; + NodeDelegateType &getNodeDelegate() { return getDerived().doGetNodeDelegate(); } @@ -74,6 +76,11 @@ void setDeserialize(bool D) { Deserialize = D; } bool getDeserialize() const { return Deserialize; } + void SetTraversalKind(ast_type_traits::TraversalKind TK) + { + Traversal = TK; + } + void Visit(const Decl *D) { getNodeDelegate().AddChild([=] { getNodeDelegate().Visit(D); @@ -97,8 +104,23 @@ }); } - void Visit(const Stmt *S, StringRef Label = {}) { + void Visit(const Stmt *S_, StringRef Label = {}) { getNodeDelegate().AddChild(Label, [=] { + + const Stmt *S = S_; + + if (auto* E = dyn_cast_or_null(S)) + { + switch (Traversal) + { + case ast_type_traits::TK_AsIs: + break; + case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses: + S = E->IgnoreParenImpCasts(); + break; + } + } + getNodeDelegate().Visit(S); if (!S) { Index: include/clang/ASTMatchers/ASTMatchers.h =================================================================== --- include/clang/ASTMatchers/ASTMatchers.h +++ include/clang/ASTMatchers/ASTMatchers.h @@ -698,6 +698,29 @@ Builder); } +/// Causes all nested matchers to be matched with the specified traversal kind +/// +/// Given +/// \code +/// void foo() +/// { +/// int i = 3.0; +/// } +/// \endcode +/// The matcher +/// \code +/// traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, +/// varDecl(hasInitializer(floatLiteral())) +/// ) +/// \endcode +/// matches the variable declaration "int i = 3.0;". +template +internal::Matcher +traverse(ast_type_traits::TraversalKind TK, + const internal::Matcher& InnerMatcher) { + return internal::Matcher(new internal::TraversalMatcher(TK, InnerMatcher)); +} + /// Matches expressions that match InnerMatcher after any implicit AST /// nodes are stripped off. 
/// Index: include/clang/ASTMatchers/ASTMatchersInternal.h =================================================================== --- include/clang/ASTMatchers/ASTMatchersInternal.h +++ include/clang/ASTMatchers/ASTMatchersInternal.h @@ -993,6 +993,7 @@ std::is_base_of::value, "unsupported type for recursive matching"); return matchesChildOf(ast_type_traits::DynTypedNode::create(Node), + getASTContext(), Matcher, Builder, Traverse, Bind); } @@ -1009,6 +1010,7 @@ std::is_base_of::value, "unsupported type for recursive matching"); return matchesDescendantOf(ast_type_traits::DynTypedNode::create(Node), + getASTContext(), Matcher, Builder, Bind); } @@ -1024,6 +1026,7 @@ std::is_base_of::value, "type not allowed for recursive matching"); return matchesAncestorOf(ast_type_traits::DynTypedNode::create(Node), + getASTContext(), Matcher, Builder, MatchMode); } @@ -1031,17 +1034,20 @@ protected: virtual bool matchesChildOf(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, ast_type_traits::TraversalKind Traverse, BindKind Bind) = 0; virtual bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, BindKind Bind) = 0; virtual bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) = 0; @@ -1153,6 +1159,28 @@ } }; +template +class TraversalMatcher : public WrapperMatcherInterface { + ast_type_traits::TraversalKind Traversal; +public: + explicit TraversalMatcher( + ast_type_traits::TraversalKind TK, + const Matcher &ChildMatcher) + : TraversalMatcher::WrapperMatcherInterface(ChildMatcher), + Traversal(TK) {} + + bool matches(const T &Node, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const override { + auto PrevTK = Finder->getASTContext().GetTraversalKind(); + 
Finder->getASTContext().SetTraversalKind(Traversal); + auto Result = this->InnerMatcher.matches( + ast_type_traits::DynTypedNode::create(Node), + Finder, Builder); + Finder->getASTContext().SetTraversalKind(PrevTK); + return Result; + } +}; + /// A PolymorphicMatcherWithParamN object can be /// created from N parameters p1, ..., pN (of type P1, ..., PN) and /// used as a Matcher where a MatcherT(p1, ..., pN) Index: lib/AST/ASTContext.cpp =================================================================== --- lib/AST/ASTContext.cpp +++ lib/AST/ASTContext.cpp @@ -98,6 +98,31 @@ Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank }; +const Expr *ASTContext::TraverseIgnored(const Expr *E) { + return TraverseIgnored(const_cast(E)); +} + +Expr *ASTContext::TraverseIgnored(Expr *E) { + if (!E) + return nullptr; + + switch(Traversal) { + case ast_type_traits::TK_AsIs: + return E; + case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses: + return E->IgnoreParenImpCasts(); + } + llvm_unreachable("Invalid Traversal type!"); +} + +ast_type_traits::DynTypedNode +ASTContext::TraverseIgnored(const ast_type_traits::DynTypedNode &N) { + if (auto E = N.get()) { + return ast_type_traits::DynTypedNode::create(*TraverseIgnored(E)); + } + return N; +} + RawComment *ASTContext::getRawCommentForDeclNoCache(const Decl *D) const { assert(D); @@ -900,7 +925,7 @@ void ASTContext::setTraversalScope(const std::vector &TopLevelDecls) { TraversalScope = TopLevelDecls; - Parents.reset(); + Parents.clear(); } void ASTContext::AddDeallocation(void (*Callback)(void*), void *Data) { @@ -10259,7 +10284,8 @@ class ASTContext::ParentMap::ASTVisitor : public RecursiveASTVisitor { public: - ASTVisitor(ParentMap &Map) : Map(Map) {} + ASTVisitor(ParentMap &Map, ASTContext& Context) + : Map(Map), Context(Context) {} private: friend class RecursiveASTVisitor; @@ -10329,9 +10355,12 @@ } bool TraverseStmt(Stmt *StmtNode) { - return TraverseNode( - StmtNode, StmtNode, [&] { return 
VisitorBase::TraverseStmt(StmtNode); }, - &Map.PointerParents); + auto FilteredNode = StmtNode; + if (auto* ExprNode = dyn_cast_or_null(FilteredNode)) + FilteredNode = Context.TraverseIgnored(ExprNode); + return TraverseNode(FilteredNode, FilteredNode, + [&] { return VisitorBase::TraverseStmt(FilteredNode); }, + &Map.PointerParents); } bool TraverseTypeLoc(TypeLoc TypeLocNode) { @@ -10349,20 +10378,22 @@ } ParentMap &Map; + ASTContext& Context; llvm::SmallVector ParentStack; }; ASTContext::ParentMap::ParentMap(ASTContext &Ctx) { - ASTVisitor(*this).TraverseAST(Ctx); + ASTVisitor(*this, Ctx).TraverseAST(Ctx); } ASTContext::DynTypedNodeList ASTContext::getParents(const ast_type_traits::DynTypedNode &Node) { - if (!Parents) + std::unique_ptr& P = Parents[Traversal]; + if (!P) // We build the parent map for the traversal scope (usually whole TU), as // hasAncestor can escape any subtree. - Parents = llvm::make_unique(*this); - return Parents->getParents(Node); + P = llvm::make_unique(*this); + return P->getParents(Node); } bool Index: lib/ASTMatchers/ASTMatchFinder.cpp =================================================================== --- lib/ASTMatchers/ASTMatchFinder.cpp +++ lib/ASTMatchers/ASTMatchFinder.cpp @@ -59,10 +59,11 @@ DynTypedMatcher::MatcherIDType MatcherID; ast_type_traits::DynTypedNode Node; BoundNodesTreeBuilder BoundNodes; + ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs; bool operator<(const MatchKey &Other) const { - return std::tie(MatcherID, Node, BoundNodes) < - std::tie(Other.MatcherID, Other.Node, Other.BoundNodes); + return std::tie(MatcherID, Node, BoundNodes, Traversal) < + std::tie(Other.MatcherID, Other.Node, Other.BoundNodes, Other.Traversal); } }; @@ -151,6 +152,8 @@ ScopedIncrement ScopedDepth(&CurrentDepth); Stmt *StmtToTraverse = StmtNode; + if (Expr *ExprNode = dyn_cast_or_null(StmtNode)) + StmtToTraverse = Finder->getASTContext().TraverseIgnored(ExprNode); if (Traversal == 
ast_type_traits::TraversalKind::TK_IgnoreImplicitCastsAndParentheses) { if (Expr *ExprNode = dyn_cast_or_null(StmtNode)) StmtToTraverse = ExprNode->IgnoreParenImpCasts(); @@ -391,6 +394,7 @@ // Matches children or descendants of 'Node' with 'BaseMatcher'. bool memoizedMatchesRecursively(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, int MaxDepth, ast_type_traits::TraversalKind Traversal, BindKind Bind) { @@ -404,6 +408,7 @@ Key.Node = Node; // Note that we key on the bindings *before* the match. Key.BoundNodes = *Builder; + Key.Traversal = Ctx.GetTraversalKind(); MemoizationMap::iterator I = ResultCache.find(Key); if (I != ResultCache.end()) { @@ -439,27 +444,30 @@ // Implements ASTMatchFinder::matchesChildOf. bool matchesChildOf(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, ast_type_traits::TraversalKind Traversal, BindKind Bind) override { if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesRecursively(Node, Matcher, Builder, 1, Traversal, + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Traversal, Bind); } // Implements ASTMatchFinder::matchesDescendantOf. bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, BindKind Bind) override { if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesRecursively(Node, Matcher, Builder, INT_MAX, + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX, ast_type_traits::TraversalKind::TK_AsIs, Bind); } // Implements ASTMatchFinder::matchesAncestorOf. 
bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) override { @@ -467,7 +475,7 @@ // don't invalidate any iterators. if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesAncestorOfRecursively(Node, Matcher, Builder, + return memoizedMatchesAncestorOfRecursively(Node, Ctx, Matcher, Builder, MatchMode); } @@ -632,16 +640,19 @@ // allow simple memoization on the ancestors. Thus, we only memoize as long // as there is a single parent. bool memoizedMatchesAncestorOfRecursively( - const ast_type_traits::DynTypedNode &Node, const DynTypedMatcher &Matcher, + const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, + const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) { // For AST-nodes that don't have an identity, we can't memoize. if (!Builder->isComparable()) - return matchesAncestorOfRecursively(Node, Matcher, Builder, MatchMode); + return matchesAncestorOfRecursively(Node, Ctx, Matcher, Builder, MatchMode); MatchKey Key; Key.MatcherID = Matcher.getID(); Key.Node = Node; Key.BoundNodes = *Builder; + Key.Traversal = Ctx.GetTraversalKind(); // Note that we cannot use insert and reuse the iterator, as recursive // calls to match might invalidate the result cache iterators. 
@@ -654,7 +665,7 @@ MemoizedMatchResult Result; Result.Nodes = *Builder; Result.ResultOfMatch = - matchesAncestorOfRecursively(Node, Matcher, &Result.Nodes, MatchMode); + matchesAncestorOfRecursively(Node, Ctx, Matcher, &Result.Nodes, MatchMode); MemoizedMatchResult &CachedResult = ResultCache[Key]; CachedResult = std::move(Result); @@ -664,6 +675,7 @@ } bool matchesAncestorOfRecursively(const ast_type_traits::DynTypedNode &Node, + ASTContext& Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) { @@ -697,7 +709,7 @@ return true; } if (MatchMode != ASTMatchFinder::AMM_ParentOnly) { - return memoizedMatchesAncestorOfRecursively(Parent, Matcher, Builder, + return memoizedMatchesAncestorOfRecursively(Parent, Ctx, Matcher, Builder, MatchMode); // Once we get back from the recursive call, the result will be the // same as the parent's result. Index: lib/ASTMatchers/ASTMatchersInternal.cpp =================================================================== --- lib/ASTMatchers/ASTMatchersInternal.cpp +++ lib/ASTMatchers/ASTMatchersInternal.cpp @@ -211,8 +211,11 @@ bool DynTypedMatcher::matches(const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const { - if (RestrictKind.isBaseOf(DynNode.getNodeKind()) && - Implementation->dynMatches(DynNode, Finder, Builder)) { + auto N = Finder->getASTContext().TraverseIgnored(DynNode); + auto NodeKind = N.getNodeKind(); + + if (RestrictKind.isBaseOf(NodeKind) && + Implementation->dynMatches(N, Finder, Builder)) { return true; } // Delete all bindings when a matcher does not match. 
@@ -225,8 +228,11 @@ bool DynTypedMatcher::matchesNoKindCheck( const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const { - assert(RestrictKind.isBaseOf(DynNode.getNodeKind())); - if (Implementation->dynMatches(DynNode, Finder, Builder)) { + auto N = Finder->getASTContext().TraverseIgnored(DynNode); + auto NodeKind = N.getNodeKind(); + + assert(RestrictKind.isBaseOf(NodeKind)); + if (Implementation->dynMatches(N, Finder, Builder)) { return true; } // Delete all bindings when a matcher does not match. Index: unittests/ASTMatchers/ASTMatchersTraversalTest.cpp =================================================================== --- unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -1510,6 +1510,96 @@ notMatches("class C {}; C a = C();", varDecl(has(cxxConstructExpr())))); } +TEST(Traversal, traverseMatcher) { + + StringRef VarDeclCode = R"cpp( +void foo() +{ + int i = 3.0; +} +)cpp"; + + auto Matcher = varDecl(hasInitializer(floatLiteral())); + + EXPECT_TRUE(notMatches( + VarDeclCode, + traverse(ast_type_traits::TK_AsIs, + Matcher + ) + )); + EXPECT_TRUE(matches( + VarDeclCode, + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + Matcher + ) + )); +} + +TEST(Traversal, traverseMatcherNesting) { + + StringRef Code = R"cpp( +float bar(int i) +{ + return i; +} + +void foo() +{ + bar(bar(3.0)); +} +)cpp"; + + EXPECT_TRUE(matches( + Code, + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + callExpr(has( + callExpr( + traverse(ast_type_traits::TK_AsIs, + callExpr(has(implicitCastExpr(has(floatLiteral())))) + ) + ) + )) + ) + )); +} + +TEST(Traversal, traverseMatcherThroughMemoization) { + + StringRef Code = R"cpp( +void foo() +{ + int i = 3.0; +} + )cpp"; + + auto Matcher = varDecl(hasInitializer(floatLiteral())); + + // Matchers such as hasDescendant memoize their result regarding AST + // nodes. 
In the matcher below, the first use of hasDescendant(Matcher) + // fails, and the use of it inside the traverse() matcher should pass + // causing the overall matcher to be a true match. + // This test verifies that the first false result is not re-used, which + // would cause the overall matcher to be incorrectly false. + + EXPECT_TRUE( + matches( + Code, + functionDecl(anyOf( + hasDescendant( + Matcher + ), + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + functionDecl( + hasDescendant( + Matcher + ) + ) + ) + )) + ) + ); +} + TEST(IgnoringImpCasts, MatchesImpCasts) { // This test checks that ignoringImpCasts matches when implicit casts are // present and its inner matcher alone does not match.