Index: include/llvm/Linker/IRMover.h
===================================================================
--- include/llvm/Linker/IRMover.h
+++ include/llvm/Linker/IRMover.h
@@ -40,8 +40,8 @@
 
 public:
   class IdentifiedStructTypeSet {
-    // The set of opaque types is the composite module.
-    DenseSet<StructType *> OpaqueStructTypes;
+    // The set of all struct types in the composite module.
+    DenseSet<StructType *> AllStructTypes;
 
     // The set of identified but non opaque structures in the composite module.
     DenseSet<StructType *, StructTypeKeyInfo> NonOpaqueStructTypes;
Index: lib/Linker/IRMover.cpp
===================================================================
--- lib/Linker/IRMover.cpp
+++ lib/Linker/IRMover.cpp
@@ -1453,20 +1453,21 @@
 
 void IRMover::IdentifiedStructTypeSet::addNonOpaque(StructType *Ty) {
   assert(!Ty->isOpaque());
+  AllStructTypes.insert(Ty);
   NonOpaqueStructTypes.insert(Ty);
 }
 
 void IRMover::IdentifiedStructTypeSet::switchToNonOpaque(StructType *Ty) {
   assert(!Ty->isOpaque());
   NonOpaqueStructTypes.insert(Ty);
-  bool Removed = OpaqueStructTypes.erase(Ty);
-  (void)Removed;
-  assert(Removed);
+  // Ty is expected to have been registered earlier (the old code asserted it
+  // was in the opaque set). It stays in AllStructTypes, so only check that.
+  assert(AllStructTypes.count(Ty) && "Ty was never added as an opaque type");
 }
 
 void IRMover::IdentifiedStructTypeSet::addOpaque(StructType *Ty) {
   assert(Ty->isOpaque());
-  OpaqueStructTypes.insert(Ty);
+  AllStructTypes.insert(Ty);
 }
 
 StructType *
@@ -1480,12 +1481,8 @@
 }
 
 bool IRMover::IdentifiedStructTypeSet::hasType(StructType *Ty) {
-  if (Ty->isOpaque())
-    return OpaqueStructTypes.count(Ty);
-  auto I = NonOpaqueStructTypes.find(Ty);
-  if (I == NonOpaqueStructTypes.end())
-    return false;
-  return *I == Ty;
+  // One lookup covers both opaque and non-opaque types.
+  return AllStructTypes.count(Ty);
 }
 
 IRMover::IRMover(Module &M) : Composite(M) {
Index: unittests/Linker/LinkModulesTest.cpp
===================================================================
--- unittests/Linker/LinkModulesTest.cpp
+++ unittests/Linker/LinkModulesTest.cpp
@@ -7,6 +7,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm-c/Core.h"
+#include "llvm-c/Linker.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/BasicBlock.h"
@@ -14,10 +16,9 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
 #include "llvm/Linker/Linker.h"
 #include "llvm/Support/SourceMgr.h"
-#include "llvm-c/Core.h"
-#include "llvm-c/Linker.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -306,4 +307,71 @@
   EXPECT_EQ(M3, M4->getOperand(0));
 }
 
+TEST_F(LinkModuleTest, MergeDuplicateStructsByName) {
+  LLVMContext C;
+  SMDiagnostic Err;
+
+  // Destination module has both types, with some dummy references to make sure
+  // they aren't eliminated when linking with Src.
+  const char *DstStr = "%\"struct.T1\" = type { i8 }\n"
+                       "%\"struct.T2\" = type { i8 }\n"
+                       "define void @foo1(%\"struct.T1\"* %x) {\n"
+                       " ret void\n"
+                       "}\n"
+                       "define void @foo2(%\"struct.T2\"* %x) {\n"
+                       " ret void\n"
+                       "}\n";
+
+  // Source module has both types along with functions that take them as
+  // arguments.
+  const char *SrcStr = "%\"struct.T1\" = type { i8 }\n"
+                       "%\"struct.T2\" = type { i8 }\n"
+                       "define void @callee1(%\"struct.T1\"* %x) {\n"
+                       " ret void\n"
+                       "}\n"
+                       "define void @callee2(%\"struct.T2\"* %x) {\n"
+                       " ret void\n"
+                       "}\n";
+
+  std::unique_ptr<Module> Dst = parseAssemblyString(DstStr, Err, C);
+  ASSERT_TRUE(Dst.get());
+
+  std::unique_ptr<Module> Src = parseAssemblyString(SrcStr, Err, C);
+  ASSERT_TRUE(Src.get());
+
+  // Both modules live in C, so the handler must be installed on C.
+  C.setDiagnosticHandler(expectNoDiags);
+  // linkModules() returns true on error.
+  ASSERT_FALSE(Linker::linkModules(*Dst, std::move(Src)));
+
+  // Check that we can call both callee1 and callee2 still.
+  // Since T1 and T2 were defined in Src, it should be valid to call functions
+  // with those types in Dst.
+  Type *T1Ptr = PointerType::get(Dst->getTypeByName("struct.T1"), 0);
+  Type *T2Ptr = PointerType::get(Dst->getTypeByName("struct.T2"), 0);
+  FunctionType *FT1 =
+      FunctionType::get(Type::getVoidTy(C), ArrayRef<Type *>({T1Ptr}), false);
+  FunctionType *FT2 =
+      FunctionType::get(Type::getVoidTy(C), ArrayRef<Type *>({T2Ptr}), false);
+  Function *F1 =
+      Function::Create(FT1, Function::ExternalLinkage, "T1Func", Dst.get());
+  Function *F2 =
+      Function::Create(FT2, Function::ExternalLinkage, "T2Func", Dst.get());
+  BasicBlock *BB1 = BasicBlock::Create(C, "entry", F1);
+  BasicBlock *BB2 = BasicBlock::Create(C, "entry", F2);
+
+  IRBuilder<> Builder(C);
+  Builder.SetInsertPoint(BB1);
+  Builder.CreateCall(Dst->getFunction("callee1"),
+                     ArrayRef<Value *>({&*F1->arg_begin()}));
+  Builder.CreateRetVoid();
+  EXPECT_FALSE(verifyFunction(*F1, &errs()));
+
+  Builder.SetInsertPoint(BB2);
+  Builder.CreateCall(Dst->getFunction("callee2"),
+                     ArrayRef<Value *>({&*F2->arg_begin()}));
+  Builder.CreateRetVoid();
+  EXPECT_FALSE(verifyFunction(*F2, &errs()));
+}
+
 } // end anonymous namespace