Diffstat (limited to 'src/zenhorde/hordecomputesocket.cpp')
-rw-r--r--  src/zenhorde/hordecomputesocket.cpp  204
1 file changed, 204 insertions, 0 deletions
diff --git a/src/zenhorde/hordecomputesocket.cpp b/src/zenhorde/hordecomputesocket.cpp
new file mode 100644
index 000000000..6ef67760c
--- /dev/null
+++ b/src/zenhorde/hordecomputesocket.cpp
@@ -0,0 +1,204 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include "hordecomputesocket.h"
+
+#include <zencore/logging.h>
+
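+// A ComputeSocket multiplexes multiple ComputeChannels over a single
+// ComputeTransport. Each frame on the wire is a FrameHeader (Channel, Size)
+// followed by Size bytes of payload; negative Size values are control codes
+// (ControlPing, ControlDetach) and carry no payload.
+//
+// Typical usage (a sketch; error handling omitted):
+//
+//   ComputeSocket Socket(std::move(Transport));
+//   Ref<ComputeChannel> Channel = Socket.CreateChannel(/*ChannelId*/ 1);
+//   Socket.StartCommunication();
+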
+namespace zen::horde {
+
+ComputeSocket::ComputeSocket(std::unique_ptr<ComputeTransport> Transport)
+: m_Log(zen::logging::Get("horde.socket"))
+, m_Transport(std::move(Transport))
+{
+}
+
+ComputeSocket::~ComputeSocket()
+{
+ // Shutdown order matters: first signal the ping thread to stop, then
+ // unblock the send threads by detaching their readers and join them, then
+ // close the transport to unblock the recv thread (which blocks in
+ // RecvMessage), and finally join the recv and ping threads.
+ {
+ std::lock_guard<std::mutex> Lock(m_PingMutex);
+ m_PingShouldStop = true;
+ m_PingCV.notify_all();
+ }
+
+ for (auto& Reader : m_Readers)
+ {
+ Reader.Detach();
+ }
+
+ for (auto& [Id, Thread] : m_SendThreads)
+ {
+ if (Thread.joinable())
+ {
+ Thread.join();
+ }
+ }
+
+ m_Transport->Close();
+
+ if (m_RecvThread.joinable())
+ {
+ m_RecvThread.join();
+ }
+ if (m_PingThread.joinable())
+ {
+ m_PingThread.join();
+ }
+}
+
+Ref<ComputeChannel>
+ComputeSocket::CreateChannel(int ChannelId)
+{
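+ // A channel is backed by two buffers: RecvBuffer, which the transport recv
+ // thread fills, and SendBuffer, which a dedicated send thread drains back
+ // to the transport.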
+ ComputeBuffer::Params Params;
+
+ ComputeBuffer RecvBuffer;
+ if (!RecvBuffer.CreateNew(Params))
+ {
+ return {};
+ }
+
+ ComputeBuffer SendBuffer;
+ if (!SendBuffer.CreateNew(Params))
+ {
+ return {};
+ }
+
+ Ref<ComputeChannel> Channel(new ComputeChannel(RecvBuffer.CreateReader(), SendBuffer.CreateWriter()));
+
+ // Attach recv buffer writer (transport recv thread writes into this)
+ {
+ std::lock_guard<std::mutex> Lock(m_WritersMutex);
+ m_Writers.emplace(ChannelId, RecvBuffer.CreateWriter());
+ }
+
+ // Attach send buffer reader (send thread reads from this)
+ {
+ ComputeBufferReader Reader = SendBuffer.CreateReader();
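+ // Keep an extra reader handle in m_Readers so the destructor can Detach()
+ // it and unblock this channel's send thread.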
+ m_Readers.push_back(Reader);
+ m_SendThreads.emplace(ChannelId, std::thread(&ComputeSocket::SendThreadProc, this, ChannelId, std::move(Reader)));
+ }
+
+ return Channel;
+}
+
+void
+ComputeSocket::StartCommunication()
+{
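+ // Start the shared recv and keep-alive threads; per-channel send threads
+ // are created in CreateChannel.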
+ m_RecvThread = std::thread(&ComputeSocket::RecvThreadProc, this);
+ m_PingThread = std::thread(&ComputeSocket::PingThreadProc, this);
+}
+
+void
+ComputeSocket::PingThreadProc()
+{
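+ // Send a keep-alive control frame every two seconds until the destructor
+ // sets m_PingShouldStop.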
+ while (true)
+ {
+ {
+ std::unique_lock<std::mutex> Lock(m_PingMutex);
+ if (m_PingCV.wait_for(Lock, std::chrono::milliseconds(2000), [this] { return m_PingShouldStop; }))
+ {
+ break;
+ }
+ }
+
+ std::lock_guard<std::mutex> Lock(m_SendMutex);
+ FrameHeader Header;
+ Header.Channel = 0;
+ Header.Size = ControlPing;
+ m_Transport->SendMessage(&Header, sizeof(Header));
+ }
+}
+
+void
+ComputeSocket::RecvThreadProc()
+{
+ // Writers are cached locally to avoid taking m_WritersMutex on every frame.
+ // The shared m_Writers map is only accessed when a channel is seen for the first time.
+ std::unordered_map<int, ComputeBufferWriter> CachedWriters;
+
+ FrameHeader Header;
+ while (m_Transport->RecvMessage(&Header, sizeof(Header)))
+ {
+ if (Header.Size >= 0)
+ {
+ // Data frame
+ auto It = CachedWriters.find(Header.Channel);
+ if (It == CachedWriters.end())
+ {
+ std::lock_guard<std::mutex> Lock(m_WritersMutex);
+ auto WIt = m_Writers.find(Header.Channel);
+ if (WIt == m_Writers.end())
+ {
+ ZEN_WARN("recv frame for unknown channel {}", Header.Channel);
+ // Drain and discard the payload so the stream stays aligned on frame boundaries
+ std::vector<uint8_t> Discard(Header.Size);
+ m_Transport->RecvMessage(Discard.data(), Header.Size);
+ continue;
+ }
+ It = CachedWriters.emplace(Header.Channel, WIt->second).first;
+ }
+
+ ComputeBufferWriter& Writer = It->second;
+ uint8_t* Dest = Writer.WaitToWrite(Header.Size);
+ if (!Dest || !m_Transport->RecvMessage(Dest, Header.Size))
+ {
+ ZEN_WARN("failed to read frame data (channel={}, size={})", Header.Channel, Header.Size);
+ return;
+ }
+ Writer.AdvanceWritePosition(Header.Size);
+ }
+ else if (Header.Size == ControlDetach)
+ {
+ // Detach the recv buffer for this channel
+ CachedWriters.erase(Header.Channel);
+
+ std::lock_guard<std::mutex> Lock(m_WritersMutex);
+ auto It = m_Writers.find(Header.Channel);
+ if (It != m_Writers.end())
+ {
+ It->second.MarkComplete();
+ m_Writers.erase(It);
+ }
+ }
+ else if (Header.Size == ControlPing)
+ {
+ // Keep-alive ping from the peer; nothing to do
+ }
+ else
+ {
+ ZEN_WARN("invalid frame header size: {}", Header.Size);
+ return;
+ }
+ }
+}
+
+void
+ComputeSocket::SendThreadProc(int Channel, ComputeBufferReader Reader)
+{
+ // Each channel has its own send thread. All send threads share m_SendMutex
+ // so that a frame's header and payload reach the transport as one unit,
+ // never interleaved with frames from other channels.
+ FrameHeader Header;
+ Header.Channel = Channel;
+
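+ // WaitToRead blocks until at least one byte is readable; it returns
+ // nullptr once the reader has been detached or the writer side completes.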
+ const uint8_t* Data;
+ while ((Data = Reader.WaitToRead(1)) != nullptr)
+ {
+ std::lock_guard<std::mutex> Lock(m_SendMutex);
+
+ Header.Size = static_cast<int32_t>(Reader.GetMaxReadSize());
+ m_Transport->SendMessage(&Header, sizeof(Header));
+ m_Transport->SendMessage(Data, Header.Size);
+ Reader.AdvanceReadPosition(Header.Size);
+ }
+
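+ // If the writer side finished cleanly, notify the peer so it can mark the
+ // matching recv buffer complete (see the ControlDetach handling in
+ // RecvThreadProc).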
+ if (Reader.IsComplete())
+ {
+ std::lock_guard<std::mutex> Lock(m_SendMutex);
+ Header.Size = ControlDetach;
+ m_Transport->SendMessage(&Header, sizeof(Header));
+ }
+}
+
+} // namespace zen::horde