diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h
--- a/clang/include/clang/AST/ASTContext.h
+++ b/clang/include/clang/AST/ASTContext.h
@@ -1614,7 +1614,10 @@
                                    QualType Wrapped);
 
+  /// \p PackIndex is the index into the substituted argument pack when the
+  /// replacement comes from a pack expansion, or None otherwise.
   QualType getSubstTemplateTypeParmType(const TemplateTypeParmType *Replaced,
-                                        QualType Replacement) const;
+                                        QualType Replacement,
+                                        Optional<unsigned> PackIndex) const;
   QualType getSubstTemplateTypeParmPackType(
                                           const TemplateTypeParmType *Replaced,
                                             const TemplateArgument &ArgPack);
diff --git a/clang/include/clang/AST/JSONNodeDumper.h b/clang/include/clang/AST/JSONNodeDumper.h
--- a/clang/include/clang/AST/JSONNodeDumper.h
+++ b/clang/include/clang/AST/JSONNodeDumper.h
@@ -220,6 +220,7 @@
   void VisitUnaryTransformType(const UnaryTransformType *UTT);
   void VisitTagType(const TagType *TT);
   void VisitTemplateTypeParmType(const TemplateTypeParmType *TTPT);
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *STTPT);
   void VisitAutoType(const AutoType *AT);
   void VisitTemplateSpecializationType(const TemplateSpecializationType *TST);
   void VisitInjectedClassNameType(const InjectedClassNameType *ICNT);
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -317,6 +317,7 @@
   void VisitUnaryTransformType(const UnaryTransformType *T);
   void VisitTagType(const TagType *T);
   void VisitTemplateTypeParmType(const TemplateTypeParmType *T);
+  void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *T);
   void VisitAutoType(const AutoType *T);
   void VisitDeducedTemplateSpecializationType(
       const DeducedTemplateSpecializationType *T);
diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h
--- a/clang/include/clang/AST/Type.h
+++ b/clang/include/clang/AST/Type.h
@@ -1790,6 +1790,18 @@
     unsigned NumArgs;
   };
 
+  class SubstTemplateTypeParmTypeBitfields {
+    friend class SubstTemplateTypeParmType;
+
+    unsigned : NumTypeBits;
+
+    /// Represents the index within a pack if this represents a substitution
+    /// from a pack expansion.
+    /// Positive non-zero number represents the index + 1.
+    /// Zero means this is not substituted from an expansion.
+    unsigned PackIndex;
+  };
+
   class SubstTemplateTypeParmPackTypeBitfields {
     friend class SubstTemplateTypeParmPackType;
 
@@ -1872,6 +1884,7 @@
     ElaboratedTypeBitfields ElaboratedTypeBits;
     VectorTypeBitfields VectorTypeBits;
     SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
+    SubstTemplateTypeParmTypeBitfields SubstTemplateTypeParmTypeBits;
     TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
     DependentTemplateSpecializationTypeBitfields
       DependentTemplateSpecializationTypeBits;
@@ -4973,9 +4986,12 @@
   // The original type parameter.
   const TemplateTypeParmType *Replaced;
 
-  SubstTemplateTypeParmType(const TemplateTypeParmType *Param, QualType Canon)
+  SubstTemplateTypeParmType(const TemplateTypeParmType *Param, QualType Canon,
+                            Optional<unsigned> PackIndex)
       : Type(SubstTemplateTypeParm, Canon, Canon->getDependence()),
-        Replaced(Param) {}
+        Replaced(Param) {
+    SubstTemplateTypeParmTypeBits.PackIndex = PackIndex ? *PackIndex + 1 : 0;
+  }
 
 public:
   /// Gets the template parameter that was substituted for.
@@ -4989,18 +5005,29 @@
     return getCanonicalTypeInternal();
   }
 
