// Copyright Epic Games, Inc. All Rights Reserved.

#if ZEN_WITH_TESTS

// NOTE(review): the header names of the angle-bracket #include lines were not recoverable from
// the copy of this file this revision was made against. They have been reconstructed from the
// symbols used below -- confirm them against the build before merging.
# include <zencore/scopeguard.h>
# include <zencore/string.h>
# include <zencore/testing.h>
# include <zencore/testutils.h>
# include <zencore/thread.h>
# include <zenhttp/httpwsclient.h>
# include <zenhttp/websocket.h>

# include "httpasio.h"
# include "wsframecodec.h"

ZEN_THIRD_PARTY_INCLUDES_START
# if ZEN_PLATFORM_WINDOWS
#  include <zencore/windows.h>
# else
#  include <netinet/in.h>
#  include <sys/socket.h>
# endif
# include <asio.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

# include <atomic>
# include <chrono>
# include <cstdint>
# include <span>
# include <string>
# include <string_view>
# include <thread>
# include <vector>

namespace zen {

using namespace std::literals;

namespace {

/**
 * Helper: view a contiguous byte container (anything exposing data()/size()) as text.
 *
 * The view borrows from Bytes and must not outlive it.
 */
template<typename ByteContainerT>
std::string_view AsText(const ByteContainerT& Bytes)
{
    return std::string_view(reinterpret_cast<const char*>(Bytes.data()), Bytes.size());
}

}  // anonymous namespace

//////////////////////////////////////////////////////////////////////////
//
// Unit tests: WsFrameCodec
//

TEST_CASE("websocket.framecodec")
{
    SUBCASE("ComputeAcceptKey RFC 6455 test vector")
    {
        // RFC 6455 section 4.2.2 example
        std::string AcceptKey = WsFrameCodec::ComputeAcceptKey("dGhlIHNhbXBsZSBub25jZQ==");
        CHECK_EQ(AcceptKey, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    SUBCASE("BuildFrame and TryParseFrame roundtrip - text")
    {
        std::string_view         Text = "Hello, WebSocket!";
        std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());

        std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, Payload);

        // Server frames are unmasked — TryParseFrame should handle them
        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.BytesConsumed, Frame.size());
        CHECK(Result.Fin);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
        CHECK_EQ(Result.Payload.size(), Text.size());
        CHECK_EQ(AsText(Result.Payload), Text);
    }

