Index: lldb/include/lldb/Core/Communication.h
===================================================================
--- lldb/include/lldb/Core/Communication.h
+++ lldb/include/lldb/Core/Communication.h
@@ -99,6 +99,9 @@
       eBroadcastBitNoMorePendingInput = (1u << 5), ///< Sent by the read thread
                                                    ///to indicate all pending
                                                    ///input has been processed.
+      eBroadcastBitWriteDone = (1u << 6), ///< Sent by the read thread
+                                          ///to indicate that a write request
+                                          ///has been processed.
       kLoUserBroadcastBit =
           (1u << 16), ///< Subclasses can used bits 31:16 for any needed events.
       kHiUserBroadcastBit = (1u << 31),
@@ -325,6 +328,7 @@
   std::atomic<bool> m_read_thread_did_exit;
   std::string
       m_bytes; ///< A buffer to cache bytes read in the ReadThread function.
+  std::string m_write_bytes; ///< Bytes queued by Write() for WriteThread().
   std::recursive_mutex m_bytes_mutex; ///< A mutex to protect multi-threaded
                                       ///access to the cached bytes.
   lldb::ConnectionStatus m_pass_status; ///< Connection status passthrough
@@ -340,6 +344,11 @@
   size_t ReadFromConnection(void *dst, size_t dst_len,
                             const Timeout<std::micro> &timeout,
                             lldb::ConnectionStatus &status, Status *error_ptr);
+  /// Write up to \a src_len bytes from \a src directly to the current
+  /// connection and return the number of bytes written. If there is no
+  /// connection, set \a status to eConnectionStatusNoConnection and return 0.
+  size_t WriteToConnection(const void *src, size_t src_len,
+                           lldb::ConnectionStatus &status, Status *error_ptr);
 
   /// Append new bytes that get read from the read thread into the internal
   /// object byte cache. This will cause a \b eBroadcastBitReadThreadGotBytes
@@ -380,6 +389,11 @@
   /// The number of bytes extracted from the data cache.
   size_t GetCachedBytes(void *dst, size_t dst_len);
 
+  /// Write out the bytes queued in m_write_bytes, store the resulting status
+  /// in m_pass_status/m_pass_error and broadcast eBroadcastBitWriteDone.
+  /// Scheduled on m_io_loop by Write() while the read thread is running.
+  void WriteThread();
+
 private:
   Communication(const Communication &) = delete;
   const Communication &operator=(const Communication &) = delete;
Index: lldb/source/Core/Communication.cpp
===================================================================
--- lldb/source/Core/Communication.cpp
+++ lldb/source/Core/Communication.cpp
@@ -186,21 +186,58 @@
 
 size_t Communication::Write(const void *src, size_t src_len,
                             ConnectionStatus &status, Status *error_ptr) {
-  lldb::ConnectionSP connection_sp(m_connection_sp);
-
   std::lock_guard<std::mutex> guard(m_write_mutex);
   LLDB_LOG(GetLog(LLDBLog::Communication),
            "{0} Communication::Write (src = {1}, src_len = {2}"
            ") connection = {3}",
-           this, src, (uint64_t)src_len, connection_sp.get());
+           this, src, (uint64_t)src_len, m_connection_sp.get());
 