+  /// Returns the index into the substituted argument pack when this
+  /// substitution comes from a pack expansion, or None otherwise.
+  Optional<unsigned> getPackIndex() const {
+    if (SubstTemplateTypeParmTypeBits.PackIndex == 0)
+      return None;
+    return SubstTemplateTypeParmTypeBits.PackIndex - 1;
+  }
+
   bool isSugared() const { return true; }
   QualType desugar() const { return getReplacementType(); }
 
   void Profile(llvm::FoldingSetNodeID &ID) {
-    Profile(ID, getReplacedParameter(), getReplacementType());
+    Profile(ID, getReplacedParameter(), getReplacementType(), getPackIndex());
   }
 
   static void Profile(llvm::FoldingSetNodeID &ID,
                       const TemplateTypeParmType *Replaced,
-                      QualType Replacement) {
+                      QualType Replacement, Optional<unsigned> PackIndex) {
     ID.AddPointer(Replaced);
     ID.AddPointer(Replacement.getAsOpaquePtr());
+    // Same "index + 1, zero for none" encoding as the stored bitfield, so that
+    // "no pack index" never profiles the same as an actual pack index.
+    ID.AddInteger(PackIndex ? *PackIndex + 1 : 0);
   }
 
   static bool classof(const Type *T) {
diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td
--- a/clang/include/clang/AST/TypeProperties.td
+++ b/clang/include/clang/AST/TypeProperties.td
@@ -734,12 +734,15 @@
   def : Property<"replacementType", QualType> {
     let Read = [{ node->getReplacementType() }];
   }
+  def : Property<"PackIndex", Optional<UInt32>> {
+    let Read = [{ node->getPackIndex() }];
+  }
 
   def : Creator<[{
     // The call to getCanonicalType here existed in ASTReader.cpp, too.
     return ctx.getSubstTemplateTypeParmType(
         cast<TemplateTypeParmType>(replacedParameter),
-        ctx.getCanonicalType(replacementType));
+        ctx.getCanonicalType(replacementType), PackIndex);
   }]>;
 }
 
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -4744,19 +4744,20 @@
 /// Retrieve a substitution-result type.
 QualType
 ASTContext::getSubstTemplateTypeParmType(const TemplateTypeParmType *Parm,
-                                         QualType Replacement) const {
+                                         QualType Replacement,
+                                         Optional<unsigned> PackIndex) const {
   assert(Replacement.isCanonical()
          && "replacement types must always be canonical");
 
   llvm::FoldingSetNodeID ID;
-  SubstTemplateTypeParmType::Profile(ID, Parm, Replacement);
+  SubstTemplateTypeParmType::Profile(ID, Parm, Replacement, PackIndex);
   void *InsertPos = nullptr;
   SubstTemplateTypeParmType *SubstParm
     = SubstTemplateTypeParmTypes.FindNodeOrInsertPos(ID, InsertPos);
 
   if (!SubstParm) {
     SubstParm = new (*this, TypeAlignment)
-      SubstTemplateTypeParmType(Parm, Replacement);
+      SubstTemplateTypeParmType(Parm, Replacement, PackIndex);
     Types.push_back(SubstParm);
     SubstTemplateTypeParmTypes.InsertNode(SubstParm, InsertPos);
   }
diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -1530,7 +1530,8 @@
     return ToReplacementTypeOrErr.takeError();
 
   return Importer.getToContext().getSubstTemplateTypeParmType(
-      *ReplacedOrErr, ToReplacementTypeOrErr->getCanonicalType());
+      *ReplacedOrErr, ToReplacementTypeOrErr->getCanonicalType(),
+      T->getPackIndex());
 }
 
 ExpectedType ASTNodeImporter::VisitSubstTemplateTypeParmPackType(
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -1062,6 +1062,8 @@
     if (!IsStructurallyEquivalent(Context, Subst1->getReplacementType(),
                                   Subst2->getReplacementType()))
       return false;
+    if (Subst1->getPackIndex() != Subst2->getPackIndex())
+      return false;
     break;
   }
 
diff --git a/clang/lib/AST/JSONNodeDumper.cpp b/clang/lib/AST/JSONNodeDumper.cpp
--- a/clang/lib/AST/JSONNodeDumper.cpp
+++ b/clang/lib/AST/JSONNodeDumper.cpp
@@ -680,6 +680,12 @@
   JOS.attribute("decl", createBareDeclRef(TTPT->getDecl()));
 }
 
