aboutsummaryrefslogtreecommitdiff
path: root/src/zenhorde/hordecomputesocket.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhorde/hordecomputesocket.cpp')
-rw-r--r--src/zenhorde/hordecomputesocket.cpp410
1 files changed, 269 insertions, 141 deletions
diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp
index 6ef67760c..8a6fc40a9 100644
--- a/src/zenhorde/hordecomputesocket.cpp
+++ b/src/zenhorde/hordecomputesocket.cpp
@@ -6,198 +6,326 @@
namespace zen::horde {
-ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport)
-: m_Log(zen::logging::Get("horde.socket"))
+AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr<AsyncComputeTransport> Transport, asio::io_context& IoContext)
+: m_Log(zen::logging::Get("horde.socket.async"))
, m_Transport(std::move(Transport))
+, m_Strand(asio::make_strand(IoContext))
+, m_PingTimer(m_Strand)
{
}
-ComputeSocket::~ComputeSocket()
+AsyncComputeSocket::~AsyncComputeSocket()
{
- // Shutdown order matters: first stop the ping thread, then unblock send threads
- // by detaching readers, then join send threads, and finally close the transport
- // to unblock the recv thread (which is blocked on RecvMessage).
- {
- std::lock_guard<std::mutex> Lock(m_PingMutex);
- m_PingShouldStop = true;
- m_PingCV.notify_all();
- }
-
- for (auto& Reader : m_Readers)
- {
- Reader.Detach();
- }
-
- for (auto& [Id, Thread] : m_SendThreads)
- {
- if (Thread.joinable())
- {
- Thread.join();
- }
- }
+ Close();
+}
- m_Transport->Close();
+void
+AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach)
+{
+ m_FrameHandlers[ChannelId] = std::move(OnFrame);
+ m_DetachHandlers[ChannelId] = std::move(OnDetach);
+}
- if (m_RecvThread.joinable())
- {
- m_RecvThread.join();
- }
- if (m_PingThread.joinable())
- {
- m_PingThread.join();
- }
+void
+AsyncComputeSocket::StartRecvPump()
+{
+ StartPingTimer();
+ DoRecvHeader();
}
-Ref<ComputeChannel>
-ComputeSocket::CreateChannel(int ChannelId)
+void
+AsyncComputeSocket::DoRecvHeader()
{
- ComputeBuffer::Params Params;
+ auto Self = shared_from_this();
+ m_Transport->AsyncRead(&m_RecvHeader,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
- ComputeBuffer RecvBuffer;
- if (!RecvBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_Closed)
+ {
+ return;
+ }
- ComputeBuffer SendBuffer;
- if (!SendBuffer.CreateNew(Params))
- {
- return {};
- }
+ if (m_RecvHeader.Size >= 0)
+ {
+ DoRecvPayload(m_RecvHeader);
+ }
+ else if (m_RecvHeader.Size == ControlDetach)
+ {
+ if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second)
+ {
+ It->second();
+ }
+ DoRecvHeader();
+ }
+ else if (m_RecvHeader.Size == ControlPing)
+ {
+ DoRecvHeader();
+ }
+ else
+ {
+ ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size);
+ }
+ }));
+}
- Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter()));
+void
+AsyncComputeSocket::DoRecvPayload(FrameHeader Header)
+{
+ auto PayloadBuf = std::make_shared<std::vector<uint8_t>>(static_cast<size_t>(Header.Size));
+ auto Self = shared_from_this();
- // Attach recv buffer writer (transport recv thread writes into this)
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter());
- }
+ m_Transport->AsyncRead(PayloadBuf->data(),
+ PayloadBuf->size(),
+ asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted && !m_Closed)
+ {
+ ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message());
+ HandleError();
+ }
+ return;
+ }
- // Attach send buffer reader (send thread reads from this)
- {
- ComputeBufferReader Reader = SendBuffer.CreateReader();
- m_Readers.push_back(Reader);
- m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader)));
- }
+ if (m_Closed)
+ {
+ return;
+ }
+
+ if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second)
+ {
+ It->second(std::move(*PayloadBuf));
+ }
+ else
+ {
+ ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
+ }
- return Channel;
+ DoRecvHeader();
+ }));
}
void
-ComputeSocket::StartCommunication()
+AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector<uint8_t> Data, SendHandler Handler)
{
- m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this);
- m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this);
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
+ {
+ if (Handler)
+ {
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
+ }
+ return;
+ }
+
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = static_cast<int32_t>(Data.size());
+ Write.Data = std::move(Data);
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::PingThreadProc()
+AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler)
{
- while (true)
- {
+ auto Self = shared_from_this();
+ asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable {
+ if (m_Closed)
{
- std::unique_lock<std::mutex> Lock(m_PingMutex);
- if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; }))
+ if (Handler)
{
- break;
+ Handler(asio::error::make_error_code(asio::error::operation_aborted));
}
+ return;
}
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- FrameHeader Header;
- Header.Channel = 0;
- Header.Size = ControlPing;
- m_Transport->SendMessage(&Header, sizeof(Header));
- }
+ PendingWrite Write;
+ Write.Header.Channel = ChannelId;
+ Write.Header.Size = ControlDetach;
+ Write.Handler = std::move(Handler);
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
+ {
+ FlushNextSend();
+ }
+ });
}
void
-ComputeSocket::RecvThreadProc()
+AsyncComputeSocket::FlushNextSend()
{
- // Writers are cached locally to avoid taking m_WritersMutex on every frame.
- // The shared m_Writers map is only accessed when a channel is seen for the first time.
- std::unordered_map<int, ComputeBufferWriter> CachedWriters;
+ if (m_SendQueue.empty() || m_Closed)
+ {
+ return;
+ }
- FrameHeader Header;
- while (m_Transport->RecvMessage(&Header, sizeof(Header)))
+ PendingWrite& Front = m_SendQueue.front();
+
+ if (Front.Data.empty())
{
- if (Header.Size >= 0)
- {
- // Data frame
- auto It = CachedWriters.find(Header.Channel);
- if (It == CachedWriters.end())
- {
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto WIt = m_Writers.find(Header.Channel);
- if (WIt == m_Writers.end())
- {
- ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
- // Skip the data
- std::vector<uint8_t> Discard(Header.Size);
- m_Transport->RecvMessage(Discard.data(), Header.Size);
- continue;
- }
- It = CachedWriters.emplace(Header.Channel, WIt->second).first;
- }
+ // Control frame — header only
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
- ComputeBufferWriter& Writer = It->second;
- uint8_t* Dest = Writer.WaitToWrite(Header.Size);
- if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size))
- {
- ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size);
- return;
- }
- Writer.AdvanceWritePosition(Header.Size);
- }
- else if (Header.Size == ControlDetach)
- {
- // Detach the recv buffer for this channel
- CachedWriters.erase(Header.Channel);
+ if (Handler)
+ {
+ Handler(Ec);
+ }
- std::lock_guard<std::mutex> Lock(m_WritersMutex);
- auto It = m_Writers.find(Header.Channel);
- if (It != m_Writers.end())
- {
- It->second.MarkComplete();
- m_Writers.erase(It);
- }
- }
- else if (Header.Size == ControlPing)
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }
+ else
+ {
+ // Data frame — write header first, then payload
+ auto Self = shared_from_this();
+ m_Transport->AsyncWrite(&Front.Header,
+ sizeof(FrameHeader),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ if (Ec)
+ {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send header error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ PendingWrite& Payload = m_SendQueue.front();
+ m_Transport->AsyncWrite(
+ Payload.Data.data(),
+ Payload.Data.size(),
+ asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) {
+ SendHandler Handler = std::move(m_SendQueue.front().Handler);
+ m_SendQueue.pop_front();
+
+ if (Handler)
+ {
+ Handler(Ec);
+ }
+
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_WARN("send payload error: {}", Ec.message());
+ HandleError();
+ }
+ return;
+ }
+
+ FlushNextSend();
+ }));
+ }));
+ }
+}
+
+void
+AsyncComputeSocket::StartPingTimer()
+{
+ if (m_Closed)
+ {
+ return;
+ }
+
+ m_PingTimer.expires_after(std::chrono::seconds(2));
+
+ auto Self = shared_from_this();
+ m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) {
+ if (Ec || m_Closed)
{
- // Ping response - ignore
+ return;
}
- else
+
+ // Enqueue a ping control frame
+ PendingWrite Write;
+ Write.Header.Channel = 0;
+ Write.Header.Size = ControlPing;
+
+ m_SendQueue.push_back(std::move(Write));
+ if (m_SendQueue.size() == 1)
{
- ZEN_WARN("invalid frame header size: {}", Header.Size);
- return;
+ FlushNextSend();
}
- }
+
+ StartPingTimer();
+ }));
}
void
-ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader)
+AsyncComputeSocket::HandleError()
{
- // Each channel has its own send thread. All send threads share m_SendMutex
- // to serialize writes to the transport, since TCP requires atomic frame writes.
- FrameHeader Header;
- Header.Channel = Channel;
+ if (m_Closed)
+ {
+ return;
+ }
+
+ Close();
- const uint8_t* Data;
- while ((Data = Reader.WaitToRead(1)) != nullptr)
+ // Notify all channels that the connection is gone so agents can clean up
+ for (auto& [ChannelId, Handler] : m_DetachHandlers)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
+ if (Handler)
+ {
+ Handler();
+ }
+ }
+}
- Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize());
- m_Transport->SendMessage(&Header, sizeof(Header));
- m_Transport->SendMessage(Data, Header.Size);
- Reader.AdvanceReadPosition(Header.Size);
+void
+AsyncComputeSocket::Close()
+{
+ if (m_Closed)
+ {
+ return;
}
- if (Reader.IsComplete())
+ m_Closed = true;
+ m_PingTimer.cancel();
+
+ if (m_Transport)
{
- std::lock_guard<std::mutex> Lock(m_SendMutex);
- Header.Size = ControlDetach;
- m_Transport->SendMessage(&Header, sizeof(Header));
+ m_Transport->Close();
}
}