Index: include/clang/AST/ASTContext.h =================================================================== --- include/clang/AST/ASTContext.h +++ include/clang/AST/ASTContext.h @@ -556,7 +556,17 @@ const TargetInfo *AuxTarget = nullptr; clang::PrintingPolicy PrintingPolicy; + ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs; + public: + ast_type_traits::TraversalKind GetTraversalKind() const { return Traversal; } + void SetTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; } + + const Expr *TraverseIgnored(const Expr *E); + Expr *TraverseIgnored(Expr *E); + ast_type_traits::DynTypedNode + TraverseIgnored(const ast_type_traits::DynTypedNode &N); + IdentifierTable &Idents; SelectorTable &Selectors; Builtin::Context &BuiltinInfo; @@ -2942,7 +2952,7 @@ std::vector TraversalScope; class ParentMap; - std::unique_ptr Parents; + std::map> Parents; std::unique_ptr VTContext; Index: include/clang/AST/ASTNodeTraverser.h =================================================================== --- include/clang/AST/ASTNodeTraverser.h +++ include/clang/AST/ASTNodeTraverser.h @@ -65,6 +65,9 @@ /// not already been loaded. 
bool Deserialize = false; + ast_type_traits::TraversalKind Traversal = + ast_type_traits::TraversalKind::TK_AsIs; + NodeDelegateType &getNodeDelegate() { return getDerived().doGetNodeDelegate(); } @@ -74,6 +77,8 @@ void setDeserialize(bool D) { Deserialize = D; } bool getDeserialize() const { return Deserialize; } + void SetTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; } + void Visit(const Decl *D) { getNodeDelegate().AddChild([=] { getNodeDelegate().Visit(D); @@ -97,8 +102,20 @@ }); } - void Visit(const Stmt *S, StringRef Label = {}) { + void Visit(const Stmt *S_, StringRef Label = {}) { getNodeDelegate().AddChild(Label, [=] { + const Stmt *S = S_; + + if (auto *E = dyn_cast_or_null(S)) { + switch (Traversal) { + case ast_type_traits::TK_AsIs: + break; + case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses: + S = E->IgnoreParenImpCasts(); + break; + } + } + getNodeDelegate().Visit(S); if (!S) { Index: include/clang/ASTMatchers/ASTMatchers.h =================================================================== --- include/clang/ASTMatchers/ASTMatchers.h +++ include/clang/ASTMatchers/ASTMatchers.h @@ -698,6 +698,29 @@ Builder); } +/// Causes all nested matchers to be matched with the specified traversal kind. +/// +/// Given +/// \code +/// void foo() +/// { +/// int i = 3.0; +/// } +/// \endcode +/// The matcher +/// \code +/// traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, +/// varDecl(hasInitializer(floatLiteral())) +/// ) +/// \endcode +/// matches the declaration of \c i, looking through the implicit cast. +template +internal::Matcher traverse(ast_type_traits::TraversalKind TK, + const internal::Matcher &InnerMatcher) { + return internal::Matcher( + new internal::TraversalMatcher(TK, InnerMatcher)); +} + /// Matches expressions that match InnerMatcher after any implicit AST /// nodes are stripped off.
/// Index: include/clang/ASTMatchers/ASTMatchersInternal.h =================================================================== --- include/clang/ASTMatchers/ASTMatchersInternal.h +++ include/clang/ASTMatchers/ASTMatchersInternal.h @@ -282,6 +282,10 @@ virtual bool dynMatches(const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const = 0; + + virtual llvm::Optional TraversalKind() const { + return {}; + } }; /// Generic interface for matchers on an AST node of type T. @@ -991,7 +995,7 @@ std::is_base_of::value, "unsupported type for recursive matching"); return matchesChildOf(ast_type_traits::DynTypedNode::create(Node), - Matcher, Builder, Traverse, Bind); + getASTContext(), Matcher, Builder, Traverse, Bind); } template @@ -1007,7 +1011,7 @@ std::is_base_of::value, "unsupported type for recursive matching"); return matchesDescendantOf(ast_type_traits::DynTypedNode::create(Node), - Matcher, Builder, Bind); + getASTContext(), Matcher, Builder, Bind); } // FIXME: Implement support for BindKind. 
@@ -1022,24 +1026,26 @@ std::is_base_of::value, "type not allowed for recursive matching"); return matchesAncestorOf(ast_type_traits::DynTypedNode::create(Node), - Matcher, Builder, MatchMode); + getASTContext(), Matcher, Builder, MatchMode); } virtual ASTContext &getASTContext() const = 0; protected: virtual bool matchesChildOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, ast_type_traits::TraversalKind Traverse, BindKind Bind) = 0; virtual bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, BindKind Bind) = 0; virtual bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) = 0; @@ -1151,6 +1157,28 @@ } }; +template +class TraversalMatcher : public WrapperMatcherInterface { + ast_type_traits::TraversalKind Traversal; + +public: + explicit TraversalMatcher(ast_type_traits::TraversalKind TK, + const Matcher &ChildMatcher) + : TraversalMatcher::WrapperMatcherInterface(ChildMatcher), Traversal(TK) { + } + + bool matches(const T &Node, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const override { + return this->InnerMatcher.matches( + ast_type_traits::DynTypedNode::create(Node), Finder, Builder); + } + + llvm::Optional + TraversalKind() const override { + return Traversal; + } +}; + /// A PolymorphicMatcherWithParamN object can be /// created from N parameters p1, ..., pN (of type P1, ..., PN) and /// used as a Matcher where a MatcherT(p1, ..., pN) Index: lib/AST/ASTContext.cpp =================================================================== --- lib/AST/ASTContext.cpp +++ lib/AST/ASTContext.cpp @@ -98,6 +98,31 @@ Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank }; +const Expr 
*ASTContext::TraverseIgnored(const Expr *E) { + return TraverseIgnored(const_cast(E)); +} + +Expr *ASTContext::TraverseIgnored(Expr *E) { + if (!E) + return nullptr; + + switch (Traversal) { + case ast_type_traits::TK_AsIs: + return E; + case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses: + return E->IgnoreParenImpCasts(); + } + llvm_unreachable("Invalid Traversal type!"); +} + +ast_type_traits::DynTypedNode +ASTContext::TraverseIgnored(const ast_type_traits::DynTypedNode &N) { + if (auto E = N.get()) { + return ast_type_traits::DynTypedNode::create(*TraverseIgnored(E)); + } + return N; +} + RawComment *ASTContext::getRawCommentForDeclNoCache(const Decl *D) const { assert(D); @@ -900,7 +925,7 @@ void ASTContext::setTraversalScope(const std::vector &TopLevelDecls) { TraversalScope = TopLevelDecls; - Parents.reset(); + Parents.clear(); } void ASTContext::AddDeallocation(void (*Callback)(void*), void *Data) { @@ -10259,7 +10284,8 @@ class ASTContext::ParentMap::ASTVisitor : public RecursiveASTVisitor { public: - ASTVisitor(ParentMap &Map) : Map(Map) {} + ASTVisitor(ParentMap &Map, ASTContext &Context) + : Map(Map), Context(Context) {} private: friend class RecursiveASTVisitor; @@ -10329,8 +10355,12 @@ } bool TraverseStmt(Stmt *StmtNode) { + auto FilteredNode = StmtNode; + if (auto *ExprNode = dyn_cast_or_null(FilteredNode)) + FilteredNode = Context.TraverseIgnored(ExprNode); return TraverseNode( - StmtNode, StmtNode, [&] { return VisitorBase::TraverseStmt(StmtNode); }, + FilteredNode, FilteredNode, + [&] { return VisitorBase::TraverseStmt(FilteredNode); }, &Map.PointerParents); } @@ -10349,20 +10379,22 @@ } ParentMap ⤅ + ASTContext &Context; llvm::SmallVector ParentStack; }; ASTContext::ParentMap::ParentMap(ASTContext &Ctx) { - ASTVisitor(*this).TraverseAST(Ctx); + ASTVisitor(*this, Ctx).TraverseAST(Ctx); } ASTContext::DynTypedNodeList ASTContext::getParents(const ast_type_traits::DynTypedNode &Node) { - if (!Parents) + std::unique_ptr &P = Parents[Traversal]; 
+ if (!P) // We build the parent map for the traversal scope (usually whole TU), as // hasAncestor can escape any subtree. - Parents = llvm::make_unique(*this); - return Parents->getParents(Node); + P = llvm::make_unique(*this); + return P->getParents(Node); } bool Index: lib/ASTMatchers/ASTMatchFinder.cpp =================================================================== --- lib/ASTMatchers/ASTMatchFinder.cpp +++ lib/ASTMatchers/ASTMatchFinder.cpp @@ -59,10 +59,12 @@ DynTypedMatcher::MatcherIDType MatcherID; ast_type_traits::DynTypedNode Node; BoundNodesTreeBuilder BoundNodes; + ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs; bool operator<(const MatchKey &Other) const { - return std::tie(MatcherID, Node, BoundNodes) < - std::tie(Other.MatcherID, Other.Node, Other.BoundNodes); + return std::tie(MatcherID, Node, BoundNodes, Traversal) < + std::tie(Other.MatcherID, Other.Node, Other.BoundNodes, + Other.Traversal); } }; @@ -143,6 +145,8 @@ ScopedIncrement ScopedDepth(&CurrentDepth); Stmt *StmtToTraverse = StmtNode; + if (Expr *ExprNode = dyn_cast_or_null(StmtNode)) + StmtToTraverse = Finder->getASTContext().TraverseIgnored(ExprNode); if (Traversal == ast_type_traits::TraversalKind::TK_IgnoreImplicitCastsAndParentheses) { if (Expr *ExprNode = dyn_cast_or_null(StmtNode)) @@ -384,6 +388,7 @@ // Matches children or descendants of 'Node' with 'BaseMatcher'. bool memoizedMatchesRecursively(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, int MaxDepth, ast_type_traits::TraversalKind Traversal, @@ -398,6 +403,7 @@ Key.Node = Node; // Note that we key on the bindings *before* the match. Key.BoundNodes = *Builder; + Key.Traversal = Ctx.GetTraversalKind(); MemoizationMap::iterator I = ResultCache.find(Key); if (I != ResultCache.end()) { @@ -434,36 +440,36 @@ // Implements ASTMatchFinder::matchesChildOf. 
bool matchesChildOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, ast_type_traits::TraversalKind Traversal, BindKind Bind) override { if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesRecursively(Node, Matcher, Builder, 1, Traversal, + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Traversal, Bind); } // Implements ASTMatchFinder::matchesDescendantOf. bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, BindKind Bind) override { if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesRecursively(Node, Matcher, Builder, INT_MAX, + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX, ast_type_traits::TraversalKind::TK_AsIs, Bind); } // Implements ASTMatchFinder::matchesAncestorOf. bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) override { // Reset the cache outside of the recursive call to make sure we // don't invalidate any iterators. if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesAncestorOfRecursively(Node, Matcher, Builder, + return memoizedMatchesAncestorOfRecursively(Node, Ctx, Matcher, Builder, MatchMode); } @@ -628,16 +634,19 @@ // allow simple memoization on the ancestors. Thus, we only memoize as long // as there is a single parent. 
bool memoizedMatchesAncestorOfRecursively( - const ast_type_traits::DynTypedNode &Node, const DynTypedMatcher &Matcher, - BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) { + const ast_type_traits::DynTypedNode &Node, ASTContext &Ctx, + const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, + AncestorMatchMode MatchMode) { // For AST-nodes that don't have an identity, we can't memoize. if (!Builder->isComparable()) - return matchesAncestorOfRecursively(Node, Matcher, Builder, MatchMode); + return matchesAncestorOfRecursively(Node, Ctx, Matcher, Builder, + MatchMode); MatchKey Key; Key.MatcherID = Matcher.getID(); Key.Node = Node; Key.BoundNodes = *Builder; + Key.Traversal = Ctx.GetTraversalKind(); // Note that we cannot use insert and reuse the iterator, as recursive // calls to match might invalidate the result cache iterators. @@ -649,8 +658,8 @@ MemoizedMatchResult Result; Result.Nodes = *Builder; - Result.ResultOfMatch = - matchesAncestorOfRecursively(Node, Matcher, &Result.Nodes, MatchMode); + Result.ResultOfMatch = matchesAncestorOfRecursively( + Node, Ctx, Matcher, &Result.Nodes, MatchMode); MemoizedMatchResult &CachedResult = ResultCache[Key]; CachedResult = std::move(Result); @@ -660,6 +669,7 @@ } bool matchesAncestorOfRecursively(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) { @@ -693,8 +703,8 @@ return true; } if (MatchMode != ASTMatchFinder::AMM_ParentOnly) { - return memoizedMatchesAncestorOfRecursively(Parent, Matcher, Builder, - MatchMode); + return memoizedMatchesAncestorOfRecursively(Parent, Ctx, Matcher, + Builder, MatchMode); // Once we get back from the recursive call, the result will be the // same as the parent's result. 
} Index: lib/ASTMatchers/ASTMatchersInternal.cpp =================================================================== --- lib/ASTMatchers/ASTMatchersInternal.cpp +++ lib/ASTMatchers/ASTMatchersInternal.cpp @@ -211,10 +211,19 @@ bool DynTypedMatcher::matches(const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const { - if (RestrictKind.isBaseOf(DynNode.getNodeKind()) && - Implementation->dynMatches(DynNode, Finder, Builder)) { + auto PreviousTraversalKind = Finder->getASTContext().GetTraversalKind(); + auto OptTK = Implementation->TraversalKind(); + if (OptTK) + Finder->getASTContext().SetTraversalKind(*OptTK); + auto N = Finder->getASTContext().TraverseIgnored(DynNode); + auto NodeKind = N.getNodeKind(); + + if (RestrictKind.isBaseOf(NodeKind) && + Implementation->dynMatches(N, Finder, Builder)) { + Finder->getASTContext().SetTraversalKind(PreviousTraversalKind); return true; } + Finder->getASTContext().SetTraversalKind(PreviousTraversalKind); // Delete all bindings when a matcher does not match. // This prevents unexpected exposure of bound nodes in unmatches // branches of the match tree. @@ -225,8 +234,11 @@ bool DynTypedMatcher::matchesNoKindCheck( const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const { - assert(RestrictKind.isBaseOf(DynNode.getNodeKind())); - if (Implementation->dynMatches(DynNode, Finder, Builder)) { + auto N = Finder->getASTContext().TraverseIgnored(DynNode); + auto NodeKind = N.getNodeKind(); + + assert(RestrictKind.isBaseOf(NodeKind)); + if (Implementation->dynMatches(N, Finder, Builder)) { return true; } // Delete all bindings when a matcher does not match. 
Index: unittests/ASTMatchers/ASTMatchersTraversalTest.cpp =================================================================== --- unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -1510,6 +1510,72 @@ notMatches("class C {}; C a = C();", varDecl(has(cxxConstructExpr())))); } +TEST(Traversal, traverseMatcher) { + + StringRef VarDeclCode = R"cpp( +void foo() +{ + int i = 3.0; +} +)cpp"; + + auto Matcher = varDecl(hasInitializer(floatLiteral())); + + EXPECT_TRUE( + notMatches(VarDeclCode, traverse(ast_type_traits::TK_AsIs, Matcher))); + EXPECT_TRUE( + matches(VarDeclCode, + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + Matcher))); +} + +TEST(Traversal, traverseMatcherNesting) { + + StringRef Code = R"cpp( +float bar(int i) +{ + return i; +} + +void foo() +{ + bar(bar(3.0)); +} +)cpp"; + + EXPECT_TRUE(matches( + Code, + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + callExpr(has(callExpr(traverse( + ast_type_traits::TK_AsIs, + callExpr(has(implicitCastExpr(has(floatLiteral()))))))))))); +} + +TEST(Traversal, traverseMatcherThroughMemoization) { + + StringRef Code = R"cpp( +void foo() +{ + int i = 3.0; +} + )cpp"; + + auto Matcher = varDecl(hasInitializer(floatLiteral())); + + // Matchers such as hasDescendant memoize their result regarding AST + // nodes. In the matcher below, the first use of hasDescendant(Matcher) + // fails, and the use of it inside the traverse() matcher should pass + // causing the overall matcher to be a true match. + // This test verifies that the first false result is not re-used, which + // would cause the overall matcher to be incorrectly false. 
+ + EXPECT_TRUE(matches( + Code, functionDecl(anyOf( + hasDescendant(Matcher), + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + functionDecl(hasDescendant(Matcher))))))); +} + TEST(IgnoringImpCasts, MatchesImpCasts) { // This test checks that ignoringImpCasts matches when implicit casts are // present and its inner matcher alone does not match.