NOTE(review): this patch text is damaged and cannot be applied with
`git apply` or `patch` as-is:
  - The newlines inside each file's diff were collapsed into single spaces,
    so the hunk line structure that the @@ headers describe is gone.
  - Every non-empty angle-bracket span appears to have been removed
    (presumably stripped as if it were an HTML tag), while the empty `<>`
    survived. Visible symptoms: `llvm::dyn_cast(&ND)` and `dyn_cast(D)`
    with no target type, `std::vector ExpectedTemplateDeclArgs` with no
    element type, bare `template` lines with no parameter list, and all
    eight expected strings in the new test reading "".
  - The original template arguments must be recovered from the upstream
    change before this patch is reused.
Summary of the change, as far as the damaged text shows:
  - clangd/AST.cpp: adds a branch that returns the arguments from
    `Var->getTemplateArgsAsWritten()`. It is inserted before the existing
    branch that returns `Var->getTemplateArgsInfo().arguments()`. The cast
    target was stripped; presumably it is VarTemplatePartialSpecializationDecl.
    That type has to be tested before VarTemplateSpecializationDecl because
    it derives from it. TODO confirm against upstream.
  - clangd/unittests/ASTTests.cpp: adds `#include "TestTU.h"` and a
    PrintTemplateSpecializationArgs test. The test builds a TU and compares
    printTemplateSpecializationArgs() for every local top-level NamedDecl
    against an expected list. Nit: clang-format's LLVM include ordering
    would likely place "TestTU.h" before "gtest/gtest.h".
  - clangd/unittests/ClangdUnitTests.cpp: changes one WithTemplateArgs
    expectation for a decl named "foo" from "<>" to a value that was lost
    in the stripping.
diff --git a/clang-tools-extra/clangd/AST.cpp b/clang-tools-extra/clangd/AST.cpp --- a/clang-tools-extra/clangd/AST.cpp +++ b/clang-tools-extra/clangd/AST.cpp @@ -36,6 +36,10 @@ llvm::dyn_cast(&ND)) { if (auto *Args = Cls->getTemplateArgsAsWritten()) return Args->arguments(); + } else if (auto *Var = + llvm::dyn_cast(&ND)) { + if (auto *Args = Var->getTemplateArgsAsWritten()) + return Args->arguments(); } else if (auto *Var = llvm::dyn_cast(&ND)) return Var->getTemplateArgsInfo().arguments(); // We return None for ClassTemplateSpecializationDecls because it does not diff --git a/clang-tools-extra/clangd/unittests/ASTTests.cpp b/clang-tools-extra/clangd/unittests/ASTTests.cpp --- a/clang-tools-extra/clangd/unittests/ASTTests.cpp +++ b/clang-tools-extra/clangd/unittests/ASTTests.cpp @@ -8,6 +8,7 @@ #include "AST.h" #include "gtest/gtest.h" +#include "TestTU.h" namespace clang { namespace clangd { @@ -36,6 +37,44 @@ "testns1::TestClass", "testns1")); } +TEST(PrintTemplateSpecializationArgs, PrintsTemplateArgs) { + TestTU TU; + TU.Code = R"cpp( + template + void foo(T) {} + template<> + void foo(int) {} + + template + struct K {}; + template + struct K {}; + template<> + struct K {}; + + template + T S = T(10); + + template + int S = 0; + template <> + int S = 0; + )cpp"; + + // The expected template args string representation for every top level decl + // in the TU. 
+ std::vector ExpectedTemplateDeclArgs{ + "", "", "", "", "", "", "", ""}; + + std::vector TopLevel = TU.build().getLocalTopLevelDecls(); + std::vector ActualTemplateArgs; + for (const Decl *D : TopLevel) { + if (const NamedDecl *ND = dyn_cast(D)) + ActualTemplateArgs.push_back(printTemplateSpecializationArgs(*ND)); + } + + EXPECT_EQ(ExpectedTemplateDeclArgs, ActualTemplateArgs); +} } // namespace } // namespace clangd diff --git a/clang-tools-extra/clangd/unittests/ClangdUnitTests.cpp b/clang-tools-extra/clangd/unittests/ClangdUnitTests.cpp --- a/clang-tools-extra/clangd/unittests/ClangdUnitTests.cpp +++ b/clang-tools-extra/clangd/unittests/ClangdUnitTests.cpp @@ -188,7 +188,7 @@ AllOf(DeclNamed("foo"), WithTemplateArgs("")), AllOf(DeclNamed("i"), WithTemplateArgs("")), AllOf(DeclNamed("d"), WithTemplateArgs("")), - AllOf(DeclNamed("foo"), WithTemplateArgs("<>")), + AllOf(DeclNamed("foo"), WithTemplateArgs("")), AllOf(DeclNamed("foo"), WithTemplateArgs(""))})); }