// Review notes for this patch. Text before the first `diff --git` header is
// treated as leading garbage by GNU patch and skipped by git apply, so these
// notes do not change what the patch does.
//
// Purpose: make the clang AST printer unit tests share one match-and-print
// harness in clang/unittests/AST/ASTPrint.h.
//  * ASTPrint.h: removes the Stmt-only PrintStmt helper, the Optional-wrapped
//    PolicyAdjusterType and PrintedStmtMatches; adds PrintingPolicyAdjuster
//    (an llvm::function_ref), the NodePrinter / NodeFilter aliases, a
//    PrintMatch callback templated on the node type, the NoNodeFilter default
//    filter and the PrintedNodeMatches driver. The driver fails on a parse
//    error (unless AllowError is set), when no node or more than one node bound
//    as `id` passes the filter, or when the printed text differs from
//    ExpectedPrinted. An empty FileName makes it call runToolOnCodeWithArgs
//    without a file name argument.
//  * DeclPrinterTest.cpp: includes ASTPrint.h, replaces PrintingPolicyModifier
//    with PrintingPolicyAdjuster, deletes its local PrintMatch class and
//    forwards PrintedDeclMatches to PrintedNodeMatches with PrintDecl plus a
//    filter that skips implicit declarations.
//  * NamedDeclPrinterTest.cpp: includes ASTPrint.h and forwards its matching
//    helper to PrintedNodeMatches through a lambda that calls Print(Out, ND)
//    and ignores the Context and PolicyAdjuster arguments.
//  * StmtPrinterTest.cpp: takes over PrintStmt, adds a PrintedStmtMatches
//    wrapper that passes an empty FileName, switches the PolicyAdjusterType /
//    None defaults to PrintingPolicyAdjuster / nullptr, and passes the policy
//    lambdas directly instead of wrapping them in PolicyAdjusterType(...).
//
// NOTE(review): this copy of the patch is damaged. Newlines and indentation
// are collapsed, and text inside angle brackets appears to have been stripped
// (e.g. `std::vector &Args`, `getNodeAs(...)` without a node type,
// `Optional>;`, `template class PrintMatch`). It cannot be applied or compiled
// as-is; regenerate it from the source tree instead of repairing it by hand.
// NOTE(review): PrintingPolicyAdjuster is non-owning (llvm::function_ref) and
// PrintMatch keeps a copy as a member. That is safe here only because
// PrintedNodeMatches runs the tool before returning; the callback object must
// not outlive the call that received the adjuster.
// NOTE(review): the NamedDeclPrinterTest.cpp hunks only replace the helper
// body; whether that file still defines its old local PrintMatch class (now
// unused) is not visible in this patch -- check for dead code.
// NOTE(review): the StmtPrinterTest.cpp hunks leave an empty line inside the
// ASSERT_TRUE argument lists where `PolicyAdjusterType(` was removed; consider
// re-running clang-format on those calls.
diff --git a/clang/unittests/AST/ASTPrint.h b/clang/unittests/AST/ASTPrint.h --- a/clang/unittests/AST/ASTPrint.h +++ b/clang/unittests/AST/ASTPrint.h @@ -19,72 +19,95 @@ namespace clang { -using PolicyAdjusterType = - Optional>; - -static void PrintStmt(raw_ostream &Out, const ASTContext *Context, - const Stmt *S, PolicyAdjusterType PolicyAdjuster) { - assert(S != nullptr && "Expected non-null Stmt"); - PrintingPolicy Policy = Context->getPrintingPolicy(); - if (PolicyAdjuster) - (*PolicyAdjuster)(Policy); - S->printPretty(Out, /*Helper*/ nullptr, Policy); -} +using PrintingPolicyAdjuster = llvm::function_ref; + +template +using NodePrinter = + std::function; +template +using NodeFilter = std::function; + +template class PrintMatch : public ast_matchers::MatchFinder::MatchCallback { + using PrinterT = NodePrinter; + using FilterT = NodeFilter; + SmallString<1024> Printed; - unsigned NumFoundStmts; - PolicyAdjusterType PolicyAdjuster; + unsigned NumFoundNodes; + PrinterT Printer; + FilterT Filter; + PrintingPolicyAdjuster PolicyAdjuster; public: - PrintMatch(PolicyAdjusterType PolicyAdjuster) - : NumFoundStmts(0), PolicyAdjuster(PolicyAdjuster) {} + PrintMatch(PrinterT Printer, PrintingPolicyAdjuster PolicyAdjuster, + FilterT Filter) + : NumFoundNodes(0), Printer(std::move(Printer)), + Filter(std::move(Filter)), PolicyAdjuster(PolicyAdjuster) {} void run(const ast_matchers::MatchFinder::MatchResult &Result) override { - const Stmt *S = Result.Nodes.getNodeAs("id"); - if (!S) + const NodeType *N = Result.Nodes.getNodeAs("id"); + if (!N || !Filter(N)) return; - NumFoundStmts++; - if (NumFoundStmts > 1) + NumFoundNodes++; + if (NumFoundNodes > 1) return; llvm::raw_svector_ostream Out(Printed); - PrintStmt(Out, Result.Context, S, PolicyAdjuster); + Printer(Out, Result.Context, N, PolicyAdjuster); } StringRef getPrinted() const { return Printed; } - unsigned getNumFoundStmts() const { return NumFoundStmts; } + unsigned getNumFoundNodes() const { return NumFoundNodes; } 
}; -template -::testing::AssertionResult -PrintedStmtMatches(StringRef Code, const std::vector &Args, - const T &NodeMatch, StringRef ExpectedPrinted, - PolicyAdjusterType PolicyAdjuster = None) { +template bool NoNodeFilter(const NodeType *) { + return true; +} - PrintMatch Printer(PolicyAdjuster); +template +::testing::AssertionResult +PrintedNodeMatches(StringRef Code, const std::vector &Args, + const Matcher &NodeMatch, StringRef ExpectedPrinted, + StringRef FileName, NodePrinter Printer, + PrintingPolicyAdjuster PolicyAdjuster = nullptr, + bool AllowError = false, + // Would like to use a lambda for the default value, but that + // trips gcc 7 up. + NodeFilter Filter = &NoNodeFilter) { + + PrintMatch Callback(Printer, PolicyAdjuster, Filter); ast_matchers::MatchFinder Finder; - Finder.addMatcher(NodeMatch, &Printer); + Finder.addMatcher(NodeMatch, &Callback); std::unique_ptr Factory( tooling::newFrontendActionFactory(&Finder)); - if (!tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args)) + bool ToolResult; + if (FileName.empty()) { + ToolResult = tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args); + } else { + ToolResult = + tooling::runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName); + } + if (!ToolResult && !AllowError) return testing::AssertionFailure() << "Parsing error in \"" << Code.str() << "\""; - if (Printer.getNumFoundStmts() == 0) - return testing::AssertionFailure() << "Matcher didn't find any statements"; + if (Callback.getNumFoundNodes() == 0) + return testing::AssertionFailure() << "Matcher didn't find any nodes"; - if (Printer.getNumFoundStmts() > 1) + if (Callback.getNumFoundNodes() > 1) return testing::AssertionFailure() - << "Matcher should match only one statement (found " - << Printer.getNumFoundStmts() << ")"; + << "Matcher should match only one node (found " + << Callback.getNumFoundNodes() << ")"; - if (Printer.getPrinted() != ExpectedPrinted) + if (Callback.getPrinted() != ExpectedPrinted) return 
::testing::AssertionFailure() << "Expected \"" << ExpectedPrinted.str() << "\", got \"" - << Printer.getPrinted().str() << "\""; + << Callback.getPrinted().str() << "\""; return ::testing::AssertionSuccess(); } diff --git a/clang/unittests/AST/DeclPrinterTest.cpp b/clang/unittests/AST/DeclPrinterTest.cpp --- a/clang/unittests/AST/DeclPrinterTest.cpp +++ b/clang/unittests/AST/DeclPrinterTest.cpp @@ -18,6 +18,7 @@ // //===----------------------------------------------------------------------===// +#include "ASTPrint.h" #include "clang/AST/ASTContext.h" #include "clang/ASTMatchers/ASTMatchFinder.h" #include "clang/ASTMatchers/ASTMatchers.h" @@ -32,10 +33,8 @@ namespace { -using PrintingPolicyModifier = void (*)(PrintingPolicy &policy); - void PrintDecl(raw_ostream &Out, const ASTContext *Context, const Decl *D, - PrintingPolicyModifier PolicyModifier) { + PrintingPolicyAdjuster PolicyModifier) { PrintingPolicy Policy = Context->getPrintingPolicy(); Policy.TerseOutput = true; Policy.Indentation = 0; @@ -44,74 +43,23 @@ D->print(Out, Policy, /*Indentation*/ 0, /*PrintInstantiation*/ false); } -class PrintMatch : public MatchFinder::MatchCallback { - SmallString<1024> Printed; - unsigned NumFoundDecls; - PrintingPolicyModifier PolicyModifier; - -public: - PrintMatch(PrintingPolicyModifier PolicyModifier) - : NumFoundDecls(0), PolicyModifier(PolicyModifier) {} - - void run(const MatchFinder::MatchResult &Result) override { - const Decl *D = Result.Nodes.getNodeAs("id"); - if (!D || D->isImplicit()) - return; - NumFoundDecls++; - if (NumFoundDecls > 1) - return; - - llvm::raw_svector_ostream Out(Printed); - PrintDecl(Out, Result.Context, D, PolicyModifier); - } - - StringRef getPrinted() const { - return Printed; - } - - unsigned getNumFoundDecls() const { - return NumFoundDecls; - } -}; - ::testing::AssertionResult PrintedDeclMatches(StringRef Code, const std::vector &Args, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName, 
- PrintingPolicyModifier PolicyModifier = nullptr, + PrintingPolicyAdjuster PolicyModifier = nullptr, bool AllowError = false) { - PrintMatch Printer(PolicyModifier); - MatchFinder Finder; - Finder.addMatcher(NodeMatch, &Printer); - std::unique_ptr Factory( - newFrontendActionFactory(&Finder)); - - if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName) && - !AllowError) - return testing::AssertionFailure() - << "Parsing error in \"" << Code.str() << "\""; - - if (Printer.getNumFoundDecls() == 0) - return testing::AssertionFailure() - << "Matcher didn't find any declarations"; - - if (Printer.getNumFoundDecls() > 1) - return testing::AssertionFailure() - << "Matcher should match only one declaration " - "(found " << Printer.getNumFoundDecls() << ")"; - - if (Printer.getPrinted() != ExpectedPrinted) - return ::testing::AssertionFailure() - << "Expected \"" << ExpectedPrinted.str() << "\", " - "got \"" << Printer.getPrinted().str() << "\""; - - return ::testing::AssertionSuccess(); + return PrintedNodeMatches( + Code, Args, NodeMatch, ExpectedPrinted, FileName, PrintDecl, + PolicyModifier, AllowError, + // Filter out implicit decls + [](const Decl *D) { return !D->isImplicit(); }); } ::testing::AssertionResult PrintedDeclCXX98Matches(StringRef Code, StringRef DeclName, StringRef ExpectedPrinted, - PrintingPolicyModifier PolicyModifier = nullptr) { + PrintingPolicyAdjuster PolicyModifier = nullptr) { std::vector Args(1, "-std=c++98"); return PrintedDeclMatches(Code, Args, namedDecl(hasName(DeclName)).bind("id"), ExpectedPrinted, "input.cc", PolicyModifier); @@ -120,7 +68,7 @@ ::testing::AssertionResult PrintedDeclCXX98Matches(StringRef Code, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyModifier PolicyModifier = nullptr) { + PrintingPolicyAdjuster PolicyModifier = nullptr) { std::vector Args(1, "-std=c++98"); return PrintedDeclMatches(Code, Args, @@ -165,7 +113,7 @@ ::testing::AssertionResult 
PrintedDeclCXX17Matches(StringRef Code, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyModifier PolicyModifier = nullptr) { + PrintingPolicyAdjuster PolicyModifier = nullptr) { std::vector Args{"-std=c++17", "-fno-delayed-template-parsing"}; return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.cc", PolicyModifier); @@ -174,7 +122,7 @@ ::testing::AssertionResult PrintedDeclC11Matches(StringRef Code, const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, - PrintingPolicyModifier PolicyModifier = nullptr) { + PrintingPolicyAdjuster PolicyModifier = nullptr) { std::vector Args(1, "-std=c11"); return PrintedDeclMatches(Code, Args, NodeMatch, ExpectedPrinted, "input.c", PolicyModifier); diff --git a/clang/unittests/AST/NamedDeclPrinterTest.cpp b/clang/unittests/AST/NamedDeclPrinterTest.cpp --- a/clang/unittests/AST/NamedDeclPrinterTest.cpp +++ b/clang/unittests/AST/NamedDeclPrinterTest.cpp @@ -15,6 +15,7 @@ // //===----------------------------------------------------------------------===// +#include "ASTPrint.h" #include "clang/AST/ASTContext.h" #include "clang/AST/Decl.h" #include "clang/AST/PrettyPrinter.h" @@ -66,31 +67,11 @@ const DeclarationMatcher &NodeMatch, StringRef ExpectedPrinted, StringRef FileName, std::function Print) { - PrintMatch Printer(std::move(Print)); - MatchFinder Finder; - Finder.addMatcher(NodeMatch, &Printer); - std::unique_ptr Factory = - newFrontendActionFactory(&Finder); - - if (!runToolOnCodeWithArgs(Factory->create(), Code, Args, FileName)) - return testing::AssertionFailure() - << "Parsing error in \"" << Code.str() << "\""; - - if (Printer.getNumFoundDecls() == 0) - return testing::AssertionFailure() - << "Matcher didn't find any named declarations"; - - if (Printer.getNumFoundDecls() > 1) - return testing::AssertionFailure() - << "Matcher should match only one named declaration " - "(found " << Printer.getNumFoundDecls() << ")"; - - if (Printer.getPrinted() != ExpectedPrinted) 
- return ::testing::AssertionFailure() - << "Expected \"" << ExpectedPrinted.str() << "\", " - "got \"" << Printer.getPrinted().str() << "\""; - - return ::testing::AssertionSuccess(); + return PrintedNodeMatches( + Code, Args, NodeMatch, ExpectedPrinted, FileName, + [Print](llvm::raw_ostream &Out, const ASTContext *Context, + const NamedDecl *ND, + PrintingPolicyAdjuster PolicyAdjuster) { Print(Out, ND); }); } ::testing::AssertionResult diff --git a/clang/unittests/AST/StmtPrinterTest.cpp b/clang/unittests/AST/StmtPrinterTest.cpp --- a/clang/unittests/AST/StmtPrinterTest.cpp +++ b/clang/unittests/AST/StmtPrinterTest.cpp @@ -38,11 +38,29 @@ has(compoundStmt(has(stmt().bind("id"))))); } +static void PrintStmt(raw_ostream &Out, const ASTContext *Context, + const Stmt *S, PrintingPolicyAdjuster PolicyAdjuster) { + assert(S != nullptr && "Expected non-null Stmt"); + PrintingPolicy Policy = Context->getPrintingPolicy(); + if (PolicyAdjuster) + PolicyAdjuster(Policy); + S->printPretty(Out, /*Helper*/ nullptr, Policy); +} + +template +::testing::AssertionResult +PrintedStmtMatches(StringRef Code, const std::vector &Args, + const Matcher &NodeMatch, StringRef ExpectedPrinted, + PrintingPolicyAdjuster PolicyAdjuster = nullptr) { + return PrintedNodeMatches(Code, Args, NodeMatch, ExpectedPrinted, "", + PrintStmt, PolicyAdjuster); +} + template ::testing::AssertionResult PrintedStmtCXXMatches(StdVer Standard, StringRef Code, const T &NodeMatch, StringRef ExpectedPrinted, - PolicyAdjusterType PolicyAdjuster = None) { + PrintingPolicyAdjuster PolicyAdjuster = nullptr) { const char *StdOpt; switch (Standard) { case StdVer::CXX98: StdOpt = "-std=c++98"; break; @@ -64,7 +82,7 @@ ::testing::AssertionResult PrintedStmtMSMatches(StringRef Code, const T &NodeMatch, StringRef ExpectedPrinted, - PolicyAdjusterType PolicyAdjuster = None) { + PrintingPolicyAdjuster PolicyAdjuster = nullptr) { std::vector Args = { "-std=c++98", "-target", "i686-pc-win32", @@ -79,7 +97,7 @@ 
::testing::AssertionResult PrintedStmtObjCMatches(StringRef Code, const T &NodeMatch, StringRef ExpectedPrinted, - PolicyAdjusterType PolicyAdjuster = None) { + PrintingPolicyAdjuster PolicyAdjuster = nullptr) { std::vector Args = { "-ObjC", "-fobjc-runtime=macosx-10.12.0", @@ -202,10 +220,10 @@ }; )"; // No implicit 'this'. - ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11, - CPPSource, memberExpr(anything()).bind("id"), "field", - PolicyAdjusterType( - [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; }))); + ASSERT_TRUE(PrintedStmtCXXMatches( + StdVer::CXX11, CPPSource, memberExpr(anything()).bind("id"), "field", + + [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; })); // Print implicit 'this'. ASSERT_TRUE(PrintedStmtCXXMatches(StdVer::CXX11, CPPSource, memberExpr(anything()).bind("id"), "this->field")); @@ -222,11 +240,10 @@ @end )"; // No implicit 'self'. - ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"), - "return ivar;\n", - PolicyAdjusterType([](PrintingPolicy &PP) { - PP.SuppressImplicitBase = true; - }))); + ASSERT_TRUE(PrintedStmtObjCMatches( + ObjCSource, returnStmt().bind("id"), "return ivar;\n", + + [](PrintingPolicy &PP) { PP.SuppressImplicitBase = true; })); // Print implicit 'self'. ASSERT_TRUE(PrintedStmtObjCMatches(ObjCSource, returnStmt().bind("id"), "return self->ivar;\n")); @@ -243,5 +260,6 @@ // body not printed when TerseOutput is on. ASSERT_TRUE(PrintedStmtCXXMatches( StdVer::CXX11, CPPSource, lambdaExpr(anything()).bind("id"), "[] {}", - PolicyAdjusterType([](PrintingPolicy &PP) { PP.TerseOutput = true; }))); + + [](PrintingPolicy &PP) { PP.TerseOutput = true; })); }