diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -82,6 +82,7 @@ bool getDeserialize() const { return Deserialize; } void SetTraversalKind(TraversalKind TK) { Traversal = TK; } + TraversalKind GetTraversalKind() const { return Traversal; } void Visit(const Decl *D) { getNodeDelegate().AddChild([=] { @@ -481,8 +482,9 @@ Visit(D->getTemplatedDecl()); - for (const auto *Child : D->specializations()) - dumpTemplateDeclSpecialization(Child); + if (Traversal == TK_AsIs) + for (const auto *Child : D->specializations()) + dumpTemplateDeclSpecialization(Child); } void VisitTypeAliasDecl(const TypeAliasDecl *D) { diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -461,6 +461,13 @@ bool canIgnoreChildDeclWhileTraversingDeclContext(const Decl *Child); +#define DEF_TRAVERSE_TMPL_INST(TMPLDECLKIND) \ + bool TraverseTemplateInstantiations(TMPLDECLKIND##TemplateDecl *D); + DEF_TRAVERSE_TMPL_INST(Class) + DEF_TRAVERSE_TMPL_INST(Var) + DEF_TRAVERSE_TMPL_INST(Function) +#undef DEF_TRAVERSE_TMPL_INST + private: // These are helper methods used by more than one Traverse* method. bool TraverseTemplateParameterListHelper(TemplateParameterList *TPL); @@ -469,12 +476,6 @@ template bool TraverseDeclTemplateParameterLists(T *D); -#define DEF_TRAVERSE_TMPL_INST(TMPLDECLKIND) \ - bool TraverseTemplateInstantiations(TMPLDECLKIND##TemplateDecl *D); - DEF_TRAVERSE_TMPL_INST(Class) - DEF_TRAVERSE_TMPL_INST(Var) - DEF_TRAVERSE_TMPL_INST(Function) -#undef DEF_TRAVERSE_TMPL_INST bool TraverseTemplateArgumentLocsHelper(const TemplateArgumentLoc *TAL, unsigned Count); bool TraverseArrayTypeLocHelper(ArrayTypeLoc TL); diff --git a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h --- a/clang/include/clang/ASTMatchers/ASTMatchersInternal.h +++ b/clang/include/clang/ASTMatchers/ASTMatchersInternal.h @@ -586,6 +586,10 @@ return this->InnerMatcher.matches(DynTypedNode::create(*Node), Finder, Builder); } + + llvm::Optional TraversalKind() const override { + return this->InnerMatcher.getTraversalKind(); + } }; private: @@ -1056,6 +1060,8 @@ virtual ASTContext &getASTContext() const = 0; + virtual bool isMatchingInImplicitTemplateInstantiation() const = 0; + protected: virtual bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx, const DynTypedMatcher &Matcher, diff --git a/clang/lib/AST/ASTDumper.cpp b/clang/lib/AST/ASTDumper.cpp --- a/clang/lib/AST/ASTDumper.cpp +++ b/clang/lib/AST/ASTDumper.cpp @@ -129,9 +129,10 @@ Visit(D->getTemplatedDecl()); - for (const auto *Child : D->specializations()) - dumpTemplateDeclSpecialization(Child, DumpExplicitInst, - !D->isCanonicalDecl()); + if (GetTraversalKind() == TK_AsIs) + for (const auto *Child : D->specializations()) + dumpTemplateDeclSpecialization(Child, DumpExplicitInst, + !D->isCanonicalDecl()); } void ASTDumper::VisitFunctionTemplateDecl(const FunctionTemplateDecl *D) { diff --git a/clang/lib/ASTMatchers/ASTMatchFinder.cpp b/clang/lib/ASTMatchers/ASTMatchFinder.cpp --- a/clang/lib/ASTMatchers/ASTMatchFinder.cpp +++ b/clang/lib/ASTMatchers/ASTMatchFinder.cpp @@ -584,7 +584,45 @@ bool shouldVisitTemplateInstantiations() const { return true; } bool shouldVisitImplicitCode() const { return true; } + bool isMatchingInImplicitTemplateInstantiation() const override { + return TraversingImplicitTemplateInstantiation; + } + + bool TraverseTemplateInstantiations(ClassTemplateDecl *D) { + ImplicitTemplateInstantiationScope RAII(this, true); + return RecursiveASTVisitor::TraverseTemplateInstantiations( + D); + } + + bool TraverseTemplateInstantiations(VarTemplateDecl *D) { + ImplicitTemplateInstantiationScope RAII(this, true); + return RecursiveASTVisitor::TraverseTemplateInstantiations( + D); + } + + bool TraverseTemplateInstantiations(FunctionTemplateDecl *D) { + ImplicitTemplateInstantiationScope RAII(this, true); + return RecursiveASTVisitor::TraverseTemplateInstantiations( + D); + } + private: + bool TraversingImplicitTemplateInstantiation = false; + + struct ImplicitTemplateInstantiationScope { + ImplicitTemplateInstantiationScope(MatchASTVisitor *V, bool B) + : MV(V), MB(V->TraversingImplicitTemplateInstantiation) { + V->TraversingImplicitTemplateInstantiation = B; + } + ~ImplicitTemplateInstantiationScope() { + MV->TraversingImplicitTemplateInstantiation = MB; + } + + private: + MatchASTVisitor *MV; + bool MB; + }; + class TimeBucketRegion { public: TimeBucketRegion() : Bucket(nullptr) {} diff --git a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp --- a/clang/lib/ASTMatchers/ASTMatchersInternal.cpp +++ b/clang/lib/ASTMatchers/ASTMatchersInternal.cpp @@ -284,6 +284,13 @@ TraversalKindScope RAII(Finder->getASTContext(), Implementation->TraversalKind()); + if (Finder->getASTContext().getParentMapContext().getTraversalKind() == + TK_IgnoreUnlessSpelledInSource) { + if (Finder->isMatchingInImplicitTemplateInstantiation()) { + return false; + } + } + auto N = Finder->getASTContext().getParentMapContext().traverseIgnored(DynNode); @@ -304,6 +311,13 @@ TraversalKindScope raii(Finder->getASTContext(), Implementation->TraversalKind()); + if (Finder->getASTContext().getParentMapContext().getTraversalKind() == + TK_IgnoreUnlessSpelledInSource) { + if (Finder->isMatchingInImplicitTemplateInstantiation()) { + return false; + } + } + auto N = Finder->getASTContext().getParentMapContext().traverseIgnored(DynNode); diff --git a/clang/unittests/AST/ASTTraverserTest.cpp b/clang/unittests/AST/ASTTraverserTest.cpp --- a/clang/unittests/AST/ASTTraverserTest.cpp +++ b/clang/unittests/AST/ASTTraverserTest.cpp @@ -68,6 +68,14 @@ void Visit(const TemplateArgument &A, SourceRange R = {}, const Decl *From = nullptr, const char *Label = nullptr) { OS << "TemplateArgument"; + switch (A.getKind()) { + case TemplateArgument::Type: { + OS << " type " << A.getAsType().getAsString(); + break; + } + default: + break; + } } template void Visit(T...) {} @@ -243,7 +251,7 @@ verifyWithDynNode(TA, R"cpp( -TemplateArgument +TemplateArgument type int `-BuiltinType )cpp"); @@ -1042,4 +1050,145 @@ } } +TEST(Traverse, IgnoreUnlessSpelledInSourceTemplateInstantiations) { + + auto AST = buildASTFromCode(R"cpp( +template +struct TemplStruct { + TemplStruct() {} + ~TemplStruct() {} + +private: + T m_t; +}; + +template +T timesTwo(T input) +{ + return input * 2; +} + +void instantiate() +{ + TemplStruct ti; + TemplStruct td; + (void)timesTwo(2); + (void)timesTwo(2); +} +)cpp"); + { + auto BN = ast_matchers::match( + classTemplateDecl(hasName("TemplStruct")).bind("rec"), + AST->getASTContext()); + EXPECT_EQ(BN.size(), 1u); + + EXPECT_EQ(dumpASTString(TK_IgnoreUnlessSpelledInSource, + BN[0].getNodeAs("rec")), + R"cpp( +ClassTemplateDecl 'TemplStruct' +|-TemplateTypeParmDecl 'T' +`-CXXRecordDecl 'TemplStruct' + |-CXXRecordDecl 'TemplStruct' + |-CXXConstructorDecl 'TemplStruct' + | `-CompoundStmt + |-CXXDestructorDecl '~TemplStruct' + | `-CompoundStmt + |-AccessSpecDecl + `-FieldDecl 'm_t' +)cpp"); + + EXPECT_EQ(dumpASTString(TK_AsIs, BN[0].getNodeAs("rec")), + R"cpp( +ClassTemplateDecl 'TemplStruct' +|-TemplateTypeParmDecl 'T' +|-CXXRecordDecl 'TemplStruct' +| |-CXXRecordDecl 'TemplStruct' +| |-CXXConstructorDecl 'TemplStruct' +| | `-CompoundStmt +| |-CXXDestructorDecl '~TemplStruct' +| | `-CompoundStmt +| |-AccessSpecDecl +| `-FieldDecl 'm_t' +|-ClassTemplateSpecializationDecl 'TemplStruct' +| |-TemplateArgument type int +| | `-BuiltinType +| |-CXXRecordDecl 'TemplStruct' +| |-CXXConstructorDecl 'TemplStruct' +| | `-CompoundStmt +| |-CXXDestructorDecl '~TemplStruct' +| | `-CompoundStmt +| |-AccessSpecDecl +| |-FieldDecl 'm_t' +| `-CXXConstructorDecl 'TemplStruct' +| `-ParmVarDecl '' +`-ClassTemplateSpecializationDecl 'TemplStruct' + |-TemplateArgument type double + | `-BuiltinType + |-CXXRecordDecl 'TemplStruct' + |-CXXConstructorDecl 'TemplStruct' + | `-CompoundStmt + |-CXXDestructorDecl '~TemplStruct' + | `-CompoundStmt + |-AccessSpecDecl + |-FieldDecl 'm_t' + `-CXXConstructorDecl 'TemplStruct' + `-ParmVarDecl '' +)cpp"); + } + { + auto BN = ast_matchers::match( + functionTemplateDecl(hasName("timesTwo")).bind("fn"), + AST->getASTContext()); + EXPECT_EQ(BN.size(), 1u); + + EXPECT_EQ(dumpASTString(TK_IgnoreUnlessSpelledInSource, + BN[0].getNodeAs("fn")), + R"cpp( +FunctionTemplateDecl 'timesTwo' +|-TemplateTypeParmDecl 'T' +`-FunctionDecl 'timesTwo' + |-ParmVarDecl 'input' + `-CompoundStmt + `-ReturnStmt + `-BinaryOperator + |-DeclRefExpr 'input' + `-IntegerLiteral +)cpp"); + + EXPECT_EQ(dumpASTString(TK_AsIs, BN[0].getNodeAs("fn")), + R"cpp( +FunctionTemplateDecl 'timesTwo' +|-TemplateTypeParmDecl 'T' +|-FunctionDecl 'timesTwo' +| |-ParmVarDecl 'input' +| `-CompoundStmt +| `-ReturnStmt +| `-BinaryOperator +| |-DeclRefExpr 'input' +| `-IntegerLiteral +|-FunctionDecl 'timesTwo' +| |-TemplateArgument type int +| | `-BuiltinType +| |-ParmVarDecl 'input' +| `-CompoundStmt +| `-ReturnStmt +| `-BinaryOperator +| |-ImplicitCastExpr +| | `-DeclRefExpr 'input' +| `-IntegerLiteral +`-FunctionDecl 'timesTwo' + |-TemplateArgument type double + | `-BuiltinType + |-ParmVarDecl 'input' + `-CompoundStmt + `-ReturnStmt + `-BinaryOperator + |-ImplicitCastExpr + | `-DeclRefExpr 'input' + `-ImplicitCastExpr + `-IntegerLiteral +)cpp"); + } +} + } // namespace clang diff --git a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp --- a/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp +++ b/clang/unittests/ASTMatchers/ASTMatchersTraversalTest.cpp @@ -2085,9 +2085,17 @@ traverse(TK_AsIs, staticAssertDecl(has(implicitCastExpr(has( substNonTypeTemplateParmExpr(has(integerLiteral()))))))))); + EXPECT_TRUE(matches( + Code, traverse(TK_IgnoreUnlessSpelledInSource, + staticAssertDecl(has(declRefExpr( + to(nonTypeTemplateParmDecl(hasName("alignment"))), + hasType(asString("unsigned int")))))))); - EXPECT_TRUE(matches(Code, traverse(TK_IgnoreUnlessSpelledInSource, - staticAssertDecl(has(integerLiteral()))))); + EXPECT_TRUE(matches(Code, traverse(TK_AsIs, staticAssertDecl(hasDescendant( + integerLiteral()))))); + EXPECT_FALSE(matches( + Code, traverse(TK_IgnoreUnlessSpelledInSource, + staticAssertDecl(hasDescendant(integerLiteral()))))); Code = R"cpp( diff --git a/clang/unittests/Tooling/TransformerTest.cpp b/clang/unittests/Tooling/TransformerTest.cpp --- a/clang/unittests/Tooling/TransformerTest.cpp +++ b/clang/unittests/Tooling/TransformerTest.cpp @@ -1064,6 +1064,70 @@ EXPECT_EQ(ErrorCount, 0); } +TEST_F(TransformerTest, TemplateInstantiation) { + + std::string NonTemplatesInput = R"cpp( +struct S { + int m_i; +}; +)cpp"; + std::string NonTemplatesExpected = R"cpp( +struct S { + safe_int m_i; +}; +)cpp"; + + std::string TemplatesInput = R"cpp( +template +struct TemplStruct { + TemplStruct() {} + ~TemplStruct() {} + +private: + T m_t; +}; + +void instantiate() +{ + TemplStruct ti; +} +)cpp"; + + auto MatchedField = fieldDecl(hasType(asString("int"))).bind("theField"); + + // Changes the 'int' in 'S', but not the 'T' in 'TemplStruct': + testRule(makeRule(traverse(TK_IgnoreUnlessSpelledInSource, MatchedField), + changeTo(cat("safe_int ", name("theField")))), + NonTemplatesInput + TemplatesInput, + NonTemplatesExpected + TemplatesInput); + + // In AsIs mode, template instantiations are modified, which is + // often not desired: + + std::string IncorrectTemplatesExpected = R"cpp( +template +struct TemplStruct { + TemplStruct() {} + ~TemplStruct() {} + +private: + safe_int m_t; +}; + +void instantiate() +{ + TemplStruct ti; +} +)cpp"; + + // Changes the 'int' in 'S', and (incorrectly) the 'T' in 'TemplStruct': + testRule(makeRule(traverse(TK_AsIs, MatchedField), + changeTo(cat("safe_int ", name("theField")))), + + NonTemplatesInput + TemplatesInput, + NonTemplatesExpected + IncorrectTemplatesExpected); +} + // Transformation of macro source text when the change encompasses the entirety // of the expanded text. TEST_F(TransformerTest, SimpleMacro) {