+void JSONNodeDumper::VisitSubstTemplateTypeParmType(
+    const SubstTemplateTypeParmType *STTPT) {
+  if (auto PackIndex = STTPT->getPackIndex())
+    JOS.attribute("pack_index", *PackIndex);
+}
+
 void JSONNodeDumper::VisitAutoType(const AutoType *AT) {
   JOS.attribute("undeduced", !AT->isDeduced());
   switch (AT->getKeyword()) {
diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp
--- a/clang/lib/AST/TextNodeDumper.cpp
+++ b/clang/lib/AST/TextNodeDumper.cpp
@@ -1568,6 +1568,12 @@
   dumpDeclRef(T->getDecl());
 }
 
+void TextNodeDumper::VisitSubstTemplateTypeParmType(
+    const SubstTemplateTypeParmType *T) {
+  if (auto PackIndex = T->getPackIndex())
+    OS << " pack_index " << *PackIndex;
+}
+
 void TextNodeDumper::VisitAutoType(const AutoType *T) {
   if (T->isDecltypeAuto())
     OS << " decltype(auto)";
diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp
--- a/clang/lib/AST/Type.cpp
+++ b/clang/lib/AST/Type.cpp
@@ -1167,7 +1167,7 @@
       return QualType(T, 0);
 
     return Ctx.getSubstTemplateTypeParmType(T->getReplacedParameter(),
-                                            replacementType);
+                                            replacementType, T->getPackIndex());
   }
 
   // FIXME: Non-trivial to implement, but important for C++
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp
--- a/clang/lib/Sema/SemaTemplateInstantiate.cpp
+++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp
@@ -1813,6 +1813,7 @@
     return NewT;
   }
 
+  Optional<unsigned> PackIndex;
   if (T->isParameterPack()) {
     assert(Arg.getKind() == TemplateArgument::Pack &&
            "Missing argument pack");
@@ -1830,6 +1831,7 @@
     }
 
     Arg = getPackSubstitutedTemplateArgument(getSema(), Arg);
+    PackIndex = getSema().ArgumentPackSubstitutionIndex;
   }
 
   assert(Arg.getKind() == TemplateArgument::Type &&
@@ -1838,8 +1840,8 @@
   QualType Replacement = Arg.getAsType();
 
   // TODO: only do this uniquing once, at the start of instantiation.
-  QualType Result
-    = getSema().Context.getSubstTemplateTypeParmType(T, Replacement);
+  QualType Result = getSema().Context.getSubstTemplateTypeParmType(
+      T, Replacement, PackIndex);
   SubstTemplateTypeParmTypeLoc NewTL
     = TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
   NewTL.setNameLoc(TL.getNameLoc());
@@ -1877,11 +1879,10 @@
 
   TemplateArgument Arg = TL.getTypePtr()->getArgumentPack();
   Arg = getPackSubstitutedTemplateArgument(getSema(), Arg);
-  QualType Result = Arg.getAsType();
 
-  Result = getSema().Context.getSubstTemplateTypeParmType(
-                                      TL.getTypePtr()->getReplacedParameter(),
-                                                          Result);
+  QualType Result = getSema().Context.getSubstTemplateTypeParmType(
+      TL.getTypePtr()->getReplacedParameter(), Arg.getAsType(),
+      getSema().ArgumentPackSubstitutionIndex);
   SubstTemplateTypeParmTypeLoc NewTL
     = TLB.push<SubstTemplateTypeParmTypeLoc>(Result);
   NewTL.setNameLoc(TL.getNameLoc());
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -4854,7 +4854,8 @@
       Replacement = SemaRef.Context.getQualifiedType(
           Replacement.getUnqualifiedType(), Qs);
       T = SemaRef.Context.getSubstTemplateTypeParmType(
-          SubstTypeParam->getReplacedParameter(), Replacement);
+          SubstTypeParam->getReplacedParameter(), Replacement,
+          SubstTypeParam->getPackIndex());
     } else if ((AutoTy = dyn_cast<AutoType>(T)) && AutoTy->isDeduced()) {
       // 'auto' types behave the same way as template parameters.
       QualType Deduced = AutoTy->getDeducedType();
@@ -6410,9 +6411,8 @@
 
   // Always canonicalize the replacement type.
   Replacement = SemaRef.Context.getCanonicalType(Replacement);
-  QualType Result
-    = SemaRef.Context.getSubstTemplateTypeParmType(T->getReplacedParameter(),
-                                                   Replacement);
+  QualType Result = SemaRef.Context.getSubstTemplateTypeParmType(
+      T->getReplacedParameter(), Replacement, T->getPackIndex());
 
   // Propagate type-source information.
   SubstTemplateTypeParmTypeLoc NewTL
diff --git a/clang/test/AST/ast-dump-template-decls.cpp b/clang/test/AST/ast-dump-template-decls.cpp
--- a/clang/test/AST/ast-dump-template-decls.cpp
+++ b/clang/test/AST/ast-dump-template-decls.cpp
@@ -136,13 +136,13 @@
 };
 using t1 = foo<int, short>::bind<char, float>;
 // CHECK:      TemplateSpecializationType 0x{{[^ ]*}} 'Y<char, float, int, short>' sugar Y
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar pack_index 0
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar pack_index 1
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar pack_index 2
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar pack_index 3
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'Bs' dependent contains_unexpanded_pack depth 0 index 0 pack
 
 template <typename... T> struct D {
@@ -152,13 +152,13 @@
 // CHECK:      TemplateSpecializationType 0x{{[^ ]*}} 'B<int, short>' sugar alias B
 // CHECK:      FunctionProtoType 0x{{[^ ]*}} 'int (int (*)(float, int), int (*)(char, short))' cdecl
 // CHECK:      FunctionProtoType 0x{{[^ ]*}} 'int (float, int)' cdecl
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'float' sugar pack_index 0
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'T' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'int' sugar pack_index 0
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'U' dependent contains_unexpanded_pack depth 0 index 0 pack
 // CHECK:      FunctionProtoType 0x{{[^ ]*}} 'int (char, short)' cdecl
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'char' sugar pack_index 1
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'T' dependent contains_unexpanded_pack depth 0 index 0 pack
-// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar
+// CHECK:      SubstTemplateTypeParmType 0x{{[^ ]*}} 'short' sugar pack_index 1
 // CHECK-NEXT: TemplateTypeParmType 0x{{[^ ]*}} 'U' dependent contains_unexpanded_pack depth 0 index 0 pack
 } // namespace PR56099
