// Copyright Epic Games, Inc. All Rights Reserved. #include "hordecomputesocket.h" #include namespace zen::horde { ComputeSocket::ComputeSocket(std::unique_ptr Transport) : m_Log(zen::logging::Get("horde.socket")) , m_Transport(std::move(Transport)) { } ComputeSocket::~ComputeSocket() { // Shutdown order matters: first stop the ping thread, then unblock send threads // by detaching readers, then join send threads, and finally close the transport // to unblock the recv thread (which is blocked on RecvMessage). { std::lock_guard Lock(m_PingMutex); m_PingShouldStop = true; m_PingCV.notify_all(); } for (auto& Reader : m_Readers) { Reader.Detach(); } for (auto& [Id, Thread] : m_SendThreads) { if (Thread.joinable()) { Thread.join(); } } m_Transport->Close(); if (m_RecvThread.joinable()) { m_RecvThread.join(); } if (m_PingThread.joinable()) { m_PingThread.join(); } } Ref ComputeSocket::CreateChannel(int ChannelId) { ComputeBuffer::Params Params; ComputeBuffer RecvBuffer; if (!RecvBuffer.CreateNew(Params)) { return {}; } ComputeBuffer SendBuffer; if (!SendBuffer.CreateNew(Params)) { return {}; } Ref Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter())); // Attach recv buffer writer (transport recv thread writes into this) { std::lock_guard Lock(m_WritersMutex); m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter()); } // Attach send buffer reader (send thread reads from this) { ComputeBufferReader Reader = SendBuffer.CreateReader(); m_Readers.push_back(Reader); m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader))); } return Channel; } void ComputeSocket::StartCommunication() { m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this); m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this); } void ComputeSocket::PingThreadProc() { while (true) { { std::unique_lock Lock(m_PingMutex); if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; })) { break; } } std::lock_guard Lock(m_SendMutex); FrameHeader Header; Header.Channel = 0; Header.Size = ControlPing; m_Transport->SendMessage(&Header, sizeof(Header)); } } void ComputeSocket::RecvThreadProc() { // Writers are cached locally to avoid taking m_WritersMutex on every frame. // The shared m_Writers map is only accessed when a channel is seen for the first time. std::unordered_map CachedWriters; FrameHeader Header; while (m_Transport->RecvMessage(&Header, sizeof(Header))) { if (Header.Size >= 0) { // Data frame auto It = CachedWriters.find(Header.Channel); if (It == CachedWriters.end()) { std::lock_guard Lock(m_WritersMutex); auto WIt = m_Writers.find(Header.Channel); if (WIt == m_Writers.end()) { ZEN_WARN("recv frame for unknown channel {}", Header.Channel); // Skip the data std::vector Discard(Header.Size); m_Transport->RecvMessage(Discard.data(), Header.Size); continue; } It = CachedWriters.emplace(Header.Channel, WIt->second).first; } ComputeBufferWriter& Writer = It->second; uint8_t* Dest = Writer.WaitToWrite(Header.Size); if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size)) { ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size); return; } Writer.AdvanceWritePosition(Header.Size); } else if (Header.Size == ControlDetach) { // Detach the recv buffer for this channel CachedWriters.erase(Header.Channel); std::lock_guard Lock(m_WritersMutex); auto It = m_Writers.find(Header.Channel); if (It != m_Writers.end()) { It->second.MarkComplete(); m_Writers.erase(It); } } else if (Header.Size == ControlPing) { // Ping response - ignore } else { ZEN_WARN("invalid frame header size: {}", Header.Size); return; } } } void ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader) { // Each channel has its own send thread. All send threads share m_SendMutex // to serialize writes to the transport, since TCP requires atomic frame writes. FrameHeader Header; Header.Channel = Channel; const uint8_t* Data; while ((Data = Reader.WaitToRead(1)) != nullptr) { std::lock_guard Lock(m_SendMutex); Header.Size = static_cast(Reader.GetMaxReadSize()); m_Transport->SendMessage(&Header, sizeof(Header)); m_Transport->SendMessage(Data, Header.Size); Reader.AdvanceReadPosition(Header.Size); } if (Reader.IsComplete()) { std::lock_guard Lock(m_SendMutex); Header.Size = ControlDetach; m_Transport->SendMessage(&Header, sizeof(Header)); } } } // namespace zen::horde