    SUBCASE("BuildFrame and TryParseFrame roundtrip - binary")
    {
        std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
        std::vector<uint8_t> Frame      = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, BinaryData);

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
        CHECK_EQ(Result.Payload, BinaryData);
    }

    SUBCASE("BuildFrame - medium payload (126-65535 bytes)")
    {
        std::vector<uint8_t> Payload(300, 0x42);
        std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Payload.size(), 300u);
        CHECK_EQ(Result.Payload, Payload);
    }

    SUBCASE("BuildFrame - large payload (>65535 bytes)")
    {
        std::vector<uint8_t> Payload(70000, 0xAB);
        std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kBinary, Payload);

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Payload.size(), 70000u);
    }

    SUBCASE("BuildCloseFrame roundtrip")
    {
        std::vector<uint8_t> Frame = WsFrameCodec::BuildCloseFrame(1000, "normal closure");

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
        REQUIRE(Result.Payload.size() >= 2);

        // Close payload: 2-byte big-endian status code followed by the UTF-8 reason
        uint16_t Code = static_cast<uint16_t>((uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]));
        CHECK_EQ(Code, 1000);

        std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
        CHECK_EQ(Reason, "normal closure");
    }

    SUBCASE("TryParseFrame - partial data returns invalid")
    {
        std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});

        // Pass only 1 byte — not enough for a frame header
        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), 1);

        CHECK_FALSE(Result.IsValid);
        CHECK_EQ(Result.BytesConsumed, 0u);
    }

    SUBCASE("TryParseFrame - empty payload")
    {
        std::vector<uint8_t> Frame = WsFrameCodec::BuildFrame(WebSocketOpcode::kText, std::span<const uint8_t>{});

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
        CHECK(Result.Payload.empty());
    }

    SUBCASE("TryParseFrame - masked client frame")
    {
        // Build a masked frame manually as a client would send
        // Frame: FIN=1, opcode=text, MASK=1, payload_len=5, mask_key=0x37FA213D, payload="Hello"
        const uint8_t          MaskKey[4] = {0x37, 0xFA, 0x21, 0x3D};
        const std::string_view Original   = "Hello";

        std::vector<uint8_t> Frame;
        Frame.push_back(0x81);  // FIN + text
        Frame.push_back(0x85);  // MASK + len=5
        Frame.insert(Frame.end(), MaskKey, MaskKey + 4);
        for (size_t i = 0; i < Original.size(); ++i)
        {
            Frame.push_back(static_cast<uint8_t>(static_cast<uint8_t>(Original[i]) ^ MaskKey[i % 4]));
        }

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
        CHECK_EQ(Result.Payload.size(), 5u);
        // Use the real payload size so a short payload fails the check instead of over-reading
        CHECK_EQ(AsText(Result.Payload), "Hello"sv);
    }

    SUBCASE("BuildMaskedFrame roundtrip - text")
    {
        std::string_view         Text = "Hello, masked WebSocket!";
        std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());

        std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kText, Payload);

        // Verify mask bit is set
        CHECK((Frame[1] & 0x80) != 0);

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.BytesConsumed, Frame.size());
        CHECK(Result.Fin);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kText);
        CHECK_EQ(Result.Payload.size(), Text.size());
        CHECK_EQ(AsText(Result.Payload), Text);
    }

    SUBCASE("BuildMaskedFrame roundtrip - binary")
    {
        std::vector<uint8_t> BinaryData = {0x00, 0x01, 0x02, 0xFF, 0xFE, 0xFD};
        std::vector<uint8_t> Frame      = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, BinaryData);

        CHECK((Frame[1] & 0x80) != 0);

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kBinary);
        CHECK_EQ(Result.Payload, BinaryData);
    }

    SUBCASE("BuildMaskedFrame - medium payload (126-65535 bytes)")
    {
        std::vector<uint8_t> Payload(300, 0x42);
        std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);

        CHECK((Frame[1] & 0x80) != 0);
        CHECK_EQ((Frame[1] & 0x7F), 126);  // 16-bit extended length

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Payload.size(), 300u);
        CHECK_EQ(Result.Payload, Payload);
    }

    SUBCASE("BuildMaskedFrame - large payload (>65535 bytes)")
    {
        std::vector<uint8_t> Payload(70000, 0xAB);
        std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedFrame(WebSocketOpcode::kBinary, Payload);

        CHECK((Frame[1] & 0x80) != 0);
        CHECK_EQ((Frame[1] & 0x7F), 127);  // 64-bit extended length

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Payload.size(), 70000u);
    }

    SUBCASE("BuildMaskedCloseFrame roundtrip")
    {
        std::vector<uint8_t> Frame = WsFrameCodec::BuildMaskedCloseFrame(1000, "normal closure");

        CHECK((Frame[1] & 0x80) != 0);

        WsFrameParseResult Result = WsFrameCodec::TryParseFrame(Frame.data(), Frame.size());

        CHECK(Result.IsValid);
        CHECK_EQ(Result.Opcode, WebSocketOpcode::kClose);
        REQUIRE(Result.Payload.size() >= 2);

        uint16_t Code = static_cast<uint16_t>((uint16_t(Result.Payload[0]) << 8) | uint16_t(Result.Payload[1]));
        CHECK_EQ(Code, 1000);

        std::string_view Reason(reinterpret_cast<const char*>(Result.Payload.data() + 2), Result.Payload.size() - 2);
        CHECK_EQ(Reason, "normal closure");
    }
}

//////////////////////////////////////////////////////////////////////////
//
// Integration tests: WebSocket over ASIO
//

