diff --git a/llvm/include/llvm/IR/Comdat.h b/llvm/include/llvm/IR/Comdat.h
--- a/llvm/include/llvm/IR/Comdat.h
+++ b/llvm/include/llvm/IR/Comdat.h
@@ -16,10 +16,12 @@
 #define LLVM_IR_COMDAT_H
 
 #include "llvm-c/Types.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/CBindingWrapping.h"
 
 namespace llvm {
 
+class GlobalObject;
 class raw_ostream;
 class StringRef;
 template <typename ValueTy> class StringMapEntry;
@@ -46,15 +48,23 @@
   StringRef getName() const;
   void print(raw_ostream &OS, bool IsForDebug = false) const;
   void dump() const;
+  /// Return the global objects whose comdat is currently set to this one.
+  const SmallPtrSetImpl<GlobalObject *> &getUsers() const { return Users; }
 
 private:
   friend class Module;
+  friend class GlobalObject;
 
   Comdat();
+  // Used by GlobalObject::setComdat() to keep Users in sync.
+  void addUser(GlobalObject *GO);
+  void removeUser(GlobalObject *GO);
 
   // Points to the map in Module.
   StringMapEntry<Comdat> *Name = nullptr;
   SelectionKind SK = Any;
+  // Globals using this comdat.
+  SmallPtrSet<GlobalObject *, 2> Users;
 };
 
 // Create wrappers for C Binding types (see CBindingWrapping.h).
diff --git a/llvm/include/llvm/IR/GlobalObject.h b/llvm/include/llvm/IR/GlobalObject.h
--- a/llvm/include/llvm/IR/GlobalObject.h
+++ b/llvm/include/llvm/IR/GlobalObject.h
@@ -48,6 +48,7 @@
         ObjComdat(nullptr) {
     setGlobalValueSubClassData(0);
   }
+  ~GlobalObject();
 
   Comdat *ObjComdat;
   enum {
@@ -122,7 +123,9 @@
   bool hasComdat() const { return getComdat() != nullptr; }
   const Comdat *getComdat() const { return ObjComdat; }
   Comdat *getComdat() { return ObjComdat; }
-  void setComdat(Comdat *C) { ObjComdat = C; }
+  /// Set the comdat of this object (nullptr clears it), moving this object
+  /// from the old comdat's user list to the new one's.
+  void setComdat(Comdat *C);
 
   using Value::addMetadata;
   using Value::clearMetadata;
diff --git a/llvm/lib/IR/Comdat.cpp b/llvm/lib/IR/Comdat.cpp
--- a/llvm/lib/IR/Comdat.cpp
+++ b/llvm/lib/IR/Comdat.cpp
@@ -25,6 +25,10 @@
 
 StringRef Comdat::getName() const { return Name->first(); }
 
+void Comdat::addUser(GlobalObject *GO) { Users.insert(GO); }
+
+void Comdat::removeUser(GlobalObject *GO) { Users.erase(GO); }
+
 LLVMComdatRef LLVMGetOrInsertComdat(LLVMModuleRef M, const char *Name) {
   return wrap(unwrap(M)->getOrInsertComdat(Name));
 }
diff --git a/llvm/lib/IR/Globals.cpp b/llvm/lib/IR/Globals.cpp
--- a/llvm/lib/IR/Globals.cpp
+++ b/llvm/lib/IR/Globals.cpp
@@ -95,6 +95,9 @@
   llvm_unreachable("not a global");
 }
 
+// Unregister from the comdat so its user list never holds a dangling pointer.
+GlobalObject::~GlobalObject() { setComdat(nullptr); }
+
 bool GlobalValue::isInterposable() const {
   if (isInterposableLinkage(getLinkage()))
     return true;
@@ -182,6 +185,14 @@
   return cast<GlobalObject>(this)->getComdat();
 }
 
