diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst
--- a/clang/docs/ReleaseNotes.rst
+++ b/clang/docs/ReleaseNotes.rst
@@ -868,6 +868,11 @@
 ------------
 - Add ``isInAnoymousNamespace`` matcher to match declarations in an anonymous namespace.
 
+- Add ``coroutineBodyStmt`` matcher.
+
+- The ``hasBody`` matcher now also supports ``CoroutineBodyStmt`` nodes.
+
+
 clang-format
 ------------
 - Add ``RemoveSemicolon`` option for removing ``;`` after a non-empty function definition.
diff --git a/clang/include/clang/ASTMatchers/ASTMatchers.h b/clang/include/clang/ASTMatchers/ASTMatchers.h
--- a/clang/include/clang/ASTMatchers/ASTMatchers.h
+++ b/clang/include/clang/ASTMatchers/ASTMatchers.h
@@ -2449,6 +2449,17 @@
 extern const internal::VariadicDynCastAllOfMatcher<Stmt, CoyieldExpr>
     coyieldExpr;
 
+/// Matches coroutine body statements.
+///
+/// coroutineBodyStmt() matches the coroutine below
+/// \code
+///   generator<int> gen() {
+///     co_return;
+///   }
+/// \endcode
+extern const internal::VariadicDynCastAllOfMatcher<Stmt, CoroutineBodyStmt>
+    coroutineBodyStmt;
+
 /// Matches nullptr literal.
 extern const internal::VariadicDynCastAllOfMatcher<Stmt, CXXNullPtrLiteralExpr>
     cxxNullPtrLiteralExpr;
@@ -5460,9 +5471,9 @@
 }
 
 /// Matches a 'for', 'while', 'do while' statement or a function
-/// definition that has a given body. Note that in case of functions
-/// this matcher only matches the definition itself and not the other
-/// declarations of the same function.
+/// or coroutine definition that has a given body. Note that in case of
+/// functions this matcher only matches the definition itself and not
+/// the other declarations of the same function.
 ///
 /// Given
 /// \code
@@ -5484,12 +5495,11 @@
 /// with compoundStmt()
 ///   matching '{}'
 ///   but does not match 'void f();'
-AST_POLYMORPHIC_MATCHER_P(hasBody,
-                          AST_POLYMORPHIC_SUPPORTED_TYPES(DoStmt, ForStmt,
-                                                          WhileStmt,
-                                                          CXXForRangeStmt,
-                                                          FunctionDecl),
-                          internal::Matcher<Stmt>, InnerMatcher) {
+AST_POLYMORPHIC_MATCHER_P(
+    hasBody,
+    AST_POLYMORPHIC_SUPPORTED_TYPES(DoStmt, ForStmt, WhileStmt, CXXForRangeStmt,
+                                    FunctionDecl, CoroutineBodyStmt),
+    internal::Matcher<Stmt>, InnerMatcher) {
   if (Finder->isTraversalIgnoringImplicitNodes() && isDefaultedHelper(&Node))
     return false;
   const Stmt *const Statement = internal::GetBodyMatcher<NodeType>::get(Node);
diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
--- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
+++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp
@@ -909,6 +909,8 @@
 const internal::VariadicDynCastAllOfMatcher<Stmt, CaseStmt> caseStmt;
 const internal::VariadicDynCastAllOfMatcher<Stmt, DefaultStmt> defaultStmt;
 const internal::VariadicDynCastAllOfMatcher<Stmt, CompoundStmt> compoundStmt;
+const internal::VariadicDynCastAllOfMatcher<Stmt, CoroutineBodyStmt>
+    coroutineBodyStmt;
 const internal::VariadicDynCastAllOfMatcher<Stmt, CXXCatchStmt> cxxCatchStmt;
 const internal::VariadicDynCastAllOfMatcher<Stmt, CXXTryStmt> cxxTryStmt;
 const internal::VariadicDynCastAllOfMatcher<Stmt, CXXThrowExpr> cxxThrowExpr;
diff --git a/clang/lib/ASTMatchers/Dynamic/Registry.cpp b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
--- a/clang/lib/ASTMatchers/Dynamic/Registry.cpp
+++ b/clang/lib/ASTMatchers/Dynamic/Registry.cpp
@@ -175,6 +175,7 @@
   REGISTER_MATCHER(containsDeclaration);
   REGISTER_MATCHER(continueStmt);
   REGISTER_MATCHER(coreturnStmt);
+  REGISTER_MATCHER(coroutineBodyStmt);
   REGISTER_MATCHER(coyieldExpr);
   REGISTER_MATCHER(cudaKernelCallExpr);
   REGISTER_MATCHER(cxxBaseSpecifier);
diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
--- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
+++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp
@@ -678,6 +678,49 @@
   EXPECT_TRUE(matchesConditionally(CoYieldCode,
                                    coyieldExpr(isExpansionInMainFile()), true,
                                    {"-std=c++20", "-I/"}, M));
+
+  StringRef NonCoroCode = R"cpp(
+#include <coro_header>
+void non_coro_function() {
+}
+)cpp";
+
+  EXPECT_TRUE(matchesConditionally(CoReturnCode, coroutineBodyStmt(), true,
+                                   {"-std=c++20", "-I/"}, M));
+  EXPECT_TRUE(matchesConditionally(CoAwaitCode, coroutineBodyStmt(), true,
+                                   {"-std=c++20", "-I/"}, M));
+  EXPECT_TRUE(matchesConditionally(CoYieldCode, coroutineBodyStmt(), true,
+                                   {"-std=c++20", "-I/"}, M));
+
+  // Expect "no match" explicitly: EXPECT_FALSE(... true ...) would also pass
+  // on a parse error or a static/dynamic matcher mismatch.
+  EXPECT_TRUE(matchesConditionally(NonCoroCode, coroutineBodyStmt(), false,
+                                   {"-std=c++20", "-I/"}, M));
+
+  StringRef CoroWithDeclCode = R"cpp(
+#include <coro_header>
+void coro() {
+  int thevar;
+  co_return 1;
+}
+)cpp";
+  EXPECT_TRUE(matchesConditionally(
+      CoroWithDeclCode,
+      coroutineBodyStmt(hasBody(compoundStmt(
+          has(declStmt(containsDeclaration(0, varDecl(hasName("thevar")))))))),
+      true, {"-std=c++20", "-I/"}, M));
+
+  StringRef CoroWithTryCatchDeclCode = R"cpp(
+#include <coro_header>
+void coro() try {
+  int thevar;
+  co_return 1;
+} catch (...) {}
+)cpp";
+  EXPECT_TRUE(matchesConditionally(
+      CoroWithTryCatchDeclCode,
+      coroutineBodyStmt(hasBody(cxxTryStmt(has(compoundStmt(has(
+          declStmt(containsDeclaration(0, varDecl(hasName("thevar")))))))))),
+      true, {"-std=c++20", "-I/"}, M));
 }
 
 TEST(Matcher, isClassMessage) {