namespace {

// Sec-WebSocket-Key sent by the raw-socket tests (the RFC 6455 section 4.2.2 sample nonce)
// and the Sec-WebSocket-Accept value a conforming server must answer with.
constexpr std::string_view TestWebSocketKey    = "dGhlIHNhbXBsZSBub25jZQ==";
constexpr std::string_view TestWebSocketAccept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

/**
 * Helper: Build a masked client-to-server frame per RFC 6455
 *
 * Uses a fixed mask key so the integration tests are deterministic and exercise the server
 * with frames that do not come from WsFrameCodec's own masking code.
 *
 * @param Opcode   Frame opcode; FIN is always set (no fragmentation).
 * @param Payload  Unmasked payload bytes.
 * @return The complete wire frame: header, extended length, mask key, masked payload.
 */
std::vector<uint8_t> BuildMaskedFrame(WebSocketOpcode Opcode, std::span<const uint8_t> Payload)
{
    std::vector<uint8_t> Frame;
    Frame.reserve(2 + 8 + 4 + Payload.size());  // largest possible header + payload

    // FIN + opcode
    Frame.push_back(static_cast<uint8_t>(0x80 | static_cast<uint8_t>(Opcode)));

    // Payload length with mask bit set (7-bit, 16-bit or 64-bit big-endian encoding)
    const uint64_t Length = Payload.size();
    if (Length < 126)
    {
        Frame.push_back(static_cast<uint8_t>(0x80 | Length));
    }
    else if (Length <= 0xFFFF)
    {
        Frame.push_back(0x80 | 126);
        Frame.push_back(static_cast<uint8_t>((Length >> 8) & 0xFF));
        Frame.push_back(static_cast<uint8_t>(Length & 0xFF));
    }
    else
    {
        Frame.push_back(0x80 | 127);
        for (int i = 7; i >= 0; --i)
        {
            Frame.push_back(static_cast<uint8_t>((Length >> (i * 8)) & 0xFF));
        }
    }

    // Mask key (use a fixed key for deterministic tests)
    const uint8_t MaskKey[4] = {0x12, 0x34, 0x56, 0x78};
    Frame.insert(Frame.end(), MaskKey, MaskKey + 4);

    // Masked payload
    for (size_t i = 0; i < Payload.size(); ++i)
    {
        Frame.push_back(static_cast<uint8_t>(Payload[i] ^ MaskKey[i & 3]));
    }

    return Frame;
}

/** Helper: masked text frame carrying Text. */
std::vector<uint8_t> BuildMaskedTextFrame(std::string_view Text)
{
    std::span<const uint8_t> Payload(reinterpret_cast<const uint8_t*>(Text.data()), Text.size());
    return BuildMaskedFrame(WebSocketOpcode::kText, Payload);
}

/** Helper: masked close frame whose payload is just the 2-byte big-endian status code. */
std::vector<uint8_t> BuildMaskedCloseFrame(uint16_t Code)
{
    const uint8_t Payload[2] = {static_cast<uint8_t>((Code >> 8) & 0xFF), static_cast<uint8_t>(Code & 0xFF)};
    return BuildMaskedFrame(WebSocketOpcode::kClose, Payload);
}

/**
 * Helper: poll Pred every 10 ms until it returns true or Timeout elapses.
 *
 * Used instead of fixed sleeps so a test waits only as long as needed and still tolerates a
 * slow machine.
 *
 * @return true as soon as Pred() holds, false if Timeout elapsed first.
 */
template<typename PredicateT>
bool WaitUntil(PredicateT&& Pred, std::chrono::milliseconds Timeout = 5s)
{
    const auto Deadline = std::chrono::steady_clock::now() + Timeout;
    while (!Pred())
    {
        if (std::chrono::steady_clock::now() >= Deadline)
        {
            return false;
        }
        Sleep(10);
    }
    return true;
}

/**
 * Test service that implements IWebSocketHandler
 *
 * Handlers run on server threads while the test thread polls the counters. Every handler
 * writes its data (connection list, last message, close code) BEFORE bumping the matching
 * counter, so a test that observes a counter change also sees the data that goes with it.
 */
struct WsTestService : public HttpService, public IWebSocketHandler
{
    const char* BaseUri() const override { return "/wstest/"; }

    void HandleRequest(HttpServerRequest& Request) override
    {
        Request.WriteResponse(HttpResponseCode::OK, HttpContentType::kText, "hello from wstest");
    }

    // IWebSocketHandler

    void OnWebSocketOpen(Ref<WebSocketConnection> Connection) override
    {
        // Register first so a test that sees m_OpenCount change can broadcast immediately
        m_ConnectionsLock.WithExclusiveLock([&] { m_Connections.push_back(Connection); });
        m_OpenCount.fetch_add(1);
    }

