Index: include/llvm/Support/StreamingMemoryObject.h =================================================================== --- include/llvm/Support/StreamingMemoryObject.h +++ include/llvm/Support/StreamingMemoryObject.h @@ -55,25 +55,32 @@ std::unique_ptr<DataStreamer> Streamer; mutable size_t BytesRead; // Bytes read from stream size_t BytesSkipped;// Bytes skipped at start of stream (e.g. wrapper/header) - mutable size_t ObjectSize; // 0 if unknown, set if wrapper seen or EOF reached - mutable bool EOFReached; + mutable size_t ObjectSize; // 0 if unknown, set if wrapper seen or end of + // object reached. + mutable bool EOOReached; // end of object reached. - // Fetch enough bytes such that Pos can be read or EOF is reached - // (i.e. BytesRead > Pos). Return true if Pos can be read. - // Unlike most of the functions in BitcodeReader, returns true on success. - // Most of the requests will be small, but we fetch at kChunkSize bytes - // at a time to avoid making too many potentially expensive GetBytes calls + // Fetch enough bytes such that Pos can be read or end of object is + // reached (i.e. BytesRead > Pos). Note: EOF sets end of object if + // not already defined. Returns true if Pos can be read. Unlike + // most of the functions in BitcodeReader, returns true on success. 
+ // Most of the requests will be small, but we fetch at kChunkSize + // bytes at a time to avoid making too many potentially expensive + // GetBytes calls bool fetchToPos(size_t Pos) const { - if (EOFReached) + if (EOOReached) return Pos < ObjectSize; while (Pos >= BytesRead) { Bytes.resize(BytesRead + BytesSkipped + kChunkSize); size_t bytes = Streamer->GetBytes(&Bytes[BytesRead + BytesSkipped], kChunkSize); BytesRead += bytes; - if (bytes != kChunkSize) { // reached EOF/ran out of bytes - ObjectSize = BytesRead; - EOFReached = true; + if (ObjectSize && BytesRead > ObjectSize) { + BytesRead = ObjectSize; + EOOReached = true; + } else if (bytes == 0) { // reached EOF/ran out of bytes + if (ObjectSize == 0) + ObjectSize = BytesRead; + EOOReached = true; break; } } Index: lib/Support/StreamingMemoryObject.cpp =================================================================== --- lib/Support/StreamingMemoryObject.cpp +++ lib/Support/StreamingMemoryObject.cpp @@ -93,7 +93,7 @@ uint64_t End = Address + Size; if (End > BytesRead) End = BytesRead; - assert(static_cast<int64_t>(End - Address) >= 0); + assert(End >= Address); Size = End - Address; memcpy(Buf, &Bytes[Address + BytesSkipped], Size); return Size; @@ -109,6 +109,10 @@ void StreamingMemoryObject::setKnownObjectSize(size_t size) { ObjectSize = size; Bytes.reserve(size); + if (BytesRead >= ObjectSize) { + BytesRead = ObjectSize; + EOOReached = true; + } } MemoryObject *getNonStreamedMemoryObject(const unsigned char *Start, @@ -118,7 +122,7 @@ StreamingMemoryObject::StreamingMemoryObject(DataStreamer *streamer) : Bytes(kChunkSize), Streamer(streamer), BytesRead(0), BytesSkipped(0), - ObjectSize(0), EOFReached(false) { + ObjectSize(0), EOOReached(false) { BytesRead = streamer->GetBytes(&Bytes[0], kChunkSize); } } Index: unittests/Support/StreamingMemoryObject.cpp =================================================================== --- unittests/Support/StreamingMemoryObject.cpp +++ 
unittests/Support/StreamingMemoryObject.cpp @@ -27,3 +27,12 @@ StreamingMemoryObject O(DS); EXPECT_TRUE(O.isValidAddress(32 * 1024)); } + +TEST(StreamingMemoryObject, TestSetKnownObjectSize) { + auto *DS = new NullDataStreamer(); + StreamingMemoryObject O(DS); + uint8_t Buf[32]; + EXPECT_EQ((uint64_t) 16, O.readBytes(Buf, 16, 0)); + O.setKnownObjectSize(24); + EXPECT_EQ((uint64_t) 8, O.readBytes(Buf, 16, 16)); +}