diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -569,7 +569,17 @@ clang::PrintingPolicy PrintingPolicy; std::unique_ptr InterpContext; + ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs; + public: + ast_type_traits::TraversalKind getTraversalKind() const { return Traversal; } + void setTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; } + + const Expr *traverseIgnored(const Expr *E) const; + Expr *traverseIgnored(Expr *E) const; + ast_type_traits::DynTypedNode + traverseIgnored(const ast_type_traits::DynTypedNode &N) const; + IdentifierTable &Idents; SelectorTable &Selectors; Builtin::Context &BuiltinInfo; @@ -2996,7 +3006,7 @@ std::vector TraversalScope; class ParentMap; - std::unique_ptr Parents; + std::map> Parents; std::unique_ptr VTContext; @@ -3040,6 +3050,27 @@ return Ctx.Selectors.getSelector(1, &II); } +/// RAII guard that switches \p Ctx to \p ScopeTK (if provided) for the +/// lifetime of the object and restores the previous traversal kind on exit. +class TraversalKindScope { + ASTContext &Ctx; + ast_type_traits::TraversalKind TK; + +public: + TraversalKindScope(ASTContext &Ctx, + llvm::Optional ScopeTK) + : Ctx(Ctx), TK(Ctx.getTraversalKind()) { + if (ScopeTK) + Ctx.setTraversalKind(*ScopeTK); + } + + // Non-copyable: a copy would restore the saved kind a second time. + TraversalKindScope(const TraversalKindScope &) = delete; + TraversalKindScope &operator=(const TraversalKindScope &) = delete; + + ~TraversalKindScope() { Ctx.setTraversalKind(TK); } +}; + } // namespace clang // operator new and delete aren't allowed inside namespaces. diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -65,6 +65,9 @@ /// not already been loaded.
bool Deserialize = false; + ast_type_traits::TraversalKind Traversal = + ast_type_traits::TraversalKind::TK_AsIs; + NodeDelegateType &getNodeDelegate() { return getDerived().doGetNodeDelegate(); } @@ -74,6 +77,8 @@ void setDeserialize(bool D) { Deserialize = D; } bool getDeserialize() const { return Deserialize; } + void SetTraversalKind(ast_type_traits::TraversalKind TK) { Traversal = TK; } + void Visit(const Decl *D) { getNodeDelegate().AddChild([=] { getNodeDelegate().Visit(D); @@ -97,8 +102,20 @@ }); } - void Visit(const Stmt *S, StringRef Label = {}) { + void Visit(const Stmt *Node, StringRef Label = {}) { getNodeDelegate().AddChild(Label, [=] { + const Stmt *S = Node; + + if (auto *E = dyn_cast_or_null(S)) { + switch (Traversal) { + case ast_type_traits::TK_AsIs: + break; + case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses: + S = E->IgnoreParenImpCasts(); + break; + } + } + getNodeDelegate().Visit(S); if (!S) { diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h --- a/clang/include/clang/ASTMatchers/ASTMatchers.h +++ b/clang/include/clang/ASTMatchers/ASTMatchers.h @@ -689,6 +689,31 @@ Builder); } +/// Causes all nested matchers to be matched with the specified traversal kind. +/// +/// Given +/// \code +/// void foo() +/// { +/// int i = 3.0; +/// } +/// \endcode +/// The matcher +/// \code +/// traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, +/// varDecl(hasInitializer(floatLiteral().bind("init"))) +/// ) +/// \endcode +/// matches the variable declaration with "init" bound to the "3.0". 
+template +internal::Matcher traverse(ast_type_traits::TraversalKind TK, + const internal::Matcher &InnerMatcher) { + return internal::DynTypedMatcher::constructRestrictedWrapper( + new internal::TraversalMatcher(TK, InnerMatcher), + InnerMatcher.getID().first) + .template unconditionalConvertTo(); +} + /// Matches expressions that match InnerMatcher after any implicit AST /// nodes are stripped off. /// diff --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h --- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h +++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h @@ -283,6 +283,10 @@ virtual bool dynMatches(const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const = 0; + + virtual llvm::Optional TraversalKind() const { + return llvm::None; + } }; /// Generic interface for matchers on an AST node of type T. @@ -371,6 +375,10 @@ ast_type_traits::ASTNodeKind SupportedKind, std::vector InnerMatchers); + static DynTypedMatcher + constructRestrictedWrapper(const DynTypedMatcher &InnerMatcher, + ast_type_traits::ASTNodeKind RestrictKind); + /// Get a "true" matcher for \p NodeKind. /// /// It only checks that the node is of the right kind. @@ -1002,7 +1010,7 @@ std::is_base_of::value, "unsupported type for recursive matching"); return matchesChildOf(ast_type_traits::DynTypedNode::create(Node), - Matcher, Builder, Traverse, Bind); + getASTContext(), Matcher, Builder, Traverse, Bind); } template @@ -1018,7 +1026,7 @@ std::is_base_of::value, "unsupported type for recursive matching"); return matchesDescendantOf(ast_type_traits::DynTypedNode::create(Node), - Matcher, Builder, Bind); + getASTContext(), Matcher, Builder, Bind); } // FIXME: Implement support for BindKind. 
@@ -1033,24 +1041,26 @@ std::is_base_of::value, "type not allowed for recursive matching"); return matchesAncestorOf(ast_type_traits::DynTypedNode::create(Node), - Matcher, Builder, MatchMode); + getASTContext(), Matcher, Builder, MatchMode); } virtual ASTContext &getASTContext() const = 0; protected: virtual bool matchesChildOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, ast_type_traits::TraversalKind Traverse, BindKind Bind) = 0; virtual bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, BindKind Bind) = 0; virtual bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) = 0; @@ -1162,6 +1172,28 @@ } }; +template +class TraversalMatcher : public WrapperMatcherInterface { + ast_type_traits::TraversalKind Traversal; + +public: + explicit TraversalMatcher(ast_type_traits::TraversalKind TK, + const Matcher &ChildMatcher) + : TraversalMatcher::WrapperMatcherInterface(ChildMatcher), Traversal(TK) { + } + + bool matches(const T &Node, ASTMatchFinder *Finder, + BoundNodesTreeBuilder *Builder) const override { + return this->InnerMatcher.matches( + ast_type_traits::DynTypedNode::create(Node), Finder, Builder); + } + + llvm::Optional + TraversalKind() const override { + return Traversal; + } +}; + /// A PolymorphicMatcherWithParamN object can be /// created from N parameters p1, ..., pN (of type P1, ..., PN) and /// used as a Matcher where a MatcherT(p1, ..., pN) diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -99,6 +99,30 @@ enum FloatingRank { Float16Rank, HalfRank, FloatRank, DoubleRank, LongDoubleRank, Float128Rank }; +const Expr 
*ASTContext::traverseIgnored(const Expr *E) const { + return traverseIgnored(const_cast(E)); +} + +Expr *ASTContext::traverseIgnored(Expr *E) const { + if (!E) + return nullptr; + + switch (Traversal) { + case ast_type_traits::TK_AsIs: + return E; + case ast_type_traits::TK_IgnoreImplicitCastsAndParentheses: + return E->IgnoreParenImpCasts(); + } + llvm_unreachable("Invalid Traversal type!"); +} + +ast_type_traits::DynTypedNode +ASTContext::traverseIgnored(const ast_type_traits::DynTypedNode &N) const { + if (const auto *E = N.get()) { + return ast_type_traits::DynTypedNode::create(*traverseIgnored(E)); + } + return N; +} /// \returns location that is relevant when searching for Doc comments related /// to \p D. @@ -959,7 +983,7 @@ void ASTContext::setTraversalScope(const std::vector &TopLevelDecls) { TraversalScope = TopLevelDecls; - Parents.reset(); + Parents.clear(); } void ASTContext::AddDeallocation(void (*Callback)(void *), void *Data) const { @@ -10397,7 +10421,8 @@ class ASTContext::ParentMap::ASTVisitor : public RecursiveASTVisitor { public: - ASTVisitor(ParentMap &Map) : Map(Map) {} + ASTVisitor(ParentMap &Map, ASTContext &Context) + : Map(Map), Context(Context) {} private: friend class RecursiveASTVisitor; @@ -10467,9 +10492,12 @@ } bool TraverseStmt(Stmt *StmtNode) { - return TraverseNode( - StmtNode, StmtNode, [&] { return VisitorBase::TraverseStmt(StmtNode); }, - &Map.PointerParents); + Stmt *FilteredNode = StmtNode; + if (auto *ExprNode = dyn_cast_or_null(FilteredNode)) + FilteredNode = Context.traverseIgnored(ExprNode); + return TraverseNode(FilteredNode, FilteredNode, + [&] { return VisitorBase::TraverseStmt(FilteredNode); }, + &Map.PointerParents); } bool TraverseTypeLoc(TypeLoc TypeLocNode) { @@ -10487,20 +10515,22 @@ } ParentMap &Map; + ASTContext &Context; llvm::SmallVector ParentStack; }; ASTContext::ParentMap::ParentMap(ASTContext &Ctx) { - ASTVisitor(*this).TraverseAST(Ctx); + ASTVisitor(*this, Ctx).TraverseAST(Ctx); }
ASTContext::DynTypedNodeList ASTContext::getParents(const ast_type_traits::DynTypedNode &Node) { - if (!Parents) + std::unique_ptr &P = Parents[Traversal]; + if (!P) // We build the parent map for the traversal scope (usually whole TU), as // hasAncestor can escape any subtree. - Parents = std::make_unique(*this); - return Parents->getParents(Node); + P = std::make_unique(*this); + return P->getParents(Node); } bool diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -59,10 +59,12 @@ DynTypedMatcher::MatcherIDType MatcherID; ast_type_traits::DynTypedNode Node; BoundNodesTreeBuilder BoundNodes; + ast_type_traits::TraversalKind Traversal = ast_type_traits::TK_AsIs; bool operator<(const MatchKey &Other) const { - return std::tie(MatcherID, Node, BoundNodes) < - std::tie(Other.MatcherID, Other.Node, Other.BoundNodes); + return std::tie(MatcherID, Node, BoundNodes, Traversal) < + std::tie(Other.MatcherID, Other.Node, Other.BoundNodes, + Other.Traversal); } }; @@ -143,6 +145,8 @@ ScopedIncrement ScopedDepth(&CurrentDepth); Stmt *StmtToTraverse = StmtNode; + if (auto *ExprNode = dyn_cast_or_null(StmtNode)) + StmtToTraverse = Finder->getASTContext().traverseIgnored(ExprNode); if (Traversal == ast_type_traits::TraversalKind::TK_IgnoreImplicitCastsAndParentheses) { if (Expr *ExprNode = dyn_cast_or_null(StmtNode)) @@ -390,6 +394,7 @@ // Matches children or descendants of 'Node' with 'BaseMatcher'. bool memoizedMatchesRecursively(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, int MaxDepth, ast_type_traits::TraversalKind Traversal, @@ -404,6 +409,7 @@ Key.Node = Node; // Note that we key on the bindings *before* the match. 
Key.BoundNodes = *Builder; + Key.Traversal = Ctx.getTraversalKind(); MemoizationMap::iterator I = ResultCache.find(Key); if (I != ResultCache.end()) { @@ -446,36 +452,36 @@ // Implements ASTMatchFinder::matchesChildOf. bool matchesChildOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, ast_type_traits::TraversalKind Traversal, BindKind Bind) override { if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesRecursively(Node, Matcher, Builder, 1, Traversal, + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Traversal, Bind); } // Implements ASTMatchFinder::matchesDescendantOf. bool matchesDescendantOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, BindKind Bind) override { if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesRecursively(Node, Matcher, Builder, INT_MAX, + return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX, ast_type_traits::TraversalKind::TK_AsIs, Bind); } // Implements ASTMatchFinder::matchesAncestorOf. bool matchesAncestorOf(const ast_type_traits::DynTypedNode &Node, - const DynTypedMatcher &Matcher, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) override { // Reset the cache outside of the recursive call to make sure we // don't invalidate any iterators. 
if (ResultCache.size() > MaxMemoizationEntries) ResultCache.clear(); - return memoizedMatchesAncestorOfRecursively(Node, Matcher, Builder, + return memoizedMatchesAncestorOfRecursively(Node, Ctx, Matcher, Builder, MatchMode); } @@ -576,7 +582,7 @@ if (EnableCheckProfiling) Timer.setBucket(&TimeByBucket[MP.second->getID()]); BoundNodesTreeBuilder Builder; - if (MP.first.matchesNoKindCheck(DynNode, this, &Builder)) { + if (MP.first.matches(DynNode, this, &Builder)) { MatchVisitor Visitor(ActiveASTContext, MP.second); Builder.visitMatches(&Visitor); } @@ -640,16 +646,19 @@ // allow simple memoization on the ancestors. Thus, we only memoize as long // as there is a single parent. bool memoizedMatchesAncestorOfRecursively( - const ast_type_traits::DynTypedNode &Node, const DynTypedMatcher &Matcher, - BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) { + const ast_type_traits::DynTypedNode &Node, ASTContext &Ctx, + const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, + AncestorMatchMode MatchMode) { // For AST-nodes that don't have an identity, we can't memoize. if (!Builder->isComparable()) - return matchesAncestorOfRecursively(Node, Matcher, Builder, MatchMode); + return matchesAncestorOfRecursively(Node, Ctx, Matcher, Builder, + MatchMode); MatchKey Key; Key.MatcherID = Matcher.getID(); Key.Node = Node; Key.BoundNodes = *Builder; + Key.Traversal = Ctx.getTraversalKind(); // Note that we cannot use insert and reuse the iterator, as recursive // calls to match might invalidate the result cache iterators. 
@@ -661,8 +670,8 @@ MemoizedMatchResult Result; Result.Nodes = *Builder; - Result.ResultOfMatch = - matchesAncestorOfRecursively(Node, Matcher, &Result.Nodes, MatchMode); + Result.ResultOfMatch = matchesAncestorOfRecursively( + Node, Ctx, Matcher, &Result.Nodes, MatchMode); MemoizedMatchResult &CachedResult = ResultCache[Key]; CachedResult = std::move(Result); @@ -672,6 +681,7 @@ } bool matchesAncestorOfRecursively(const ast_type_traits::DynTypedNode &Node, + ASTContext &Ctx, const DynTypedMatcher &Matcher, BoundNodesTreeBuilder *Builder, AncestorMatchMode MatchMode) { @@ -705,8 +715,8 @@ return true; } if (MatchMode != ASTMatchFinder::AMM_ParentOnly) { - return memoizedMatchesAncestorOfRecursively(Parent, Matcher, Builder, - MatchMode); + return memoizedMatchesAncestorOfRecursively(Parent, Ctx, Matcher, + Builder, MatchMode); // Once we get back from the recursive call, the result will be the // same as the parent's result. } @@ -804,8 +814,6 @@ /// kind (and derived kinds) so it is a waste to try every matcher on every /// node. /// We precalculate a list of matchers that pass the toplevel restrict check. - /// This also allows us to skip the restrict check at matching time. See - /// use \c matchesNoKindCheck() above. 
llvm::DenseMap> MatcherFiltersMap; diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp --- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp +++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp @@ -189,6 +189,14 @@ llvm_unreachable("Invalid Op value."); } +DynTypedMatcher DynTypedMatcher::constructRestrictedWrapper( + const DynTypedMatcher &InnerMatcher, + ast_type_traits::ASTNodeKind RestrictKind) { + DynTypedMatcher Copy = InnerMatcher; + Copy.RestrictKind = RestrictKind; + return Copy; +} + DynTypedMatcher DynTypedMatcher::trueMatcher( ast_type_traits::ASTNodeKind NodeKind) { return DynTypedMatcher(NodeKind, NodeKind, &*TrueMatcherInstance); @@ -211,8 +219,13 @@ bool DynTypedMatcher::matches(const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const { - if (RestrictKind.isBaseOf(DynNode.getNodeKind()) && - Implementation->dynMatches(DynNode, Finder, Builder)) { + TraversalKindScope RAII(Finder->getASTContext(), + Implementation->TraversalKind()); + + auto N = Finder->getASTContext().traverseIgnored(DynNode); + + if (RestrictKind.isBaseOf(N.getNodeKind()) && + Implementation->dynMatches(N, Finder, Builder)) { return true; } // Delete all bindings when a matcher does not match. @@ -225,8 +238,13 @@ bool DynTypedMatcher::matchesNoKindCheck( const ast_type_traits::DynTypedNode &DynNode, ASTMatchFinder *Finder, BoundNodesTreeBuilder *Builder) const { - assert(RestrictKind.isBaseOf(DynNode.getNodeKind())); - if (Implementation->dynMatches(DynNode, Finder, Builder)) { + TraversalKindScope RAII(Finder->getASTContext(), + Implementation->TraversalKind()); + + auto N = Finder->getASTContext().traverseIgnored(DynNode); + + assert(RestrictKind.isBaseOf(N.getNodeKind())); + if (Implementation->dynMatches(N, Finder, Builder)) { return true; } // Delete all bindings when a matcher does not match.
diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -1595,6 +1595,91 @@ notMatches("class C {}; C a = C();", varDecl(has(cxxConstructExpr())))); } +TEST(Traversal, traverseMatcher) { + + StringRef VarDeclCode = R"cpp( +void foo() +{ + int i = 3.0; +} +)cpp"; + + auto Matcher = varDecl(hasInitializer(floatLiteral())); + + EXPECT_TRUE( + notMatches(VarDeclCode, traverse(ast_type_traits::TK_AsIs, Matcher))); + EXPECT_TRUE( + matches(VarDeclCode, + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + Matcher))); +} + +TEST(Traversal, traverseMatcherNesting) { + + StringRef Code = R"cpp( +float bar(int i) +{ + return i; +} + +void foo() +{ + bar(bar(3.0)); +} +)cpp"; + + EXPECT_TRUE(matches( + Code, + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + callExpr(has(callExpr(traverse( + ast_type_traits::TK_AsIs, + callExpr(has(implicitCastExpr(has(floatLiteral()))))))))))); +} + +TEST(Traversal, traverseMatcherThroughImplicit) { + StringRef Code = R"cpp( +struct S { + S(int x); +}; + +void constructImplicit() { + int a = 8; + S s(a); +} + )cpp"; + + auto Matcher = traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + implicitCastExpr()); + + // Verify that it does not segfault + EXPECT_FALSE(matches(Code, Matcher)); +} + +TEST(Traversal, traverseMatcherThroughMemoization) { + + StringRef Code = R"cpp( +void foo() +{ + int i = 3.0; +} + )cpp"; + + auto Matcher = varDecl(hasInitializer(floatLiteral())); + + // Matchers such as hasDescendant memoize their result regarding AST + // nodes. In the matcher below, the first use of hasDescendant(Matcher) + // fails, and the use of it inside the traverse() matcher should pass, + // causing the overall matcher to be a true match.
+ // This test verifies that the first false result is not re-used, which + // would cause the overall matcher to be incorrectly false. + + EXPECT_TRUE(matches( + Code, functionDecl(anyOf( + hasDescendant(Matcher), + traverse(ast_type_traits::TK_IgnoreImplicitCastsAndParentheses, + functionDecl(hasDescendant(Matcher))))))); +} + TEST(IgnoringImpCasts, MatchesImpCasts) { // This test checks that ignoringImpCasts matches when implicit casts are // present and its inner matcher alone does not match.