Index: include/clang/AST/DeclCXX.h
===================================================================
--- include/clang/AST/DeclCXX.h
+++ include/clang/AST/DeclCXX.h
@@ -1377,6 +1377,17 @@
   /// \brief Set the kind of specialization or template instantiation this is.
   void setTemplateSpecializationKind(TemplateSpecializationKind TSK);
 
+  /// \brief Retrieve the record declaration from which this record could be
+  /// instantiated. Returns null if this class is not a template instantiation.
+  const CXXRecordDecl *getTemplateInstantiationPattern() const;
+
+  /// \brief Non-const overload: forwards to the const version above and casts
+  /// the constness back off the result.
+  CXXRecordDecl *getTemplateInstantiationPattern() {
+    return const_cast<CXXRecordDecl *>(const_cast<const CXXRecordDecl *>(this)
+                                           ->getTemplateInstantiationPattern());
+  }
+
   /// \brief Returns the destructor decl for this class.
   CXXDestructorDecl *getDestructor() const;
 
Index: lib/AST/DeclCXX.cpp
===================================================================
--- lib/AST/DeclCXX.cpp
+++ lib/AST/DeclCXX.cpp
@@ -1261,6 +1261,51 @@
   llvm_unreachable("Not a class template or member class specialization");
 }
 
+const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const {
+  // If it's a class template specialization, find the template or partial
+  // specialization from which it was instantiated.
+  if (auto *TD = dyn_cast<ClassTemplateSpecializationDecl>(this)) {
+    auto From = TD->getInstantiatedFrom();
+    if (auto *CTD = From.dyn_cast<ClassTemplateDecl *>()) {
+      // Follow the instantiated-from-member-template chain, but do not step
+      // onto a template that reports itself as a member specialization.
+      while (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate()) {
+        if (NewCTD->isMemberSpecialization())
+          break;
+        CTD = NewCTD;
+      }
+      return CTD->getTemplatedDecl();
+    }
+    if (auto *CTPSD =
+            From.dyn_cast<ClassTemplatePartialSpecializationDecl *>()) {
+      // Same walk, over the partial specialization's instantiated-from chain.
+      while (auto *NewCTPSD = CTPSD->getInstantiatedFromMember()) {
+        if (NewCTPSD->isMemberSpecialization())
+          break;
+        CTPSD = NewCTPSD;
+      }
+      return CTPSD;
+    }
+  }
+
+  // Otherwise, if the member specialization info marks this class as a
+  // template instantiation, follow getInstantiatedFromMemberClass() to the
+  // end of the chain.
+  if (MemberSpecializationInfo *MSInfo = getMemberSpecializationInfo()) {
+    if (isTemplateInstantiation(MSInfo->getTemplateSpecializationKind())) {
+      const CXXRecordDecl *RD = this;
+      while (auto *NewRD = RD->getInstantiatedFromMemberClass())
+        RD = NewRD;
+      return RD;
+    }
+  }
+
+  // No pattern was found; that is only expected for non-instantiations.
+  assert(!isTemplateInstantiation(this->getTemplateSpecializationKind()) &&
+         "couldn't find pattern for class template instantiation");
+  return nullptr;
+}
+
 CXXDestructorDecl *CXXRecordDecl::getDestructor() const {
   ASTContext &Context = getASTContext();
   QualType ClassType = Context.getTypeDeclType(this);
Index: lib/Sema/SemaLookup.cpp
===================================================================
--- lib/Sema/SemaLookup.cpp
+++ lib/Sema/SemaLookup.cpp
@@ -1175,21 +1175,8 @@
     if (FunctionDecl *Pattern = FD->getTemplateInstantiationPattern())
       Entity = Pattern;
   } else if (CXXRecordDecl *RD = dyn_cast<CXXRecordDecl>(Entity)) {
-    // If it's a class template specialization, find the template or partial
-    // specialization from which it was instantiated.
-    if (ClassTemplateSpecializationDecl *SpecRD =
-            dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
-      llvm::PointerUnion<ClassTemplateDecl*,
-                         ClassTemplatePartialSpecializationDecl*> From =
-          SpecRD->getInstantiatedFrom();
-      if (ClassTemplateDecl *FromTemplate = From.dyn_cast<ClassTemplateDecl*>())
-        Entity = FromTemplate->getTemplatedDecl();
-      else if (From)
-        Entity = From.get<ClassTemplatePartialSpecializationDecl*>();
-      // Otherwise, it's an explicit specialization.
-    } else if (MemberSpecializationInfo *MSInfo =
-                   RD->getMemberSpecializationInfo())
-      Entity = getInstantiatedFrom(RD, MSInfo);
+    if (CXXRecordDecl *Pattern = RD->getTemplateInstantiationPattern())
+      Entity = Pattern;
   } else if (EnumDecl *ED = dyn_cast<EnumDecl>(Entity)) {
     if (MemberSpecializationInfo *MSInfo = ED->getMemberSpecializationInfo())
       Entity = getInstantiatedFrom(ED, MSInfo);
Index: lib/Sema/SemaType.cpp
===================================================================
--- lib/Sema/SemaType.cpp
+++ lib/Sema/SemaType.cpp
@@ -5085,31 +5085,9 @@
 
   // If this definition was instantiated from a template, map back to the
   // pattern from which it was instantiated.
-  //
-  // FIXME: There must be a better place for this to live.
   if (auto *RD = dyn_cast<CXXRecordDecl>(D)) {
-    if (auto *TD = dyn_cast<ClassTemplateSpecializationDecl>(RD)) {
-      auto From = TD->getInstantiatedFrom();
-      if (auto *CTD = From.dyn_cast<ClassTemplateDecl *>()) {
-        while (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate()) {
-          if (NewCTD->isMemberSpecialization())
-            break;
-          CTD = NewCTD;
-        }
-        RD = CTD->getTemplatedDecl();
-      } else if (auto *CTPSD = From.dyn_cast<
-                     ClassTemplatePartialSpecializationDecl *>()) {
-        while (auto *NewCTPSD = CTPSD->getInstantiatedFromMember()) {
-          if (NewCTPSD->isMemberSpecialization())
-            break;
-          CTPSD = NewCTPSD;
-        }
-        RD = CTPSD;
-      }
-    } else if (isTemplateInstantiation(RD->getTemplateSpecializationKind())) {
-      while (auto *NewRD = RD->getInstantiatedFromMemberClass())
-        RD = NewRD;
-    }
+    if (auto *Pattern = RD->getTemplateInstantiationPattern())
+      RD = Pattern;
     D = RD->getDefinition();
   } else if (auto *ED = dyn_cast<EnumDecl>(D)) {
     while (auto *NewED = ED->getInstantiatedFromMemberEnum())