    void OnWebSocketMessage(WebSocketConnection& Conn, const WebSocketMessage& Msg) override
    {
        const bool       IsText = Msg.Opcode == WebSocketOpcode::kText;
        std::string_view Text;
        if (IsText)
        {
            Text = std::string_view(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
            m_MessageLock.WithExclusiveLock([&] { m_LastMessage = std::string(Text); });
        }

        m_MessageCount.fetch_add(1);

        if (IsText)
        {
            // Echo the message back
            Conn.SendText(Text);
        }
    }

    void OnWebSocketClose(WebSocketConnection& Conn, uint16_t Code, [[maybe_unused]] std::string_view Reason) override
    {
        m_LastCloseCode = Code;
        m_ConnectionsLock.WithExclusiveLock([&] {
            std::erase_if(m_Connections, [&Conn](const Ref<WebSocketConnection>& C) { return C.Get() == &Conn; });
        });
        m_CloseCount.fetch_add(1);
    }

    /** Copy of the currently registered connections, taken under the shared lock. */
    std::vector<Ref<WebSocketConnection>> SnapshotConnections()
    {
        std::vector<Ref<WebSocketConnection>> Snapshot;
        m_ConnectionsLock.WithSharedLock([&] { Snapshot = m_Connections; });
        return Snapshot;
    }

    /**
     * Send Text to every open connection.
     *
     * Sends from a snapshot, outside m_ConnectionsLock, so a send that ends up in
     * OnWebSocketClose (which takes the lock exclusively) cannot deadlock with us.
     */
    void SendToAll(std::string_view Text)
    {
        for (auto& Conn : SnapshotConnections())
        {
            if (Conn->IsOpen())
            {
                Conn->SendText(Text);
            }
        }
    }

    /** Thread-safe copy of the most recent text message received. */
    std::string GetLastMessage()
    {
        std::string Copy;
        m_MessageLock.WithSharedLock([&] { Copy = m_LastMessage; });
        return Copy;
    }

    std::atomic<int> m_OpenCount{0};
    std::atomic<int> m_MessageCount{0};
    std::atomic<int> m_CloseCount{0};
    std::atomic<int> m_LastCloseCode{0};

    RwLock      m_MessageLock;  // guards m_LastMessage
    std::string m_LastMessage;

    RwLock                                m_ConnectionsLock;  // guards m_Connections
    std::vector<Ref<WebSocketConnection>> m_Connections;
};

/** Plain HTTP service (no IWebSocketHandler), used to verify that upgrades are not accepted. */
struct PlainService : public HttpService
{
    const char* BaseUri() const override { return "/plain/"; }

