diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -488,13 +488,18 @@ bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) { if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) { - for (auto *SubStmt : RF->children()) { - if (SubStmt == RF->getInit() || SubStmt == RF->getLoopVarStmt() || - SubStmt == RF->getRangeInit() || SubStmt == RF->getBody()) { - TraverseStmt(SubStmt, Queue); - } else { - ASTNodeNotSpelledInSourceScope RAII(this, true); - TraverseStmt(SubStmt, Queue); + { + ASTNodeNotAsIsSourceScope RAII(this, true); + TraverseStmt(RF->getInit()); + // Don't traverse under the loop variable + match(*RF->getLoopVariable()); + TraverseStmt(RF->getRangeInit()); + TraverseStmt(RF->getBody()); + } + { + ASTNodeNotSpelledInSourceScope RAII(this, true); + for (auto *SubStmt : RF->children()) { + TraverseStmt(SubStmt); } } return true; diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -2784,6 +2784,31 @@ EXPECT_FALSE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); } + Code = R"cpp( + struct Range { + int* begin() const; + int* end() const; + }; + Range getRange(int); + + void rangeFor() + { + for (auto i : getRange(42)) + { + } + } + )cpp"; + { + auto M = integerLiteral(equals(42)); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + { + auto M = callExpr(hasDescendant(integerLiteral(equals(42)))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, M))); + EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, M))); + } + Code = R"cpp( void rangeFor() { @@ -2855,6 +2880,32 @@ matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, 
M), true, {"-std=c++20"})); } + + Code = R"cpp( + struct Range { + int* begin() const; + int* end() const; + }; + Range getRange(int); + + int getNum(int); + + void rangeFor() + { + for (auto j = getNum(42); auto i : getRange(j)) + { + } + } + )cpp"; + { + auto M = integerLiteral(equals(42)); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_AsIs, M), true, {"-std=c++20"})); + EXPECT_TRUE( + matchesConditionally(Code, traverse(TK_IgnoreUnlessSpelledInSource, M), + true, {"-std=c++20"})); + } + Code = R"cpp( void hasDefaultArg(int i, int j = 0) {