Index: include/llvm/Support/StreamingMemoryObject.h
===================================================================
--- include/llvm/Support/StreamingMemoryObject.h
+++ include/llvm/Support/StreamingMemoryObject.h
@@ -55,29 +55,37 @@
   std::unique_ptr<DataStreamer> Streamer;
   mutable size_t BytesRead;   // Bytes read from stream
   size_t BytesSkipped;// Bytes skipped at start of stream (e.g. wrapper/header)
-  mutable size_t ObjectSize; // 0 if unknown, set if wrapper seen or EOF reached
-  mutable bool EOFReached;
+  mutable size_t ObjectSize; // 0 if unknown, set if wrapper seen or end of
+                             // object reached.
+  mutable bool EOOReached;   // end of object reached.
 
-  // Fetch enough bytes such that Pos can be read or EOF is reached
-  // (i.e. BytesRead > Pos). Return true if Pos can be read.
-  // Unlike most of the functions in BitcodeReader, returns true on success.
-  // Most of the requests will be small, but we fetch at kChunkSize bytes
-  // at a time to avoid making too many potentially expensive GetBytes calls
+  // Fetch enough bytes such that Pos can be read or end of object is
+  // reached (i.e. BytesRead > Pos). Note: EOF sets end of object if
+  // not already defined. Returns true if Pos can be read, i.e. Pos is
+  // below BytesRead and, when known, below ObjectSize: the stream may end
+  // before a size announced via setKnownObjectSize(), so a known ObjectSize
+  // alone does not make Pos readable. Unlike most of the functions in
+  // BitcodeReader, returns true on success.
+  // Most of the requests will be small, but we fetch at kChunkSize
+  // bytes at a time to avoid making too many potentially expensive
+  // GetBytes calls
   bool fetchToPos(size_t Pos) const {
-    if (EOFReached)
-      return Pos < ObjectSize;
-    while (Pos >= BytesRead) {
+    // Nothing at or past a known end of object is readable, so don't pull
+    // more bytes from the streamer just to find that out.
+    if (ObjectSize && Pos >= ObjectSize)
+      return false;
+    while (!EOOReached && Pos >= BytesRead) {
       Bytes.resize(BytesRead + BytesSkipped + kChunkSize);
       size_t bytes = Streamer->GetBytes(&Bytes[BytesRead + BytesSkipped],
                                         kChunkSize);
       BytesRead += bytes;
-      if (bytes != kChunkSize) { // reached EOF/ran out of bytes
-        ObjectSize = BytesRead;
-        EOFReached = true;
-        break;
+      if (bytes == 0) { // reached EOF/ran out of bytes
+        if (ObjectSize == 0)
+          ObjectSize = BytesRead;
+        EOOReached = true;
       }
     }
     return Pos < BytesRead;
   }
 
   StreamingMemoryObject(const StreamingMemoryObject&) = delete;
Index: lib/Support/StreamingMemoryObject.cpp
===================================================================
--- lib/Support/StreamingMemoryObject.cpp
+++ lib/Support/StreamingMemoryObject.cpp
@@ -87,13 +87,19 @@
 uint64_t StreamingMemoryObject::readBytes(uint8_t *Buf, uint64_t Size,
                                           uint64_t Address) const {
   fetchToPos(Address + Size - 1);
-  if (Address >= BytesRead)
+  // Readable bytes end at the known object size, which may be smaller than
+  // BytesRead (setKnownObjectSize() after more was streamed), but never
+  // extend past the bytes actually received: the stream may end early.
+  uint64_t MaxAddress = BytesRead;
+  if (ObjectSize && ObjectSize < MaxAddress)
+    MaxAddress = ObjectSize;
+  if (Address >= MaxAddress)
     return 0;
 
   uint64_t End = Address + Size;
-  if (End > BytesRead)
-    End = BytesRead;
-  assert(static_cast<int64_t>(End - Address) >= 0);
+  if (End > MaxAddress)
+    End = MaxAddress;
+  assert(End >= Address);
   Size = End - Address;
   memcpy(Buf, &Bytes[Address + BytesSkipped], Size);
   return Size;
@@ -109,6 +115,8 @@
 void StreamingMemoryObject::setKnownObjectSize(size_t size) {
   ObjectSize = size;
   Bytes.reserve(size);
+  if (BytesRead >= ObjectSize)
+    EOOReached = true;
 }
 
 MemoryObject *getNonStreamedMemoryObject(const unsigned char *Start,
@@ -118,7 +126,7 @@
 
 StreamingMemoryObject::StreamingMemoryObject(DataStreamer *streamer) :
   Bytes(kChunkSize), Streamer(streamer), BytesRead(0), BytesSkipped(0),
-  ObjectSize(0), EOFReached(false) {
+  ObjectSize(0), EOOReached(false) {
   BytesRead = streamer->GetBytes(&Bytes[0], kChunkSize);
 }
 }
Index: unittests/Support/StreamingMemoryObject.cpp
===================================================================
--- unittests/Support/StreamingMemoryObject.cpp
+++ unittests/Support/StreamingMemoryObject.cpp
@@ -27,3 +27,40 @@
   StreamingMemoryObject O(DS);
   EXPECT_TRUE(O.isValidAddress(32 * 1024));
 }
+
+TEST(StreamingMemoryObject, TestSetKnownObjectSize) {
+  auto *DS = new NullDataStreamer();
+  StreamingMemoryObject O(DS);
+  uint8_t Buf[32];
+  EXPECT_EQ((uint64_t) 16, O.readBytes(Buf, 16, 0));
+  O.setKnownObjectSize(24);
+  EXPECT_EQ((uint64_t) 8, O.readBytes(Buf, 16, 16));
+}
+
+namespace {
+// Streams a fixed number of zero bytes, then reports end of stream.
+class ShortDataStreamer : public DataStreamer {
+  size_t Remaining;
+
+public:
+  explicit ShortDataStreamer(size_t Size) : Remaining(Size) {}
+  size_t GetBytes(unsigned char *buf, size_t len) override {
+    size_t N = len < Remaining ? len : Remaining;
+    for (size_t I = 0; I != N; ++I)
+      buf[I] = 0;
+    Remaining -= N;
+    return N;
+  }
+};
+} // end anonymous namespace
+
+TEST(StreamingMemoryObject, TestStreamEndsBeforeKnownObjectSize) {
+  // Only 8 bytes arrive even though a 24-byte object is announced; bytes
+  // that never arrived must not be handed out by readBytes.
+  auto *DS = new ShortDataStreamer(8);
+  StreamingMemoryObject O(DS);
+  O.setKnownObjectSize(24);
+  uint8_t Buf[32];
+  EXPECT_EQ((uint64_t) 8, O.readBytes(Buf, 16, 0));
+  EXPECT_EQ((uint64_t) 0, O.readBytes(Buf, 16, 8));
+}