diff --git a/clang-tools-extra/clangd/index/Serialization.cpp b/clang-tools-extra/clangd/index/Serialization.cpp
--- a/clang-tools-extra/clangd/index/Serialization.cpp
+++ b/clang-tools-extra/clangd/index/Serialization.cpp
@@ -16,6 +16,7 @@
 #include "support/Trace.h"
 #include "clang/Tooling/CompilationDatabase.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/Compression.h"
 #include "llvm/Support/Endian.h"
 #include "llvm/Support/Error.h"
@@ -104,6 +105,28 @@
     llvm::StringRef Raw = consume(SymbolID::RawSize); // short if truncated.
     return LLVM_UNLIKELY(err()) ? SymbolID() : SymbolID::fromRaw(Raw);
   }
+
+  // Could this size table possibly be valid to read? Each element needs at
+  // least MinSize (> 0) bytes, so N <= (End - Begin) / MinSize. This stops
+  // a corrupt file forcing huge preallocations; size_t avoids truncation.
+  bool badSize(uint32_t N, unsigned MinSize = 1) {
+    size_t MaxPossible = static_cast<size_t>(End - Begin) / MinSize;
+    return N > MaxPossible;
+  }
+
+  // Read a varint (as consumeVar) and resize the container accordingly.
+  // If the size is invalid, return false and mark an error.
+  // (The caller should abort in this case).
+  template <typename T>
+  LLVM_NODISCARD bool consumeSize(T &Container, unsigned MinSize = 1) {
+    const auto Size = consumeVar();
+    if (!badSize(Size, MinSize)) {
+      Container.resize(Size);
+      return true;
+    }
+    Err = true; // Flag the reader as failed; callers should bail out.
+    return false;
+  }
 };
 
 void write32(uint32_t I, llvm::raw_ostream &OS) {
@@ -257,7 +280,8 @@
   IGN.URI = Data.consumeString(Strings);
   llvm::StringRef Digest = Data.consume(IGN.Digest.size());
   std::copy(Digest.bytes_begin(), Digest.bytes_end(), IGN.Digest.begin());
-  IGN.DirectIncludes.resize(Data.consumeVar());
+  if (!Data.consumeSize(IGN.DirectIncludes))
+    return IGN;
   for (llvm::StringRef &Include : IGN.DirectIncludes)
     Include = Data.consumeString(Strings);
   return IGN;
@@ -323,7 +347,8 @@
   Sym.Documentation = Data.consumeString(Strings);
   Sym.ReturnType = Data.consumeString(Strings);
   Sym.Type = Data.consumeString(Strings);
-  Sym.IncludeHeaders.resize(Data.consumeVar());
+  if (!Data.consumeSize(Sym.IncludeHeaders))
+    return Sym;
   for (auto &I : Sym.IncludeHeaders) {
     I.IncludeHeader = Data.consumeString(Strings);
     I.References = Data.consumeVar();
@@ -353,7 +378,8 @@
 readRefs(Reader &Data, llvm::ArrayRef<llvm::StringRef> Strings) {
   std::pair<SymbolID, std::vector<Ref>> Result;
   Result.first = Data.consumeID();
-  Result.second.resize(Data.consumeVar());
+  if (!Data.consumeSize(Result.second))
+    return Result;
   for (auto &Ref : Result.second) {
     Ref.Kind = static_cast<RefKind>(Data.consume8());
     Ref.Location = readLocation(Data, Strings);
@@ -400,7 +426,8 @@
 readCompileCommand(Reader CmdReader, llvm::ArrayRef<llvm::StringRef> Strings) {
   InternedCompileCommand Cmd;
   Cmd.Directory = CmdReader.consumeString(Strings);
-  Cmd.CommandLine.resize(CmdReader.consumeVar());
+  if (!CmdReader.consumeSize(Cmd.CommandLine))
+    return Cmd;
   for (llvm::StringRef &C : Cmd.CommandLine)
     C = CmdReader.consumeString(Strings);
   return Cmd;
diff --git a/clang-tools-extra/clangd/unittests/SerializationTests.cpp b/clang-tools-extra/clangd/unittests/SerializationTests.cpp
--- a/clang-tools-extra/clangd/unittests/SerializationTests.cpp
+++
b/clang-tools-extra/clangd/unittests/SerializationTests.cpp
@@ -7,15 +7,21 @@
 //===----------------------------------------------------------------------===//
 
 #include "Headers.h"
+#include "RIFF.h"
 #include "index/Index.h"
 #include "index/Serialization.h"
+#include "support/Logger.h"
 #include "clang/Tooling/CompilationDatabase.h"
+#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/ScopedPrinter.h"
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#ifdef LLVM_ON_UNIX
+#include <sys/resource.h>
+#endif
 
-using ::testing::_;
-using ::testing::AllOf;
 using ::testing::ElementsAre;
 using ::testing::Pair;
 using ::testing::UnorderedElementsAre;
@@ -297,6 +303,96 @@
     EXPECT_NE(SerializedCmd.Output, Cmd.Output);
   }
 }
+
+#ifdef LLVM_ON_UNIX // rlimit is part of POSIX
+// Lowers the soft address-space limit (RLIMIT_AS) while in scope, restoring
+// the previous limits on destruction. Failure is logged, not fatal.
+class ScopedMemoryLimit {
+  struct rlimit OriginalLimit;
+  bool Succeeded = false;
+
+public:
+  explicit ScopedMemoryLimit(rlim_t Bytes) {
+    if (!getrlimit(RLIMIT_AS, &OriginalLimit)) {
+      struct rlimit NewLimit = OriginalLimit;
+      NewLimit.rlim_cur = Bytes;
+      Succeeded = !setrlimit(RLIMIT_AS, &NewLimit);
+    }
+    if (!Succeeded)
+      elog("Failed to set rlimit");
+  }
+  // Copying would restore the original limit more than once.
+  ScopedMemoryLimit(const ScopedMemoryLimit &) = delete;
+  ScopedMemoryLimit &operator=(const ScopedMemoryLimit &) = delete;
+
+  ~ScopedMemoryLimit() {
+    if (Succeeded)
+      setrlimit(RLIMIT_AS, &OriginalLimit);
+  }
+};
+#else
+// Non-POSIX fallback: no limit is applied.
+class ScopedMemoryLimit {
+public:
+  explicit ScopedMemoryLimit(unsigned Bytes) {}
+};
+#endif
+
+// Test that our deserialization detects invalid array sizes without allocating.
+// If this detection fails, the test should allocate a huge array and crash.
+TEST(SerializationTest, NoCrashOnBadArraySize) {
+  // This test is tricky because we need to construct a subtly invalid file.
+  // First, create a valid serialized file.
+  auto In = readIndexFile(YAML);
+  ASSERT_FALSE(!In) << In.takeError();
+  IndexFileOut Out(*In);
+  Out.Format = IndexFileFormat::RIFF;
+  std::string Serialized = llvm::to_string(Out);
+  // Make sure the uncorrupted file parses, so a failure below is due to the
+  // corruption we inject rather than a pre-existing problem.
+  auto Reparsed = readIndexFile(Serialized);
+  ASSERT_FALSE(!Reparsed) << Reparsed.takeError();
+
+  // Low-level parse it again and find the `srcs` chunk we're going to corrupt.
+  auto Parsed = riff::readFile(Serialized);
+  ASSERT_FALSE(!Parsed) << Parsed.takeError();
+  auto Srcs = llvm::find_if(Parsed->Chunks, [](const riff::Chunk &C) {
+    return C.ID == riff::fourCC("srcs");
+  });
+  ASSERT_NE(Srcs, Parsed->Chunks.end());
+
+  // Srcs consists of a sequence of IncludeGraphNodes. In our case, just one.
+  // The node has:
+  //  - 1 byte: flags (1)
+  //  - varint(stringID): URI
+  //  - 8 byte: file digest
+  //  - varint: DirectIncludes.length
+  //  - repeated varint(stringID): DirectIncludes
+  // We want to set DirectIncludes.length to a huge number.
+  // The offset isn't trivial to find, so we search for the file digest bytes.
+  std::string FileDigest = llvm::fromHex("EED8F5EAF25C453C");
+  size_t Pos = Srcs->Data.find(FileDigest); // substring, not char-set search
+  ASSERT_NE(Pos, StringRef::npos) << "Couldn't locate file digest";
+  Pos += FileDigest.size();
+
+  // Varints are little-endian base-128 numbers, where the top-bit of each byte
+  // indicates whether there are more. 8fffffff7f -> 0xffffff8f (32 bits).
+  std::string CorruptSrcs =
+      (Srcs->Data.take_front(Pos) + llvm::fromHex("8fffffff7f") +
+       "some_random_garbage")
+          .str();
+  Srcs->Data = CorruptSrcs;
+
+  // Try to crash rather than hang on large allocation.
+  ScopedMemoryLimit MemLimit(1000 * 1024 * 1024); // ~1GB
+
+  std::string CorruptFile = llvm::to_string(*Parsed);
+  auto CorruptParsed = readIndexFile(CorruptFile);
+  ASSERT_TRUE(!CorruptParsed);
+  EXPECT_EQ(llvm::toString(CorruptParsed.takeError()),
+            "malformed or truncated include uri");
+}
+
 } // namespace
 } // namespace clangd
 } // namespace clang