diff --git a/clang/lib/AST/ParentMapContext.cpp b/clang/lib/AST/ParentMapContext.cpp
--- a/clang/lib/AST/ParentMapContext.cpp
+++ b/clang/lib/AST/ParentMapContext.cpp
@@ -49,7 +49,18 @@
   return N;
 }
 
+template <typename T, typename... U>
+std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+matchParents(const DynTypedNodeList &NodeList,
+             ParentMapContext::ParentMap *ParentMap);
+
+template <typename, typename...> struct MatchParents;
+
 class ParentMapContext::ParentMap {
+
+  // MatchParents calls getDynNodeFromMap and reads PointerParents.
+  template <typename, typename...> friend struct ::MatchParents;
+
   /// Contains parents of a node.
   using ParentVector = llvm::SmallVector<DynTypedNode, 2>;
 
@@ -117,11 +128,76 @@
     if (Node.getNodeKind().hasPointerIdentity()) {
       auto ParentList =
           getDynNodeFromMap(Node.getMemoizationData(), PointerParents);
-      if (ParentList.size() == 1 && TK == TK_IgnoreUnlessSpelledInSource) {
-        const auto *E = ParentList[0].get<Expr>();
-        const auto *Child = Node.get<Expr>();
-        if (E && Child)
-          return AscendIgnoreUnlessSpelledInSource(E, Child);
+      if (ParentList.size() > 0 && TK == TK_IgnoreUnlessSpelledInSource) {
+
+        const auto *ChildExpr = Node.get<Expr>();
+
+        {
+          // Don't match explicit node types because different stdlib
+          // implementations implement this in different ways and have
+          // different intermediate nodes.
+          // Look up 4 levels for a cxxRewrittenBinaryOperator as that is
+          // enough for the major stdlib implementations.
+          auto RewrittenBinOpParentsList = ParentList;
+          int I = 0;
+          while (ChildExpr && RewrittenBinOpParentsList.size() == 1 &&
+                 I++ < 4) {
+            const auto *S = RewrittenBinOpParentsList[0].get<Stmt>();
+            if (!S)
+              break;
+
+            const auto *RWBO = dyn_cast<CXXRewrittenBinaryOperator>(S);
+            if (!RWBO) {
+              RewrittenBinOpParentsList = getDynNodeFromMap(S, PointerParents);
+              continue;
+            }
+            if (RWBO->getLHS()->IgnoreUnlessSpelledInSource() != ChildExpr &&
+                RWBO->getRHS()->IgnoreUnlessSpelledInSource() != ChildExpr)
+              break;
+            return DynTypedNode::create(*RWBO);
+          }
+        }
+
+        const auto *ParentExpr = ParentList[0].get<Expr>();
+        if (ParentExpr && ChildExpr)
+          return AscendIgnoreUnlessSpelledInSource(ParentExpr, ChildExpr);
+
+        {
+          // The range-for loop variable's parent is the CXXForRangeStmt.
+          auto AncestorNodes =
+              matchParents<DeclStmt, CXXForRangeStmt>(ParentList, this);
+          if (std::get<bool>(AncestorNodes) &&
+              std::get<const CXXForRangeStmt *>(AncestorNodes)
+                      ->getLoopVarStmt() ==
+                  std::get<const DeclStmt *>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          // The range-init expression's parent is the CXXForRangeStmt.
+          auto AncestorNodes = matchParents<VarDecl, DeclStmt, CXXForRangeStmt>(
+              ParentList, this);
+          if (std::get<bool>(AncestorNodes) &&
+              std::get<const CXXForRangeStmt *>(AncestorNodes)
+                      ->getRangeStmt() ==
+                  std::get<const DeclStmt *>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          // Children of a lambda class's methods have the LambdaExpr as parent.
+          auto AncestorNodes =
+              matchParents<CXXMethodDecl, CXXRecordDecl, LambdaExpr>(ParentList,
+                                                                     this);
+          if (std::get<bool>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
+        {
+          // Likewise for children of a method template of a lambda class.
+          auto AncestorNodes =
+              matchParents<FunctionTemplateDecl, CXXRecordDecl, LambdaExpr>(
+                  ParentList, this);
+          if (std::get<bool>(AncestorNodes))
+            return std::get<DynTypedNodeList>(AncestorNodes);
+        }
       }
       return ParentList;
     }
@@ -194,6 +270,67 @@
   }
 };
 
+/// Implementation of tuple_pop_front: copies the elements at indices 1 + Is.
+template <typename Tuple, std::size_t... Is>
+auto tuple_pop_front_impl(const Tuple &tuple, std::index_sequence<Is...>) {
+  return std::make_tuple(std::get<1 + Is>(tuple)...);
+}
+
+/// Returns a copy of \p tuple with its first element removed.
+template <typename Tuple> auto tuple_pop_front(const Tuple &tuple) {
+  return tuple_pop_front_impl(
+      tuple, std::make_index_sequence<std::tuple_size<Tuple>::value - 1>());
+}
+
+/// Matches a chain of ancestors of the node types T, U... starting at
+/// NodeList[0], where every node in the chain has exactly one parent.
+/// Returns {matched, list holding the last node of the chain, T node,
+/// U nodes...}; on failure the flag is false, the list is \p NodeList and
+/// the node pointers are null.
+template <typename T, typename... U> struct MatchParents {
+  static std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+  match(const DynTypedNodeList &NodeList,
+        ParentMapContext::ParentMap *ParentMap) {
+    if (const auto *TypedNode = NodeList[0].get<T>()) {
+      auto NextParentList =
+          ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
+      if (NextParentList.size() == 1) {
+        auto TailTuple = MatchParents<U...>::match(NextParentList, ParentMap);
+        if (std::get<bool>(TailTuple)) {
+          return std::tuple_cat(
+              std::make_tuple(true, std::get<DynTypedNodeList>(TailTuple),
+                              TypedNode),
+              tuple_pop_front(tuple_pop_front(TailTuple)));
+        }
+      }
+    }
+    return std::tuple_cat(std::make_tuple(false, NodeList),
+                          std::tuple<const T *, const U *...>());
+  }
+};
+
+/// Base case: a chain consisting of a single node of type T.
+template <typename T> struct MatchParents<T> {
+  static std::tuple<bool, DynTypedNodeList, const T *>
+  match(const DynTypedNodeList &NodeList,
+        ParentMapContext::ParentMap *ParentMap) {
+    if (const auto *TypedNode = NodeList[0].get<T>()) {
+      auto NextParentList =
+          ParentMap->getDynNodeFromMap(TypedNode, ParentMap->PointerParents);
+      if (NextParentList.size() == 1)
+        return std::make_tuple(true, NodeList, TypedNode);
+    }
+    return std::make_tuple(false, NodeList, nullptr);
+  }
+};
+
+template <typename T, typename... U>
+std::tuple<bool, DynTypedNodeList, const T *, const U *...>
+matchParents(const DynTypedNodeList &NodeList,
+             ParentMapContext::ParentMap *ParentMap) {
+  return MatchParents<T, U...>::match(NodeList, ParentMap);
+}
+
 /// Template specializations to abstract away from pointers and TypeLocs.
 /// @{
 template <typename T> static DynTypedNode createDynTypedNode(const T &Node) {
diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -2933,6 +2933,37 @@
     EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
     EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
   }
+  {
+    auto M = ifStmt(hasParent(compoundStmt(hasParent(cxxForRangeStmt()))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(
+        has(varDecl(hasName("i"), hasParent(cxxForRangeStmt()))));
+    EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(hasDescendant(varDecl(
+        hasName("i"), hasParent(declStmt(hasParent(cxxForRangeStmt()))))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+  {
+    auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
+        to(varDecl(hasName("arr"))), hasParent(cxxForRangeStmt()))));
+    EXPECT_FALSE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
+
+  {
+    auto M = cxxForRangeStmt(hasRangeInit(declRefExpr(
+        to(varDecl(hasName("arr"))), hasParent(varDecl(hasParent(declStmt(
+                                         hasParent(cxxForRangeStmt()))))))));
+    EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M)));
+    EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M)));
+  }
 
   Code = R"cpp(
 struct Range {
@@ -3035,6 +3066,15 @@
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+  {
+    auto M = cxxForRangeStmt(hasInitStatement(declStmt(
+        hasSingleDecl(varDecl(hasName("a"))), hasParent(cxxForRangeStmt()))));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
 
   Code = R"cpp(
 struct Range {
@@ -3511,6 +3551,20 @@
                        forFunction(functionDecl(hasName("func13"))))))),
       langCxx20OrLater()));
 
