diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp
--- a/clang/lib/AST/ASTImporter.cpp
+++ b/clang/lib/AST/ASTImporter.cpp
@@ -6433,7 +6433,8 @@
 
   ToFunc->setAccess(D->getAccess());
   ToFunc->setLexicalDeclContext(LexicalDC);
-  LexicalDC->addDeclInternal(ToFunc);
+  if (D->getFriendObjectKind() == Decl::FOK_None)
+    LexicalDC->addDeclInternal(ToFunc);
 
   ASTImporterLookupTable *LT = Importer.SharedState->getLookupTable();
   if (LT && !OldParamDC.empty()) {
diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp
--- a/clang/lib/AST/ASTStructuralEquivalence.cpp
+++ b/clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -1464,6 +1464,160 @@
   return IsStructurallyEquivalent(GetName(D1), GetName(D2));
 }
 
+/// Check that two C++ records have equal base counts, types and virtuality.
+static bool
+IsCXXRecordBaseStructurallyEquivalent(StructuralEquivalenceContext &Context,
+                                      RecordDecl *D1, RecordDecl *D2) {
+  auto *D1CXX = cast<CXXRecordDecl>(D1);
+  auto *D2CXX = cast<CXXRecordDecl>(D2);
+
+  if (D1CXX->getNumBases() != D2CXX->getNumBases()) {
+    if (Context.Complain) {
+      Context.Diag2(D2->getLocation(), Context.getApplicableDiagnostic(
+                                           diag::err_odr_tag_type_inconsistent))
+          << Context.ToCtx.getTypeDeclType(D2);
+      Context.Diag2(D2->getLocation(), diag::note_odr_number_of_bases)
+          << D2CXX->getNumBases();
+      Context.Diag1(D1->getLocation(), diag::note_odr_number_of_bases)
+          << D1CXX->getNumBases();
+    }
+    return false;
+  }
+
+  for (CXXRecordDecl::base_class_iterator Base1 = D1CXX->bases_begin(),
+                                          BaseEnd1 = D1CXX->bases_end(),
+                                          Base2 = D2CXX->bases_begin();
+       Base1 != BaseEnd1; ++Base1, ++Base2) {
+    if (!IsStructurallyEquivalent(Context, Base1->getType(),
+                                  Base2->getType())) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2);
+        Context.Diag2(Base2->getBeginLoc(), diag::note_odr_base)
+            << Base2->getType() << Base2->getSourceRange();
+        Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base)
+            << Base1->getType() << Base1->getSourceRange();
+      }
+      return false;
+    }
+
+    // Check virtual vs. non-virtual inheritance mismatch.
+    if (Base1->isVirtual() != Base2->isVirtual()) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2);
+        Context.Diag2(Base2->getBeginLoc(), diag::note_odr_virtual_base)
+            << Base2->isVirtual() << Base2->getSourceRange();
+        // Mirror Base2: the argument is isVirtual(), not a type.
+        Context.Diag1(Base1->getBeginLoc(), diag::note_odr_virtual_base)
+            << Base1->isVirtual() << Base1->getSourceRange();
+      }
+      return false;
+    }
+  }
+
+  return true;
+}
+
+using NonEquivalentDeclSet = llvm::DenseSet<std::pair<Decl *, Decl *>>;
+
+/// Check whether two friends are structurally equivalent. Friend decls are
+/// compared as decls and friend types as types; a decl never matches a type.
+static bool IsEquivalentFriend(FriendDecl *F1, FriendDecl *F2,
+                               NonEquivalentDeclSet &NonEquivalentDecls) {
+  StructuralEquivalenceContext Ctx(
+      F1->getASTContext(), F2->getASTContext(), NonEquivalentDecls,
+      StructuralEquivalenceKind::Minimal, false, false);
+  if (F1->getFriendDecl() && F2->getFriendDecl())
+    return Ctx.IsEquivalent(F1->getFriendDecl(), F2->getFriendDecl());
+  if (F1->getFriendType() && F2->getFriendType())
+    return Ctx.IsEquivalent(F1->getFriendType()->getType(),
+                            F2->getFriendType()->getType());
+
+  return false;
+}
+
+/// Check whether \p F is equivalent to any friend in \p Friends.
+static bool
+IsEquivalentToAnyExistingFriends(FriendDecl *F, ArrayRef<FriendDecl *> Friends,
+                                 NonEquivalentDeclSet &NonEquivalentDecls) {
+  for (FriendDecl *Other : Friends)
+    if (IsEquivalentFriend(F, Other, NonEquivalentDecls))
+      return true;
+
+  return false;
+}
+
+/// Collect the friends of \p RD in iteration order, skipping any friend that
+/// is equivalent to one already collected (e.g. a friend declared twice).
+static SmallVector<FriendDecl *> getDeduplicatedFriends(CXXRecordDecl *RD) {
+  NonEquivalentDeclSet NonEquivalentDecls;
+  SmallVector<FriendDecl *> UniqueFriends;
+
+  for (FriendDecl *Friend : RD->friends())
+    if (!IsEquivalentToAnyExistingFriends(Friend, UniqueFriends,
+                                          NonEquivalentDecls))
+      UniqueFriends.push_back(Friend);
+
+  return UniqueFriends;
+}
+
+/// Compare the deduplicated friend lists of two C++ records pairwise.
+static bool
+IsFriendInCXXRecordStructurallyEquivalent(StructuralEquivalenceContext &Context,
+                                          RecordDecl *D1, RecordDecl *D2) {
+  auto *D1CXX = cast<CXXRecordDecl>(D1);
+  auto *D2CXX = cast<CXXRecordDecl>(D2);
+
+  const auto &Friends1 = getDeduplicatedFriends(D1CXX);
+  const auto &Friends2 = getDeduplicatedFriends(D2CXX);
+
+  auto Friend2 = Friends2.begin(), Friend2End = Friends2.end();
+  for (auto Friend1 = Friends1.begin(), Friend1End = Friends1.end();
+       Friend1 != Friend1End; ++Friend1, ++Friend2) {
+    if (Friend2 == Friend2End) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2CXX);
+        Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
+        Context.Diag2(D2->getLocation(), diag::note_odr_missing_friend);
+      }
+      return false;
+    }
+
+    if (!IsStructurallyEquivalent(Context, *Friend1, *Friend2)) {
+      if (Context.Complain) {
+        Context.Diag2(D2->getLocation(),
+                      Context.getApplicableDiagnostic(
+                          diag::err_odr_tag_type_inconsistent))
+            << Context.ToCtx.getTypeDeclType(D2CXX);
+        Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend);
+        Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
+      }
+      return false;
+    }
+  }
+
+  if (Friend2 != Friend2End) {
+    if (Context.Complain) {
+      Context.Diag2(D2->getLocation(), Context.getApplicableDiagnostic(
+                                           diag::err_odr_tag_type_inconsistent))
+          << Context.ToCtx.getTypeDeclType(D2);
+      Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend);
+      Context.Diag1(D1->getLocation(), diag::note_odr_missing_friend);
+    }
+    return false;
+  }
+
+  return true;
+}
+
 /// Determine structural equivalence of two records.
