Index: llvm/include/llvm/IR/DerivedTypes.h
===================================================================
--- llvm/include/llvm/IR/DerivedTypes.h
+++ llvm/include/llvm/IR/DerivedTypes.h
@@ -269,6 +269,10 @@
     return create(StructFields, Name);
   }
 
+  /// Return the named StructType registered under \p Name in \p Context, or a
+  /// null pointer if there is none. This never creates a type.
+  static StructType *getIfExists(LLVMContext &Context, StringRef Name);
+
   /// This static method is the primary way to create a literal StructType.
   static StructType *get(LLVMContext &Context, ArrayRef<Type*> Elements,
                          bool isPacked = false);
Index: llvm/lib/Bitcode/Reader/BitcodeReader.cpp
===================================================================
--- llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1543,6 +1543,24 @@
   return parseTypeTableBody();
 }
 
+/// Look up the named struct type \p Name in \p Context and return it only if
+/// it has a body with exactly the element types \p EltTys and the packedness
+/// \p IsPacked. Returns a null pointer otherwise.
+static StructType *findMatchingStructType(LLVMContext &Context, StringRef Name,
+                                          ArrayRef<Type *> EltTys,
+                                          bool IsPacked) {
+  StructType *Candidate = StructType::getIfExists(Context, Name);
+  // An opaque candidate has no body yet, so it cannot stand in for a struct
+  // with a body (not even an empty one).
+  if (!Candidate || Candidate->isOpaque())
+    return nullptr;
+  // Packed and non-packed structs with the same elements are different types.
+  if (Candidate->isPacked() != IsPacked ||
+      !EltTys.equals(Candidate->elements()))
+    return nullptr;
+  return Candidate;
+}
+
 Error BitcodeReader::parseTypeTableBody() {
   if (!TypeList.empty())
     return error("Invalid multiple blocks");
@@ -1709,25 +1727,33 @@
       if (NumRecords >= TypeList.size())
         return error("Invalid TYPE table");
 
-      // Check to see if this was forward referenced, if so fill in the temp.
-      StructType *Res = cast_or_null<StructType>(TypeList[NumRecords]);
-      if (Res) {
-        Res->setName(TypeName);
-        TypeList[NumRecords] = nullptr;
-      } else  // Otherwise, create a new struct.
-        Res = createIdentifiedStructType(Context, TypeName);
-      TypeName.clear();
-
-      SmallVector<Type*, 8> EltTys;
+      SmallVector<Type *, 8> EltTys;
       for (unsigned i = 1, e = Record.size(); i != e; ++i) {
         if (Type *T = getTypeByID(Record[i]))
           EltTys.push_back(T);
         else
           break;
       }
-      if (EltTys.size() != Record.size()-1)
+      if (EltTys.size() != Record.size() - 1)
         return error("Invalid record");
-      Res->setBody(EltTys, Record[0]);
+
+      // Check to see if this was forward referenced, if so fill in the temp.
+      StructType *Res = cast_or_null<StructType>(TypeList[NumRecords]);
+      if (Res) {
+        Res->setName(TypeName);
+        TypeList[NumRecords] = nullptr;
+        Res->setBody(EltTys, Record[0]);
+      } else {
+        // If it was not, check to see whether the context already contains a
+        // matching type of that name, and create a new one if necessary.
+        Res = findMatchingStructType(Context, TypeName, EltTys, Record[0]);
+        if (!Res) {
+          Res = createIdentifiedStructType(Context, TypeName);
+          Res->setBody(EltTys, Record[0]);
+        }
+      }
+      TypeName.clear();
+
       ResultTy = Res;
       break;
     }
Index: llvm/lib/IR/Type.cpp
===================================================================
--- llvm/lib/IR/Type.cpp
+++ llvm/lib/IR/Type.cpp
@@ -336,8 +336,15 @@
 //                       StructType Implementation
 //===----------------------------------------------------------------------===//
 
+StructType *StructType::getIfExists(LLVMContext &Context, StringRef Name) {
+  auto I = Context.pImpl->NamedStructTypes.find(Name);
+  if (I == Context.pImpl->NamedStructTypes.end())
+    return nullptr;
+  return I->getValue();
+}
+
 // Primitive Constructors.
 
 StructType *StructType::get(LLVMContext &Context, ArrayRef<Type*> ETypes,
                             bool isPacked) {
   LLVMContextImpl *pImpl = Context.pImpl;
Index: llvm/unittests/Bitcode/BitReaderTest.cpp
===================================================================
--- llvm/unittests/Bitcode/BitReaderTest.cpp
+++ llvm/unittests/Bitcode/BitReaderTest.cpp
@@ -190,4 +190,61 @@
   EXPECT_FALSE(verifyModule(*M, &dbgs()));
 }
 
+TEST(BitReaderTest, UseExistingNameStructType) {
+  // Make a module using a struct type, then write that.
+  LLVMContext C1;
+  StructType *T1 = StructType::create(C1, "Correct");
+  T1->setBody(Type::getInt32Ty(C1));
+  std::unique_ptr<Module> M1(new Module("M1", C1));
+  M1->getOrInsertFunction("F1", T1);
+  SmallString<1024> Memory;
+  raw_svector_ostream OS(Memory);
+  WriteBitcodeToFile(*M1, OS);
+
+  // Define that struct type in a new context. Read the module into that
+  // context. At this point, the context contains the StructType the module
+  // uses.
+  LLVMContext C2;
+  StructType *T2 = StructType::create(C2, T1->getName());
+  T2->setBody(Type::getInt32Ty(C2));
+  auto Careful = parseBitcodeFile(MemoryBufferRef(Memory.str(), "test"), C2);
+  ASSERT_TRUE((bool)Careful);
+  std::unique_ptr<Module> M2(Careful.get().release());
+
+  // Was the already-present struct type used?
+  Function *F2 = M2->getFunction("F1");
+  ASSERT_NE(F2, nullptr);
+  EXPECT_EQ(F2->getReturnType(), T2);
+  EXPECT_EQ(cast<StructType>(F2->getReturnType())->getName(), T1->getName());
+
+  // Define a different StructType of the same name, then read the bitcode into
+  // THAT context.
+  LLVMContext C3;
+  StructType *T3 = StructType::create(C3, T1->getName());
+  T3->setBody(Type::getFloatTy(C3));
+  Careful = parseBitcodeFile(MemoryBufferRef(Memory.str(), "test"), C3);
+  ASSERT_TRUE((bool)Careful);
+  std::unique_ptr<Module> M3(Careful.get().release());
+
+  // The same test: Was the already-present struct type used? It should not be
+  // used, since its structure differs.
+  Function *F3 = M3->getFunction("F1");
+  ASSERT_NE(F3, nullptr);
+  EXPECT_NE(F3->getReturnType(), T3);
+  EXPECT_NE(cast<StructType>(F3->getReturnType())->getName(), T1->getName());
+
+  // A packed struct of the same name and elements is also a different type
+  // and must not be used either.
+  LLVMContext C4;
+  StructType *T4 = StructType::create(C4, T1->getName());
+  Type *PackedElts[] = {Type::getInt32Ty(C4)};
+  T4->setBody(PackedElts, /*isPacked=*/true);
+  Careful = parseBitcodeFile(MemoryBufferRef(Memory.str(), "test"), C4);
+  ASSERT_TRUE((bool)Careful);
+  std::unique_ptr<Module> M4(Careful.get().release());
+  Function *F4 = M4->getFunction("F1");
+  ASSERT_NE(F4, nullptr);
+  EXPECT_NE(F4->getReturnType(), T4);
+}
+
 } // end namespace