diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp
--- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp
+++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteCommunicationServerLLGS.cpp
@@ -3474,15 +3474,31 @@
   if (packet.GetBytesLeft() < 1 || packet.GetChar() != ':')
     return SendIllFormedResponse(packet, invalid_type_err);
 
-  int32_t type =
-      packet.GetS32(std::numeric_limits<int32_t>::max(), /*base=*/16);
-  if (type == std::numeric_limits<int32_t>::max() ||
+  // Type is a signed integer but packed into the packet as its raw bytes.
+  // GetU64 uses strtoull, which skips whitespace and allows +/-. Reject those.
+  const char *first_type_char = packet.Peek();
+  if (first_type_char && !llvm::isHexDigit(*first_type_char))
+    return SendIllFormedResponse(packet, invalid_type_err);
+
+  // Extract type as unsigned then cast to signed.
+  // Using a uint64_t here so that we have some value outside of the 32 bit
+  // range to use as the invalid return value.
+  uint64_t raw_type =
+      packet.GetU64(std::numeric_limits<uint64_t>::max(), /*base=*/16);
+
+  if ( // Make sure the cast below would be valid
+      raw_type > std::numeric_limits<uint32_t>::max() ||
       // To catch inputs like "123aardvark" that will parse but clearly aren't
       // valid in this case.
       packet.GetBytesLeft()) {
     return SendIllFormedResponse(packet, invalid_type_err);
   }
 
+  // First narrow to 32 bits. Otherwise the copy into type would take the
+  // wrong 4 bytes on big endian.
+  uint32_t raw_type_32 = raw_type;
+  int32_t type = reinterpret_cast<int32_t &>(raw_type_32);
+
   StreamGDBRemote response;
   std::vector<uint8_t> tags;
   Status error = m_current_process->ReadMemoryTags(type, addr, length, tags);
@@ -3552,7 +3568,11 @@
       packet.GetU64(std::numeric_limits<uint64_t>::max(), /*base=*/16);
   if (raw_type > std::numeric_limits<uint32_t>::max())
     return SendIllFormedResponse(packet, invalid_type_err);
-  int32_t type = static_cast<int32_t>(raw_type);
+
+  // First narrow to 32 bits. Otherwise the copy below would get the wrong
+  // 4 bytes on big endian.
+  uint32_t raw_type_32 = raw_type;
+  int32_t type = reinterpret_cast<int32_t &>(raw_type_32);
 
   // Tag data
   if (packet.GetBytesLeft() < 1 || packet.GetChar() != ':')
diff --git a/lldb/test/API/tools/lldb-server/memory-tagging/TestGdbRemoteMemoryTagging.py b/lldb/test/API/tools/lldb-server/memory-tagging/TestGdbRemoteMemoryTagging.py
--- a/lldb/test/API/tools/lldb-server/memory-tagging/TestGdbRemoteMemoryTagging.py
+++ b/lldb/test/API/tools/lldb-server/memory-tagging/TestGdbRemoteMemoryTagging.py
@@ -105,13 +105,21 @@
         self.check_tag_read("{:x},10:".format(buf_address), "E03")
         # Types we don't support
         self.check_tag_read("{:x},10:FF".format(buf_address), "E01")
+        # Types can also be negative, -1 in this case.
+        # So this is E01 for not supported, instead of E03 for invalid formatting.
+        self.check_tag_read("{:x},10:FFFFFFFF".format(buf_address), "E01")
         # (even if the length of the read is zero)
         self.check_tag_read("{:x},0:FF".format(buf_address), "E01")
-        self.check_tag_read("{:x},10:-1".format(buf_address), "E01")
-        self.check_tag_read("{:x},10:+20".format(buf_address), "E01")
         # Invalid type format
         self.check_tag_read("{:x},10:cat".format(buf_address), "E03")
         self.check_tag_read("{:x},10:?11".format(buf_address), "E03")
+        # Type is signed but in packet as raw bytes, no +/- or leading space.
+        self.check_tag_read("{:x},10:-1".format(buf_address), "E03")
+        self.check_tag_read("{:x},10:+20".format(buf_address), "E03")
+        self.check_tag_read("{:x},10: 12".format(buf_address), "E03")
+        # We do use a uint64_t for unpacking but that's just an implementation
+        # detail. Any value > 32 bit is invalid.
+        self.check_tag_read("{:x},10:123412341".format(buf_address), "E03")
 
         # Valid packets
 
diff --git a/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp b/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp
--- a/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp
+++ b/lldb/unittests/Process/gdb-remote/GDBRemoteCommunicationClientTest.cpp
@@ -470,19 +470,18 @@
 
 static void
 check_qmemtags(TestClient &client, MockServer &server, size_t read_len,
-               const char *packet, llvm::StringRef response,
+               int32_t type, const char *packet, llvm::StringRef response,
                llvm::Optional<std::vector<uint8_t>> expected_tag_data) {
-  const auto &ReadMemoryTags = [&](size_t len, const char *packet,
-                                   llvm::StringRef response) {
+  const auto &ReadMemoryTags = [&]() {
     std::future<DataBufferSP> result = std::async(std::launch::async, [&] {
-      return client.ReadMemoryTags(0xDEF0, read_len, 1);
+      return client.ReadMemoryTags(0xDEF0, read_len, type);
     });
 
     HandlePacket(server, packet, response);
     return result.get();
   };
 
-  auto result = ReadMemoryTags(0, packet, response);
+  auto result = ReadMemoryTags();
   if (expected_tag_data) {
     ASSERT_TRUE(result);
     llvm::ArrayRef<uint8_t> expected(*expected_tag_data);
@@ -495,41 +494,53 @@
 
 TEST_F(GDBRemoteCommunicationClientTest, ReadMemoryTags) {
   // Zero length reads are valid
-  check_qmemtags(client, server, 0, "qMemTags:def0,0:1", "m",
+  check_qmemtags(client, server, 0, 1, "qMemTags:def0,0:1", "m",
                  std::vector<uint8_t>{});
 
+  // Type can be negative. Put into the packet as the raw bytes
+  // (as opposed to a literal -1)
+  check_qmemtags(client, server, 0, -1, "qMemTags:def0,0:ffffffff", "m",
+                 std::vector<uint8_t>{});
+  check_qmemtags(client, server, 0, std::numeric_limits<int32_t>::min(),
+                 "qMemTags:def0,0:80000000", "m", std::vector<uint8_t>{});
+  check_qmemtags(client, server, 0, std::numeric_limits<int32_t>::max(),
+                 "qMemTags:def0,0:7fffffff", "m", std::vector<uint8_t>{});
+
   // The client layer does not check the length of the received data.
   // All we need is the "m" and for the decode to use all of the chars
-  check_qmemtags(client, server, 32, "qMemTags:def0,20:1", "m09",
+  check_qmemtags(client, server, 32, 2, "qMemTags:def0,20:2", "m09",
                  std::vector<uint8_t>{0x9});
 
   // Zero length response is fine as long as the "m" is present
-  check_qmemtags(client, server, 0, "qMemTags:def0,0:1", "m",
+  check_qmemtags(client, server, 0, 0x34, "qMemTags:def0,0:34", "m",
                  std::vector<uint8_t>{});
 
   // Normal responses
-  check_qmemtags(client, server, 16, "qMemTags:def0,10:1", "m66",
+  check_qmemtags(client, server, 16, 1, "qMemTags:def0,10:1", "m66",
                  std::vector<uint8_t>{0x66});
-  check_qmemtags(client, server, 32, "qMemTags:def0,20:1", "m0102",
+  check_qmemtags(client, server, 32, 1, "qMemTags:def0,20:1", "m0102",
                  std::vector<uint8_t>{0x1, 0x2});
 
   // Empty response is an error
-  check_qmemtags(client, server, 17, "qMemTags:def0,11:1", "", llvm::None);
+  check_qmemtags(client, server, 17, 1, "qMemTags:def0,11:1", "", llvm::None);
   // Usual error response
-  check_qmemtags(client, server, 17, "qMemTags:def0,11:1", "E01", llvm::None);
+  check_qmemtags(client, server, 17, 1, "qMemTags:def0,11:1", "E01",
+                 llvm::None);
   // Leading m missing
-  check_qmemtags(client, server, 17, "qMemTags:def0,11:1", "01", llvm::None);
+  check_qmemtags(client, server, 17, 1, "qMemTags:def0,11:1", "01", llvm::None);
   // Anything other than m is an error
-  check_qmemtags(client, server, 17, "qMemTags:def0,11:1", "z01", llvm::None);
+  check_qmemtags(client, server, 17, 1, "qMemTags:def0,11:1", "z01",
+                 llvm::None);
   // Decoding tag data doesn't use all the chars in the packet
-  check_qmemtags(client, server, 32, "qMemTags:def0,20:1", "m09zz", llvm::None);
+  check_qmemtags(client, server, 32, 1, "qMemTags:def0,20:1", "m09zz",
+                 llvm::None);
   // Data that is not hex bytes
-  check_qmemtags(client, server, 32, "qMemTags:def0,20:1", "mhello",
+  check_qmemtags(client, server, 32, 1, "qMemTags:def0,20:1", "mhello",
                  llvm::None);
   // Data is not a complete hex char
-  check_qmemtags(client, server, 32, "qMemTags:def0,20:1", "m9", llvm::None);
+  check_qmemtags(client, server, 32, 1, "qMemTags:def0,20:1", "m9", llvm::None);
   // Data has a trailing hex char
-  check_qmemtags(client, server, 32, "qMemTags:def0,20:1", "m01020",
+  check_qmemtags(client, server, 32, 1, "qMemTags:def0,20:1", "m01020",
                  llvm::None);
 }
 