    void HandleRequest(HttpServerRequest& Request) override { Request.WriteResponse(HttpResponseCode::OK); }
};

/** Helper: connect Sock to 127.0.0.1:Port (throws on failure, like socket::connect). */
void ConnectLoopback(asio::ip::tcp::socket& Sock, int Port)
{
    Sock.connect(asio::ip::tcp::endpoint(asio::ip::make_address("127.0.0.1"), static_cast<uint16_t>(Port)));
}

/** Helper: write an RFC 6455 upgrade request for Path, using TestWebSocketKey. */
void SendUpgradeRequest(asio::ip::tcp::socket& Sock, std::string_view Path, int Port)
{
    ExtendableStringBuilder<512> Request;
    Request << "GET " << Path << " HTTP/1.1\r\n"
            << "Host: 127.0.0.1:" << Port << "\r\n"
            << "Upgrade: websocket\r\n"
            << "Connection: Upgrade\r\n"
            << "Sec-WebSocket-Key: " << TestWebSocketKey << "\r\n"
            << "Sec-WebSocket-Version: 13\r\n"
            << "\r\n";

    std::string_view ReqStr = Request.ToView();
    asio::write(Sock, asio::buffer(ReqStr.data(), ReqStr.size()));
}

/**
 * Helper: read an HTTP response head (status line + headers) from Sock.
 *
 * Returns whatever was received if the connection fails or closes early. Any bytes that
 * arrived after the blank line are discarded along with the local streambuf.
 */
std::string ReadHttpResponseHead(asio::ip::tcp::socket& Sock)
{
    asio::streambuf ResponseBuf;
    asio::error_code Ec;
    asio::read_until(Sock, ResponseBuf, "\r\n\r\n", Ec);
    return std::string(asio::buffers_begin(ResponseBuf.data()), asio::buffers_end(ResponseBuf.data()));
}

/**
 * Helper: extract the status code from a raw response ("HTTP/1.1 101 ..." -> 101).
 *
 * @return The 3-digit status code, or 0 if the status line is missing or malformed.
 */
int ParseHttpStatusCode(std::string_view Response)
{
    if (!Response.starts_with("HTTP/"))
    {
        return 0;
    }

    const size_t Space = Response.find(' ');
    if (Space == std::string_view::npos || Space + 4 > Response.size())
    {
        return 0;
    }

    int Code = 0;
    for (size_t i = Space + 1; i < Space + 4; ++i)
    {
        const char C = Response[i];
        if (C < '0' || C > '9')
        {
            return 0;
        }
        Code = Code * 10 + (C - '0');
    }
    return Code;
}

/**
 * Helper: Perform the WebSocket upgrade handshake on a raw TCP socket
 *
 * Returns true only for a 101 response that carries the Sec-WebSocket-Accept value expected
 * for TestWebSocketKey (RFC 6455 section 4.1), false otherwise.
 */
bool DoWebSocketHandshake(asio::ip::tcp::socket& Sock, std::string_view Path, int Port)
{
    SendUpgradeRequest(Sock, Path, Port);

    const std::string Response = ReadHttpResponseHead(Sock);
    return ParseHttpStatusCode(Response) == 101 && Response.find(TestWebSocketAccept) != std::string::npos;
}

/**
 * Helper: Read a single server-to-client frame from a socket
 *
 * Drives an asynchronous read on IoCtx (the context Sock was created on) with run_until(), so a
 * frame that never arrives fails the test after Timeout instead of blocking forever. A blocking
 * read on a helper thread cannot be reliably interrupted: socket::cancel() only aborts
 * asynchronous operations, and touching the socket from two threads is a data race.
 *
 * IoCtx must have no other outstanding work. Bytes received after the first complete frame
 * are discarded.
 *
 * @return The parsed frame, or a default-constructed (invalid) result on timeout, error or EOF.
 */
WsFrameParseResult ReadOneFrame(asio::io_context& IoCtx, asio::ip::tcp::socket& Sock, std::chrono::milliseconds Timeout = 5s)
{
    const auto           Deadline = std::chrono::steady_clock::now() + Timeout;
    std::vector<uint8_t> Buffer;
    uint8_t              Chunk[4096];

    for (;;)
    {
        bool             Completed = false;
        asio::error_code ReadEc;
        size_t           BytesRead = 0;

        Sock.async_read_some(asio::buffer(Chunk), [&](const asio::error_code& Ec, size_t N) {
            Completed = true;
            ReadEc    = Ec;
            BytesRead = N;
        });

        // The context stops whenever it runs out of work, so it must be restarted each round
        IoCtx.restart();
        IoCtx.run_until(Deadline);

        if (!Completed)
        {
            // Timed out: abort the pending read and let its handler run, so nothing that
            // references this stack frame is left queued on the io_context.
            asio::error_code IgnoredEc;
            Sock.cancel(IgnoredEc);
            IoCtx.restart();
            IoCtx.run();
            return WsFrameParseResult{};
        }

        if (ReadEc || BytesRead == 0)
        {
            return WsFrameParseResult{};
        }

        Buffer.insert(Buffer.end(), Chunk, Chunk + BytesRead);

        WsFrameParseResult Frame = WsFrameCodec::TryParseFrame(Buffer.data(), Buffer.size());
        if (Frame.IsValid)
        {
            return Frame;
        }
    }
}

}  // anonymous namespace

TEST_CASE("websocket.integration")
{
    // Services are declared before the server so they outlive it: locals are destroyed in
    // reverse order, and ServerGuard stops and closes the server before either is destroyed.
    WsTestService            TestService;
    PlainService             Plain;  // registered only by the subcase that needs it
    ScopedTemporaryDirectory TmpDir;

    auto Server = CreateHttpAsioServer(AsioConfig{});

    int Port = Server->Initialize(7575, TmpDir.Path());
    REQUIRE(Port != 0);

    Server->RegisterService(TestService);

    std::thread ServerThread([&]() { Server->Run(false); });

    auto ServerGuard = MakeGuard([&]() {
        Server->RequestExit();
        if (ServerThread.joinable())
        {
            ServerThread.join();
        }
        Server->Close();
    });

    // Give server a moment to start accepting
    Sleep(100);

    SUBCASE("handshake succeeds with 101")
    {
        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
        CHECK(Ok);

        const bool Opened = WaitUntil([&] { return TestService.m_OpenCount.load() >= 1; });
        CHECK(Opened);
        CHECK_EQ(TestService.m_OpenCount.load(), 1);

        Sock.close();
    }

    SUBCASE("normal HTTP still works alongside WebSocket service")
    {
        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        // Send a normal HTTP GET (not upgrade)
        std::string HttpReq = fmt::format(
            "GET /wstest/hello HTTP/1.1\r\n"
            "Host: 127.0.0.1:{}\r\n"
            "Connection: close\r\n"
            "\r\n",
            Port);
        asio::write(Sock, asio::buffer(HttpReq));

        std::string Response = ReadHttpResponseHead(Sock);
        CHECK_EQ(ParseHttpStatusCode(Response), 200);
    }

    SUBCASE("echo message roundtrip")
    {
        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
        REQUIRE(Ok);
        const bool Opened = WaitUntil([&] { return TestService.m_OpenCount.load() >= 1; });
        REQUIRE(Opened);

        // Send a text message (masked, as client)
        std::vector<uint8_t> Frame = BuildMaskedTextFrame("ping test");
        asio::write(Sock, asio::buffer(Frame));

        // Read the echo reply
        WsFrameParseResult Reply = ReadOneFrame(IoCtx, Sock);
        REQUIRE(Reply.IsValid);
        CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
        CHECK_EQ(AsText(Reply.Payload), "ping test"sv);

        // The service records the message before echoing it, so both are visible by now
        CHECK_EQ(TestService.m_MessageCount.load(), 1);
        CHECK_EQ(TestService.GetLastMessage(), "ping test");

        Sock.close();
    }

    SUBCASE("server push to client")
    {
        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
        REQUIRE(Ok);
        // OnWebSocketOpen registers the connection before counting it, so SendToAll will see it
        const bool Opened = WaitUntil([&] { return TestService.m_OpenCount.load() >= 1; });
        REQUIRE(Opened);

        // Server pushes a message
        TestService.SendToAll("server says hello");

        WsFrameParseResult Frame = ReadOneFrame(IoCtx, Sock);
        REQUIRE(Frame.IsValid);
        CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
        CHECK_EQ(AsText(Frame.Payload), "server says hello"sv);

        Sock.close();
    }

    SUBCASE("client close handshake")
    {
        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
        REQUIRE(Ok);
        const bool Opened = WaitUntil([&] { return TestService.m_OpenCount.load() >= 1; });
        REQUIRE(Opened);

        // Send close frame
        std::vector<uint8_t> CloseFrame = BuildMaskedCloseFrame(1000);
        asio::write(Sock, asio::buffer(CloseFrame));

        // Server should echo close back
        WsFrameParseResult Reply = ReadOneFrame(IoCtx, Sock);
        REQUIRE(Reply.IsValid);
        CHECK_EQ(Reply.Opcode, WebSocketOpcode::kClose);

        const bool Closed = WaitUntil([&] { return TestService.m_CloseCount.load() >= 1; });
        CHECK(Closed);
        CHECK_EQ(TestService.m_CloseCount.load(), 1);
        CHECK_EQ(TestService.m_LastCloseCode.load(), 1000);

        Sock.close();
    }

    SUBCASE("multiple concurrent connections")
    {
        constexpr int NumClients = 5;

        asio::io_context                   IoCtx;
        std::vector<asio::ip::tcp::socket> Sockets;
        Sockets.reserve(NumClients);

        for (int i = 0; i < NumClients; ++i)
        {
            Sockets.emplace_back(IoCtx);
            ConnectLoopback(Sockets.back(), Port);

            bool Ok = DoWebSocketHandshake(Sockets.back(), "/wstest/ws", Port);
            REQUIRE(Ok);
        }

        const bool AllOpened = WaitUntil([&] { return TestService.m_OpenCount.load() >= NumClients; });
        CHECK(AllOpened);
        CHECK_EQ(TestService.m_OpenCount.load(), NumClients);

        // Broadcast from server
        TestService.SendToAll("broadcast");

        // Each client should receive the message
        for (asio::ip::tcp::socket& Sock : Sockets)
        {
            WsFrameParseResult Frame = ReadOneFrame(IoCtx, Sock);
            REQUIRE(Frame.IsValid);
            CHECK_EQ(Frame.Opcode, WebSocketOpcode::kText);
            CHECK_EQ(AsText(Frame.Payload), "broadcast"sv);
        }

        // Close all
        for (asio::ip::tcp::socket& Sock : Sockets)
        {
            Sock.close();
        }
    }

    SUBCASE("service without IWebSocketHandler rejects upgrade")
    {
        // Register a plain HTTP service (no WebSocket)
        Server->RegisterService(Plain);
        Sleep(50);

        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        // Attempt WebSocket upgrade on the plain service
        SendUpgradeRequest(Sock, "/plain/ws", Port);
        std::string Response = ReadHttpResponseHead(Sock);

        // Should NOT get 101 — should fall through to normal request handling
        const int Status = ParseHttpStatusCode(Response);
        CHECK_NE(Status, 0);  // a well-formed HTTP response did arrive
        CHECK_NE(Status, 101);

        Sock.close();
    }

    SUBCASE("ping/pong auto-response")
    {
        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
        REQUIRE(Ok);
        const bool Opened = WaitUntil([&] { return TestService.m_OpenCount.load() >= 1; });
        REQUIRE(Opened);

        // Send a ping frame with payload "test"
        std::string_view         PingPayload = "test";
        std::span<const uint8_t> PingData(reinterpret_cast<const uint8_t*>(PingPayload.data()), PingPayload.size());

        std::vector<uint8_t> PingFrame = BuildMaskedFrame(WebSocketOpcode::kPing, PingData);
        asio::write(Sock, asio::buffer(PingFrame));

        // Should receive a pong with the same payload
        WsFrameParseResult Reply = ReadOneFrame(IoCtx, Sock);
        REQUIRE(Reply.IsValid);
        CHECK_EQ(Reply.Opcode, WebSocketOpcode::kPong);
        CHECK_EQ(Reply.Payload.size(), 4u);
        CHECK_EQ(AsText(Reply.Payload), "test"sv);

        Sock.close();
    }

    SUBCASE("multiple messages in sequence")
    {
        asio::io_context      IoCtx;
        asio::ip::tcp::socket Sock(IoCtx);
        ConnectLoopback(Sock, Port);

        bool Ok = DoWebSocketHandshake(Sock, "/wstest/ws", Port);
        REQUIRE(Ok);
        const bool Opened = WaitUntil([&] { return TestService.m_OpenCount.load() >= 1; });
        REQUIRE(Opened);

        for (int i = 0; i < 10; ++i)
        {
            std::string          Msg   = fmt::format("message {}", i);
            std::vector<uint8_t> Frame = BuildMaskedTextFrame(Msg);
            asio::write(Sock, asio::buffer(Frame));

            WsFrameParseResult Reply = ReadOneFrame(IoCtx, Sock);
            REQUIRE(Reply.IsValid);
            CHECK_EQ(Reply.Opcode, WebSocketOpcode::kText);
            CHECK_EQ(AsText(Reply.Payload), Msg);
        }

        CHECK_EQ(TestService.m_MessageCount.load(), 10);

        Sock.close();
    }
}

//////////////////////////////////////////////////////////////////////////
//
// Integration tests: HttpWsClient
//

namespace {

/**
 * Client-side handler that records callbacks for the test thread to poll.
 *
 * As in WsTestService, data is written before the matching counter is bumped, so a test that
 * observes the counter change also sees the data.
 */
struct TestWsClientHandler : public IWsClientHandler
{
    void OnWsOpen() override { m_OpenCount.fetch_add(1); }