static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context, RecordDecl *D1, RecordDecl *D2) { @@ -1562,98 +1716,11 @@ return false; } - if (D1CXX->getNumBases() != D2CXX->getNumBases()) { - if (Context.Complain) { - Context.Diag2(D2->getLocation(), - Context.getApplicableDiagnostic( - diag::err_odr_tag_type_inconsistent)) - << Context.ToCtx.getTypeDeclType(D2); - Context.Diag2(D2->getLocation(), diag::note_odr_number_of_bases) - << D2CXX->getNumBases(); - Context.Diag1(D1->getLocation(), diag::note_odr_number_of_bases) - << D1CXX->getNumBases(); - } + if (!IsCXXRecordBaseStructurallyEquivalent(Context, D1, D2)) return false; - } - - // Check the base classes. - for (CXXRecordDecl::base_class_iterator Base1 = D1CXX->bases_begin(), - BaseEnd1 = D1CXX->bases_end(), - Base2 = D2CXX->bases_begin(); - Base1 != BaseEnd1; ++Base1, ++Base2) { - if (!IsStructurallyEquivalent(Context, Base1->getType(), - Base2->getType())) { - if (Context.Complain) { - Context.Diag2(D2->getLocation(), - Context.getApplicableDiagnostic( - diag::err_odr_tag_type_inconsistent)) - << Context.ToCtx.getTypeDeclType(D2); - Context.Diag2(Base2->getBeginLoc(), diag::note_odr_base) - << Base2->getType() << Base2->getSourceRange(); - Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base) - << Base1->getType() << Base1->getSourceRange(); - } - return false; - } - - // Check virtual vs. non-virtual inheritance mismatch. - if (Base1->isVirtual() != Base2->isVirtual()) { - if (Context.Complain) { - Context.Diag2(D2->getLocation(), - Context.getApplicableDiagnostic( - diag::err_odr_tag_type_inconsistent)) - << Context.ToCtx.getTypeDeclType(D2); - Context.Diag2(Base2->getBeginLoc(), diag::note_odr_virtual_base) - << Base2->isVirtual() << Base2->getSourceRange(); - Context.Diag1(Base1->getBeginLoc(), diag::note_odr_base) - << Base1->isVirtual() << Base1->getSourceRange(); - } - return false; - } - } - // Check the friends for consistency. 
- CXXRecordDecl::friend_iterator Friend2 = D2CXX->friend_begin(), - Friend2End = D2CXX->friend_end(); - for (CXXRecordDecl::friend_iterator Friend1 = D1CXX->friend_begin(), - Friend1End = D1CXX->friend_end(); - Friend1 != Friend1End; ++Friend1, ++Friend2) { - if (Friend2 == Friend2End) { - if (Context.Complain) { - Context.Diag2(D2->getLocation(), - Context.getApplicableDiagnostic( - diag::err_odr_tag_type_inconsistent)) - << Context.ToCtx.getTypeDeclType(D2CXX); - Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend); - Context.Diag2(D2->getLocation(), diag::note_odr_missing_friend); - } - return false; - } - - if (!IsStructurallyEquivalent(Context, *Friend1, *Friend2)) { - if (Context.Complain) { - Context.Diag2(D2->getLocation(), - Context.getApplicableDiagnostic( - diag::err_odr_tag_type_inconsistent)) - << Context.ToCtx.getTypeDeclType(D2CXX); - Context.Diag1((*Friend1)->getFriendLoc(), diag::note_odr_friend); - Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend); - } - return false; - } - } - - if (Friend2 != Friend2End) { - if (Context.Complain) { - Context.Diag2(D2->getLocation(), - Context.getApplicableDiagnostic( - diag::err_odr_tag_type_inconsistent)) - << Context.ToCtx.getTypeDeclType(D2); - Context.Diag2((*Friend2)->getFriendLoc(), diag::note_odr_friend); - Context.Diag1(D1->getLocation(), diag::note_odr_missing_friend); - } + if (!IsFriendInCXXRecordStructurallyEquivalent(Context, D1, D2)) return false; - } } else if (D1CXX->getNumBases() > 0) { if (Context.Complain) { Context.Diag2(D2->getLocation(), @@ -2327,8 +2394,8 @@ Decl *D1 = P.first; Decl *D2 = P.second; - bool Equivalent = - CheckCommonEquivalence(D1, D2) && CheckKindSpecificEquivalence(D1, D2); + bool Equivalent = (D1 == D2) || (CheckCommonEquivalence(D1, D2) && + CheckKindSpecificEquivalence(D1, D2)); if (!Equivalent) { // Note that these two declarations are not equivalent (and we already diff --git a/clang/unittests/AST/ASTImporterTest.cpp 
b/clang/unittests/AST/ASTImporterTest.cpp --- a/clang/unittests/AST/ASTImporterTest.cpp +++ b/clang/unittests/AST/ASTImporterTest.cpp @@ -10,6 +10,7 @@ // //===----------------------------------------------------------------------===// +#include "clang/AST/ASTStructuralEquivalence.h" #include "clang/AST/RecordLayout.h" #include "clang/ASTMatchers/ASTMatchers.h" #include "llvm/ADT/StringMap.h" @@ -4351,6 +4352,44 @@ EXPECT_EQ(ToFriend2, ToImportedFriend2); } +TEST_P(ASTImporterOptionSpecificTestBase, ImportRepeatedFriendDeclIntoEmptyDC) { + Decl *From, *To; + std::tie(From, To) = getImportedDecl(R"( + template + class A { + public: + template friend A &f(); + template friend A &f(); + }; + )", + Lang_CXX17, "", Lang_CXX17, "A"); + + auto *FromFriend1 = FirstDeclMatcher().match(From, friendDecl()); + auto *FromFriend2 = LastDeclMatcher().match(From, friendDecl()); + auto *ToFriend1 = FirstDeclMatcher().match(To, friendDecl()); + auto *ToFriend2 = LastDeclMatcher().match(To, friendDecl()); + + // Two different FriendDecls in From context. + EXPECT_TRUE(FromFriend1 != FromFriend2); + // Only one is imported into empty DC. + EXPECT_TRUE(ToFriend1 == ToFriend2); + + // 'A' is imported into empty DC, keeping structure equivalence. 
+ llvm::DenseSet> NonEquivalentDecls01; + llvm::DenseSet> NonEquivalentDecls10; + StructuralEquivalenceContext Ctx01( + From->getASTContext(), To->getASTContext(), NonEquivalentDecls01, + StructuralEquivalenceKind::Default, false, false); + StructuralEquivalenceContext Ctx10( + To->getASTContext(), From->getASTContext(), NonEquivalentDecls10, + StructuralEquivalenceKind::Default, false, false); + + bool Eq01 = Ctx01.IsEquivalent(From, To); + bool Eq10 = Ctx10.IsEquivalent(To, From); + EXPECT_EQ(Eq01, Eq10); + EXPECT_TRUE(Eq01); +} + TEST_P(ASTImporterOptionSpecificTestBase, FriendFunInClassTemplate) { auto *Code = R"( template diff --git a/clang/unittests/AST/StructuralEquivalenceTest.cpp b/clang/unittests/AST/StructuralEquivalenceTest.cpp --- a/clang/unittests/AST/StructuralEquivalenceTest.cpp +++ b/clang/unittests/AST/StructuralEquivalenceTest.cpp @@ -833,7 +833,7 @@ auto t = makeNamedDecls("struct foo { friend class X; };", "struct foo { friend class X; friend class X; };", Lang_CXX11); - EXPECT_FALSE(testStructuralMatch(t)); + EXPECT_TRUE(testStructuralMatch(t)); } TEST_F(StructuralEquivalenceRecordTest, SameFriendsDifferentOrder) {