diff options
Diffstat (limited to 'src/zenhorde/hordecomputesocket.cpp')
| -rw-r--r-- | src/zenhorde/hordecomputesocket.cpp | 410 |
1 files changed, 269 insertions, 141 deletions
diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp index 6ef67760c..8a6fc40a9 100644 --- a/src/zenhorde/hordecomputesocket.cpp +++ b/src/zenhorde/hordecomputesocket.cpp @@ -6,198 +6,326 @@ namespace zen::horde { -ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport) -: m_Log(zen::logging::Get("horde.socket")) +AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext) +: m_Log(zen::logging::Get("horde.socket.async")) , m_Transport(std::move(Transport)) +, m_Strand(asio::make_strand(IoContext)) +, m_PingTimer(m_Strand) { } -ComputeSocket::~ComputeSocket() +AsyncComputeSocket::~AsyncComputeSocket() { - // Shutdown order matters: first stop the ping thread, then unblock send threads - // by detaching readers, then join send threads, and finally close the transport - // to unblock the recv thread (which is blocked on RecvMessage). - { - std::lock_guard<std::mutex> Lock(m_PingMutex); - m_PingShouldStop = true; - m_PingCV.notify_all(); - } - - for (auto& Reader : m_Readers) - { - Reader.Detach(); - } - - for (auto& [Id, Thread] : m_SendThreads) - { - if (Thread.joinable()) - { - Thread.join(); - } - } + Close(); +} - m_Transport->Close(); +void +AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach) +{ + m_FrameHandlers[ChannelId] = std::move(OnFrame); + m_DetachHandlers[ChannelId] = std::move(OnDetach); +} - if (m_RecvThread.joinable()) - { - m_RecvThread.join(); - } - if (m_PingThread.joinable()) - { - m_PingThread.join(); - } +void +AsyncComputeSocket::StartRecvPump() +{ + StartPingTimer(); + DoRecvHeader(); } -Ref<ComputeChannel> -ComputeSocket::CreateChannel(int ChannelId) +void +AsyncComputeSocket::DoRecvHeader() { - ComputeBuffer::Params Params; + auto Self = shared_from_this(); + m_Transport->AsyncRead(&m_RecvHeader, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv header error: {}", Ec.message()); + HandleError(); + } + return; + } - ComputeBuffer RecvBuffer; - if (!RecvBuffer.CreateNew(Params)) - { - return {}; - } + if (m_Closed) + { + return; + } - ComputeBuffer SendBuffer; - if (!SendBuffer.CreateNew(Params)) - { - return {}; - } + if (m_RecvHeader.Size >= 0) + { + DoRecvPayload(m_RecvHeader); + } + else if (m_RecvHeader.Size == ControlDetach) + { + if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second) + { + It->second(); + } + DoRecvHeader(); + } + else if (m_RecvHeader.Size == ControlPing) + { + DoRecvHeader(); + } + else + { + ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size); + } + })); +} - Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); +void +AsyncComputeSocket::DoRecvPayload(FrameHeader Header) +{ + auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size)); + auto Self = shared_from_this(); - // Attach recv buffer writer (transport recv thread writes into this) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); - } + m_Transport->AsyncRead(PayloadBuf->data(), + PayloadBuf->size(), + asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + if (Ec != asio::error::operation_aborted && !m_Closed) + { + ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message()); + HandleError(); + } + return; + } - // Attach send buffer reader (send thread reads from this) - { - ComputeBufferReader Reader = SendBuffer.CreateReader(); - m_Readers.push_back(Reader); - m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); - } + if (m_Closed) + { + return; + } + + if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second) + { + It->second(std::move(*PayloadBuf)); + } + else + { + ZEN_WARN("recv frame for unknown channel {}", Header.Channel); + } - return Channel; + DoRecvHeader(); + })); } void -ComputeSocket::StartCommunication() +AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler) { - m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); - m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable { + if (m_Closed) + { + if (Handler) + { + Handler(asio::error::make_error_code(asio::error::operation_aborted)); + } + return; + } + + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = static_cast<int32_t>(Data.size()); + Write.Data = std::move(Data); + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::PingThreadProc() +AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler) { - while (true) - { + auto Self = shared_from_this(); + asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable { + if (m_Closed) { - std::unique_lock<std::mutex> Lock(m_PingMutex); - if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) + if (Handler) { - break; + Handler(asio::error::make_error_code(asio::error::operation_aborted)); } + return; } - std::lock_guard<std::mutex> Lock(m_SendMutex); - FrameHeader Header; - Header.Channel = 0; - Header.Size = ControlPing; - m_Transport->SendMessage(&Header, sizeof(Header)); - } + PendingWrite Write; + Write.Header.Channel = ChannelId; + Write.Header.Size = ControlDetach; + Write.Handler = std::move(Handler); + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) + { + FlushNextSend(); + } + }); } void -ComputeSocket::RecvThreadProc() +AsyncComputeSocket::FlushNextSend() { - // Writers are cached locally to avoid taking m_WritersMutex on every frame. - // The shared m_Writers map is only accessed when a channel is seen for the first time. - std::unordered_map<int, ComputeBufferWriter> CachedWriters; + if (m_SendQueue.empty() || m_Closed) + { + return; + } - FrameHeader Header; - while (m_Transport->RecvMessage(&Header, sizeof(Header))) + PendingWrite& Front = m_SendQueue.front(); + + if (Front.Data.empty()) { - if (Header.Size >= 0) - { - // Data frame - auto It = CachedWriters.find(Header.Channel); - if (It == CachedWriters.end()) - { - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto WIt = m_Writers.find(Header.Channel); - if (WIt == m_Writers.end()) - { - ZEN_WARN("recv frame for unknown channel {}", Header.Channel); - // Skip the data - std::vector<uint8_t> Discard(Header.Size); - m_Transport->RecvMessage(Discard.data(), Header.Size); - continue; - } - It = CachedWriters.emplace(Header.Channel, WIt->second).first; - } + // Control frame — header only + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); - ComputeBufferWriter& Writer = It->second; - uint8_t* Dest = Writer.WaitToWrite(Header.Size); - if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) - { - ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); - return; - } - Writer.AdvanceWritePosition(Header.Size); - } - else if (Header.Size == ControlDetach) - { - // Detach the recv buffer for this channel - CachedWriters.erase(Header.Channel); + if (Handler) + { + Handler(Ec); + } - std::lock_guard<std::mutex> Lock(m_WritersMutex); - auto It = m_Writers.find(Header.Channel); - if (It != m_Writers.end()) - { - It->second.MarkComplete(); - m_Writers.erase(It); - } - } - else if (Header.Size == ControlPing) + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + } + else + { + // Data frame — write header first, then payload + auto Self = shared_from_this(); + m_Transport->AsyncWrite(&Front.Header, + sizeof(FrameHeader), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + if (Ec) + { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + if (Handler) + { + Handler(Ec); + } + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send header error: {}", Ec.message()); + HandleError(); + } + return; + } + + PendingWrite& Payload = m_SendQueue.front(); + m_Transport->AsyncWrite( + Payload.Data.data(), + Payload.Data.size(), + asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { + SendHandler Handler = std::move(m_SendQueue.front().Handler); + m_SendQueue.pop_front(); + + if (Handler) + { + Handler(Ec); + } + + if (Ec) + { + if (Ec != asio::error::operation_aborted) + { + ZEN_WARN("send payload error: {}", Ec.message()); + HandleError(); + } + return; + } + + FlushNextSend(); + })); + })); + } +} + +void +AsyncComputeSocket::StartPingTimer() +{ + if (m_Closed) + { + return; + } + + m_PingTimer.expires_after(std::chrono::seconds(2)); + + auto Self = shared_from_this(); + m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) { + if (Ec || m_Closed) { - // Ping response - ignore + return; } - else + + // Enqueue a ping control frame + PendingWrite Write; + Write.Header.Channel = 0; + Write.Header.Size = ControlPing; + + m_SendQueue.push_back(std::move(Write)); + if (m_SendQueue.size() == 1) { - ZEN_WARN("invalid frame header size: {}", Header.Size); - return; + FlushNextSend(); } - } + + StartPingTimer(); + })); } void -ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) +AsyncComputeSocket::HandleError() { - // Each channel has its own send thread. All send threads share m_SendMutex - // to serialize writes to the transport, since TCP requires atomic frame writes. - FrameHeader Header; - Header.Channel = Channel; + if (m_Closed) + { + return; + } + + Close(); - const uint8_t* Data; - while ((Data = Reader.WaitToRead(1)) != nullptr) + // Notify all channels that the connection is gone so agents can clean up + for (auto& [ChannelId, Handler] : m_DetachHandlers) { - std::lock_guard<std::mutex> Lock(m_SendMutex); + if (Handler) + { + Handler(); + } + } +} - Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize()); - m_Transport->SendMessage(&Header, sizeof(Header)); - m_Transport->SendMessage(Data, Header.Size); - Reader.AdvanceReadPosition(Header.Size); +void +AsyncComputeSocket::Close() +{ + if (m_Closed) + { + return; } - if (Reader.IsComplete()) + m_Closed = true; + m_PingTimer.cancel(); + + if (m_Transport) { - std::lock_guard<std::mutex> Lock(m_SendMutex); - Header.Size = ControlDetach; - m_Transport->SendMessage(&Header, sizeof(Header)); + m_Transport->Close(); } } |