diff --git a/clang/unittests/AST/ASTImporterTest.cpp b/clang/unittests/AST/ASTImporterTest.cpp
--- a/clang/unittests/AST/ASTImporterTest.cpp
+++ b/clang/unittests/AST/ASTImporterTest.cpp
@@ -4793,6 +4793,53 @@
       ToD2->getDeclContext(), ToD2->getTemplateParameters()->getParam(0)));
 }
 
+TEST_P(ASTImporterOptionSpecificTestBase, ImportSubstTemplateTypeParmType) {
+  constexpr auto Code = R"(
+    template <class A1, class... A2> struct A {
+      using B = A1(A2...);
+    };
+    template struct A<void, char, float, int, short>;
+    )";
+  Decl *FromTU = getTuDecl(Code, Lang_CXX11, "input.cpp");
+  auto *FromClass = FirstDeclMatcher<ClassTemplateSpecializationDecl>().match(
+      FromTU, classTemplateSpecializationDecl());
+
+  auto testType = [&](ASTContext &Ctx, const char *Name,
+                      llvm::Optional<unsigned> PackIndex) {
+    const auto *Subst = selectFirst<SubstTemplateTypeParmType>(
+        "sttp", match(substTemplateTypeParmType(
+                          hasReplacementType(hasCanonicalType(asString(Name))))
+                          .bind("sttp"),
+                      Ctx));
+    const char *ExpectedTemplateParamName = PackIndex ? "A2" : "A1";
+    ASSERT_TRUE(Subst);
+    ASSERT_EQ(Subst->getReplacedParameter()->getIdentifier()->getName(),
+              ExpectedTemplateParamName);
+    ASSERT_EQ(Subst->getPackIndex(), PackIndex);
+  };
+  auto tests = [&](ASTContext &Ctx) {
+    testType(Ctx, "void", None);
+    testType(Ctx, "char", 0);
+    testType(Ctx, "float", 1);
+    testType(Ctx, "int", 2);
+    testType(Ctx, "short", 3);
+  };
+
+  tests(FromTU->getASTContext());
+
+  ClassTemplateSpecializationDecl *ToClass = Import(FromClass, Lang_CXX11);
+  tests(ToClass->getASTContext());
+
+  // Substitutions that differ only in their pack index must not be uniqued
+  // into the same type node.
+  ASTContext &FromCtx = FromTU->getASTContext();
+  const auto *Param = cast<TemplateTypeParmType>(
+      FromCtx.getTemplateTypeParmType(0, 0, /*ParameterPack=*/true)
+          .getTypePtr());
+  EXPECT_NE(FromCtx.getSubstTemplateTypeParmType(Param, FromCtx.IntTy, None),
+            FromCtx.getSubstTemplateTypeParmType(Param, FromCtx.IntTy, 1));
+}
+
 const AstTypeMatcher<SubstTemplateTypeParmPackType>
     substTemplateTypeParmPackType;
 