Index: llvm/include/llvm/IR/DerivedTypes.h
===================================================================
--- llvm/include/llvm/IR/DerivedTypes.h
+++ llvm/include/llvm/IR/DerivedTypes.h
@@ -269,6 +269,10 @@
     return create(StructFields, Name);
   }
 
+  /// Return the identified struct type named \p Name in \p Context, or null
+  /// if the context contains no struct type with that name.
+  static StructType *getIfExists(LLVMContext &Context, StringRef Name);
+
   /// This static method is the primary way to create a literal StructType.
   static StructType *get(LLVMContext &Context, ArrayRef<Type*> Elements,
                          bool isPacked = false);
Index: llvm/lib/Bitcode/Reader/BitcodeReader.cpp
===================================================================
--- llvm/lib/Bitcode/Reader/BitcodeReader.cpp
+++ llvm/lib/Bitcode/Reader/BitcodeReader.cpp
@@ -1543,6 +1543,25 @@
   return parseTypeTableBody();
 }
 
+/// Return the identified struct type \p Name already present in \p Context
+/// if its body is exactly \p EltTys with packedness \p IsPacked, so that the
+/// reader can reuse it instead of creating a renamed copy; otherwise null.
+static StructType *findMatchingStructType(LLVMContext &Context, StringRef Name,
+                                          ArrayRef<Type *> EltTys,
+                                          bool IsPacked) {
+  StructType *Candidate = StructType::getIfExists(Context, Name);
+  // An opaque struct reports zero elements and would falsely match an empty
+  // body; it has no body to compare against, so never reuse it here.
+  if (!Candidate || Candidate->isOpaque())
+    return nullptr;
+  // Types are uniqued per context, so comparing element pointers compares the
+  // element types themselves. Packedness changes layout and must match too.
+  if (Candidate->isPacked() != IsPacked ||
+      !Candidate->elements().equals(EltTys))
+    return nullptr;
+  return Candidate;
+}
+
 Error BitcodeReader::parseTypeTableBody() {
   if (!TypeList.empty())
     return error("Invalid multiple blocks");
@@ -1709,15 +1728,6 @@
       if (NumRecords >= TypeList.size())
         return error("Invalid TYPE table");
 
-      // Check to see if this was forward referenced, if so fill in the temp.
-      StructType *Res = cast_or_null<StructType>(TypeList[NumRecords]);
-      if (Res) {
-        Res->setName(TypeName);
-        TypeList[NumRecords] = nullptr;
-      } else  // Otherwise, create a new struct.
-        Res = createIdentifiedStructType(Context, TypeName);
-      TypeName.clear();
-
       SmallVector<Type*, 8> EltTys;
       for (unsigned i = 1, e = Record.size(); i != e; ++i) {
         if (Type *T = getTypeByID(Record[i]))
@@ -1727,7 +1737,23 @@
       }
       if (EltTys.size() != Record.size()-1)
         return error("Invalid record");
-      Res->setBody(EltTys, Record[0]);
+
+      // Check to see if this was forward referenced, if so fill in the temp.
+      StructType *Res = cast_or_null<StructType>(TypeList[NumRecords]);
+      if (Res) {
+        Res->setName(TypeName);
+        TypeList[NumRecords] = nullptr;
+        Res->setBody(EltTys, Record[0]);
+      } else {
+        // Otherwise reuse an identical struct type of the same name already
+        // present in the context, or create a new one.
+        Res = findMatchingStructType(Context, TypeName, EltTys, Record[0]);
+        if (!Res) {
+          Res = createIdentifiedStructType(Context, TypeName);
+          Res->setBody(EltTys, Record[0]);
+        }
+      }
+      TypeName.clear();
       ResultTy = Res;
       break;
     }
Index: llvm/lib/IR/Type.cpp
===================================================================
--- llvm/lib/IR/Type.cpp
+++ llvm/lib/IR/Type.cpp
@@ -336,8 +336,13 @@
 //                       StructType Implementation
 //===----------------------------------------------------------------------===//
 
+StructType *StructType::getIfExists(LLVMContext &Context, StringRef Name) {
+  // StringMap::lookup yields a null StructType* when Name is not present.
+  return Context.pImpl->NamedStructTypes.lookup(Name);
+}
+
 // Primitive Constructors.
 
 StructType *StructType::get(LLVMContext &Context, ArrayRef<Type*> ETypes,
                             bool isPacked) {
   LLVMContextImpl *pImpl = Context.pImpl;
Index: llvm/unittests/Bitcode/BitReaderTest.cpp
===================================================================
--- llvm/unittests/Bitcode/BitReaderTest.cpp
+++ llvm/unittests/Bitcode/BitReaderTest.cpp
@@ -190,4 +190,46 @@
   EXPECT_FALSE(verifyModule(*M, &dbgs()));
 }
 
+TEST(BitReaderTest, UseExistingNameStructType) {
+  // Write a module whose function returns the named struct type %Correct.
+  LLVMContext C1;
+  StructType *T1 = StructType::create(C1, "Correct");
+  T1->setBody(Type::getInt32Ty(C1));
+  std::unique_ptr<Module> M1(new Module("M1", C1));
+  M1->getOrInsertFunction("F1", FunctionType::get(T1, false));
+  SmallString<1024> Memory;
+  raw_svector_ostream OS(Memory);
+  WriteBitcodeToFile(*M1, OS);
+
+  // Define an identical struct type of the same name in a fresh context and
+  // read the module into it: the reader must reuse that very type.
+  LLVMContext C2;
+  StructType *T2 = StructType::create(C2, T1->getName());
+  T2->setBody(Type::getInt32Ty(C2));
+  Expected<std::unique_ptr<Module>> M2 =
+      parseBitcodeFile(MemoryBufferRef(Memory.str(), "test"), C2);
+  if (!M2) {
+    consumeError(M2.takeError());
+    FAIL() << "could not parse bitcode into C2";
+  }
+  Function *F2 = (*M2)->getFunction("F1");
+  ASSERT_TRUE(F2 != nullptr);
+  EXPECT_EQ(T2, F2->getReturnType());
+
+  // Define a same-named struct type with a different body in another context:
+  // the reader must not reuse it, and creates a renamed type instead.
+  LLVMContext C3;
+  StructType *T3 = StructType::create(C3, T1->getName());
+  T3->setBody(Type::getFloatTy(C3));
+  Expected<std::unique_ptr<Module>> M3 =
+      parseBitcodeFile(MemoryBufferRef(Memory.str(), "test"), C3);
+  if (!M3) {
+    consumeError(M3.takeError());
+    FAIL() << "could not parse bitcode into C3";
+  }
+  Function *F3 = (*M3)->getFunction("F1");
+  ASSERT_TRUE(F3 != nullptr);
+  EXPECT_NE(T3, F3->getReturnType());
+}
+
 } // end namespace