    void OnWsMessage(const WebSocketMessage& Msg) override
    {
        if (Msg.Opcode == WebSocketOpcode::kText)
        {
            std::string_view Text(static_cast<const char*>(Msg.Payload.Data()), Msg.Payload.Size());
            m_MessageLock.WithExclusiveLock([&] { m_LastMessage = std::string(Text); });
        }
        m_MessageCount.fetch_add(1);
    }

    void OnWsClose(uint16_t Code, [[maybe_unused]] std::string_view Reason) override
    {
        m_LastCloseCode = Code;
        m_CloseCount.fetch_add(1);
    }

    /** Thread-safe copy of the most recent text message received. */
    std::string GetLastMessage()
    {
        std::string Copy;
        m_MessageLock.WithSharedLock([&] { Copy = m_LastMessage; });
        return Copy;
    }

    std::atomic<int> m_OpenCount{0};
    std::atomic<int> m_MessageCount{0};
    std::atomic<int> m_CloseCount{0};
    std::atomic<int> m_LastCloseCode{0};

    RwLock      m_MessageLock;  // guards m_LastMessage
    std::string m_LastMessage;
};

}  // anonymous namespace

TEST_CASE("websocket.client")
{
    WsTestService            TestService;
    ScopedTemporaryDirectory TmpDir;

    auto Server = CreateHttpAsioServer(AsioConfig{});

    int Port = Server->Initialize(7576, TmpDir.Path());
    REQUIRE(Port != 0);

    Server->RegisterService(TestService);

    std::thread ServerThread([&]() { Server->Run(false); });

    auto ServerGuard = MakeGuard([&]() {
        Server->RequestExit();
        if (ServerThread.joinable())
        {
            ServerThread.join();
        }
        Server->Close();
    });

    Sleep(100);

    SUBCASE("connect, echo, close")
    {
        TestWsClientHandler Handler;
        std::string         Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);

        HttpWsClient Client(Url, Handler);
        Client.Connect();

        // Wait for OnWsOpen
        WaitUntil([&] { return Handler.m_OpenCount.load() != 0; });
        REQUIRE_EQ(Handler.m_OpenCount.load(), 1);
        CHECK(Client.IsOpen());

        // Send text, expect echo
        Client.SendText("hello from client");

        WaitUntil([&] { return Handler.m_MessageCount.load() != 0; });
        CHECK_EQ(Handler.m_MessageCount.load(), 1);
        CHECK_EQ(Handler.GetLastMessage(), "hello from client");

        // Close
        Client.Close(1000, "done");

        WaitUntil([&] { return Handler.m_CloseCount.load() != 0; });

        // The server echoes the close frame, which triggers OnWsClose on the client side
        // with the server's close code. Allow the connection to settle.
        WaitUntil([&] { return !Client.IsOpen(); }, 1s);
        CHECK_FALSE(Client.IsOpen());
    }