-  if (connection_sp)
-    return connection_sp->Write(src, src_len, status, error_ptr);
+  if (m_read_thread_enabled) {
+    // Hand the write off to the read thread: queue the bytes, schedule
+    // WriteThread() on the I/O loop and wait for it to report back.
+    // m_write_mutex ensures only one request is in flight at a time.
+    assert(m_write_bytes.empty());
 
-  if (error_ptr)
-    error_ptr->SetErrorString("Invalid connection.");
-  status = eConnectionStatusNoConnection;
-  return 0;
+    if (!m_connection_sp) {
+      if (error_ptr)
+        error_ptr->SetErrorString("Invalid connection.");
+      status = eConnectionStatusNoConnection;
+      return 0;
+    }
+
+    ListenerSP listener_sp(Listener::MakeListener("Communication::Write"));
+    listener_sp->StartListeningForEvents(
+        this, eBroadcastBitWriteDone | eBroadcastBitReadThreadDidExit);
+
+    m_write_bytes.assign(static_cast<const char *>(src), src_len);
+    m_io_loop->AddPendingCallback(
+        [this](MainLoopBase &loop) { WriteThread(); });
+
+    EventSP event_sp;
+    // Time out periodically so that the loop re-checks m_read_thread_enabled.
+    while (m_read_thread_enabled) {
+      if (listener_sp->GetEvent(event_sp, std::chrono::seconds(5))) {
+        const uint32_t event_type = event_sp->GetType();
+        if (event_type & eBroadcastBitWriteDone) {
+          size_t ret = src_len - m_write_bytes.size();
+          status = m_pass_status;
+          if (error_ptr)
+            *error_ptr = std::move(m_pass_error);
+          m_write_bytes.clear();
+          return ret;
+        }
+
+        if (event_type & eBroadcastBitReadThreadDidExit)
+          break;
+      }
+    }
+
+    // If read thread exited before performing the write, fall back
+    // to writing directly.
+    m_write_bytes.clear();
+  }
+
+  return WriteToConnection(src, src_len, status, error_ptr);
 }
 
 size_t Communication::WriteAll(const void *src, size_t src_len,
@@ -330,6 +367,19 @@
   return 0;
 }
 
+size_t Communication::WriteToConnection(const void *src, size_t src_len,
+                                        ConnectionStatus &status,
+                                        Status *error_ptr) {
+  lldb::ConnectionSP connection_sp(m_connection_sp);
+  if (connection_sp)
+    return connection_sp->Write(src, src_len, status, error_ptr);
+
+  if (error_ptr)
+    error_ptr->SetErrorString("Invalid connection.");
+  status = eConnectionStatusNoConnection;
+  return 0;
+}
+
 bool Communication::ReadThreadIsRunning() { return m_read_thread_enabled; }
 
 lldb::thread_result_t Communication::ReadThread() {
@@ -429,6 +479,24 @@
   return {};
 }
 
+void Communication::WriteThread() {
+  // There should be only one pending request queued.
+  assert(!m_write_bytes.empty());
+
+  ConnectionStatus status = eConnectionStatusSuccess;
+  Status error;
+  do {
+    size_t bytes_written = WriteToConnection(
+        m_write_bytes.data(), m_write_bytes.size(), status, &error);
+    if (bytes_written > 0)
+      m_write_bytes.erase(0, bytes_written);
+  } while (!m_write_bytes.empty() && status == eConnectionStatusSuccess);
+
+  m_pass_status = status;
+  m_pass_error = std::move(error);
+  BroadcastEvent(eBroadcastBitWriteDone);
+}
+
 void Communication::SetReadThreadBytesReceivedCallback(
     ReadThreadBytesReceived callback, void *callback_baton) {
   m_callback = callback;
Index: lldb/unittests/Core/CommunicationTest.cpp
===================================================================
--- lldb/unittests/Core/CommunicationTest.cpp
+++ lldb/unittests/Core/CommunicationTest.cpp
@@ -116,6 +116,93 @@
   CommunicationReadTest(/*use_thread=*/true);
 }
 
