Index: llvm/include/llvm/IR/DerivedTypes.h =================================================================== --- llvm/include/llvm/IR/DerivedTypes.h +++ llvm/include/llvm/IR/DerivedTypes.h @@ -270,6 +270,10 @@ return create(StructFields, Name); } + /// This static method returns a StructType by that name if one exists, and a + /// null pointer otherwise. + static StructType *getIfExists(LLVMContext &Context, StringRef Name); + /// This static method is the primary way to create a literal StructType. static StructType *get(LLVMContext &Context, ArrayRef Elements, bool isPacked = false); Index: llvm/lib/Bitcode/Reader/BitcodeReader.cpp =================================================================== --- llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -589,8 +589,6 @@ private: std::vector IdentifiedStructTypes; - StructType *createIdentifiedStructType(LLVMContext &Context, StringRef Name); - StructType *createIdentifiedStructType(LLVMContext &Context); Type *getTypeByID(unsigned ID); @@ -1149,25 +1147,7 @@ if (ID >= TypeList.size()) return nullptr; - if (Type *Ty = TypeList[ID]) - return Ty; - - // If we have a forward reference, the only possible case is when it is to a - // named struct. Just create a placeholder for now. - return TypeList[ID] = createIdentifiedStructType(Context); -} - -StructType *BitcodeReader::createIdentifiedStructType(LLVMContext &Context, - StringRef Name) { - auto *Ret = StructType::create(Context, Name); - IdentifiedStructTypes.push_back(Ret); - return Ret; -} - -StructType *BitcodeReader::createIdentifiedStructType(LLVMContext &Context) { - auto *Ret = StructType::create(Context); - IdentifiedStructTypes.push_back(Ret); - return Ret; + return TypeList[ID]; } //===----------------------------------------------------------------------===// @@ -1620,13 +1600,13 @@ if (!TypeList.empty()) return error("Invalid multiple blocks"); - SmallVector Record; - unsigned NumRecords = 0; - SmallString<64> TypeName; + std::vector>>> + TypeRecords; // Read all the records for this type table. - while (true) { + bool Done = false; + while (!Done) { Expected MaybeEntry = Stream.advanceSkippingSubblocks(); if (!MaybeEntry) return MaybeEntry.takeError(); @@ -1637,68 +1617,100 @@ case BitstreamEntry::Error: return error("Malformed block"); case BitstreamEntry::EndBlock: - if (NumRecords != TypeList.size()) + if (TypeRecords.size() != TypeList.size()) return error("Malformed block"); - return Error::success(); + Done = true; + break; case BitstreamEntry::Record: - // The interesting case. + std::shared_ptr> Record( + new SmallVector()); + Expected MaybeCode = Stream.readRecord(Entry.ID, *Record); + if(!MaybeCode) + return MaybeCode.takeError(); + unsigned Code = MaybeCode.get(); + switch (Code) { + case bitc::TYPE_CODE_NUMENTRY: // TYPE_CODE_NUMENTRY: [numentries] + // TYPE_CODE_NUMENTRY contains a count of the number of types in the + // type list. This allows us to reserve space. + if (Record->size() < 1) + return error("Invalid record"); + TypeList.resize((*Record)[0]); + break; + case bitc::TYPE_CODE_STRUCT_NAME: // STRUCT_NAME: [strchr x N] + // TYPE_CODE_STRUCT_NAME provides a name for the next struct type. We + // provide that type already, so forward references can reach it. If + // necessary the type is left opaque for now. + if (convertToString(*Record, 0, TypeName)) + return error("Invalid record"); + break; + case bitc::TYPE_CODE_STRUCT_NAMED: + case bitc::TYPE_CODE_OPAQUE: { + StructType *NamedStructType = nullptr; + if (!TypeName.empty()) + NamedStructType = StructType::getIfExists(Context, TypeName); + if (!NamedStructType) + NamedStructType = StructType::create(Context, TypeName); + TypeList[TypeRecords.size()] = NamedStructType; + IdentifiedStructTypes.push_back(NamedStructType); + TypeName.clear(); + } + LLVM_FALLTHROUGH; + default: + // All other records are left to process once the forward references are + // resolvable. + TypeRecords.push_back(make_pair(Code, Record)); + break; + } break; } + } - // Read a record. - Record.clear(); + unsigned NumRecords = 0; + // Process the deferred type records. + for (auto TypeRecord : TypeRecords) { + auto Record = TypeRecord.second; Type *ResultTy = nullptr; - Expected MaybeRecord = Stream.readRecord(Entry.ID, Record); - if (!MaybeRecord) - return MaybeRecord.takeError(); - switch (MaybeRecord.get()) { + switch (TypeRecord.first) { default: return error("Invalid value"); - case bitc::TYPE_CODE_NUMENTRY: // TYPE_CODE_NUMENTRY: [numentries] - // TYPE_CODE_NUMENTRY contains a count of the number of types in the - // type list. This allows us to reserve space. - if (Record.size() < 1) - return error("Invalid record"); - TypeList.resize(Record[0]); - continue; - case bitc::TYPE_CODE_VOID: // VOID + case bitc::TYPE_CODE_VOID: // VOID ResultTy = Type::getVoidTy(Context); break; - case bitc::TYPE_CODE_HALF: // HALF + case bitc::TYPE_CODE_HALF: // HALF ResultTy = Type::getHalfTy(Context); break; - case bitc::TYPE_CODE_FLOAT: // FLOAT + case bitc::TYPE_CODE_FLOAT: // FLOAT ResultTy = Type::getFloatTy(Context); break; - case bitc::TYPE_CODE_DOUBLE: // DOUBLE + case bitc::TYPE_CODE_DOUBLE: // DOUBLE ResultTy = Type::getDoubleTy(Context); break; - case bitc::TYPE_CODE_X86_FP80: // X86_FP80 + case bitc::TYPE_CODE_X86_FP80: // X86_FP80 ResultTy = Type::getX86_FP80Ty(Context); break; - case bitc::TYPE_CODE_FP128: // FP128 + case bitc::TYPE_CODE_FP128: // FP128 ResultTy = Type::getFP128Ty(Context); break; case bitc::TYPE_CODE_PPC_FP128: // PPC_FP128 ResultTy = Type::getPPC_FP128Ty(Context); break; - case bitc::TYPE_CODE_LABEL: // LABEL + case bitc::TYPE_CODE_LABEL: // LABEL ResultTy = Type::getLabelTy(Context); break; - case bitc::TYPE_CODE_METADATA: // METADATA + case bitc::TYPE_CODE_METADATA: // METADATA ResultTy = Type::getMetadataTy(Context); break; - case bitc::TYPE_CODE_X86_MMX: // X86_MMX + case bitc::TYPE_CODE_X86_MMX: // X86_MMX ResultTy = Type::getX86_MMXTy(Context); break; - case bitc::TYPE_CODE_TOKEN: // TOKEN + case bitc::TYPE_CODE_TOKEN: // TOKEN ResultTy = Type::getTokenTy(Context); break; case bitc::TYPE_CODE_INTEGER: { // INTEGER: [width] - if (Record.size() < 1) + if (Record->size() < 1) return error("Invalid record"); - uint64_t NumBits = Record[0]; + uint64_t NumBits = (*Record)[0]; if (NumBits < IntegerType::MIN_INT_BITS || NumBits > IntegerType::MAX_INT_BITS) return error("Bitwidth for integer type out of range"); @@ -1707,14 +1719,13 @@ } case bitc::TYPE_CODE_POINTER: { // POINTER: [pointee type] or // [pointee type, address space] - if (Record.size() < 1) + if (Record->size() < 1) return error("Invalid record"); unsigned AddressSpace = 0; - if (Record.size() == 2) - AddressSpace = Record[1]; - ResultTy = getTypeByID(Record[0]); - if (!ResultTy || - !PointerType::isValidElementType(ResultTy)) + if (Record->size() == 2) + AddressSpace = (*Record)[1]; + ResultTy = getTypeByID((*Record)[0]); + if (!ResultTy || !PointerType::isValidElementType(ResultTy)) return error("Invalid type"); ResultTy = PointerType::get(ResultTy, AddressSpace); break; @@ -1722,131 +1733,133 @@ case bitc::TYPE_CODE_FUNCTION_OLD: { // FIXME: attrid is dead, remove it in LLVM 4.0 // FUNCTION: [vararg, attrid, retty, paramty x N] - if (Record.size() < 3) + if (Record->size() < 3) return error("Invalid record"); - SmallVector ArgTys; - for (unsigned i = 3, e = Record.size(); i != e; ++i) { - if (Type *T = getTypeByID(Record[i])) + SmallVector ArgTys; + for (unsigned i = 3, e = Record->size(); i != e; ++i) { + if (Type *T = getTypeByID((*Record)[i])) ArgTys.push_back(T); else break; } - ResultTy = getTypeByID(Record[2]); - if (!ResultTy || ArgTys.size() < Record.size()-3) + ResultTy = getTypeByID((*Record)[2]); + if (!ResultTy || ArgTys.size() < Record->size() - 3) return error("Invalid type"); - ResultTy = FunctionType::get(ResultTy, ArgTys, Record[0]); + ResultTy = FunctionType::get(ResultTy, ArgTys, (*Record)[0]); break; } case bitc::TYPE_CODE_FUNCTION: { // FUNCTION: [vararg, retty, paramty x N] - if (Record.size() < 2) + if (Record->size() < 2) return error("Invalid record"); - SmallVector ArgTys; - for (unsigned i = 2, e = Record.size(); i != e; ++i) { - if (Type *T = getTypeByID(Record[i])) { + SmallVector ArgTys; + for (unsigned i = 2, e = Record->size(); i != e; ++i) { + if (Type *T = getTypeByID((*Record)[i])) { if (!FunctionType::isValidArgumentType(T)) return error("Invalid function argument type"); ArgTys.push_back(T); - } - else + } else break; } - ResultTy = getTypeByID(Record[1]); - if (!ResultTy || ArgTys.size() < Record.size()-2) + ResultTy = getTypeByID((*Record)[1]); + if (!ResultTy || ArgTys.size() < Record->size() - 2) return error("Invalid type"); - ResultTy = FunctionType::get(ResultTy, ArgTys, Record[0]); + ResultTy = FunctionType::get(ResultTy, ArgTys, (*Record)[0]); break; } - case bitc::TYPE_CODE_STRUCT_ANON: { // STRUCT: [ispacked, eltty x N] - if (Record.size() < 1) + case bitc::TYPE_CODE_STRUCT_ANON: { // STRUCT: [ispacked, eltty x N] + if (Record->size() < 1) return error("Invalid record"); - SmallVector EltTys; - for (unsigned i = 1, e = Record.size(); i != e; ++i) { - if (Type *T = getTypeByID(Record[i])) + SmallVector EltTys; + for (unsigned i = 1, e = Record->size(); i != e; ++i) { + if (Type *T = getTypeByID((*Record)[i])) EltTys.push_back(T); else break; } - if (EltTys.size() != Record.size()-1) + if (EltTys.size() != Record->size() - 1) return error("Invalid type"); - ResultTy = StructType::get(Context, EltTys, Record[0]); + ResultTy = StructType::get(Context, EltTys, (*Record)[0]); break; } - case bitc::TYPE_CODE_STRUCT_NAME: // STRUCT_NAME: [strchr x N] - if (convertToString(Record, 0, TypeName)) - return error("Invalid record"); - continue; - case bitc::TYPE_CODE_STRUCT_NAMED: { // STRUCT: [ispacked, eltty x N] - if (Record.size() < 1) + if (Record->size() < 1) return error("Invalid record"); if (NumRecords >= TypeList.size()) return error("Invalid TYPE table"); - // Check to see if this was forward referenced, if so fill in the temp. - StructType *Res = cast_or_null(TypeList[NumRecords]); - if (Res) { - Res->setName(TypeName); - TypeList[NumRecords] = nullptr; - } else // Otherwise, create a new struct. - Res = createIdentifiedStructType(Context, TypeName); - TypeName.clear(); - - SmallVector EltTys; - for (unsigned i = 1, e = Record.size(); i != e; ++i) { - if (Type *T = getTypeByID(Record[i])) + SmallVector EltTys; + for (unsigned i = 1, e = Record->size(); i != e; ++i) { + if (Type *T = getTypeByID((*Record)[i])) EltTys.push_back(T); else break; } - if (EltTys.size() != Record.size()-1) + if (EltTys.size() != Record->size() - 1) return error("Invalid record"); - Res->setBody(EltTys, Record[0]); + + StructType *Res = cast(TypeList[NumRecords]); + TypeList[NumRecords] = nullptr; + if (Res->isOpaque()) { + Res->setBody(EltTys, (*Record)[0]); + } else if (!Res->elements().equals(EltTys)) { + // Ouch! The LLVMContext's existing named struct type and the one being + // read have different structures. This must mean that the LLVMContext + // contains more than one module, and there is disagreement. Several + // possibilities: 1a. There are no opaque references to structs. In this + // case renaming either type is safe (but perhaps not desirable). + // 1b. There are opaque references, but no further Modules will be read. + // In this case renaming the type in the new module is safe. 2a. There + // may be opaque references, and all are to the type that's already in + // the Context. In this case renaming the new type is safe. 2b. There + // may be opaque references, and all are to the type that's being + // read. In this case renaming the existing type is safe. 2c. There may + // be opaque references to either. In this case nothing is safe. + + // 3. This kind of conflict should not happen. + return error( + "named struct types match by name and differ by structure"); + } + ResultTy = Res; break; } - case bitc::TYPE_CODE_OPAQUE: { // OPAQUE: [] - if (Record.size() != 1) + case bitc::TYPE_CODE_OPAQUE: { // OPAQUE: [] + if (Record->size() != 1) return error("Invalid record"); if (NumRecords >= TypeList.size()) return error("Invalid TYPE table"); - // Check to see if this was forward referenced, if so fill in the temp. - StructType *Res = cast_or_null(TypeList[NumRecords]); - if (Res) { - Res->setName(TypeName); - TypeList[NumRecords] = nullptr; - } else // Otherwise, create a new struct with no body. - Res = createIdentifiedStructType(Context, TypeName); - TypeName.clear(); + StructType *Res = cast(TypeList[NumRecords]); + TypeList[NumRecords] = nullptr; ResultTy = Res; break; } - case bitc::TYPE_CODE_ARRAY: // ARRAY: [numelts, eltty] - if (Record.size() < 2) + case bitc::TYPE_CODE_ARRAY: // ARRAY: [numelts, eltty] + if (Record->size() < 2) return error("Invalid record"); - ResultTy = getTypeByID(Record[1]); + ResultTy = getTypeByID((*Record)[1]); if (!ResultTy || !ArrayType::isValidElementType(ResultTy)) return error("Invalid type"); - ResultTy = ArrayType::get(ResultTy, Record[0]); + ResultTy = ArrayType::get(ResultTy, (*Record)[0]); break; case bitc::TYPE_CODE_VECTOR: // VECTOR: [numelts, eltty] or // [numelts, eltty, scalable] - if (Record.size() < 2) + if (Record->size() < 2) return error("Invalid record"); - if (Record[0] == 0) + if ((*Record)[0] == 0) return error("Invalid vector length"); - ResultTy = getTypeByID(Record[1]); + ResultTy = getTypeByID((*Record)[1]); if (!ResultTy || !StructType::isValidElementType(ResultTy)) return error("Invalid type"); - bool Scalable = Record.size() > 2 ? Record[2] : false; - ResultTy = VectorType::get(ResultTy, Record[0], Scalable); + bool Scalable = Record->size() > 2 ? (*Record)[2] : false; + ResultTy = VectorType::get(ResultTy, (*Record)[0], Scalable); break; } @@ -1858,6 +1871,9 @@ assert(ResultTy && "Didn't read a type?"); TypeList[NumRecords++] = ResultTy; } + if (NumRecords < TypeList.size()) + return error("Invalid TYPE table"); + return Error::success(); } Error BitcodeReader::parseOperandBundleTags() { Index: llvm/lib/IR/Type.cpp =================================================================== --- llvm/lib/IR/Type.cpp +++ llvm/lib/IR/Type.cpp @@ -336,8 +336,16 @@ // StructType Implementation //===----------------------------------------------------------------------===// +StructType *StructType::getIfExists(LLVMContext &Context, StringRef Name) { + auto I = Context.pImpl->NamedStructTypes.find(Name); + if (I == Context.pImpl->NamedStructTypes.end()) + return nullptr; + return I->getValue(); +} + // Primitive Constructors. + StructType *StructType::get(LLVMContext &Context, ArrayRef ETypes, bool isPacked) { LLVMContextImpl *pImpl = Context.pImpl; Index: llvm/unittests/Bitcode/BitReaderTest.cpp =================================================================== --- llvm/unittests/Bitcode/BitReaderTest.cpp +++ llvm/unittests/Bitcode/BitReaderTest.cpp @@ -190,4 +190,130 @@ EXPECT_FALSE(verifyModule(*M, &dbgs())); } +TEST(BitReaderTest, UseExistingNameStructType) { + // Make a module using a struct type, then write that. + LLVMContext C1; + StructType *T1 = StructType::create(C1, "Correct"); + T1->setBody(Type::getInt32Ty(C1)); + std::unique_ptr M1(new Module("M1", C1)); + M1->getOrInsertFunction("F1", T1); + SmallString<1024> Memory; + raw_svector_ostream OS(Memory); + WriteBitcodeToFile(*M1, OS); + + // Define that struct type in a new context. Read the module into that + // context. At this point, the context contains the StructType the module + // uses. + LLVMContext C2; + StructType *T2 = StructType::create(C2, T1->getName()); + T2->setBody(Type::getInt32Ty(C2)); + auto Careful = parseBitcodeFile(MemoryBufferRef(Memory.str(), "test"), C2); + EXPECT_TRUE((bool)Careful); + std::unique_ptr M2(Careful.get().release()); + + // Then one single test: was the already-present struct type used? + Function *F2 = M2->getFunction("F1"); + EXPECT_NE(F2, nullptr); + EXPECT_EQ(cast(F2->getReturnType())->getName(), T1->getName()); + + // Define a different StructType of the same name. Reading the bitcode into + // THAT context should fail. + LLVMContext C3; + StructType *T3 = StructType::create(C3, T1->getName()); + T3->setBody(Type::getFloatTy(C3)); + Careful = parseBitcodeFile(MemoryBufferRef(Memory.str(), "test"), C3); + EXPECT_FALSE((bool)Careful); + handleAllErrors(Careful.takeError(), [&](const ErrorInfoBase &DE) {}); +} + +TEST(BitReaderTest, ResolveForwardStructReferences) { + // This test is intended to resemble reading a subclass and a superclass from + // separate .bc files into the same LLVMContext. Subclasses and superclasses + // often contain functions (or methods if you will) with the same signature. + // This test approximates that using two functions with the same FunctionType. + + // Make a context using two struct types that reference each other. One of the + // types necessarily has to contain a forward reference when written to the + // .bc files. + + LLVMContext C1; + StructType *T1A = StructType::create(C1, "Struct1"); + StructType *T2A = StructType::create(C1, "Struct2"); + T1A->setBody(T2A->getPointerTo()); + T2A->setBody(T1A->getPointerTo()); + + // Make two modules, each containing one function with the same function type. + + std::unique_ptr M1A(new Module("M1", C1)); + std::unique_ptr M2A(new Module("M2", C1)); + M1A->getOrInsertFunction("F1", Type::getVoidTy(C1), T1A->getPointerTo(), + T2A->getPointerTo()); + M2A->getOrInsertFunction("F2", Type::getVoidTy(C1), T1A->getPointerTo(), + T2A->getPointerTo()); + EXPECT_EQ(M1A->getFunction("F1")->getFunctionType(), + M2A->getFunction("F2")->getFunctionType()); + + // Write both modules to separate .bc files, then read the .bc files into the + // same new context. + + SmallString<1024> BC1; + raw_svector_ostream OS1(BC1); + WriteBitcodeToFile(*M1A, OS1); + SmallString<1024> BC2; + raw_svector_ostream OS2(BC2); + WriteBitcodeToFile(*M2A, OS2); + + LLVMContext C2; + auto Careful = parseBitcodeFile(MemoryBufferRef(BC1.str(), "test"), C2); + EXPECT_TRUE((bool)Careful); + std::unique_ptr M1B(Careful.get().release()); + Careful = parseBitcodeFile(MemoryBufferRef(BC2.str(), "test"), C2); + EXPECT_TRUE((bool)Careful); + std::unique_ptr M2B(Careful.get().release()); + + // The two functions should still have the same signature. + + EXPECT_EQ(M1B->getFunction("F1")->getFunctionType(), + M2B->getFunction("F2")->getFunctionType()); + + // I wish I could EXPECT_THAT(BC1->containsForwardTypeReference()) but that + // seems entirely unreasonable. +} + +TEST(BitReaderTest, CreateUnnamedStructTypes) { + // This test checks that two unnamed named struct types aren't inappropriately + // merged. I love it when I write a test just to be safe, and it passes the + // first time I run it. + + LLVMContext C1; + StructType *T1 = StructType::create(C1); + StructType *T2 = StructType::create(C1); + T1->setBody(T2->getPointerTo()); + // T2 intentionally opaque + + std::unique_ptr M(new Module("M", C1)); + M->getOrInsertFunction("F", Type::getVoidTy(C1), T1->getPointerTo(), + T2->getPointerTo()); + + SmallString<1024> BC; + raw_svector_ostream OS(BC); + WriteBitcodeToFile(*M, OS); + + LLVMContext C2; + StructType *T3 = StructType::create(C2); + auto Careful = parseBitcodeFile(MemoryBufferRef(BC.str(), "test"), C2); + EXPECT_TRUE((bool)Careful); + + // There's no reason to believe that T1 or T2 matches T3, and T1 and T2 + // definitely do not match each other. Verify all of that. + + Function *F = Careful.get().release()->getFunction("F"); + EXPECT_NE(nullptr, F); + FunctionType *FT = F->getFunctionType(); + EXPECT_EQ(2U, FT->getNumParams()); + EXPECT_NE(FT->getParamType(0), FT->getParamType(1)); + EXPECT_NE(T3, FT->getParamType(0)); + EXPECT_NE(T3, FT->getParamType(1)); +} + } // end namespace