// Copyright Epic Games, Inc. All Rights Reserved. #include "hordecomputesocket.h" #include namespace zen::horde { AsyncComputeSocket::AsyncComputeSocket(std::unique_ptr Transport, asio::io_context& IoContext) : m_Log(zen::logging::Get("horde.socket.async")) , m_Transport(std::move(Transport)) , m_Strand(asio::make_strand(IoContext)) , m_PingTimer(m_Strand) { } AsyncComputeSocket::~AsyncComputeSocket() { Close(); } void AsyncComputeSocket::RegisterChannel(int ChannelId, FrameHandler OnFrame, DetachHandler OnDetach) { m_FrameHandlers[ChannelId] = std::move(OnFrame); m_DetachHandlers[ChannelId] = std::move(OnDetach); } void AsyncComputeSocket::StartRecvPump() { StartPingTimer(); DoRecvHeader(); } void AsyncComputeSocket::DoRecvHeader() { auto Self = shared_from_this(); m_Transport->AsyncRead(&m_RecvHeader, sizeof(FrameHeader), asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { if (Ec) { if (Ec != asio::error::operation_aborted && !m_Closed) { ZEN_WARN("recv header error: {}", Ec.message()); HandleError(); } return; } if (m_Closed) { return; } if (m_RecvHeader.Size >= 0) { DoRecvPayload(m_RecvHeader); } else if (m_RecvHeader.Size == ControlDetach) { if (auto It = m_DetachHandlers.find(m_RecvHeader.Channel); It != m_DetachHandlers.end() && It->second) { It->second(); } DoRecvHeader(); } else if (m_RecvHeader.Size == ControlPing) { DoRecvHeader(); } else { ZEN_WARN("invalid frame header size: {}", m_RecvHeader.Size); } })); } void AsyncComputeSocket::DoRecvPayload(FrameHeader Header) { auto PayloadBuf = std::make_shared>(static_cast(Header.Size)); auto Self = shared_from_this(); m_Transport->AsyncRead(PayloadBuf->data(), PayloadBuf->size(), asio::bind_executor(m_Strand, [this, Self, Header, PayloadBuf](const std::error_code& Ec, size_t /*Bytes*/) { if (Ec) { if (Ec != asio::error::operation_aborted && !m_Closed) { ZEN_WARN("recv payload error (channel={}, size={}): {}", Header.Channel, Header.Size, Ec.message()); HandleError(); } return; } if (m_Closed) { return; } if (auto It = m_FrameHandlers.find(Header.Channel); It != m_FrameHandlers.end() && It->second) { It->second(std::move(*PayloadBuf)); } else { ZEN_WARN("recv frame for unknown channel {}", Header.Channel); } DoRecvHeader(); })); } void AsyncComputeSocket::AsyncSendFrame(int ChannelId, std::vector Data, SendHandler Handler) { auto Self = shared_from_this(); asio::dispatch(m_Strand, [this, Self, ChannelId, Data = std::move(Data), Handler = std::move(Handler)]() mutable { if (m_Closed) { if (Handler) { Handler(asio::error::make_error_code(asio::error::operation_aborted)); } return; } PendingWrite Write; Write.Header.Channel = ChannelId; Write.Header.Size = static_cast(Data.size()); Write.Data = std::move(Data); Write.Handler = std::move(Handler); m_SendQueue.push_back(std::move(Write)); if (m_SendQueue.size() == 1) { FlushNextSend(); } }); } void AsyncComputeSocket::AsyncSendDetach(int ChannelId, SendHandler Handler) { auto Self = shared_from_this(); asio::dispatch(m_Strand, [this, Self, ChannelId, Handler = std::move(Handler)]() mutable { if (m_Closed) { if (Handler) { Handler(asio::error::make_error_code(asio::error::operation_aborted)); } return; } PendingWrite Write; Write.Header.Channel = ChannelId; Write.Header.Size = ControlDetach; Write.Handler = std::move(Handler); m_SendQueue.push_back(std::move(Write)); if (m_SendQueue.size() == 1) { FlushNextSend(); } }); } void AsyncComputeSocket::FlushNextSend() { if (m_SendQueue.empty() || m_Closed) { return; } PendingWrite& Front = m_SendQueue.front(); if (Front.Data.empty()) { // Control frame - header only auto Self = shared_from_this(); m_Transport->AsyncWrite(&Front.Header, sizeof(FrameHeader), asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { SendHandler Handler = std::move(m_SendQueue.front().Handler); m_SendQueue.pop_front(); if (Handler) { Handler(Ec); } if (Ec) { if (Ec != asio::error::operation_aborted) { ZEN_WARN("send error: {}", Ec.message()); HandleError(); } return; } FlushNextSend(); })); } else { // Data frame - write header first, then payload auto Self = shared_from_this(); m_Transport->AsyncWrite(&Front.Header, sizeof(FrameHeader), asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { if (Ec) { SendHandler Handler = std::move(m_SendQueue.front().Handler); m_SendQueue.pop_front(); if (Handler) { Handler(Ec); } if (Ec != asio::error::operation_aborted) { ZEN_WARN("send header error: {}", Ec.message()); HandleError(); } return; } PendingWrite& Payload = m_SendQueue.front(); m_Transport->AsyncWrite( Payload.Data.data(), Payload.Data.size(), asio::bind_executor(m_Strand, [this, Self](const std::error_code& Ec, size_t /*Bytes*/) { SendHandler Handler = std::move(m_SendQueue.front().Handler); m_SendQueue.pop_front(); if (Handler) { Handler(Ec); } if (Ec) { if (Ec != asio::error::operation_aborted) { ZEN_WARN("send payload error: {}", Ec.message()); HandleError(); } return; } FlushNextSend(); })); })); } } void AsyncComputeSocket::StartPingTimer() { if (m_Closed) { return; } m_PingTimer.expires_after(std::chrono::seconds(2)); auto Self = shared_from_this(); m_PingTimer.async_wait(asio::bind_executor(m_Strand, [this, Self](const asio::error_code& Ec) { if (Ec || m_Closed) { return; } // Enqueue a ping control frame PendingWrite Write; Write.Header.Channel = 0; Write.Header.Size = ControlPing; m_SendQueue.push_back(std::move(Write)); if (m_SendQueue.size() == 1) { FlushNextSend(); } StartPingTimer(); })); } void AsyncComputeSocket::HandleError() { if (m_Closed) { return; } Close(); // Notify all channels that the connection is gone so agents can clean up for (auto& [ChannelId, Handler] : m_DetachHandlers) { if (Handler) { Handler(); } } } void AsyncComputeSocket::Close() { if (m_Closed) { return; } m_Closed = true; m_PingTimer.cancel(); if (m_Transport) { m_Transport->Close(); } } } // namespace zen::horde