    SUBCASE("connect to bad port")
    {
        TestWsClientHandler Handler;
        std::string         Url = "ws://127.0.0.1:1/wstest/ws";

        HttpWsClient Client(Url, Handler, HttpWsClientSettings{.ConnectTimeout = std::chrono::milliseconds(2000)});
        Client.Connect();

        WaitUntil([&] { return Handler.m_CloseCount.load() != 0; });

        CHECK_EQ(Handler.m_CloseCount.load(), 1);
        CHECK_EQ(Handler.m_LastCloseCode.load(), 1006);
        CHECK_EQ(Handler.m_OpenCount.load(), 0);
    }

    SUBCASE("server-initiated close")
    {
        TestWsClientHandler Handler;
        std::string         Url = fmt::format("ws://127.0.0.1:{}/wstest/ws", Port);

        HttpWsClient Client(Url, Handler);
        Client.Connect();

        WaitUntil([&] { return Handler.m_OpenCount.load() != 0; });
        REQUIRE_EQ(Handler.m_OpenCount.load(), 1);

        // Close from a snapshot, outside the lock, to avoid deadlocking with
        // OnWebSocketClose which acquires an exclusive lock
        for (auto& Conn : TestService.SnapshotConnections())
        {
            Conn->Close(1001, "going away");
        }

        WaitUntil([&] { return Handler.m_CloseCount.load() != 0; });

        CHECK_EQ(Handler.m_CloseCount.load(), 1);
        CHECK_EQ(Handler.m_LastCloseCode.load(), 1001);
        CHECK_FALSE(Client.IsOpen());
    }
}

void websocket_forcelink()
{
}

}  // namespace zen

#endif  // ZEN_WITH_TESTS