+// Exercise Communication::Write() and WriteAll() with or without the read
+// thread running, including the no-connection and write-error cases.
+static void CommunicationWriteTest(bool use_read_thread) {
+  std::unique_ptr<TCPSocket> a, b;
+  ASSERT_TRUE(CreateTCPConnectedSockets("localhost", &a, &b));
+
+  Communication comm("test");
+  comm.SetConnection(std::make_unique<ConnectionFileDescriptor>(a.release()));
+  comm.SetCloseOnEOF(true);
+
+  if (use_read_thread)
+    ASSERT_TRUE(comm.StartReadThread());
+
+  // In our test case, a short Write() should be atomic.
+  lldb::ConnectionStatus status = lldb::eConnectionStatusSuccess;
+  Status error;
+  std::string test_str{"test"};
+  EXPECT_EQ(comm.Write(test_str.data(), test_str.size(), status, &error), 4U);
+  EXPECT_EQ(status, lldb::eConnectionStatusSuccess);
+  EXPECT_THAT_ERROR(error.ToError(), llvm::Succeeded());
+
+  char buf[5];
+  size_t bytes_read = 4;
+  ASSERT_THAT_ERROR(b->Read(buf, bytes_read).ToError(), llvm::Succeeded());
+  EXPECT_EQ(bytes_read, 4U);
+  buf[4] = 0;
+  EXPECT_EQ(test_str, buf);
+
+  // Test WriteAll() too.
+  test_str[3] = '2';
+  error.Clear();
+  EXPECT_EQ(comm.WriteAll(test_str.data(), test_str.size(), status, &error),
+            4U);
+  EXPECT_EQ(status, lldb::eConnectionStatusSuccess);
+  EXPECT_THAT_ERROR(error.ToError(), llvm::Succeeded());
+
+  bytes_read = 4;
+  ASSERT_THAT_ERROR(b->Read(buf, bytes_read).ToError(), llvm::Succeeded());
+  EXPECT_EQ(bytes_read, 4U);
+  EXPECT_EQ(test_str, buf);
+
+  EXPECT_TRUE(comm.StopReadThread());
+
+  // Test using Communication that is disconnected.
+  ASSERT_EQ(comm.Disconnect(), lldb::eConnectionStatusSuccess);
+  if (use_read_thread)
+    ASSERT_TRUE(comm.StartReadThread());
+  error.Clear();
+  EXPECT_EQ(comm.Write(test_str.data(), test_str.size(), status, &error), 0U);
+  EXPECT_EQ(status, lldb::eConnectionStatusNoConnection);
+  EXPECT_THAT_ERROR(error.ToError(), llvm::Failed());
+  EXPECT_TRUE(comm.StopReadThread());
+
+  // Test using Communication without a connection.
+  comm.SetConnection(nullptr);
+  if (use_read_thread)
+    ASSERT_TRUE(comm.StartReadThread());
+  error.Clear();
+  EXPECT_EQ(comm.Write(test_str.data(), test_str.size(), status, &error), 0U);
+  EXPECT_EQ(status, lldb::eConnectionStatusNoConnection);
+  EXPECT_THAT_ERROR(error.ToError(), llvm::Failed());
+  EXPECT_TRUE(comm.StopReadThread());
+
+  // Test using the wrong end of a pipe.
+  Pipe pipe;
+  ASSERT_THAT_ERROR(pipe.CreateNew(/*child_process_inherit=*/false).ToError(),
+                    llvm::Succeeded());
+  comm.SetConnection(std::make_unique<ConnectionFileDescriptor>(
+      pipe.ReleaseReadFileDescriptor(), /*owns_fd=*/true));
+  comm.SetCloseOnEOF(true);
+  if (use_read_thread)
+    ASSERT_TRUE(comm.StartReadThread());
+  error.Clear();
+  EXPECT_EQ(comm.Write(test_str.data(), test_str.size(), status, &error), 0U);
+  EXPECT_EQ(status, lldb::eConnectionStatusError);
+  EXPECT_THAT_ERROR(error.ToError(), llvm::Failed());
+  EXPECT_TRUE(comm.StopReadThread());
+}
+
+TEST_F(CommunicationTest, Write) {
+  CommunicationWriteTest(/*use_read_thread=*/false);
+}
+
+TEST_F(CommunicationTest, WriteThread) {
+  CommunicationWriteTest(/*use_read_thread=*/true);
+}
+
 TEST_F(CommunicationTest, StopReadThread) {
   std::condition_variable finished;
   std::mutex finished_mutex;