diff --git a/clang/include/clang/AST/PrettyPrinter.h b/clang/include/clang/AST/PrettyPrinter.h --- a/clang/include/clang/AST/PrettyPrinter.h +++ b/clang/include/clang/AST/PrettyPrinter.h @@ -75,7 +75,7 @@ PrintCanonicalTypes(false), PrintInjectedClassNameWithArguments(true), UsePreferredNames(true), AlwaysIncludeTypeForTemplateArgument(false), CleanUglifiedParameters(false), EntireContentsOfLargeArray(true), - UseEnumerators(true) {} + UseEnumerators(true), AlwaysIncludeTypeForNonTypeTemplateArgument(false) {} /// Adjust this printing policy for cases where it's known that we're /// printing C++ code (for instance, if AST dumping reaches a C++-only @@ -295,6 +295,18 @@ /// enumerator name or via cast of an integer. unsigned UseEnumerators : 1; + /// Whether to print full type names of non-type template arguments. + /// + /// \code + /// struct Point { int x, y; }; + /// template< Point p > struct S {}; + /// S< Point{ 1, 2 } > s; + /// \endcode + /// + /// decltype(s) will be printed as "S<Point{1,2}>" if enabled and as "S<{1,2}>" if disabled, + /// regardless of whether PrintCanonicalTypes is enabled. + unsigned AlwaysIncludeTypeForNonTypeTemplateArgument : 1; + /// Callbacks to use to allow the behavior of printing to be customized. const PrintingCallbacks *Callbacks = nullptr; }; diff --git a/clang/lib/AST/APValue.cpp b/clang/lib/AST/APValue.cpp --- a/clang/lib/AST/APValue.cpp +++ b/clang/lib/AST/APValue.cpp @@ -892,6 +892,8 @@ assert(BI != CD->bases_end()); if (!First) Out << ", "; + if (Policy.AlwaysIncludeTypeForNonTypeTemplateArgument) + BI->getType().getUnqualifiedType().print(Out, Policy); getStructBase(I).printPretty(Out, Policy, BI->getType(), Ctx); First = false; } diff --git a/clang/lib/AST/TemplateBase.cpp b/clang/lib/AST/TemplateBase.cpp --- a/clang/lib/AST/TemplateBase.cpp +++ b/clang/lib/AST/TemplateBase.cpp @@ -432,10 +432,11 @@ } case Declaration: { - // FIXME: Include the type if it's not obvious from the context. 
NamedDecl *ND = getAsDecl(); if (getParamTypeForDecl()->isRecordType()) { if (auto *TPO = dyn_cast(ND)) { + if (Policy.AlwaysIncludeTypeForNonTypeTemplateArgument) + TPO->getType().getUnqualifiedType().print(Out, Policy); TPO->printAsInit(Out, Policy); break; } diff --git a/clang/unittests/AST/TypePrinterTest.cpp b/clang/unittests/AST/TypePrinterTest.cpp --- a/clang/unittests/AST/TypePrinterTest.cpp +++ b/clang/unittests/AST/TypePrinterTest.cpp @@ -48,7 +48,7 @@ std::string Code = R"cpp( namespace N { template struct Type {}; - + template void Foo(const Type &Param); } @@ -127,3 +127,86 @@ Policy.EntireContentsOfLargeArray = true; })); } + +TEST(TypePrinter, TemplateIdWithFullTypeNTTP) { + constexpr char Code[] = R"cpp( + enum struct Encoding { UTF8, ASCII }; + template + struct Str { + constexpr Str(char const (&s)[N]) { __builtin_memcpy(value, s, N); } + char value[N]; + }; + template class ASCII {}; + + ASCII<"some string"> x; + )cpp"; + auto Matcher = classTemplateSpecializationDecl( + hasName("ASCII"), has(cxxConstructorDecl( + isMoveConstructor(), + has(parmVarDecl(hasType(qualType().bind("id"))))))); + + ASSERT_TRUE(PrintedTypeMatches( + Code, {"-std=c++20"}, Matcher, + R"(ASCII{"some string"}> &&)", + [](PrintingPolicy &Policy) { + Policy.AlwaysIncludeTypeForNonTypeTemplateArgument = true; + })); + + ASSERT_TRUE(PrintedTypeMatches( + Code, {"-std=c++20"}, Matcher, R"(ASCII<{"some string"}> &&)", + [](PrintingPolicy &Policy) { + Policy.AlwaysIncludeTypeForNonTypeTemplateArgument = false; + })); +} + +TEST(TypePrinter, TemplateIdWithComplexFullTypeNTTP) { + constexpr char Code[] = R"cpp( + template< typename T, auto ... 
dims > + struct NDArray {}; + + struct Dimension + { + using value_type = unsigned short; + + value_type size{ value_type( 0 ) }; + }; + + template < typename ConcreteDim > + struct DimensionImpl : Dimension {}; + + struct Width : DimensionImpl< Width > {}; + struct Height : DimensionImpl< Height > {}; + struct Channels : DimensionImpl< Channels > {}; + + inline constexpr Width W; + inline constexpr Height H; + inline constexpr Channels C; + + template< auto ... Dims > + consteval auto makeArray() noexcept + { + return NDArray< float, Dims ... >{}; + } + + [[ maybe_unused ]] auto x { makeArray< H, W, C >() }; + + )cpp"; + auto Matcher = varDecl( + allOf(hasAttr(attr::Kind::Unused), hasType(qualType().bind("id")))); + + ASSERT_TRUE(PrintedTypeMatches( + Code, {"-std=c++20"}, Matcher, + R"(NDArray)", + [](PrintingPolicy &Policy) { + Policy.PrintCanonicalTypes = true; + Policy.AlwaysIncludeTypeForNonTypeTemplateArgument = false; + })); + + ASSERT_TRUE(PrintedTypeMatches( + Code, {"-std=c++20"}, Matcher, + R"(NDArray{Dimension{0}}}, Width{DimensionImpl{Dimension{0}}}, Channels{DimensionImpl{Dimension{0}}}>)", + [](PrintingPolicy &Policy) { + Policy.PrintCanonicalTypes = true; + Policy.AlwaysIncludeTypeForNonTypeTemplateArgument = true; + })); +}