diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -243,10 +243,14 @@ return true; ScopedIncrement ScopedDepth(&CurrentDepth); if (auto *Init = Node->getInit()) - if (!match(*Init)) + if (!traverse(*Init)) return false; - if (!match(*Node->getLoopVariable()) || !match(*Node->getRangeInit()) || - !match(*Node->getBody())) + if (!match(*Node->getLoopVariable())) + return false; + if (match(*Node->getRangeInit())) + if (!VisitorBase::TraverseStmt(Node->getRangeInit())) + return false; + if (!match(*Node->getBody())) return false; return VisitorBase::TraverseStmt(Node->getBody()); } @@ -488,15 +492,21 @@ bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) { if (auto *RF = dyn_cast(S)) { - for (auto *SubStmt : RF->children()) { - if (SubStmt == RF->getInit() || SubStmt == RF->getLoopVarStmt() || - SubStmt == RF->getRangeInit() || SubStmt == RF->getBody()) { - TraverseStmt(SubStmt, Queue); - } else { - ASTNodeNotSpelledInSourceScope RAII(this, true); - TraverseStmt(SubStmt, Queue); + { + ASTNodeNotAsIsSourceScope RAII(this, true); + TraverseStmt(RF->getInit()); + // Don't traverse under the loop variable + match(*RF->getLoopVariable()); + TraverseStmt(RF->getRangeInit()); + } + { + ASTNodeNotSpelledInSourceScope RAII(this, true); + for (auto *SubStmt : RF->children()) { + if (SubStmt != RF->getBody()) + TraverseStmt(SubStmt); } } + TraverseStmt(RF->getBody()); return true; } else if (auto *RBO = dyn_cast(S)) { { diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -2820,6 +2820,36 @@ EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); } + Code = R"cpp( + struct Range { + int* begin() const; + int* end() const; + }; + Range getRange(int); + + void rangeFor() + { + for (auto i : getRange(42)) + { + } + } + )cpp"; + { + auto M = integerLiteral(equals(42)); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = callExpr(hasDescendant(integerLiteral(equals(42)))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = compoundStmt(hasDescendant(integerLiteral(equals(42)))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + Code = R"cpp( void rangeFor() { @@ -2891,6 +2921,40 @@ matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M), true, {"-std=c++20"})); } + + Code = R"cpp( + struct Range { + int* begin() const; + int* end() const; + }; + Range getRange(int); + + int getNum(int); + + void rangeFor() + { + for (auto j = getNum(42); auto i : getRange(j)) + { + } + } + )cpp"; + { + auto M = integerLiteral(equals(42)); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"})); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M), + true, {"-std=c++20"})); + } + { + auto M = compoundStmt(hasDescendant(integerLiteral(equals(42)))); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"})); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M), + true, {"-std=c++20"})); + } + Code = R"cpp( void hasDefaultArg(int i, int j = 0) {