diff options
Diffstat (limited to 'src/zenhorde/hordetransport.cpp')
| -rw-r--r-- | src/zenhorde/hordetransport.cpp | 153 |
1 files changed, 55 insertions, 98 deletions
diff --git a/src/zenhorde/hordetransport.cpp b/src/zenhorde/hordetransport.cpp index 69766e73e..65eaea477 100644 --- a/src/zenhorde/hordetransport.cpp +++ b/src/zenhorde/hordetransport.cpp @@ -9,71 +9,33 @@ ZEN_THIRD_PARTY_INCLUDES_START #include <asio.hpp> ZEN_THIRD_PARTY_INCLUDES_END -#if ZEN_PLATFORM_WINDOWS -# undef SendMessage -#endif - namespace zen::horde { -// ComputeTransport base +// --- AsyncTcpComputeTransport --- -bool -ComputeTransport::SendMessage(const void* Data, size_t Size) +struct AsyncTcpComputeTransport::Impl { - const uint8_t* Ptr = static_cast<const uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Sent = Send(Ptr, Remaining); - if (Sent == 0) - { - return false; - } - Ptr += Sent; - Remaining -= Sent; - } + asio::io_context& IoContext; + asio::ip::tcp::socket Socket; - return true; -} + explicit Impl(asio::io_context& Ctx) : IoContext(Ctx), Socket(Ctx) {} +}; -bool -ComputeTransport::RecvMessage(void* Data, size_t Size) +AsyncTcpComputeTransport::AsyncTcpComputeTransport(asio::io_context& IoContext) +: m_Impl(std::make_unique<Impl>(IoContext)) +, m_Log(zen::logging::Get("horde.transport.async")) { - uint8_t* Ptr = static_cast<uint8_t*>(Data); - size_t Remaining = Size; - - while (Remaining > 0) - { - const size_t Received = Recv(Ptr, Remaining); - if (Received == 0) - { - return false; - } - Ptr += Received; - Remaining -= Received; - } - - return true; } -// TcpComputeTransport - ASIO pimpl - -struct TcpComputeTransport::Impl +AsyncTcpComputeTransport::~AsyncTcpComputeTransport() { - asio::io_context IoContext; - asio::ip::tcp::socket Socket; - - Impl() : Socket(IoContext) {} -}; + Close(); +} -// Uses ASIO in synchronous mode only — no async operations or io_context::run(). -// The io_context is only needed because ASIO sockets require one to be constructed. -TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) -: m_Impl(std::make_unique<Impl>()) -, m_Log(zen::logging::Get("horde.transport")) +void +AsyncTcpComputeTransport::AsyncConnect(const MachineInfo& Info, AsyncConnectHandler Handler) { - ZEN_TRACE_CPU("TcpComputeTransport::Connect"); + ZEN_TRACE_CPU("AsyncTcpComputeTransport::AsyncConnect"); asio::error_code Ec; @@ -82,80 +44,75 @@ TcpComputeTransport::TcpComputeTransport(const MachineInfo& Info) { ZEN_WARN("invalid address '{}': {}", Info.GetConnectionAddress(), Ec.message()); m_HasErrors = true; + asio::post(m_Impl->IoContext, [Handler = std::move(Handler), Ec] { Handler(Ec); }); return; } const asio::ip::tcp::endpoint Endpoint(Address, Info.GetConnectionPort()); - m_Impl->Socket.connect(Endpoint, Ec); - if (Ec) - { - ZEN_WARN("failed to connect to Horde compute [{}:{}]: {}", Info.GetConnectionAddress(), Info.GetConnectionPort(), Ec.message()); - m_HasErrors = true; - return; - } + // Copy the nonce so it survives past this scope into the async callback + auto NonceBuf = std::make_shared<std::vector<uint8_t>>(Info.Nonce, Info.Nonce + NonceSize); - // Disable Nagle's algorithm for lower latency - m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), Ec); -} + m_Impl->Socket.async_connect(Endpoint, [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec) mutable { + if (Ec) + { + ZEN_WARN("async connect failed: {}", Ec.message()); + m_HasErrors = true; + Handler(Ec); + return; + } -TcpComputeTransport::~TcpComputeTransport() -{ - Close(); + asio::error_code SetOptEc; + m_Impl->Socket.set_option(asio::ip::tcp::no_delay(true), SetOptEc); + + // Send the 64-byte nonce as the first thing on the wire + asio::async_write(m_Impl->Socket, + asio::buffer(*NonceBuf), + [this, Handler = std::move(Handler), NonceBuf](const asio::error_code& Ec, size_t /*BytesWritten*/) { + if (Ec) + { + ZEN_WARN("nonce write failed: {}", Ec.message()); + m_HasErrors = true; + } + Handler(Ec); + }); + }); } bool -TcpComputeTransport::IsValid() const +AsyncTcpComputeTransport::IsValid() const { return m_Impl && m_Impl->Socket.is_open() && !m_HasErrors && !m_IsClosed; } -size_t -TcpComputeTransport::Send(const void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncWrite(const void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Sent = m_Impl->Socket.send(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - m_HasErrors = true; - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Sent; + asio::async_write(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } -size_t -TcpComputeTransport::Recv(void* Data, size_t Size) +void +AsyncTcpComputeTransport::AsyncRead(void* Data, size_t Size, AsyncIoHandler Handler) { if (!IsValid()) { - return 0; - } - - asio::error_code Ec; - const size_t Received = m_Impl->Socket.receive(asio::buffer(Data, Size), 0, Ec); - - if (Ec) - { - return 0; + asio::post(m_Impl->IoContext, + [Handler = std::move(Handler)] { Handler(asio::error::make_error_code(asio::error::not_connected), 0); }); + return; } - return Received; -} - -void -TcpComputeTransport::MarkComplete() -{ + asio::async_read(m_Impl->Socket, asio::buffer(Data, Size), std::move(Handler)); } void -TcpComputeTransport::Close() +AsyncTcpComputeTransport::Close() { if (!m_IsClosed && m_Impl && m_Impl->Socket.is_open()) { |