+  EXPECT_TRUE(matches(Code,
+                      traverse(TK_IgnoreUnlessSpelledInSource,
+                               compoundStmt(hasParent(lambdaExpr(forFunction(
+                                   functionDecl(hasName("func13"))))))),
+                      langCxx20OrLater()));
+
+  EXPECT_TRUE(matches(
+      Code,
+      traverse(TK_IgnoreUnlessSpelledInSource,
+               templateTypeParmDecl(hasName("TemplateType"),
+                                    hasParent(lambdaExpr(forFunction(
+                                        functionDecl(hasName("func14"))))))),
+      langCxx20OrLater()));
+
   EXPECT_TRUE(matches(
       Code,
       traverse(TK_IgnoreUnlessSpelledInSource,
@@ -3635,6 +3689,16 @@
         matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
                              true, {"-std=c++20"}));
   }
+  {
+    auto M = cxxRewrittenBinaryOperator(
+        hasLHS(expr(hasParent(cxxRewrittenBinaryOperator()))),
+        hasRHS(expr(hasParent(cxxRewrittenBinaryOperator()))));
+    EXPECT_FALSE(
+        matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"}));
+    EXPECT_TRUE(
+        matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M),
+                             true, {"-std=c++20"}));
+  }
   {
     EXPECT_TRUE(matchesConditionally(
         Code,