aboutsummaryrefslogtreecommitdiff
path: root/src/zenhttp/clients/httpwsclient.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/zenhttp/clients/httpwsclient.cpp')
-rw-r--r--src/zenhttp/clients/httpwsclient.cpp566
1 files changed, 566 insertions, 0 deletions
diff --git a/src/zenhttp/clients/httpwsclient.cpp b/src/zenhttp/clients/httpwsclient.cpp
new file mode 100644
index 000000000..9497dadb8
--- /dev/null
+++ b/src/zenhttp/clients/httpwsclient.cpp
@@ -0,0 +1,566 @@
+// Copyright Epic Games, Inc. All Rights Reserved.
+
+#include <zenhttp/httpwsclient.h>
+
+#include "../servers/wsframecodec.h"
+
+#include <zencore/base64.h>
+#include <zencore/logging.h>
+#include <zencore/string.h>
+
+ZEN_THIRD_PARTY_INCLUDES_START
+#include <asio.hpp>
+ZEN_THIRD_PARTY_INCLUDES_END
+
+#include <deque>
+#include <random>
+#include <thread>
+
+namespace zen {
+
+//////////////////////////////////////////////////////////////////////////
+
+struct HttpWsClient::Impl
+{
+ Impl(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_OwnedIoContext(std::make_unique<asio::io_context>())
+ , m_IoContext(*m_OwnedIoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ Impl(std::string_view Url, IWsClientHandler& Handler, asio::io_context& IoContext, const HttpWsClientSettings& Settings)
+ : m_Handler(Handler)
+ , m_Settings(Settings)
+ , m_Log(logging::Get(Settings.LogCategory))
+ , m_IoContext(IoContext)
+ {
+ ParseUrl(Url);
+ }
+
+ ~Impl()
+ {
+ // Release work guard so io_context::run() can return
+ m_WorkGuard.reset();
+
+ // Close the socket to cancel pending async ops
+ if (m_Socket)
+ {
+ asio::error_code Ec;
+ m_Socket->close(Ec);
+ }
+
+ if (m_IoThread.joinable())
+ {
+ m_IoThread.join();
+ }
+ }
+
+ void ParseUrl(std::string_view Url)
+ {
+ // Expected format: ws://host:port/path
+ if (Url.substr(0, 5) == "ws://")
+ {
+ Url.remove_prefix(5);
+ }
+
+ auto SlashPos = Url.find('/');
+ std::string_view HostPort;
+ if (SlashPos != std::string_view::npos)
+ {
+ HostPort = Url.substr(0, SlashPos);
+ m_Path = std::string(Url.substr(SlashPos));
+ }
+ else
+ {
+ HostPort = Url;
+ m_Path = "/";
+ }
+
+ auto ColonPos = HostPort.find(':');
+ if (ColonPos != std::string_view::npos)
+ {
+ m_Host = std::string(HostPort.substr(0, ColonPos));
+ m_Port = std::string(HostPort.substr(ColonPos + 1));
+ }
+ else
+ {
+ m_Host = std::string(HostPort);
+ m_Port = "80";
+ }
+ }
+
+ void Connect()
+ {
+ if (m_OwnedIoContext)
+ {
+ m_WorkGuard = std::make_unique<asio::io_context::work>(m_IoContext);
+ m_IoThread = std::thread([this] { m_IoContext.run(); });
+ }
+
+ asio::post(m_IoContext, [this] { DoResolve(); });
+ }
+
+ void DoResolve()
+ {
+ m_Resolver = std::make_unique<asio::ip::tcp::resolver>(m_IoContext);
+
+ m_Resolver->async_resolve(m_Host, m_Port, [this](const asio::error_code& Ec, asio::ip::tcp::resolver::results_type Results) {
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket resolve failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "resolve failed");
+ return;
+ }
+
+ DoConnect(Results);
+ });
+ }
+
+ void DoConnect(const asio::ip::tcp::resolver::results_type& Endpoints)
+ {
+ m_Socket = std::make_unique<asio::ip::tcp::socket>(m_IoContext);
+
+ // Start connect timeout timer
+ m_Timer = std::make_unique<asio::steady_timer>(m_IoContext, m_Settings.ConnectTimeout);
+ m_Timer->async_wait([this](const asio::error_code& Ec) {
+ if (!Ec && !m_IsOpen.load(std::memory_order_relaxed))
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect timeout for {}:{}", m_Host, m_Port);
+ if (m_Socket)
+ {
+ asio::error_code CloseEc;
+ m_Socket->close(CloseEc);
+ }
+ }
+ });
+
+ asio::async_connect(*m_Socket, Endpoints, [this](const asio::error_code& Ec, const asio::ip::tcp::endpoint&) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket connect failed for {}:{}: {}", m_Host, m_Port, Ec.message());
+ m_Handler.OnWsClose(1006, "connect failed");
+ return;
+ }
+
+ DoHandshake();
+ });
+ }
+
+ void DoHandshake()
+ {
+ // Generate random Sec-WebSocket-Key (16 random bytes, base64 encoded)
+ uint8_t KeyBytes[16];
+ {
+ static thread_local std::mt19937 s_Rng(std::random_device{}());
+ for (int i = 0; i < 4; ++i)
+ {
+ uint32_t Val = s_Rng();
+ std::memcpy(KeyBytes + i * 4, &Val, 4);
+ }
+ }
+
+ char KeyBase64[Base64::GetEncodedDataSize(16) + 1];
+ uint32_t KeyLen = Base64::Encode(KeyBytes, 16, KeyBase64);
+ KeyBase64[KeyLen] = '\0';
+ m_WebSocketKey = std::string(KeyBase64, KeyLen);
+
+ // Build the HTTP upgrade request
+ ExtendableStringBuilder<512> Request;
+ Request << "GET " << m_Path << " HTTP/1.1\r\n"
+ << "Host: " << m_Host << ":" << m_Port << "\r\n"
+ << "Upgrade: websocket\r\n"
+ << "Connection: Upgrade\r\n"
+ << "Sec-WebSocket-Key: " << m_WebSocketKey << "\r\n"
+ << "Sec-WebSocket-Version: 13\r\n";
+
+ // Add Authorization header if access token provider is set
+ if (m_Settings.AccessTokenProvider)
+ {
+ HttpClientAccessToken Token = (*m_Settings.AccessTokenProvider)();
+ if (Token.IsValid())
+ {
+ Request << "Authorization: Bearer " << Token.Value << "\r\n";
+ }
+ }
+
+ Request << "\r\n";
+
+ std::string_view ReqStr = Request.ToView();
+
+ m_HandshakeBuffer = std::make_shared<std::string>(ReqStr);
+
+ asio::async_write(*m_Socket,
+ asio::buffer(m_HandshakeBuffer->data(), m_HandshakeBuffer->size()),
+ [this](const asio::error_code& Ec, std::size_t) {
+ if (Ec)
+ {
+ m_Timer->cancel();
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake write failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake write failed");
+ return;
+ }
+
+ DoReadHandshakeResponse();
+ });
+ }
+
+ void DoReadHandshakeResponse()
+ {
+ asio::async_read_until(*m_Socket, m_ReadBuffer, "\r\n\r\n", [this](const asio::error_code& Ec, std::size_t) {
+ m_Timer->cancel();
+
+ if (Ec)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake read failed: {}", Ec.message());
+ m_Handler.OnWsClose(1006, "handshake read failed");
+ return;
+ }
+
+ // Parse the response
+ const auto& Data = m_ReadBuffer.data();
+ std::string Response(asio::buffers_begin(Data), asio::buffers_end(Data));
+
+ // Consume the headers from the read buffer (any extra data stays for frame parsing)
+ auto HeaderEnd = Response.find("\r\n\r\n");
+ if (HeaderEnd != std::string::npos)
+ {
+ m_ReadBuffer.consume(HeaderEnd + 4);
+ }
+
+ // Validate 101 response
+ if (Response.find("101") == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake rejected (no 101): {}", Response.substr(0, 80));
+ m_Handler.OnWsClose(1006, "handshake rejected");
+ return;
+ }
+
+ // Validate Sec-WebSocket-Accept
+ std::string ExpectedAccept = WsFrameCodec::ComputeAcceptKey(m_WebSocketKey);
+ if (Response.find(ExpectedAccept) == std::string::npos)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket handshake: invalid Sec-WebSocket-Accept");
+ m_Handler.OnWsClose(1006, "invalid accept key");
+ return;
+ }
+
+ m_IsOpen.store(true);
+ m_Handler.OnWsOpen();
+ EnqueueRead();
+ });
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Read loop
+ //
+
+ void EnqueueRead()
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ asio::async_read(*m_Socket, m_ReadBuffer, asio::transfer_at_least(1), [this](const asio::error_code& Ec, std::size_t) {
+ OnDataReceived(Ec);
+ });
+ }
+
+ void OnDataReceived(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::eof && Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket read error: {}", Ec.message());
+ }
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "connection lost");
+ }
+ return;
+ }
+
+ ProcessReceivedData();
+
+ if (m_IsOpen.load(std::memory_order_relaxed))
+ {
+ EnqueueRead();
+ }
+ }
+
+ void ProcessReceivedData()
+ {
+ while (m_ReadBuffer.size() > 0)
+ {
+ const auto& InputBuffer = m_ReadBuffer.data();
+ const auto* RawData = static_cast<const uint8_t*>(InputBuffer.data());
+ const auto Size = InputBuffer.size();
+
+ WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(RawData, Size);
+ if (!Frame.IsValid)
+ {
+ break;
+ }
+
+ m_ReadBuffer.consume(Frame.BytesConsumed);
+
+ switch (Frame.Opcode)
+ {
+ case WebSocketOpcode::kText:
+ case WebSocketOpcode::kBinary:
+ {
+ WebSocketMessage Msg;
+ Msg.Opcode = Frame.Opcode;
+ Msg.Payload = IoBuffer(IoBuffer::Clone, Frame.Payload.data(), Frame.Payload.size());
+ m_Handler.OnWsMessage(Msg);
+ break;
+ }
+
+ case WebSocketOpcode::kPing:
+ {
+ // Auto-respond with masked pong
+ std::vector<uint8_t> PongFrame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kPong, Frame.Payload);
+ EnqueueWrite(std::move(PongFrame));
+ break;
+ }
+
+ case WebSocketOpcode::kPong:
+ break;
+
+ case WebSocketOpcode::kClose:
+ {
+ uint16_t Code = 1000;
+ std::string_view Reason;
+
+ if (Frame.Payload.size() >= 2)
+ {
+ Code = (uint16_t(Frame.Payload[0]) << 8) | uint16_t(Frame.Payload[1]);
+ if (Frame.Payload.size() > 2)
+ {
+ Reason =
+ std::string_view(reinterpret_cast<const char*>(Frame.Payload.data() + 2), Frame.Payload.size() - 2);
+ }
+ }
+
+ // Echo masked close frame if we haven't sent one yet
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+
+ m_IsOpen.store(false);
+ m_Handler.OnWsClose(Code, Reason);
+ return;
+ }
+
+ default:
+ ZEN_LOG_WARN(m_Log, "Unknown WebSocket opcode: {:#x}", static_cast<uint8_t>(Frame.Opcode));
+ break;
+ }
+ }
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Write queue
+ //
+
+ void EnqueueWrite(std::vector<uint8_t> Frame)
+ {
+ bool ShouldFlush = false;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_WriteQueue.push_back(std::move(Frame));
+ if (!m_IsWriting)
+ {
+ m_IsWriting = true;
+ ShouldFlush = true;
+ }
+ });
+
+ if (ShouldFlush)
+ {
+ FlushWriteQueue();
+ }
+ }
+
+ void FlushWriteQueue()
+ {
+ std::vector<uint8_t> Frame;
+
+ m_WriteLock.WithExclusiveLock([&] {
+ if (m_WriteQueue.empty())
+ {
+ m_IsWriting = false;
+ return;
+ }
+ Frame = std::move(m_WriteQueue.front());
+ m_WriteQueue.pop_front();
+ });
+
+ if (Frame.empty())
+ {
+ return;
+ }
+
+ auto OwnedFrame = std::make_shared<std::vector<uint8_t>>(std::move(Frame));
+
+ asio::async_write(*m_Socket,
+ asio::buffer(OwnedFrame->data(), OwnedFrame->size()),
+ [this, OwnedFrame](const asio::error_code& Ec, std::size_t) { OnWriteComplete(Ec); });
+ }
+
+ void OnWriteComplete(const asio::error_code& Ec)
+ {
+ if (Ec)
+ {
+ if (Ec != asio::error::operation_aborted)
+ {
+ ZEN_LOG_DEBUG(m_Log, "WebSocket write error: {}", Ec.message());
+ }
+
+ m_WriteLock.WithExclusiveLock([&] {
+ m_IsWriting = false;
+ m_WriteQueue.clear();
+ });
+
+ if (m_IsOpen.exchange(false))
+ {
+ m_Handler.OnWsClose(1006, "write error");
+ }
+ return;
+ }
+
+ FlushWriteQueue();
+ }
+
+ //////////////////////////////////////////////////////////////////////////
+ //
+ // Public operations
+ //
+
+ void SendText(std::string_view Text)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void SendBinary(std::span<const uint8_t> Data)
+ {
+ if (!m_IsOpen.load(std::memory_order_relaxed))
+ {
+ return;
+ }
+
+ std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Data);
+ EnqueueWrite(std::move(Frame));
+ }
+
+ void DoClose(uint16_t Code, std::string_view Reason)
+ {
+ if (!m_IsOpen.exchange(false))
+ {
+ return;
+ }
+
+ if (!m_CloseSent.exchange(true))
+ {
+ std::vector<uint8_t> CloseFrame = WsFrameCodec::BuildMaskedCloseFrame(Code, Reason);
+ EnqueueWrite(std::move(CloseFrame));
+ }
+ }
+
+ IWsClientHandler& m_Handler;
+ HttpWsClientSettings m_Settings;
+ LoggerRef m_Log;
+
+ std::string m_Host;
+ std::string m_Port;
+ std::string m_Path;
+
+ // io_context: owned (standalone) or external (shared)
+ std::unique_ptr<asio::io_context> m_OwnedIoContext;
+ asio::io_context& m_IoContext;
+ std::unique_ptr<asio::io_context::work> m_WorkGuard;
+ std::thread m_IoThread;
+
+ // Connection state
+ std::unique_ptr<asio::ip::tcp::resolver> m_Resolver;
+ std::unique_ptr<asio::ip::tcp::socket> m_Socket;
+ std::unique_ptr<asio::steady_timer> m_Timer;
+ asio::streambuf m_ReadBuffer;
+ std::string m_WebSocketKey;
+ std::shared_ptr<std::string> m_HandshakeBuffer;
+
+ // Write queue
+ RwLock m_WriteLock;
+ std::deque<std::vector<uint8_t>> m_WriteQueue;
+ bool m_IsWriting = false;
+
+ std::atomic<bool> m_IsOpen{false};
+ std::atomic<bool> m_CloseSent{false};
+};
+
+//////////////////////////////////////////////////////////////////////////
+
+HttpWsClient::HttpWsClient(std::string_view Url, IWsClientHandler& Handler, const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, Settings))
+{
+}
+
+HttpWsClient::HttpWsClient(std::string_view Url,
+ IWsClientHandler& Handler,
+ asio::io_context& IoContext,
+ const HttpWsClientSettings& Settings)
+: m_Impl(std::make_unique<Impl>(Url, Handler, IoContext, Settings))
+{
+}
+
+HttpWsClient::~HttpWsClient() = default;
+
+void
+HttpWsClient::Connect()
+{
+ m_Impl->Connect();
+}
+
+void
+HttpWsClient::SendText(std::string_view Text)
+{
+ m_Impl->SendText(Text);
+}
+
+void
+HttpWsClient::SendBinary(std::span<const uint8_t> Data)
+{
+ m_Impl->SendBinary(Data);
+}
+
+void
+HttpWsClient::Close(uint16_t Code, std::string_view Reason)
+{
+ m_Impl->DoClose(Code, Reason);
+}
+
+bool
+HttpWsClient::IsOpen() const
+{
+ return m_Impl->m_IsOpen.load(std::memory_order_relaxed);
+}
+
+} // namespace zen