+void GlobalObject::setComdat(Comdat *C) {
+  if (ObjComdat)
+    ObjComdat->removeUser(this);
+  ObjComdat = C;
+  if (C)
+    C->addUser(this);
+}
+
 StringRef GlobalValue::getPartition() const {
   if (!hasPartition())
     return "";
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -179,65 +179,29 @@
 
 void llvm::filterDeadComdatFunctions(
     Module &M, SmallVectorImpl<Function *> &DeadComdatFunctions) {
-  // Build a map from the comdat to the number of entries in that comdat we
-  // think are dead. If this fully covers the comdat group, then the entire
-  // group is dead. If we find another entry in the comdat group though, we'll
-  // have to preserve the whole group.
-  SmallDenseMap<Comdat *, int, 16> ComdatEntriesCovered;
+  SmallPtrSet<Function *, 32> MaybeDeadFunctions;
+  SmallPtrSet<Comdat *, 32> MaybeDeadComdats;
   for (Function *F : DeadComdatFunctions) {
-    Comdat *C = F->getComdat();
-    assert(C && "Expected all input GVs to be in a comdat!");
-    ComdatEntriesCovered[C] += 1;
+    MaybeDeadFunctions.insert(F);
+    if (Comdat *C = F->getComdat())
+      MaybeDeadComdats.insert(C);
   }
 
-  auto CheckComdat = [&](Comdat &C) {
-    auto CI = ComdatEntriesCovered.find(&C);
-    if (CI == ComdatEntriesCovered.end())
-      return;
-
-    // If this could have been covered by a dead entry, just subtract one to
-    // account for it.
-    if (CI->second > 0) {
-      CI->second -= 1;
-      return;
-    }
-
-    // If we've already accounted for all the entries that were dead, the
-    // entire comdat is alive so remove it from the map.
-    ComdatEntriesCovered.erase(CI);
-  };
-
-  auto CheckAllComdats = [&] {
-    for (Function &F : M.functions())
-      if (Comdat *C = F.getComdat()) {
-        CheckComdat(*C);
-        if (ComdatEntriesCovered.empty())
-          return;
-      }
-    for (GlobalVariable &GV : M.globals())
-      if (Comdat *C = GV.getComdat()) {
-        CheckComdat(*C);
-        if (ComdatEntriesCovered.empty())
-          return;
-      }
-    for (GlobalAlias &GA : M.aliases())
-      if (Comdat *C = GA.getComdat()) {
-        CheckComdat(*C);
-        if (ComdatEntriesCovered.empty())
-          return;
-      }
-  };
-  CheckAllComdats();
-
-  if (ComdatEntriesCovered.empty()) {
-    DeadComdatFunctions.clear();
-    return;
-  }
+  // Find comdats for which all users are dead now. A user that is not a
+  // function, or a function not known to be dead, keeps the comdat alive.
+  auto IsUserDead = [&](GlobalObject *GO) {
+    auto *F = dyn_cast<Function>(GO);
+    return F && MaybeDeadFunctions.contains(F);
+  };
+  SmallPtrSet<Comdat *, 32> DeadComdats;
+  for (Comdat *C : MaybeDeadComdats)
+    if (all_of(C->getUsers(), IsUserDead))
+      DeadComdats.insert(C);
 
-  // Remove the entries that were not covering.
-  erase_if(DeadComdatFunctions, [&](GlobalValue *GV) {
-    return ComdatEntriesCovered.find(GV->getComdat()) ==
-           ComdatEntriesCovered.end();
+  // Only keep functions which have no comdat or a dead comdat.
+  erase_if(DeadComdatFunctions, [&](Function *F) {
+    Comdat *C = F->getComdat();
+    return C && !DeadComdats.contains(C);
   });
 }
 
diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -763,5 +763,32 @@
   }
 }
 
+TEST(ConstantsTest, ComdatUserTracking) {
+  LLVMContext Context;
+  Module M("MyModule", Context);
+
+  Comdat *C = M.getOrInsertComdat("comdat");
+  const SmallPtrSetImpl<GlobalObject *> &Users = C->getUsers();
+  EXPECT_TRUE(Users.empty());
+
+  Type *Ty = Type::getInt8Ty(Context);
+  GlobalVariable *GV1 = cast<GlobalVariable>(M.getOrInsertGlobal("gv1", Ty));
+  GV1->setComdat(C);
+  EXPECT_EQ(Users.size(), 1u);
+  EXPECT_TRUE(Users.contains(GV1));
+
+  GlobalVariable *GV2 = cast<GlobalVariable>(M.getOrInsertGlobal("gv2", Ty));
+  GV2->setComdat(C);
+  EXPECT_EQ(Users.size(), 2u);
+  EXPECT_TRUE(Users.contains(GV2));
+
+  GV1->eraseFromParent();
+  EXPECT_EQ(Users.size(), 1u);
+  EXPECT_TRUE(Users.contains(GV2));
+
+  GV2->eraseFromParent();
+  EXPECT_TRUE(Users.empty());
+}
+
 } // end anonymous namespace
 } // end namespace llvm